Files
smartnetwork/rust/crates/rustnetwork/src/traceroute.rs

309 lines
9.8 KiB
Rust

use socket2::{Domain, Protocol, Socket, Type};
use std::io;
use std::mem::MaybeUninit;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
/// One line of a traceroute: the outcome of the probe sent with a given TTL.
#[derive(Debug, Clone, PartialEq)]
pub struct TracerouteHop {
    /// TTL / hop-limit value the probe was sent with; the first hop is 1.
    pub ttl: u8,
    /// Address of the host that answered the probe, rendered as text
    /// (`"unknown"` if the sender address was not an IP socket address),
    /// or `None` if the probe timed out or failed.
    pub ip: Option<String>,
    /// Round-trip time in milliseconds, or `None` if no answer arrived.
    pub rtt: Option<f64>,
}
/// Traces the route to `host` by sending ICMP Echo Requests with TTL
/// 1, 2, …, `max_hops`, waiting up to `timeout_ms` milliseconds per hop.
///
/// The host is resolved first. The raw-socket work then runs on tokio's
/// blocking thread pool so it does not stall the async runtime. Raw ICMP
/// sockets usually need elevated privileges (e.g. root or `CAP_NET_RAW` on
/// Linux).
///
/// # Errors
/// Returns an error string if name resolution fails or the blocking task
/// cannot be joined. Failures of individual probes do not fail the call;
/// they show up as hops with `ip`/`rtt` set to `None`.
pub async fn traceroute(
    host: &str,
    max_hops: u8,
    timeout_ms: u64,
) -> Result<Vec<TracerouteHop>, String> {
    let target = resolve_host(host).await?;
    let per_hop_timeout = Duration::from_millis(timeout_ms);
    let job = move || traceroute_blocking(target, max_hops, per_hop_timeout);
    match tokio::task::spawn_blocking(job).await {
        Ok(outcome) => outcome,
        Err(e) => Err(format!("Task join error: {e}")),
    }
}
/// Synchronous core of `traceroute`: probes TTL 1..=`max_hops` in order.
///
/// Every TTL tried yields exactly one `TracerouteHop`. Probes that time out
/// or fail are recorded with `ip` and `rtt` set to `None`; failures are also
/// written to stderr. The walk stops after the hop whose responder address
/// equals `dest`. This function currently always returns `Ok`.
fn traceroute_blocking(
    dest: IpAddr,
    max_hops: u8,
    timeout: Duration,
) -> Result<Vec<TracerouteHop>, String> {
    let dest_text = dest.to_string();
    let mut route = Vec::with_capacity(usize::from(max_hops));
    for ttl in 1..=max_hops {
        let (ip, rtt) = match send_probe(dest, ttl, timeout) {
            Ok((responder, millis)) => (responder, Some(millis)),
            Err(ProbeError::Timeout) => (None, None),
            Err(ProbeError::Other(e)) => {
                // Log but keep walking: one failed probe should not abort the trace.
                eprintln!("Probe error at TTL {ttl}: {e}");
                (None, None)
            }
        };
        let destination_reached = ip.as_deref() == Some(dest_text.as_str());
        route.push(TracerouteHop { ttl, ip, rtt });
        if destination_reached {
            break;
        }
    }
    Ok(route)
}
/// Why a single probe produced no usable answer.
#[derive(Debug)]
enum ProbeError {
    /// No matching ICMP reply arrived within the per-probe timeout.
    Timeout,
    /// Any other failure (socket setup, send or receive), with a description.
    Other(String),
}
fn send_probe(dest: IpAddr, ttl: u8, timeout: Duration) -> Result<(Option<String>, f64), ProbeError> {
let (domain, proto) = match dest {
IpAddr::V4(_) => (Domain::IPV4, Protocol::ICMPV4),
IpAddr::V6(_) => (Domain::IPV6, Protocol::ICMPV6),
};
let sock = Socket::new(domain, Type::RAW, Some(proto))
.map_err(|e| ProbeError::Other(format!("Socket creation failed: {e}")))?;
sock.set_ttl(ttl as u32)
.map_err(|e| ProbeError::Other(format!("Failed to set TTL: {e}")))?;
sock.set_read_timeout(Some(timeout))
.map_err(|e| ProbeError::Other(format!("Failed to set timeout: {e}")))?;
let dest_addr = match dest {
IpAddr::V4(v4) => SocketAddr::new(IpAddr::V4(v4), 0),
IpAddr::V6(v6) => SocketAddr::new(IpAddr::V6(v6), 0),
};
// Build ICMP Echo Request packet
let ident = (std::process::id() as u16) ^ (ttl as u16);
let seq = ttl as u16;
let packet = match dest {
IpAddr::V4(_) => build_icmpv4_echo_request(ident, seq),
IpAddr::V6(_) => build_icmpv6_echo_request(ident, seq),
};
let start = Instant::now();
sock.send_to(&packet, &dest_addr.into())
.map_err(|e| ProbeError::Other(format!("Send failed: {e}")))?;
// Wait for response using MaybeUninit buffer as required by socket2
let mut buf_uninit = [MaybeUninit::<u8>::uninit(); 512];
loop {
match sock.recv_from(&mut buf_uninit) {
Ok((n, from_addr)) => {
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
// Safety: recv_from initialized the first n bytes
let buf: &[u8] = unsafe {
std::slice::from_raw_parts(buf_uninit.as_ptr() as *const u8, n)
};
let from_ip = match from_addr.as_socket() {
Some(sa) => sa.ip().to_string(),
None => "unknown".to_string(),
};
// Check if this response is for our probe
match dest {
IpAddr::V4(_) => {
if is_relevant_icmpv4_response(buf, ident, seq) {
return Ok((Some(from_ip), elapsed));
}
}
IpAddr::V6(_) => {
if is_relevant_icmpv6_response(buf, ident, seq) {
return Ok((Some(from_ip), elapsed));
}
}
}
// Check if we've exceeded timeout
if start.elapsed() >= timeout {
return Err(ProbeError::Timeout);
}
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => {
return Err(ProbeError::Timeout);
}
Err(e) => {
return Err(ProbeError::Other(format!("Recv error: {e}")));
}
}
}
}
/// Decides whether a raw ICMPv4 datagram answers the probe `(ident, seq)`.
///
/// `buf` starts with the IPv4 header, followed by the ICMP message. Two
/// message types are accepted:
/// * Echo Reply (type 0) whose identifier and sequence equal `ident` and `seq`;
/// * Time Exceeded (type 11) whose quoted original datagram is an Echo
///   Request (type 8) carrying `ident` and `seq`.
///
/// Any other message, or any buffer too short to hold the inspected fields,
/// yields `false`.
fn is_relevant_icmpv4_response(buf: &[u8], ident: u16, seq: u16) -> bool {
    const ECHO_REPLY: u8 = 0;
    const ECHO_REQUEST: u8 = 8;
    const TIME_EXCEEDED: u8 = 11;

    // In an echo header the identifier sits at bytes 4..6 and the sequence at 6..8.
    let carries_probe_ids = |echo: &[u8]| {
        u16::from_be_bytes([echo[4], echo[5]]) == ident
            && u16::from_be_bytes([echo[6], echo[7]]) == seq
    };

    if buf.len() < 20 {
        return false;
    }
    // IHL (low nibble of byte 0) counts 32-bit words.
    let outer_header_len = usize::from(buf[0] & 0x0f) * 4;
    let icmp = match buf.get(outer_header_len..) {
        Some(rest) if rest.len() >= 8 => rest,
        _ => return false,
    };

    match icmp[0] {
        ECHO_REPLY => carries_probe_ids(icmp),
        TIME_EXCEEDED => {
            // 8-byte outer ICMP header + minimal 20-byte quoted IPv4 header
            // + 8 quoted ICMP bytes.
            if icmp.len() < 36 {
                return false;
            }
            let quoted = &icmp[8..];
            let quoted_header_len = usize::from(quoted[0] & 0x0f) * 4;
            match quoted.get(quoted_header_len..) {
                Some(echo) if echo.len() >= 8 => {
                    echo[0] == ECHO_REQUEST && carries_probe_ids(echo)
                }
                _ => false,
            }
        }
        _ => false,
    }
}
/// Decides whether a raw ICMPv6 message answers the probe `(ident, seq)`.
///
/// Unlike the IPv4 path, `buf[0]` is the ICMPv6 type here, because an IPv6
/// raw socket delivers the message without the IPv6 header. Accepted:
/// * Echo Reply (type 129) carrying `ident` and `seq`;
/// * Time Exceeded (type 3) whose quoted packet, assumed to be a fixed
///   40-byte IPv6 header with no extension headers, is followed by an Echo
///   Request (type 128) carrying `ident` and `seq`.
///
/// Everything else, including buffers too short to inspect, yields `false`.
fn is_relevant_icmpv6_response(buf: &[u8], ident: u16, seq: u16) -> bool {
    // Offset of the quoted ICMPv6 header inside a Time Exceeded message:
    // 8-byte outer ICMPv6 header + 40-byte quoted IPv6 header.
    const QUOTED_ICMP_OFFSET: usize = 8 + 40;

    let carries_probe_ids =
        |echo: &[u8]| echo[4..6] == ident.to_be_bytes() && echo[6..8] == seq.to_be_bytes();

    if buf.len() < 8 {
        return false;
    }
    match buf[0] {
        129 => carries_probe_ids(buf),
        3 if buf.len() >= QUOTED_ICMP_OFFSET + 8 => {
            let echo = &buf[QUOTED_ICMP_OFFSET..];
            echo[0] == 128 && carries_probe_ids(echo)
        }
        _ => false,
    }
}
/// Builds a 64-byte ICMPv4 Echo Request.
///
/// The packet is an 8-byte header (type 8, code 0, checksum, then `ident` and
/// `seq`, all big-endian) followed by a 56-byte payload whose byte at offset
/// `i` has the value `i` (8, 9, …, 63). The checksum is computed over the
/// whole packet with the checksum field zeroed, then written into bytes 2..4.
fn build_icmpv4_echo_request(ident: u16, seq: u16) -> Vec<u8> {
    const PACKET_LEN: u8 = 64;
    let mut packet: Vec<u8> = Vec::with_capacity(usize::from(PACKET_LEN));
    // Type 8 (Echo Request), code 0, checksum placeholder 0x0000.
    packet.extend_from_slice(&[8, 0, 0, 0]);
    packet.extend_from_slice(&ident.to_be_bytes());
    packet.extend_from_slice(&seq.to_be_bytes());
    packet.extend(8..PACKET_LEN);
    let checksum = icmp_checksum(&packet);
    packet[2..4].copy_from_slice(&checksum.to_be_bytes());
    packet
}
/// Builds a 64-byte ICMPv6 Echo Request.
///
/// The packet is an 8-byte header (type 128, code 0, checksum 0, then `ident`
/// and `seq`, big-endian) followed by a 56-byte payload whose byte at offset
/// `i` has the value `i` (8, 9, …, 63). The checksum is deliberately left
/// zero: for ICMPv6 raw sockets on Linux the kernel fills it in.
fn build_icmpv6_echo_request(ident: u16, seq: u16) -> Vec<u8> {
    const PACKET_LEN: u8 = 64;
    let mut packet: Vec<u8> = Vec::with_capacity(usize::from(PACKET_LEN));
    // Type 128 (Echo Request), code 0, checksum left as 0x0000.
    packet.extend_from_slice(&[128, 0, 0, 0]);
    packet.extend_from_slice(&ident.to_be_bytes());
    packet.extend_from_slice(&seq.to_be_bytes());
    packet.extend(8..PACKET_LEN);
    packet
}
/// Computes the 16-bit Internet checksum (RFC 1071) of `data`.
///
/// Bytes are summed as big-endian 16-bit words in ones'-complement
/// arithmetic. An odd trailing byte counts as the high byte of a final word
/// padded with zero. The result is the bitwise complement of the folded sum.
fn icmp_checksum(data: &[u8]) -> u16 {
    let mut sum: u32 = data
        .chunks(2)
        .map(|word| (u32::from(word[0]) << 8) | u32::from(word.get(1).copied().unwrap_or(0)))
        .sum();
    // Fold any carries above bit 15 back into the low 16 bits.
    while sum > 0xffff {
        sum = (sum >> 16) + (sum & 0xffff);
    }
    !(sum as u16)
}
/// Resolves `host` to a single IP address.
///
/// A literal IPv4 or IPv6 address is returned as-is without a DNS lookup.
/// Otherwise the first address yielded by `tokio::net::lookup_host` is used.
///
/// # Errors
/// Returns a message if DNS resolution fails or produces no addresses.
async fn resolve_host(host: &str) -> Result<IpAddr, String> {
    if let Ok(addr) = host.parse::<IpAddr>() {
        return Ok(addr);
    }
    // lookup_host needs a `host:port` pair; the port is a placeholder and is discarded.
    let mut addrs = tokio::net::lookup_host(format!("{host}:0"))
        .await
        .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?;
    addrs
        .next()
        .map(|addr| addr.ip())
        .ok_or_else(|| format!("No addresses found for {host}"))
}