456 lines
17 KiB
Rust
456 lines
17 KiB
Rust
use anyhow::Result;
|
|
use dashmap::DashMap;
|
|
use quinn::{ClientConfig, Endpoint, ServerConfig as QuinnServerConfig};
|
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
use super::protocol::{
|
|
self, ClusterRequest, ClusterResponse, ShardReadResponse, ShardWriteAck, ShardWriteRequest,
|
|
};
|
|
use super::shard_store::{ShardId, ShardStore};
|
|
|
|
/// QUIC transport layer for inter-node communication.
|
|
///
|
|
/// Manages a QUIC endpoint for both sending and receiving cluster messages.
|
|
/// Uses self-signed TLS certificates generated at init time.
|
|
/// Maintains a connection pool to peer nodes.
|
|
pub struct QuicTransport {
|
|
endpoint: Endpoint,
|
|
/// Cached connections to peer nodes: node_id -> Connection
|
|
connections: Arc<DashMap<String, quinn::Connection>>,
|
|
local_node_id: String,
|
|
}
|
|
|
|
impl QuicTransport {
|
|
/// Create a new QUIC transport, binding to the specified address.
|
|
pub async fn new(bind_addr: SocketAddr, local_node_id: String) -> Result<Self> {
|
|
let (server_config, client_config) = Self::generate_tls_configs()?;
|
|
|
|
let endpoint = Endpoint::server(server_config, bind_addr)?;
|
|
|
|
// Also configure the endpoint for client connections
|
|
let mut endpoint_client = endpoint.clone();
|
|
endpoint_client.set_default_client_config(client_config);
|
|
|
|
Ok(Self {
|
|
endpoint,
|
|
connections: Arc::new(DashMap::new()),
|
|
local_node_id,
|
|
})
|
|
}
|
|
|
|
/// Get or establish a connection to a peer node.
|
|
pub async fn get_connection(
|
|
&self,
|
|
node_id: &str,
|
|
addr: SocketAddr,
|
|
) -> Result<quinn::Connection> {
|
|
// Check cache first
|
|
if let Some(conn) = self.connections.get(node_id) {
|
|
if conn.close_reason().is_none() {
|
|
return Ok(conn.clone());
|
|
}
|
|
// Connection is closed, remove from cache
|
|
drop(conn);
|
|
self.connections.remove(node_id);
|
|
}
|
|
|
|
// Establish new connection
|
|
let conn = self
|
|
.endpoint
|
|
.connect(addr, "smartstorage")?
|
|
.await?;
|
|
|
|
self.connections
|
|
.insert(node_id.to_string(), conn.clone());
|
|
|
|
Ok(conn)
|
|
}
|
|
|
|
/// Send a cluster request and receive the response.
|
|
pub async fn send_request(
|
|
&self,
|
|
conn: &quinn::Connection,
|
|
request: &ClusterRequest,
|
|
) -> Result<ClusterResponse> {
|
|
let (mut send, mut recv) = conn.open_bi().await?;
|
|
|
|
// Encode and send request
|
|
let encoded = protocol::encode_request(request)?;
|
|
send.write_all(&encoded).await?;
|
|
send.finish()?;
|
|
|
|
// Read response
|
|
let response_data = recv.read_to_end(64 * 1024 * 1024).await?; // 64MB max
|
|
let (response, _) = protocol::decode_response(&response_data)?;
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
/// Send a shard write request with streaming data.
|
|
///
|
|
/// Sends the request header first, then streams the shard data bytes.
|
|
pub async fn send_shard_write(
|
|
&self,
|
|
conn: &quinn::Connection,
|
|
request: ShardWriteRequest,
|
|
shard_data: &[u8],
|
|
) -> Result<ShardWriteAck> {
|
|
let (mut send, mut recv) = conn.open_bi().await?;
|
|
|
|
// Send request header
|
|
let encoded = protocol::encode_request(&ClusterRequest::ShardWrite(request))?;
|
|
send.write_all(&encoded).await?;
|
|
|
|
// Stream shard data
|
|
send.write_all(shard_data).await?;
|
|
send.finish()?;
|
|
|
|
// Read ack
|
|
let response_data = recv.read_to_end(1024).await?;
|
|
let (response, _) = protocol::decode_response(&response_data)?;
|
|
|
|
match response {
|
|
ClusterResponse::ShardWriteAck(ack) => Ok(ack),
|
|
ClusterResponse::Error(e) => {
|
|
anyhow::bail!("Shard write error: {} - {}", e.code, e.message)
|
|
}
|
|
other => anyhow::bail!("Unexpected response to shard write: {:?}", other),
|
|
}
|
|
}
|
|
|
|
/// Send a shard read request and receive the shard data.
|
|
///
|
|
/// Returns (shard_data, checksum).
|
|
pub async fn send_shard_read(
|
|
&self,
|
|
conn: &quinn::Connection,
|
|
request: &ClusterRequest,
|
|
) -> Result<Option<(Vec<u8>, u32)>> {
|
|
let (mut send, mut recv) = conn.open_bi().await?;
|
|
|
|
// Send request
|
|
let encoded = protocol::encode_request(request)?;
|
|
send.write_all(&encoded).await?;
|
|
send.finish()?;
|
|
|
|
// Read response header
|
|
let mut header_len_buf = [0u8; 4];
|
|
recv.read_exact(&mut header_len_buf).await?;
|
|
let header_len = u32::from_le_bytes(header_len_buf) as usize;
|
|
|
|
let mut header_buf = vec![0u8; header_len];
|
|
recv.read_exact(&mut header_buf).await?;
|
|
let response: ClusterResponse = bincode::deserialize(&header_buf)?;
|
|
|
|
match response {
|
|
ClusterResponse::ShardReadResponse(read_resp) => {
|
|
if !read_resp.found {
|
|
return Ok(None);
|
|
}
|
|
// Read shard data that follows
|
|
let mut shard_data = vec![0u8; read_resp.shard_data_length as usize];
|
|
recv.read_exact(&mut shard_data).await?;
|
|
Ok(Some((shard_data, read_resp.checksum)))
|
|
}
|
|
ClusterResponse::Error(e) => {
|
|
anyhow::bail!("Shard read error: {} - {}", e.code, e.message)
|
|
}
|
|
other => anyhow::bail!("Unexpected response to shard read: {:?}", other),
|
|
}
|
|
}
|
|
|
|
/// Accept incoming connections and dispatch to the handler.
|
|
pub async fn accept_loop(
|
|
self: Arc<Self>,
|
|
shard_store: Arc<ShardStore>,
|
|
mut shutdown: tokio::sync::watch::Receiver<bool>,
|
|
) {
|
|
loop {
|
|
tokio::select! {
|
|
incoming = self.endpoint.accept() => {
|
|
match incoming {
|
|
Some(incoming_conn) => {
|
|
let transport = self.clone();
|
|
let store = shard_store.clone();
|
|
tokio::spawn(async move {
|
|
match incoming_conn.await {
|
|
Ok(conn) => {
|
|
transport.handle_connection(conn, store).await;
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Failed to accept QUIC connection: {}", e);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
None => break,
|
|
}
|
|
}
|
|
_ = shutdown.changed() => break,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle a single QUIC connection (may have multiple streams).
|
|
async fn handle_connection(
|
|
&self,
|
|
conn: quinn::Connection,
|
|
shard_store: Arc<ShardStore>,
|
|
) {
|
|
loop {
|
|
match conn.accept_bi().await {
|
|
Ok((send, recv)) => {
|
|
let store = shard_store.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = Self::handle_stream(send, recv, store).await {
|
|
tracing::error!("Stream handler error: {}", e);
|
|
}
|
|
});
|
|
}
|
|
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
|
Err(e) => {
|
|
tracing::error!("Connection error: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle a single bidirectional stream (one request-response exchange).
|
|
async fn handle_stream(
|
|
mut send: quinn::SendStream,
|
|
mut recv: quinn::RecvStream,
|
|
shard_store: Arc<ShardStore>,
|
|
) -> Result<()> {
|
|
// Read the length-prefixed request header
|
|
let mut len_buf = [0u8; 4];
|
|
recv.read_exact(&mut len_buf).await?;
|
|
let msg_len = u32::from_le_bytes(len_buf) as usize;
|
|
|
|
let mut msg_buf = vec![0u8; msg_len];
|
|
recv.read_exact(&mut msg_buf).await?;
|
|
let request: ClusterRequest = bincode::deserialize(&msg_buf)?;
|
|
|
|
match request {
|
|
ClusterRequest::ShardWrite(write_req) => {
|
|
// Read shard data from the stream
|
|
let mut shard_data = vec![0u8; write_req.shard_data_length as usize];
|
|
recv.read_exact(&mut shard_data).await?;
|
|
|
|
let shard_id = ShardId {
|
|
bucket: write_req.bucket,
|
|
key: write_req.key,
|
|
chunk_index: write_req.chunk_index,
|
|
shard_index: write_req.shard_index,
|
|
};
|
|
|
|
let result = shard_store
|
|
.write_shard(&shard_id, &shard_data, write_req.checksum)
|
|
.await;
|
|
|
|
let ack = ShardWriteAck {
|
|
request_id: write_req.request_id,
|
|
success: result.is_ok(),
|
|
error: result.err().map(|e| e.to_string()),
|
|
};
|
|
let response = protocol::encode_response(&ClusterResponse::ShardWriteAck(ack))?;
|
|
send.write_all(&response).await?;
|
|
send.finish()?;
|
|
}
|
|
|
|
ClusterRequest::ShardRead(read_req) => {
|
|
let shard_id = ShardId {
|
|
bucket: read_req.bucket,
|
|
key: read_req.key,
|
|
chunk_index: read_req.chunk_index,
|
|
shard_index: read_req.shard_index,
|
|
};
|
|
|
|
match shard_store.read_shard(&shard_id).await {
|
|
Ok((data, checksum)) => {
|
|
let header = ShardReadResponse {
|
|
request_id: read_req.request_id,
|
|
found: true,
|
|
shard_data_length: data.len() as u64,
|
|
checksum,
|
|
};
|
|
// Send header
|
|
let header_bytes = bincode::serialize(&ClusterResponse::ShardReadResponse(header))?;
|
|
send.write_all(&(header_bytes.len() as u32).to_le_bytes()).await?;
|
|
send.write_all(&header_bytes).await?;
|
|
// Send shard data
|
|
send.write_all(&data).await?;
|
|
send.finish()?;
|
|
}
|
|
Err(_) => {
|
|
let header = ShardReadResponse {
|
|
request_id: read_req.request_id,
|
|
found: false,
|
|
shard_data_length: 0,
|
|
checksum: 0,
|
|
};
|
|
let header_bytes = bincode::serialize(&ClusterResponse::ShardReadResponse(header))?;
|
|
send.write_all(&(header_bytes.len() as u32).to_le_bytes()).await?;
|
|
send.write_all(&header_bytes).await?;
|
|
send.finish()?;
|
|
}
|
|
}
|
|
}
|
|
|
|
ClusterRequest::ShardDelete(del_req) => {
|
|
let shard_id = ShardId {
|
|
bucket: del_req.bucket,
|
|
key: del_req.key,
|
|
chunk_index: del_req.chunk_index,
|
|
shard_index: del_req.shard_index,
|
|
};
|
|
let result = shard_store.delete_shard(&shard_id).await;
|
|
let ack = protocol::ClusterResponse::ShardDeleteAck(protocol::ShardDeleteAck {
|
|
request_id: del_req.request_id,
|
|
success: result.is_ok(),
|
|
});
|
|
let response = protocol::encode_response(&ack)?;
|
|
send.write_all(&response).await?;
|
|
send.finish()?;
|
|
}
|
|
|
|
ClusterRequest::ShardHead(head_req) => {
|
|
let shard_id = ShardId {
|
|
bucket: head_req.bucket,
|
|
key: head_req.key,
|
|
chunk_index: head_req.chunk_index,
|
|
shard_index: head_req.shard_index,
|
|
};
|
|
let resp = match shard_store.head_shard(&shard_id).await {
|
|
Ok(Some(meta)) => protocol::ShardHeadResponse {
|
|
request_id: head_req.request_id,
|
|
found: true,
|
|
data_size: meta.data_size,
|
|
checksum: meta.checksum,
|
|
},
|
|
_ => protocol::ShardHeadResponse {
|
|
request_id: head_req.request_id,
|
|
found: false,
|
|
data_size: 0,
|
|
checksum: 0,
|
|
},
|
|
};
|
|
let response =
|
|
protocol::encode_response(&ClusterResponse::ShardHeadResponse(resp))?;
|
|
send.write_all(&response).await?;
|
|
send.finish()?;
|
|
}
|
|
|
|
// Heartbeat, Join, TopologySync, Heal, and Manifest operations
|
|
// will be handled by the membership and coordinator modules.
|
|
// For now, send a generic ack.
|
|
_ => {
|
|
let response_data = recv.read_to_end(0).await.unwrap_or_default();
|
|
drop(response_data);
|
|
let err = protocol::ErrorResponse {
|
|
request_id: String::new(),
|
|
code: "NotImplemented".to_string(),
|
|
message: "This cluster operation is not yet implemented".to_string(),
|
|
};
|
|
let response = protocol::encode_response(&ClusterResponse::Error(err))?;
|
|
send.write_all(&response).await?;
|
|
send.finish()?;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Generate self-signed TLS certificates for cluster-internal communication.
|
|
fn generate_tls_configs() -> Result<(QuinnServerConfig, ClientConfig)> {
|
|
// Generate self-signed certificate
|
|
let cert = rcgen::generate_simple_self_signed(vec!["smartstorage".to_string()])?;
|
|
let cert_der = CertificateDer::from(cert.cert);
|
|
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()));
|
|
|
|
// Server config
|
|
let mut server_crypto = rustls::ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_single_cert(vec![cert_der.clone()], key_der.clone_key())?;
|
|
server_crypto.alpn_protocols = vec![b"smartstorage".to_vec()];
|
|
let server_config = QuinnServerConfig::with_crypto(Arc::new(
|
|
quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)?,
|
|
));
|
|
|
|
// Client config: skip server certificate verification (cluster-internal)
|
|
let mut client_crypto = rustls::ClientConfig::builder()
|
|
.dangerous()
|
|
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
|
|
.with_no_client_auth();
|
|
client_crypto.alpn_protocols = vec![b"smartstorage".to_vec()];
|
|
let client_config = ClientConfig::new(Arc::new(
|
|
quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto)?,
|
|
));
|
|
|
|
Ok((server_config, client_config))
|
|
}
|
|
|
|
/// Close the QUIC endpoint gracefully.
|
|
pub fn close(&self) {
|
|
self.endpoint
|
|
.close(quinn::VarInt::from_u32(0), b"shutdown");
|
|
}
|
|
|
|
/// Get the local node ID.
|
|
pub fn local_node_id(&self) -> &str {
|
|
&self.local_node_id
|
|
}
|
|
}
|
|
|
|
/// Certificate verifier that skips verification (for cluster-internal self-signed certs).
|
|
#[derive(Debug)]
|
|
struct SkipServerVerification;
|
|
|
|
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
|
|
fn verify_server_cert(
|
|
&self,
|
|
_end_entity: &CertificateDer<'_>,
|
|
_intermediates: &[CertificateDer<'_>],
|
|
_server_name: &rustls::pki_types::ServerName<'_>,
|
|
_ocsp_response: &[u8],
|
|
_now: rustls::pki_types::UnixTime,
|
|
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
|
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
|
}
|
|
|
|
fn verify_tls12_signature(
|
|
&self,
|
|
_message: &[u8],
|
|
_cert: &CertificateDer<'_>,
|
|
_dss: &rustls::DigitallySignedStruct,
|
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
|
}
|
|
|
|
fn verify_tls13_signature(
|
|
&self,
|
|
_message: &[u8],
|
|
_cert: &CertificateDer<'_>,
|
|
_dss: &rustls::DigitallySignedStruct,
|
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
|
}
|
|
|
|
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
|
vec![
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
|
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
|
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
|
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
|
rustls::SignatureScheme::ED25519,
|
|
rustls::SignatureScheme::RSA_PSS_SHA256,
|
|
rustls::SignatureScheme::RSA_PSS_SHA384,
|
|
rustls::SignatureScheme::RSA_PSS_SHA512,
|
|
]
|
|
}
|
|
}
|