248 lines
8.5 KiB
Rust
248 lines
8.5 KiB
Rust
|
|
mod common;
|
||
|
|
|
||
|
|
use common::*;
|
||
|
|
use rustproxy::RustProxy;
|
||
|
|
use rustproxy_config::RustProxyOptions;
|
||
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||
|
|
use tokio::net::TcpStream;
|
||
|
|
|
||
|
|
/// Build a minimal TLS ClientHello record that carries `domain` in a server_name (SNI) extension.
///
/// The record holds just enough for the proxy's SNI parser to extract the domain: a TLS 1.2
/// ClientHello with an all-zero random, an empty session id, one cipher suite (`0x00ff`), the
/// null compression method, and a single server_name extension. It is not a handshake a real
/// TLS server would complete.
///
/// Nesting (all lengths big-endian):
/// record header (5 bytes) > handshake header (type + 3-byte length) > ClientHello body >
/// server_name extension.
///
/// # Panics
///
/// Panics if `domain` is too long for the 16-bit length fields of the TLS wire format, instead
/// of emitting truncated length prefixes that disagree with the bytes that follow.
fn build_client_hello(domain: &str) -> Vec<u8> {
    /// Convert a byte count into a TLS `u16` length prefix, panicking instead of truncating.
    fn len_u16(len: usize) -> u16 {
        use std::convert::TryFrom;
        u16::try_from(len).expect("SNI domain too long to encode in a TLS ClientHello")
    }

    let name = domain.as_bytes();

    // server_name extension (type 0x0000):
    //   extension type (2) | extension data length (2) |
    //   server name list length (2) | name type (1) | host name length (2) | host name
    let mut sni_ext = Vec::with_capacity(name.len() + 9);
    sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name
    sni_ext.extend_from_slice(&len_u16(name.len() + 5).to_be_bytes()); // extension data length
    sni_ext.extend_from_slice(&len_u16(name.len() + 3).to_be_bytes()); // server name list length
    sni_ext.push(0x00); // name type: host_name
    sni_ext.extend_from_slice(&len_u16(name.len()).to_be_bytes()); // host name length
    sni_ext.extend_from_slice(name);

    // ClientHello body
    let mut client_hello = Vec::new();
    client_hello.extend_from_slice(&[0x03, 0x03]); // client_version: TLS 1.2
    client_hello.extend_from_slice(&[0x00; 32]); // random (all zero)
    client_hello.push(0x00); // session_id length: empty
    client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites: len 2, 0x00ff
    client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods: len 1, null
    client_hello.extend_from_slice(&len_u16(sni_ext.len()).to_be_bytes()); // extensions length
    client_hello.extend_from_slice(&sni_ext);

    // The handshake (4-byte header + ClientHello) must fit in a single record, whose length
    // field is a u16. Checking that bound also guarantees the ClientHello length fits in the
    // handshake header's 24-bit length field.
    let handshake_len = len_u16(client_hello.len() + 4);
    let hello_len = u32::from(handshake_len - 4);

    // Handshake message (type 1 = ClientHello) with a 3-byte big-endian length
    let mut handshake = Vec::with_capacity(usize::from(handshake_len));
    handshake.push(0x01);
    handshake.extend_from_slice(&hello_len.to_be_bytes()[1..]);
    handshake.extend_from_slice(&client_hello);

    // TLS record wrapping the handshake
    let mut record = Vec::with_capacity(handshake.len() + 5);
    record.push(0x16); // ContentType: Handshake
    record.extend_from_slice(&[0x03, 0x01]); // record-layer version: TLS 1.0
    record.extend_from_slice(&handshake_len.to_be_bytes());
    record.extend_from_slice(&handshake);
    record
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_tls_passthrough_sni_routing() {
|
||
|
|
let backend1_port = next_port();
|
||
|
|
let backend2_port = next_port();
|
||
|
|
let proxy_port = next_port();
|
||
|
|
|
||
|
|
let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await;
|
||
|
|
let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await;
|
||
|
|
|
||
|
|
let options = RustProxyOptions {
|
||
|
|
routes: vec![
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
|
||
|
|
],
|
||
|
|
..Default::default()
|
||
|
|
};
|
||
|
|
|
||
|
|
let mut proxy = RustProxy::new(options).unwrap();
|
||
|
|
proxy.start().await.unwrap();
|
||
|
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||
|
|
|
||
|
|
// Send a fake ClientHello with SNI "one.example.com"
|
||
|
|
let result = with_timeout(async {
|
||
|
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
let hello = build_client_hello("one.example.com");
|
||
|
|
stream.write_all(&hello).await.unwrap();
|
||
|
|
|
||
|
|
let mut buf = vec![0u8; 4096];
|
||
|
|
let n = stream.read(&mut buf).await.unwrap();
|
||
|
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||
|
|
}, 5)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
// Backend1 should have received the ClientHello and prefixed its response
|
||
|
|
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
|
||
|
|
|
||
|
|
// Now test routing to backend2
|
||
|
|
let result2 = with_timeout(async {
|
||
|
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
let hello = build_client_hello("two.example.com");
|
||
|
|
stream.write_all(&hello).await.unwrap();
|
||
|
|
|
||
|
|
let mut buf = vec![0u8; 4096];
|
||
|
|
let n = stream.read(&mut buf).await.unwrap();
|
||
|
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||
|
|
}, 5)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
|
||
|
|
|
||
|
|
proxy.stop().await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_tls_passthrough_unknown_sni() {
|
||
|
|
let backend_port = next_port();
|
||
|
|
let proxy_port = next_port();
|
||
|
|
|
||
|
|
let _backend = start_echo_server(backend_port).await;
|
||
|
|
|
||
|
|
let options = RustProxyOptions {
|
||
|
|
routes: vec![
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
|
||
|
|
],
|
||
|
|
..Default::default()
|
||
|
|
};
|
||
|
|
|
||
|
|
let mut proxy = RustProxy::new(options).unwrap();
|
||
|
|
proxy.start().await.unwrap();
|
||
|
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||
|
|
|
||
|
|
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
||
|
|
let result = with_timeout(async {
|
||
|
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
let hello = build_client_hello("unknown.example.com");
|
||
|
|
stream.write_all(&hello).await.unwrap();
|
||
|
|
|
||
|
|
let mut buf = vec![0u8; 4096];
|
||
|
|
// Should either get 0 bytes (closed) or an error
|
||
|
|
match stream.read(&mut buf).await {
|
||
|
|
Ok(0) => true, // Connection closed = no route matched
|
||
|
|
Ok(_) => false, // Got data = route shouldn't have matched
|
||
|
|
Err(_) => true, // Error = connection dropped
|
||
|
|
}
|
||
|
|
}, 5)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
assert!(result, "Unknown SNI should result in dropped connection");
|
||
|
|
|
||
|
|
proxy.stop().await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_tls_passthrough_wildcard_domain() {
|
||
|
|
let backend_port = next_port();
|
||
|
|
let proxy_port = next_port();
|
||
|
|
|
||
|
|
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
||
|
|
|
||
|
|
let options = RustProxyOptions {
|
||
|
|
routes: vec![
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
|
||
|
|
],
|
||
|
|
..Default::default()
|
||
|
|
};
|
||
|
|
|
||
|
|
let mut proxy = RustProxy::new(options).unwrap();
|
||
|
|
proxy.start().await.unwrap();
|
||
|
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||
|
|
|
||
|
|
// Should match any subdomain of example.com
|
||
|
|
let result = with_timeout(async {
|
||
|
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
let hello = build_client_hello("anything.example.com");
|
||
|
|
stream.write_all(&hello).await.unwrap();
|
||
|
|
|
||
|
|
let mut buf = vec![0u8; 4096];
|
||
|
|
let n = stream.read(&mut buf).await.unwrap();
|
||
|
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||
|
|
}, 5)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
|
||
|
|
|
||
|
|
proxy.stop().await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_tls_passthrough_multiple_domains() {
|
||
|
|
let b1_port = next_port();
|
||
|
|
let b2_port = next_port();
|
||
|
|
let b3_port = next_port();
|
||
|
|
let proxy_port = next_port();
|
||
|
|
|
||
|
|
let _b1 = start_prefix_echo_server(b1_port, "B1:").await;
|
||
|
|
let _b2 = start_prefix_echo_server(b2_port, "B2:").await;
|
||
|
|
let _b3 = start_prefix_echo_server(b3_port, "B3:").await;
|
||
|
|
|
||
|
|
let options = RustProxyOptions {
|
||
|
|
routes: vec![
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port),
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port),
|
||
|
|
make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port),
|
||
|
|
],
|
||
|
|
..Default::default()
|
||
|
|
};
|
||
|
|
|
||
|
|
let mut proxy = RustProxy::new(options).unwrap();
|
||
|
|
proxy.start().await.unwrap();
|
||
|
|
assert!(wait_for_port(proxy_port, 2000).await);
|
||
|
|
|
||
|
|
for (domain, expected_prefix) in [
|
||
|
|
("alpha.example.com", "B1:"),
|
||
|
|
("beta.example.com", "B2:"),
|
||
|
|
("gamma.example.com", "B3:"),
|
||
|
|
] {
|
||
|
|
let result = with_timeout(async {
|
||
|
|
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
let hello = build_client_hello(domain);
|
||
|
|
stream.write_all(&hello).await.unwrap();
|
||
|
|
|
||
|
|
let mut buf = vec![0u8; 4096];
|
||
|
|
let n = stream.read(&mut buf).await.unwrap();
|
||
|
|
String::from_utf8_lossy(&buf[..n]).to_string()
|
||
|
|
}, 5)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
assert!(
|
||
|
|
result.starts_with(expected_prefix),
|
||
|
|
"Domain {} should route to {}, got: {}",
|
||
|
|
domain, expected_prefix, result
|
||
|
|
);
|
||
|
|
}
|
||
|
|
|
||
|
|
proxy.stop().await.unwrap();
|
||
|
|
}
|