309 lines
9.8 KiB
Rust
309 lines
9.8 KiB
Rust
use socket2::{Domain, Protocol, Socket, Type};
|
|
use std::io;
|
|
use std::mem::MaybeUninit;
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::time::{Duration, Instant};
|
|
|
|
/// One row of a traceroute: the outcome of the probe sent with a particular TTL.
///
/// `ip` and `rtt` are both `Some` when a matching ICMP reply arrived, and both `None`
/// when the probe timed out or failed.
#[derive(Debug, Clone, PartialEq)]
pub struct TracerouteHop {
    /// TTL (IPv4) / hop limit (IPv6) the probe was sent with; the first hop is 1.
    pub ttl: u8,
    /// Textual address of the host that answered, or `"unknown"` if the sender address
    /// was not an IP socket address.
    pub ip: Option<String>,
    /// Round-trip time in milliseconds.
    pub rtt: Option<f64>,
}
|
|
|
|
pub async fn traceroute(
|
|
host: &str,
|
|
max_hops: u8,
|
|
timeout_ms: u64,
|
|
) -> Result<Vec<TracerouteHop>, String> {
|
|
let dest: IpAddr = resolve_host(host).await?;
|
|
let timeout_dur = Duration::from_millis(timeout_ms);
|
|
|
|
// Run blocking raw-socket traceroute on the blocking thread pool
|
|
tokio::task::spawn_blocking(move || traceroute_blocking(dest, max_hops, timeout_dur))
|
|
.await
|
|
.map_err(|e| format!("Task join error: {e}"))?
|
|
}
|
|
|
|
fn traceroute_blocking(
|
|
dest: IpAddr,
|
|
max_hops: u8,
|
|
timeout: Duration,
|
|
) -> Result<Vec<TracerouteHop>, String> {
|
|
let mut hops = Vec::new();
|
|
|
|
for ttl in 1..=max_hops {
|
|
match send_probe(dest, ttl, timeout) {
|
|
Ok((ip, rtt)) => {
|
|
let reached = ip.as_ref().map(|a| a == &dest.to_string()).unwrap_or(false);
|
|
hops.push(TracerouteHop {
|
|
ttl,
|
|
ip,
|
|
rtt: Some(rtt),
|
|
});
|
|
if reached {
|
|
break;
|
|
}
|
|
}
|
|
Err(ProbeError::Timeout) => {
|
|
hops.push(TracerouteHop {
|
|
ttl,
|
|
ip: None,
|
|
rtt: None,
|
|
});
|
|
}
|
|
Err(ProbeError::Other(e)) => {
|
|
hops.push(TracerouteHop {
|
|
ttl,
|
|
ip: None,
|
|
rtt: None,
|
|
});
|
|
// Log but continue
|
|
eprintln!("Probe error at TTL {ttl}: {e}");
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(hops)
|
|
}
|
|
|
|
/// Reason a single probe in `send_probe` produced no matching reply.
enum ProbeError {
    /// The wait for a matching ICMP reply ran out.
    Timeout,
    /// Any other failure (socket setup, send or receive), carrying a readable message.
    Other(String),
}
|
|
|
|
fn send_probe(dest: IpAddr, ttl: u8, timeout: Duration) -> Result<(Option<String>, f64), ProbeError> {
|
|
let (domain, proto) = match dest {
|
|
IpAddr::V4(_) => (Domain::IPV4, Protocol::ICMPV4),
|
|
IpAddr::V6(_) => (Domain::IPV6, Protocol::ICMPV6),
|
|
};
|
|
|
|
let sock = Socket::new(domain, Type::RAW, Some(proto))
|
|
.map_err(|e| ProbeError::Other(format!("Socket creation failed: {e}")))?;
|
|
|
|
sock.set_ttl(ttl as u32)
|
|
.map_err(|e| ProbeError::Other(format!("Failed to set TTL: {e}")))?;
|
|
sock.set_read_timeout(Some(timeout))
|
|
.map_err(|e| ProbeError::Other(format!("Failed to set timeout: {e}")))?;
|
|
|
|
let dest_addr = match dest {
|
|
IpAddr::V4(v4) => SocketAddr::new(IpAddr::V4(v4), 0),
|
|
IpAddr::V6(v6) => SocketAddr::new(IpAddr::V6(v6), 0),
|
|
};
|
|
|
|
// Build ICMP Echo Request packet
|
|
let ident = (std::process::id() as u16) ^ (ttl as u16);
|
|
let seq = ttl as u16;
|
|
let packet = match dest {
|
|
IpAddr::V4(_) => build_icmpv4_echo_request(ident, seq),
|
|
IpAddr::V6(_) => build_icmpv6_echo_request(ident, seq),
|
|
};
|
|
|
|
let start = Instant::now();
|
|
|
|
sock.send_to(&packet, &dest_addr.into())
|
|
.map_err(|e| ProbeError::Other(format!("Send failed: {e}")))?;
|
|
|
|
// Wait for response using MaybeUninit buffer as required by socket2
|
|
let mut buf_uninit = [MaybeUninit::<u8>::uninit(); 512];
|
|
loop {
|
|
match sock.recv_from(&mut buf_uninit) {
|
|
Ok((n, from_addr)) => {
|
|
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
|
|
// Safety: recv_from initialized the first n bytes
|
|
let buf: &[u8] = unsafe {
|
|
std::slice::from_raw_parts(buf_uninit.as_ptr() as *const u8, n)
|
|
};
|
|
let from_ip = match from_addr.as_socket() {
|
|
Some(sa) => sa.ip().to_string(),
|
|
None => "unknown".to_string(),
|
|
};
|
|
|
|
// Check if this response is for our probe
|
|
match dest {
|
|
IpAddr::V4(_) => {
|
|
if is_relevant_icmpv4_response(buf, ident, seq) {
|
|
return Ok((Some(from_ip), elapsed));
|
|
}
|
|
}
|
|
IpAddr::V6(_) => {
|
|
if is_relevant_icmpv6_response(buf, ident, seq) {
|
|
return Ok((Some(from_ip), elapsed));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if we've exceeded timeout
|
|
if start.elapsed() >= timeout {
|
|
return Err(ProbeError::Timeout);
|
|
}
|
|
}
|
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => {
|
|
return Err(ProbeError::Timeout);
|
|
}
|
|
Err(e) => {
|
|
return Err(ProbeError::Other(format!("Recv error: {e}")));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Check whether a raw ICMPv4 datagram answers the probe identified by `ident` and `seq`.
///
/// `buf` starts with the IPv4 header, as the IPv4 raw socket in this file receives it.
/// Two kinds of answer are accepted:
/// * Echo Reply (type 0) whose identifier and sequence number match, and
/// * Time Exceeded (type 11) whose quoted original datagram is an Echo Request (type 8)
///   carrying the matching identifier and sequence number.
///
/// Truncated data and headers that are not IPv4 or declare an IHL below the 5-word minimum
/// are rejected, not misparsed. This applies to both the outer and the quoted header.
fn is_relevant_icmpv4_response(buf: &[u8], ident: u16, seq: u16) -> bool {
    const ECHO_REPLY: u8 = 0;
    const ECHO_REQUEST: u8 = 8;
    const TIME_EXCEEDED: u8 = 11;

    /// Returns the bytes following a well-formed IPv4 header (version 4, IHL >= 5).
    fn ipv4_payload(packet: &[u8]) -> Option<&[u8]> {
        let first = *packet.first()?;
        let header_len = usize::from(first & 0x0f) * 4;
        if first >> 4 != 4 || header_len < 20 {
            return None;
        }
        packet.get(header_len..)
    }

    /// Decodes (type, identifier, sequence number) from an 8-byte ICMP echo-style header.
    fn echo_header(icmp: &[u8]) -> Option<(u8, u16, u16)> {
        match icmp {
            [kind, _, _, _, id_hi, id_lo, seq_hi, seq_lo, ..] => Some((
                *kind,
                u16::from_be_bytes([*id_hi, *id_lo]),
                u16::from_be_bytes([*seq_hi, *seq_lo]),
            )),
            _ => None,
        }
    }

    let icmp = match ipv4_payload(buf) {
        Some(icmp) => icmp,
        None => return false,
    };

    match icmp.first().copied() {
        Some(ECHO_REPLY) => echo_header(icmp) == Some((ECHO_REPLY, ident, seq)),
        // Time Exceeded: 8-byte ICMP header, then the original IP header + 8 bytes of ICMP.
        Some(TIME_EXCEEDED) => {
            icmp.get(8..).and_then(ipv4_payload).and_then(echo_header)
                == Some((ECHO_REQUEST, ident, seq))
        }
        _ => false,
    }
}
|
|
|
|
/// Decide whether an ICMPv6 message answers the probe identified by `ident` and `seq`.
///
/// `buf` begins directly with the ICMPv6 header; the IPv6 raw socket hands over no IPv6
/// header. Accepted answers:
/// * type 129 (Echo Reply) with matching identifier and sequence number;
/// * type 3 (Time Exceeded) of at least 56 bytes whose quoted packet carries, right after
///   a 40-byte IPv6 header, an Echo Request (type 128) with matching identifier and
///   sequence number.
fn is_relevant_icmpv6_response(buf: &[u8], ident: u16, seq: u16) -> bool {
    // Big-endian u16 at byte offset `at`.
    let read_u16 = |bytes: &[u8], at: usize| u16::from_be_bytes([bytes[at], bytes[at + 1]]);
    // Identifier lives at bytes 4-5 and sequence number at 6-7 of an echo header.
    let echo_matches = |msg: &[u8]| read_u16(msg, 4) == ident && read_u16(msg, 6) == seq;

    if buf.len() < 8 {
        return false;
    }

    match buf[0] {
        129 => echo_matches(buf),
        // 8 (outer ICMPv6 header) + 40 (quoted IPv6 header) + 8 (quoted ICMPv6 header)
        3 if buf.len() >= 56 => {
            let quoted = &buf[48..];
            quoted[0] == 128 && echo_matches(quoted)
        }
        _ => false,
    }
}
|
|
|
|
/// Build an ICMPv4 Echo Request packet
|
|
fn build_icmpv4_echo_request(ident: u16, seq: u16) -> Vec<u8> {
|
|
let mut pkt = vec![0u8; 64]; // 8 header + 56 payload
|
|
pkt[0] = 8; // Type: Echo Request
|
|
pkt[1] = 0; // Code
|
|
// Checksum placeholder [2,3]
|
|
pkt[4] = (ident >> 8) as u8;
|
|
pkt[5] = (ident & 0xff) as u8;
|
|
pkt[6] = (seq >> 8) as u8;
|
|
pkt[7] = (seq & 0xff) as u8;
|
|
|
|
// Fill payload with pattern
|
|
for i in 8..64 {
|
|
pkt[i] = (i as u8) & 0xff;
|
|
}
|
|
|
|
// Calculate checksum
|
|
let cksum = icmp_checksum(&pkt);
|
|
pkt[2] = (cksum >> 8) as u8;
|
|
pkt[3] = (cksum & 0xff) as u8;
|
|
|
|
pkt
|
|
}
|
|
|
|
/// Assemble a 64-byte ICMPv6 Echo Request (type 128, code 0) carrying `ident` and `seq`.
///
/// The payload bytes hold their own offsets (8..=63). The checksum field is left as zero:
/// for IPPROTO_ICMPV6 raw sockets the kernel computes the ICMPv6 checksum (RFC 3542 §3.1).
fn build_icmpv6_echo_request(ident: u16, seq: u16) -> Vec<u8> {
    let [id_hi, id_lo] = ident.to_be_bytes();
    let [seq_hi, seq_lo] = seq.to_be_bytes();
    let header = [128u8, 0, 0, 0, id_hi, id_lo, seq_hi, seq_lo];

    header.iter().copied().chain(8u8..64).collect()
}
|
|
|
|
/// Compute the Internet checksum (RFC 1071) used by ICMP over `data`.
///
/// Bytes are summed as big-endian 16-bit words; a trailing odd byte is treated as the
/// high half of a final word. Carries are folded back into the low 16 bits, and the
/// one's complement of the result is returned.
fn icmp_checksum(data: &[u8]) -> u16 {
    let mut words = data.chunks_exact(2);
    let mut acc: u32 = words
        .by_ref()
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    if let [odd] = words.remainder() {
        acc += u32::from(*odd) << 8;
    }

    while acc > 0xffff {
        acc = (acc & 0xffff) + (acc >> 16);
    }

    !(acc as u16)
}
|
|
|
|
async fn resolve_host(host: &str) -> Result<IpAddr, String> {
|
|
if let Ok(addr) = host.parse::<IpAddr>() {
|
|
return Ok(addr);
|
|
}
|
|
let addrs = tokio::net::lookup_host(format!("{host}:0"))
|
|
.await
|
|
.map_err(|e| format!("DNS resolution failed for {host}: {e}"))?;
|
|
|
|
for addr in addrs {
|
|
return Ok(addr.ip());
|
|
}
|
|
Err(format!("No addresses found for {host}"))
|
|
}
|