use anyhow::Result; use dashmap::DashMap; use quinn::{ClientConfig, Endpoint, ServerConfig as QuinnServerConfig}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::protocol::{ self, ClusterRequest, ClusterResponse, ShardReadResponse, ShardWriteAck, ShardWriteRequest, }; use super::shard_store::{ShardId, ShardStore}; /// QUIC transport layer for inter-node communication. /// /// Manages a QUIC endpoint for both sending and receiving cluster messages. /// Uses self-signed TLS certificates generated at init time. /// Maintains a connection pool to peer nodes. pub struct QuicTransport { endpoint: Endpoint, /// Cached connections to peer nodes: node_id -> Connection connections: Arc>, local_node_id: String, } impl QuicTransport { /// Create a new QUIC transport, binding to the specified address. pub async fn new(bind_addr: SocketAddr, local_node_id: String) -> Result { let (server_config, client_config) = Self::generate_tls_configs()?; let endpoint = Endpoint::server(server_config, bind_addr)?; // Also configure the endpoint for client connections let mut endpoint_client = endpoint.clone(); endpoint_client.set_default_client_config(client_config); Ok(Self { endpoint, connections: Arc::new(DashMap::new()), local_node_id, }) } /// Get or establish a connection to a peer node. pub async fn get_connection( &self, node_id: &str, addr: SocketAddr, ) -> Result { // Check cache first if let Some(conn) = self.connections.get(node_id) { if conn.close_reason().is_none() { return Ok(conn.clone()); } // Connection is closed, remove from cache drop(conn); self.connections.remove(node_id); } // Establish new connection let conn = self .endpoint .connect(addr, "smartstorage")? .await?; self.connections .insert(node_id.to_string(), conn.clone()); Ok(conn) } /// Send a cluster request and receive the response. pub async fn send_request( &self, conn: &quinn::Connection, request: &ClusterRequest, ) -> Result { let (mut send, mut recv) = conn.open_bi().await?; // Encode and send request let encoded = protocol::encode_request(request)?; send.write_all(&encoded).await?; send.finish()?; // Read response let response_data = recv.read_to_end(64 * 1024 * 1024).await?; // 64MB max let (response, _) = protocol::decode_response(&response_data)?; Ok(response) } /// Send a shard write request with streaming data. /// /// Sends the request header first, then streams the shard data bytes. pub async fn send_shard_write( &self, conn: &quinn::Connection, request: ShardWriteRequest, shard_data: &[u8], ) -> Result { let (mut send, mut recv) = conn.open_bi().await?; // Send request header let encoded = protocol::encode_request(&ClusterRequest::ShardWrite(request))?; send.write_all(&encoded).await?; // Stream shard data send.write_all(shard_data).await?; send.finish()?; // Read ack let response_data = recv.read_to_end(1024).await?; let (response, _) = protocol::decode_response(&response_data)?; match response { ClusterResponse::ShardWriteAck(ack) => Ok(ack), ClusterResponse::Error(e) => { anyhow::bail!("Shard write error: {} - {}", e.code, e.message) } other => anyhow::bail!("Unexpected response to shard write: {:?}", other), } } /// Send a shard read request and receive the shard data. /// /// Returns (shard_data, checksum). pub async fn send_shard_read( &self, conn: &quinn::Connection, request: &ClusterRequest, ) -> Result, u32)>> { let (mut send, mut recv) = conn.open_bi().await?; // Send request let encoded = protocol::encode_request(request)?; send.write_all(&encoded).await?; send.finish()?; // Read response header let mut header_len_buf = [0u8; 4]; recv.read_exact(&mut header_len_buf).await?; let header_len = u32::from_le_bytes(header_len_buf) as usize; let mut header_buf = vec![0u8; header_len]; recv.read_exact(&mut header_buf).await?; let response: ClusterResponse = bincode::deserialize(&header_buf)?; match response { ClusterResponse::ShardReadResponse(read_resp) => { if !read_resp.found { return Ok(None); } // Read shard data that follows let mut shard_data = vec![0u8; read_resp.shard_data_length as usize]; recv.read_exact(&mut shard_data).await?; Ok(Some((shard_data, read_resp.checksum))) } ClusterResponse::Error(e) => { anyhow::bail!("Shard read error: {} - {}", e.code, e.message) } other => anyhow::bail!("Unexpected response to shard read: {:?}", other), } } /// Accept incoming connections and dispatch to the handler. pub async fn accept_loop( self: Arc, shard_store: Arc, mut shutdown: tokio::sync::watch::Receiver, ) { loop { tokio::select! { incoming = self.endpoint.accept() => { match incoming { Some(incoming_conn) => { let transport = self.clone(); let store = shard_store.clone(); tokio::spawn(async move { match incoming_conn.await { Ok(conn) => { transport.handle_connection(conn, store).await; } Err(e) => { tracing::error!("Failed to accept QUIC connection: {}", e); } } }); } None => break, } } _ = shutdown.changed() => break, } } } /// Handle a single QUIC connection (may have multiple streams). async fn handle_connection( &self, conn: quinn::Connection, shard_store: Arc, ) { loop { match conn.accept_bi().await { Ok((send, recv)) => { let store = shard_store.clone(); tokio::spawn(async move { if let Err(e) = Self::handle_stream(send, recv, store).await { tracing::error!("Stream handler error: {}", e); } }); } Err(quinn::ConnectionError::ApplicationClosed(_)) => break, Err(e) => { tracing::error!("Connection error: {}", e); break; } } } } /// Handle a single bidirectional stream (one request-response exchange). async fn handle_stream( mut send: quinn::SendStream, mut recv: quinn::RecvStream, shard_store: Arc, ) -> Result<()> { // Read the length-prefixed request header let mut len_buf = [0u8; 4]; recv.read_exact(&mut len_buf).await?; let msg_len = u32::from_le_bytes(len_buf) as usize; let mut msg_buf = vec![0u8; msg_len]; recv.read_exact(&mut msg_buf).await?; let request: ClusterRequest = bincode::deserialize(&msg_buf)?; match request { ClusterRequest::ShardWrite(write_req) => { // Read shard data from the stream let mut shard_data = vec![0u8; write_req.shard_data_length as usize]; recv.read_exact(&mut shard_data).await?; let shard_id = ShardId { bucket: write_req.bucket, key: write_req.key, chunk_index: write_req.chunk_index, shard_index: write_req.shard_index, }; let result = shard_store .write_shard(&shard_id, &shard_data, write_req.checksum) .await; let ack = ShardWriteAck { request_id: write_req.request_id, success: result.is_ok(), error: result.err().map(|e| e.to_string()), }; let response = protocol::encode_response(&ClusterResponse::ShardWriteAck(ack))?; send.write_all(&response).await?; send.finish()?; } ClusterRequest::ShardRead(read_req) => { let shard_id = ShardId { bucket: read_req.bucket, key: read_req.key, chunk_index: read_req.chunk_index, shard_index: read_req.shard_index, }; match shard_store.read_shard(&shard_id).await { Ok((data, checksum)) => { let header = ShardReadResponse { request_id: read_req.request_id, found: true, shard_data_length: data.len() as u64, checksum, }; // Send header let header_bytes = bincode::serialize(&ClusterResponse::ShardReadResponse(header))?; send.write_all(&(header_bytes.len() as u32).to_le_bytes()).await?; send.write_all(&header_bytes).await?; // Send shard data send.write_all(&data).await?; send.finish()?; } Err(_) => { let header = ShardReadResponse { request_id: read_req.request_id, found: false, shard_data_length: 0, checksum: 0, }; let header_bytes = bincode::serialize(&ClusterResponse::ShardReadResponse(header))?; send.write_all(&(header_bytes.len() as u32).to_le_bytes()).await?; send.write_all(&header_bytes).await?; send.finish()?; } } } ClusterRequest::ShardDelete(del_req) => { let shard_id = ShardId { bucket: del_req.bucket, key: del_req.key, chunk_index: del_req.chunk_index, shard_index: del_req.shard_index, }; let result = shard_store.delete_shard(&shard_id).await; let ack = protocol::ClusterResponse::ShardDeleteAck(protocol::ShardDeleteAck { request_id: del_req.request_id, success: result.is_ok(), }); let response = protocol::encode_response(&ack)?; send.write_all(&response).await?; send.finish()?; } ClusterRequest::ShardHead(head_req) => { let shard_id = ShardId { bucket: head_req.bucket, key: head_req.key, chunk_index: head_req.chunk_index, shard_index: head_req.shard_index, }; let resp = match shard_store.head_shard(&shard_id).await { Ok(Some(meta)) => protocol::ShardHeadResponse { request_id: head_req.request_id, found: true, data_size: meta.data_size, checksum: meta.checksum, }, _ => protocol::ShardHeadResponse { request_id: head_req.request_id, found: false, data_size: 0, checksum: 0, }, }; let response = protocol::encode_response(&ClusterResponse::ShardHeadResponse(resp))?; send.write_all(&response).await?; send.finish()?; } // Heartbeat, Join, TopologySync, Heal, and Manifest operations // will be handled by the membership and coordinator modules. // For now, send a generic ack. _ => { let response_data = recv.read_to_end(0).await.unwrap_or_default(); drop(response_data); let err = protocol::ErrorResponse { request_id: String::new(), code: "NotImplemented".to_string(), message: "This cluster operation is not yet implemented".to_string(), }; let response = protocol::encode_response(&ClusterResponse::Error(err))?; send.write_all(&response).await?; send.finish()?; } } Ok(()) } /// Generate self-signed TLS certificates for cluster-internal communication. fn generate_tls_configs() -> Result<(QuinnServerConfig, ClientConfig)> { // Generate self-signed certificate let cert = rcgen::generate_simple_self_signed(vec!["smartstorage".to_string()])?; let cert_der = CertificateDer::from(cert.cert); let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der())); // Server config let mut server_crypto = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(vec![cert_der.clone()], key_der.clone_key())?; server_crypto.alpn_protocols = vec![b"smartstorage".to_vec()]; let server_config = QuinnServerConfig::with_crypto(Arc::new( quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)?, )); // Client config: skip server certificate verification (cluster-internal) let mut client_crypto = rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(SkipServerVerification)) .with_no_client_auth(); client_crypto.alpn_protocols = vec![b"smartstorage".to_vec()]; let client_config = ClientConfig::new(Arc::new( quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto)?, )); Ok((server_config, client_config)) } /// Close the QUIC endpoint gracefully. pub fn close(&self) { self.endpoint .close(quinn::VarInt::from_u32(0), b"shutdown"); } /// Get the local node ID. pub fn local_node_id(&self) -> &str { &self.local_node_id } } /// Certificate verifier that skips verification (for cluster-internal self-signed certs). #[derive(Debug)] struct SkipServerVerification; impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, ] } }