Files
smartstorage/rust/src/cluster/membership.rs

185 lines
6.2 KiB
Rust

use anyhow::Result;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use super::protocol::{
ClusterRequest, ClusterResponse, HeartbeatMessage, JoinRequestMessage, NodeInfo,
};
use super::quic_transport::QuicTransport;
use super::state::ClusterState;
/// Manages cluster membership: heartbeating, joining, failure detection.
///
/// The manager itself holds no membership data; everything it learns is
/// written into the shared [`ClusterState`], and all network traffic goes
/// through the shared [`QuicTransport`].
pub struct MembershipManager {
/// Shared cluster state. Updated when joining (`apply_topology`, `add_node`)
/// and after every heartbeat round (`tick_heartbeats`).
state: Arc<ClusterState>,
/// Transport used to obtain connections and to send join and heartbeat requests.
transport: Arc<QuicTransport>,
/// Period between two heartbeat rounds in `heartbeat_loop`.
heartbeat_interval: Duration,
/// Descriptor of this node. Sent in join requests and registered in `state`;
/// its `node_id` is stamped on every outgoing heartbeat.
local_node_info: NodeInfo,
}
impl MembershipManager {
/// Creates a membership manager.
///
/// `heartbeat_interval_ms` is the heartbeat period in milliseconds; it is
/// converted to a [`Duration`] once here. `local_node_info` describes the
/// node this manager runs on.
pub fn new(
    state: Arc<ClusterState>,
    transport: Arc<QuicTransport>,
    heartbeat_interval_ms: u64,
    local_node_info: NodeInfo,
) -> Self {
    let heartbeat_interval = Duration::from_millis(heartbeat_interval_ms);
    Self {
        state,
        transport,
        heartbeat_interval,
        local_node_info,
    }
}
/// Join the cluster by contacting seed nodes.
///
/// Seeds are tried in the order given. Each entry must parse as a
/// [`SocketAddr`]; entries that do not parse are logged and skipped. The
/// first seed that accepts the join ends the search.
///
/// If the seed list is empty, or no seed accepts the join, the local node is
/// added to the local cluster state and the call still returns `Ok(())`.
pub async fn join_cluster(&self, seed_nodes: &[String]) -> Result<()> {
    if !seed_nodes.is_empty() {
        for seed in seed_nodes {
            let parsed = seed.parse::<SocketAddr>();
            let target = match parsed {
                Ok(target) => target,
                Err(parse_err) => {
                    tracing::warn!("Invalid seed node address '{}': {}", seed, parse_err);
                    continue;
                }
            };

            tracing::info!("Attempting to join cluster via seed node {}", seed);
            if let Err(join_err) = self.try_join(target).await {
                tracing::warn!("Failed to join via {}: {}", seed, join_err);
                continue;
            }

            tracing::info!("Successfully joined cluster via {}", seed);
            return Ok(());
        }
        // Every seed was either unparseable or refused/failed the join.
        tracing::info!("Could not reach any seed nodes, starting as initial cluster node");
    } else {
        tracing::info!("No seed nodes configured, starting as initial cluster node");
    }

    // Fallback shared by both paths above: register only ourselves.
    self.state.add_node(self.local_node_info.clone()).await;
    Ok(())
}
/// Sends one `JoinRequest` to the node at `addr` and applies the outcome.
///
/// On an accepted join, the topology carried in the response (if any) is
/// applied to the local state, and the local node is then registered in the
/// local state. The local node is registered even when the response carries
/// no topology, so an accepted join never leaves this node missing from its
/// own state.
///
/// # Errors
///
/// Returns an error if the connection or request fails, if the peer rejects
/// the join, if the peer answers with `ClusterResponse::Error`, or if the
/// response is of any other kind.
async fn try_join(&self, addr: SocketAddr) -> Result<()> {
    // NOTE(review): every seed is requested under the same connection key
    // "seed". If `QuicTransport::get_connection` caches by this key, a later
    // seed could be served the first seed's connection — confirm.
    let conn = self.transport.get_connection("seed", addr).await?;

    let request = ClusterRequest::JoinRequest(JoinRequestMessage {
        node_info: self.local_node_info.clone(),
    });

    let join_resp = match self.transport.send_request(&conn, &request).await? {
        ClusterResponse::JoinResponse(resp) => resp,
        ClusterResponse::Error(e) => {
            anyhow::bail!("Join error: {} - {}", e.code, e.message)
        }
        _ => anyhow::bail!("Unexpected response to join request"),
    };

    if !join_resp.accepted {
        anyhow::bail!(
            "Join rejected: {}",
            join_resp.error.unwrap_or_default()
        );
    }

    // Apply the peer's view first so that registering ourselves afterwards
    // is not overwritten by it.
    if let Some(topology) = &join_resp.topology {
        self.state.apply_topology(topology).await;
    }

    // Register self on every accepted join, with or without a topology.
    self.state.add_node(self.local_node_info.clone()).await;

    match &join_resp.topology {
        Some(topology) => {
            tracing::info!(
                "Applied cluster topology (version {}, {} nodes, {} erasure sets)",
                topology.version,
                topology.nodes.len(),
                topology.erasure_sets.len(),
            );
        }
        None => {
            tracing::warn!(
                "Join accepted by {} without a topology; registered local node only",
                addr
            );
        }
    }

    Ok(())
}
/// Run the heartbeat loop. Sends heartbeats to all peers periodically.
pub async fn heartbeat_loop(self: Arc<Self>, mut shutdown: tokio::sync::watch::Receiver<bool>) {
let mut interval = tokio::time::interval(self.heartbeat_interval);
loop {
tokio::select! {
_ = interval.tick() => {
self.send_heartbeats().await;
}
_ = shutdown.changed() => break,
}
}
}
/// Performs one heartbeat round.
///
/// Sends a heartbeat to every peer returned by `online_peers`, one after
/// another, each bounded by a fixed timeout. The ids of the peers whose
/// heartbeat completed are then passed to `tick_heartbeats`, and every status
/// change it reports is logged.
///
/// Failures are never propagated: a peer that errors, times out, or has an
/// address that does not parse is simply left out of the responder list.
async fn send_heartbeats(&self) {
    // Upper bound on how long a single peer may take to answer.
    const PEER_TIMEOUT: Duration = Duration::from_secs(5);

    let peers = self.state.online_peers().await;
    let topology_version = self.state.version().await;
    let mut responded = Vec::new();

    for peer in &peers {
        let addr: SocketAddr = match peer.quic_addr.parse() {
            Ok(a) => a,
            Err(e) => {
                // Without this log the peer would silently count as
                // unresponsive, hiding what is really a bad address.
                tracing::warn!(
                    peer = %peer.node_id,
                    addr = %peer.quic_addr,
                    error = %e,
                    "Skipping heartbeat: invalid peer address"
                );
                continue;
            }
        };

        // Built per peer so each message carries its own send timestamp.
        let heartbeat = ClusterRequest::Heartbeat(HeartbeatMessage {
            node_id: self.local_node_info.node_id.clone(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            drive_states: Vec::new(), // TODO: populate from DriveManager
            topology_version,
        });

        match tokio::time::timeout(
            PEER_TIMEOUT,
            self.send_heartbeat_to_peer(&peer.node_id, addr, &heartbeat),
        )
        .await
        {
            Ok(Ok(())) => {
                responded.push(peer.node_id.clone());
            }
            Ok(Err(e)) => {
                tracing::debug!(
                    peer = %peer.node_id,
                    error = %e,
                    "Heartbeat failed"
                );
            }
            Err(_) => {
                tracing::debug!(peer = %peer.node_id, "Heartbeat timed out");
            }
        }
    }

    // Update state based on responses
    let status_changes = self.state.tick_heartbeats(&responded).await;
    for (node_id, status) in &status_changes {
        tracing::info!(node = %node_id, status = ?status, "Node status changed");
    }
}
/// Delivers one heartbeat request to a single peer.
///
/// Obtains a connection for `node_id` at `addr` from the transport and sends
/// `heartbeat` over it. The reply is not inspected: completing the round
/// trip is what counts as success.
///
/// # Errors
///
/// Returns any error raised while obtaining the connection or sending the
/// request.
async fn send_heartbeat_to_peer(
    &self,
    node_id: &str,
    addr: SocketAddr,
    heartbeat: &ClusterRequest,
) -> Result<()> {
    let transport = &self.transport;
    let connection = transport.get_connection(node_id, addr).await?;
    let _reply = transport.send_request(&connection, heartbeat).await?;
    Ok(())
}
}