From 0eee869a2ba6732b3775f195b433b4f9641692be Mon Sep 17 00:00:00 2001 From: Marco Cadetg Date: Mon, 30 Jun 2025 14:15:40 +0200 Subject: [PATCH] working code ;) --- Cargo.lock | 97 +- Cargo.toml | 3 + README.md | 27 + src/app.rs | 1072 +++++++++---------- src/main.rs | 196 ++-- src/network/capture.rs | 380 +++++++ src/network/dpi/dns.rs | 75 ++ src/network/dpi/http.rs | 130 +++ src/network/dpi/mod.rs | 100 ++ src/network/dpi/quic.rs | 22 + src/network/dpi/tls.rs | 201 ++++ src/network/linux.rs | 183 ---- src/network/macos.rs | 391 ------- src/network/merge.rs | 303 ++++++ src/network/mod.rs | 1730 +------------------------------ src/network/mod.rs.old | 1713 ++++++++++++++++++++++++++++++ src/network/parser.rs | 479 +++++++++ src/network/platform/linux.rs | 228 ++++ src/network/platform/macos.rs | 81 ++ src/network/platform/mod.rs | 73 ++ src/network/platform/windows.rs | 59 ++ src/network/services.rs | 272 +++++ src/network/types.rs | 273 +++++ src/network/windows.rs | 212 ---- src/ui.rs | 555 ++++------ 25 files changed, 5360 insertions(+), 3495 deletions(-) create mode 100644 src/network/capture.rs create mode 100644 src/network/dpi/dns.rs create mode 100644 src/network/dpi/http.rs create mode 100644 src/network/dpi/mod.rs create mode 100644 src/network/dpi/quic.rs create mode 100644 src/network/dpi/tls.rs delete mode 100644 src/network/linux.rs delete mode 100644 src/network/macos.rs create mode 100644 src/network/merge.rs create mode 100644 src/network/mod.rs.old create mode 100644 src/network/parser.rs create mode 100644 src/network/platform/linux.rs create mode 100644 src/network/platform/macos.rs create mode 100644 src/network/platform/mod.rs create mode 100644 src/network/platform/windows.rs create mode 100644 src/network/services.rs create mode 100644 src/network/types.rs delete mode 100644 src/network/windows.rs diff --git a/Cargo.lock b/Cargo.lock index 5ef4483..ac97a33 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -287,6 +287,62 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crossterm" version = "0.28.1" @@ -365,6 +421,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deranged" version = "0.4.0" @@ -516,6 +586,12 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.4" @@ -533,6 +609,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -709,7 +791,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown", + "hashbrown 0.15.4", ] [[package]] @@ -761,6 +843,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "num_threads" version = "0.1.7" @@ -1100,9 +1192,12 @@ dependencies = [ "arboard", "chrono", "clap", + "crossbeam", "crossterm 0.29.0", + "dashmap", "dns-lookup", "log", + "num_cpus", "pcap", "pnet_datalink", "procfs", diff --git a/Cargo.toml b/Cargo.toml index 1a6cea1..cdbfc5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,11 @@ edition = "2024" anyhow = "1.0" arboard = "3.5" crossterm = "0.29" +crossbeam = "0.8" +dashmap = "6.1" dns-lookup = "2.0" log = "0.4" +num_cpus = "1.17" pcap = "2.2" pnet_datalink = "0.35" clap = { version = "4.5", features = ["derive"] } diff --git a/README.md b/README.md index 743f842..dd78596 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,31 @@ refresh_interval: 1000 show_locations: true ``` +## Architecture + +┌─────────────────┐ +│ Packet Capture │ ──packets──> channel +└─────────────────┘ │ + ├──> ┌──────────────────┐ + ├──> │ Packet Processor │ ──> DashMap + ├──> │ (Thread 0) │ │ + └──> │ (Thread N) │ │ + └──────────────────┘ │ + │ +┌─────────────────-┐ │ +│Process Enrichment│ ──────────────────────────────────────────> DashMap +└─────────────────-┘ │ + │ +┌─────────────────┐ │ +│Snapshot Provider│ <────────────────────────────────────────── DashMap +└─────────────────┘ │ + │ │ + └──> RwLock> (for UI) │ + │ +┌─────────────────┐ │ +│ Cleanup Thread │ <────────────────────────────────────────── DashMap +└─────────────────┘ + ## Internationalization RustNet supports multiple languages. The application looks for language files in the following locations: @@ -96,6 +121,7 @@ RustNet supports multiple languages. The application looks for language files in 4. `/usr/share/rustnet/i18n/[language].yml` Currently supported languages: + - English (en) - French (fr) @@ -114,6 +140,7 @@ RustNet attempts to identify the process associated with each network connection ## TODOs ### GeoIP Lookup + For GeoIP lookup: MaxMind GeoLite2 City database (place `GeoLite2-City.mmdb` in the application directory) When a MaxMind GeoLite2 City database is available, RustNet can display geographical information about remote IP addresses. To use this feature: diff --git a/src/app.rs b/src/app.rs index ee47fca..baeec24 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,651 +1,561 @@ +// app.rs - Main application orchestration (with debug logging) use anyhow::Result; -use arboard::Clipboard; -use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use log::error; +use crossbeam::channel::{self, Receiver, Sender}; +use dashmap::DashMap; +use log::{debug, error, info, warn}; use std::collections::HashMap; -use std::net::IpAddr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc; -use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, RwLock}; use std::thread; -use std::time::Duration; +use std::time::{Duration, Instant, SystemTime}; -use crate::config::Config; -use crate::i18n::I18n; -use crate::network::{self, Connection, NetworkMonitor, Process}; +use crate::network::{ + capture::{CaptureConfig, PacketReader, setup_packet_capture}, + merge::{ + create_connection_from_packet, enrich_with_process_info, enrich_with_service_name, + merge_packet_into_connection, + }, + parser::{PacketParser, ParsedPacket, ParserConfig}, + platform::{ProcessLookup, create_process_lookup}, + services::ServiceLookup, + types::Connection, +}; -/// Application actions -pub enum Action { - Quit, - Refresh, +/// Application configuration +#[derive(Debug, Clone)] +pub struct Config { + /// Network interface to capture from (None for default) + pub interface: Option, + /// Filter localhost connections + pub filter_localhost: bool, + /// UI refresh interval in milliseconds + pub refresh_interval: u64, + /// Enable deep packet inspection + pub enable_dpi: bool, + /// Process lookup interval in seconds + pub process_lookup_interval: u64, + /// Connection timeout in seconds (remove inactive connections) + pub connection_timeout: u64, + /// BPF filter for packet capture + pub bpf_filter: Option, } -/// Application view modes -pub enum ViewMode { - Overview, - ConnectionDetails, - Help, +impl Default for Config { + fn default() -> Self { + Self { + interface: None, + filter_localhost: true, + refresh_interval: 1000, + enable_dpi: true, + process_lookup_interval: 2, + connection_timeout: 60, + bpf_filter: None, // No filter by default to see all packets + } + } } -/// Fields that can be focused for copying in the Connection Details view -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum DetailFocusField { - LocalIp, - RemoteIp, +/// Application statistics +#[derive(Debug)] +pub struct AppStats { + pub packets_processed: AtomicU64, + pub packets_dropped: AtomicU64, + pub connections_tracked: AtomicU64, + pub last_update: RwLock, } -/// Application state +impl Default for AppStats { + fn default() -> Self { + Self { + packets_processed: AtomicU64::new(0), + packets_dropped: AtomicU64::new(0), + connections_tracked: AtomicU64::new(0), + last_update: RwLock::new(Instant::now()), + } + } +} + +/// Main application state pub struct App { - pub config: Config, - pub i18n: I18n, - pub mode: ViewMode, - network_monitor: Arc>, - pub connections: Vec, - pub processes: HashMap, - pub selected_connection: Option, - pub selected_connection_idx: usize, - pub show_locations: bool, - pub show_hostnames: bool, - connection_order: HashMap, - next_order_index: usize, - dns_cache: HashMap, - connections_data_shared: Arc>>, - pub detail_focus: DetailFocusField, - processes_data_shared: Arc>>, - pub is_loading: bool, - pub loading_message: String, - loading_spinner_index: usize, - should_stop: Arc, -} + /// Configuration + config: Config, -const PROCESS_INFO_UPDATE_INTERVAL: Duration = Duration::from_secs(5); + /// Control flag for graceful shutdown + should_stop: Arc, + + /// Current connections snapshot for UI + connections_snapshot: Arc>>, + + /// Service name lookup + service_lookup: Arc, + + /// Application statistics + stats: Arc, + + /// Loading state + is_loading: Arc, +} impl App { - pub fn new(config: Config, i18n: I18n) -> Result { - log::info!("App::new - Starting application initialization"); - let monitor_start = std::time::Instant::now(); - let monitor = NetworkMonitor::new(config.interface.clone(), config.filter_localhost)?; - log::info!( - "App::new - NetworkMonitor created in {:?}", - monitor_start.elapsed() - ); - let app = Self { + /// Create a new application instance + pub fn new(config: Config) -> Result { + // Load service definitions + let service_lookup = ServiceLookup::from_file("/etc/services").unwrap_or_else(|e| { + warn!("Failed to load /etc/services: {}, using defaults", e); + ServiceLookup::with_defaults() + }); + + Ok(Self { config, - i18n, - mode: ViewMode::Overview, - network_monitor: Arc::new(Mutex::new(monitor)), - connections: Vec::new(), - processes: HashMap::new(), - selected_connection: None, - selected_connection_idx: 0, - show_locations: true, - show_hostnames: false, - connection_order: HashMap::new(), - next_order_index: 0, - dns_cache: HashMap::new(), - connections_data_shared: Arc::new(Mutex::new(Vec::new())), - processes_data_shared: Arc::new(Mutex::new(HashMap::new())), - detail_focus: DetailFocusField::LocalIp, - is_loading: true, - loading_message: "Initializing network monitor...".to_string(), - loading_spinner_index: 0, should_stop: Arc::new(AtomicBool::new(false)), - }; - log::info!("App::new - Application initialized successfully"); - Ok(app) + connections_snapshot: Arc::new(RwLock::new(Vec::new())), + service_lookup: Arc::new(service_lookup), + stats: Arc::new(AppStats::default()), + is_loading: Arc::new(AtomicBool::new(true)), + }) } - pub fn start_capture(&mut self) -> Result<()> { - log::info!("App::start_capture - Starting network capture setup"); - let start_time = std::time::Instant::now(); + /// Start all background threads + pub fn start(&mut self) -> Result<()> { + info!("Starting network monitor application"); - // Update loading message - self.loading_message = "Starting background threads...".to_string(); + // Create shared connection map + let connections: Arc> = Arc::new(DashMap::new()); - // --- Packet Capture Thread --- - let (packet_tx, packet_rx) = mpsc::channel::>(); - let interface_name = self.config.interface.clone(); - let should_stop_capture = Arc::clone(&self.should_stop); + // Start packet capture pipeline + self.start_packet_capture_pipeline(connections.clone())?; + + // Start process enrichment thread + self.start_process_enrichment(connections.clone())?; + + // Start snapshot provider for UI + self.start_snapshot_provider(connections.clone())?; + + // Start cleanup thread + self.start_cleanup_thread(connections)?; + + // Mark loading as complete after a short delay + let is_loading = Arc::clone(&self.is_loading); thread::spawn(move || { - log::info!("Starting packet capture thread"); - if let Err(e) = - network::packet_capture_thread(interface_name, packet_tx, should_stop_capture) - { - error!( - "Packet capture thread failed (this is normal if not running as root): {}", - e - ); - log::info!("Packet capture disabled, will rely on platform connections only"); - } + thread::sleep(Duration::from_millis(500)); + is_loading.store(false, Ordering::Relaxed); }); - // --- Connection Management Thread --- - let monitor_clone: Arc> = Arc::clone(&self.network_monitor); - let connections_shared_clone: Arc>> = - Arc::clone(&self.connections_data_shared); - let tick_rate = self.config.refresh_interval; - let should_stop_mgmt = Arc::clone(&self.should_stop); + Ok(()) + } + + /// Start packet capture and processing pipeline + fn start_packet_capture_pipeline( + &self, + connections: Arc>, + ) -> Result<()> { + // Create packet channel + let (packet_tx, packet_rx) = channel::unbounded(); + + // Start capture thread + self.start_capture_thread(packet_tx)?; + + // Start multiple packet processing threads + let num_processors = thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4) + .min(4); + + for i in 0..num_processors { + self.start_packet_processor(i, packet_rx.clone(), connections.clone()); + } + + Ok(()) + } + + /// Start packet capture thread + fn start_capture_thread(&self, packet_tx: Sender>) -> Result<()> { + let capture_config = CaptureConfig { + interface: self.config.interface.clone(), + filter: self.config.bpf_filter.clone(), + ..Default::default() + }; + + let should_stop = Arc::clone(&self.should_stop); + let stats = Arc::clone(&self.stats); + thread::spawn(move || { - log::info!("Starting connection management thread"); + match setup_packet_capture(capture_config) { + Ok(capture) => { + info!("Packet capture started successfully"); + let mut reader = PacketReader::new(capture); + let mut packets_read = 0u64; + let mut last_log = Instant::now(); + let mut last_stats_check = Instant::now(); - // Add a small delay to let the UI render first - thread::sleep(Duration::from_millis(100)); + loop { + if should_stop.load(Ordering::Relaxed) { + info!("Capture thread stopping"); + break; + } - // Do initial connection discovery - log::info!("Performing initial connection discovery..."); - match monitor_clone.lock().unwrap().get_connections() { - Ok(initial_conns) => { - log::info!( - "Initial discovery found {} connections", - initial_conns.len() - ); - *connections_shared_clone.lock().unwrap() = initial_conns; - } - Err(e) => { - error!("Error in initial connection discovery: {}", e); - } - } + match reader.next_packet() { + Ok(Some(packet)) => { + packets_read += 1; - loop { - // Process all pending packets from the queue (may be empty if capture failed) - let packets: Vec<_> = packet_rx.try_iter().collect(); - if !packets.is_empty() { - log::debug!("Processing {} packets", packets.len()); - for packet_data in packets { - monitor_clone.lock().unwrap().process_packet(&packet_data); - } - } + // Log first packet immediately + if packets_read == 1 { + info!("First packet captured! Size: {} bytes", packet.len()); + } - // Update shared connections periodically - match monitor_clone.lock().unwrap().get_connections() { - Ok(conns) => { - log::debug!( - "Connection management thread: Found {} connections", - conns.len() - ); - *connections_shared_clone.lock().unwrap() = conns; - } - Err(e) => { - error!("Error getting connections in management thread: {}", e); - } - } + // Log every 100 packets or every 5 seconds + if packets_read % 100 == 0 + || last_log.elapsed() > Duration::from_secs(5) + { + info!("Read {} packets so far", packets_read); + last_log = Instant::now(); + } - if should_stop_mgmt.load(Ordering::Relaxed) { - log::info!("Connection management thread stopping"); - break; - } - - thread::sleep(Duration::from_millis(tick_rate / 10)); // Sleep for 1/10th of the tick rate or in other words we update connections 10 times per tick - } - }); - - // --- Process Information Fetching Thread --- - let monitor_clone_procs: Arc> = Arc::clone(&self.network_monitor); - let connections_shared_procs: Arc>> = - Arc::clone(&self.connections_data_shared); - let processes_shared_clone: Arc>> = - Arc::clone(&self.processes_data_shared); - let should_stop_platform_info = Arc::clone(&self.should_stop); - thread::spawn(move || { - loop { - thread::sleep(PROCESS_INFO_UPDATE_INTERVAL); - - let connections_to_check = connections_shared_procs.lock().unwrap().clone(); - let mut collected_processes: HashMap = HashMap::new(); - - for conn in connections_to_check { - if conn.pid.is_none() { - if let Some(process) = monitor_clone_procs - .lock() - .unwrap() - .get_platform_process_for_connection(&conn) - { - if !process.name.is_empty() { - collected_processes.insert(process.pid, process); + if packet_tx.send(packet).is_err() { + warn!("Packet channel closed"); + break; + } + } + Ok(None) => { + // Timeout - check stats every second + if last_stats_check.elapsed() > Duration::from_secs(1) { + if let Ok(capture_stats) = reader.stats() { + if capture_stats.received > 0 { + debug!( + "Capture stats - Received: {}, Dropped: {}", + capture_stats.received, capture_stats.dropped + ); + } + stats + .packets_dropped + .store(capture_stats.dropped as u64, Ordering::Relaxed); + } + last_stats_check = Instant::now(); + } + } + Err(e) => { + error!("Capture error: {}", e); + break; } } } - } - if !collected_processes.is_empty() { - let mut processes_guard = processes_shared_clone.lock().unwrap(); - for (pid, process) in collected_processes { - processes_guard.insert(pid, process); - } + info!( + "Capture thread exiting, total packets read: {}", + packets_read + ); } - if should_stop_platform_info.load(Ordering::Relaxed) { - log::info!("Process information thread stopping"); - break; + Err(e) => { + error!("Failed to start packet capture: {}", e); + error!( + "Make sure you have permission to capture packets (try running with sudo)" + ); + warn!("Application will run in process-only mode"); } } }); - log::info!( - "App::start_capture - All threads started in {:?}", - start_time.elapsed() - ); - - // Don't mark loading as complete here - let the background thread discovery do that - self.loading_message = "Threads started, discovering connections...".to_string(); - Ok(()) } - pub fn handle_key(&mut self, key: KeyEvent) -> Option { - match self.mode { - ViewMode::Overview => self.handle_overview_keys(key), - ViewMode::ConnectionDetails => self.handle_details_keys(key), - ViewMode::Help => self.handle_help_keys(key), - } + /// Start a packet processor thread + fn start_packet_processor( + &self, + id: usize, + packet_rx: Receiver>, + connections: Arc>, + ) { + let should_stop = Arc::clone(&self.should_stop); + let stats = Arc::clone(&self.stats); + let parser_config = ParserConfig { + enable_dpi: self.config.enable_dpi, + ..Default::default() + }; + + thread::spawn(move || { + info!("Packet processor {} started", id); + let parser = PacketParser::with_config(parser_config); + let mut batch = Vec::new(); + let mut total_processed = 0u64; + let mut last_log = Instant::now(); + + loop { + if should_stop.load(Ordering::Relaxed) { + info!("Packet processor {} stopping", id); + break; + } + + // Collect packets in batches + batch.clear(); + let deadline = Instant::now() + Duration::from_millis(10); + + while batch.len() < 100 && Instant::now() < deadline { + match packet_rx.recv_timeout(Duration::from_millis(1)) { + Ok(packet) => batch.push(packet), + Err(_) => break, + } + } + + // Process batch + let mut parsed_count = 0; + for packet_data in &batch { + if let Some(parsed) = parser.parse_packet(packet_data) { + update_connection(&connections, parsed, &stats); + parsed_count += 1; + } + } + + if !batch.is_empty() { + total_processed += batch.len() as u64; + stats + .packets_processed + .fetch_add(batch.len() as u64, Ordering::Relaxed); + + // Log progress + if total_processed % 100 == 0 || last_log.elapsed() > Duration::from_secs(5) { + debug!( + "Processor {}: {} packets processed ({} parsed)", + id, total_processed, parsed_count + ); + last_log = Instant::now(); + } + } + } + + info!( + "Packet processor {} exiting, total processed: {}", + id, total_processed + ); + }); } - pub fn shutdown(&mut self) { - log::info!("App shutting down, signaling threads to stop"); - self.should_stop.store(true, Ordering::Relaxed); - } + /// Start process enrichment thread + fn start_process_enrichment( + &self, + connections: Arc>, + ) -> Result<()> { + let process_lookup = create_process_lookup()?; + let should_stop = Arc::clone(&self.should_stop); + let interval = Duration::from_secs(self.config.process_lookup_interval); - fn handle_overview_keys(&mut self, key: KeyEvent) -> Option { - match key.code { - KeyCode::Char('q') | KeyCode::Char('c') - if key.modifiers.contains(KeyModifiers::CONTROL) => - { - Some(Action::Quit) - } - KeyCode::Char('r') => Some(Action::Refresh), - KeyCode::Down => { - if !self.connections.is_empty() { - self.selected_connection_idx = - (self.selected_connection_idx + 1) % self.connections.len(); - self.selected_connection = - Some(self.connections[self.selected_connection_idx].clone()); - } - None - } - KeyCode::Up => { - if !self.connections.is_empty() { - self.selected_connection_idx = self - .selected_connection_idx - .checked_sub(1) - .unwrap_or(self.connections.len() - 1); - self.selected_connection = - Some(self.connections[self.selected_connection_idx].clone()); - } - None - } - KeyCode::Enter => { - if !self.connections.is_empty() { - self.mode = ViewMode::ConnectionDetails; - } - None - } - KeyCode::Char('h') => { - self.mode = ViewMode::Help; - None - } - KeyCode::Char('l') => { - self.show_locations = !self.show_locations; - None - } - KeyCode::Char('d') => { - self.show_hostnames = !self.show_hostnames; - if !self.show_hostnames { - self.dns_cache.clear(); - } - None - } - _ => None, - } - } + thread::spawn(move || { + info!("Process enrichment thread started"); + let mut last_refresh = Instant::now(); - fn handle_details_keys(&mut self, key: KeyEvent) -> Option { - match key.code { - KeyCode::Esc | KeyCode::Char('q') => { - self.mode = ViewMode::Overview; - None - } - KeyCode::Up | KeyCode::Down => { - self.detail_focus = match self.detail_focus { - DetailFocusField::LocalIp => DetailFocusField::RemoteIp, - DetailFocusField::RemoteIp => DetailFocusField::LocalIp, - }; - None - } - KeyCode::Char('c') => { - if let Some(conn) = self.connections.get(self.selected_connection_idx) { - let ip_to_copy = match self.detail_focus { - DetailFocusField::LocalIp => conn.local_addr.ip().to_string(), - DetailFocusField::RemoteIp => conn.remote_addr.ip().to_string(), - }; - if let Ok(mut clipboard) = Clipboard::new() { - if let Err(e) = clipboard.set_text(ip_to_copy.clone()) { - error!("Failed to copy IP to clipboard: {}", e); + loop { + if should_stop.load(Ordering::Relaxed) { + info!("Process enrichment thread stopping"); + break; + } + + // Refresh process lookup periodically + if last_refresh.elapsed() > Duration::from_secs(5) { + if let Err(e) = process_lookup.refresh() { + debug!("Process lookup refresh failed: {}", e); + } + last_refresh = Instant::now(); + } + + // Enrich connections without process info + let mut enriched = 0; + for mut entry in connections.iter_mut() { + if entry.process_name.is_none() { + if let Some((pid, name)) = process_lookup.get_process_for_connection(&entry) + { + entry.pid = Some(pid); + entry.process_name = Some(name); + enriched += 1; } } } - None - } - _ => None, - } - } - fn handle_help_keys(&mut self, key: KeyEvent) -> Option { - match key.code { - KeyCode::Esc | KeyCode::Char('q') | KeyCode::Char('h') => { - self.mode = ViewMode::Overview; - None - } - _ => None, - } - } - - fn get_connection_key(&self, conn: &Connection) -> String { - format!( - "{:?}-{}-{}-{:?}", - conn.protocol, - conn.local_addr, - conn.remote_addr, - conn.state() - ) - } - - fn find_connection_index_by_key(&self, target_key: &str) -> Option { - self.connections - .iter() - .position(|conn| self.get_connection_key(conn) == target_key) - } - - pub fn on_tick(&mut self) -> Result<()> { - let selected_conn_key = self - .selected_connection - .as_ref() - .map(|sc| self.get_connection_key(sc)); - - let mut new_connections_list = self.connections_data_shared.lock().unwrap().clone(); - log::debug!( - "on_tick: Processing {} connections from shared data", - new_connections_list.len() - ); - - // Update loading status based on connections availability - if self.is_loading { - if !new_connections_list.is_empty() { - self.is_loading = false; - self.loading_message.clear(); - } else { - // Update spinner animation and vary the loading message - self.loading_spinner_index = (self.loading_spinner_index + 1) % 4; - - // Vary the loading message to show progress - let elapsed = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() - % 10; - - self.loading_message = match elapsed { - 0..=2 => "Scanning network interfaces...".to_string(), - 3..=5 => "Discovering active connections...".to_string(), - 6..=8 => "Gathering process information...".to_string(), - _ => "Please wait, this may take 10-30 seconds...".to_string(), - }; - } - } - - let mut keys_to_process = Vec::new(); - for conn in &new_connections_list { - keys_to_process.push(self.get_connection_key(conn)); - } - - for key in keys_to_process { - self.connection_order.entry(key).or_insert_with(|| { - let index = self.next_order_index; - self.next_order_index += 1; - index - }); - } - - new_connections_list.sort_by(|a, b| { - let is_a_loopback = a.local_addr.ip().is_loopback() || a.remote_addr.ip().is_loopback(); - let is_b_loopback = b.local_addr.ip().is_loopback() || b.remote_addr.ip().is_loopback(); - is_a_loopback.cmp(&is_b_loopback).then_with(|| { - let key_a = self.get_connection_key(a); - let key_b = self.get_connection_key(b); - let order_a = self.connection_order.get(&key_a).unwrap_or(&usize::MAX); - let order_b = self.connection_order.get(&key_b).unwrap_or(&usize::MAX); - order_a.cmp(order_b) - }) - }); - - self.connections = new_connections_list; - - if let Some(key) = selected_conn_key { - if let Some(idx) = self.find_connection_index_by_key(&key) { - self.selected_connection_idx = idx; - self.selected_connection = Some(self.connections[idx].clone()); - } else if !self.connections.is_empty() { - self.selected_connection_idx = 0; - self.selected_connection = Some(self.connections[0].clone()); - } else { - self.selected_connection_idx = 0; - self.selected_connection = None; - } - } else if !self.connections.is_empty() && self.selected_connection.is_none() { - self.selected_connection_idx = 0; - self.selected_connection = Some(self.connections[0].clone()); - } - - if let Ok(shared_procs_guard) = self.processes_data_shared.lock() { - self.processes = shared_procs_guard.clone(); - } - - for conn in &mut self.connections { - if let Some(pid) = conn.pid { - if let Some(cached_process_info) = self.processes.get(&pid) { - if !cached_process_info.name.is_empty() { - conn.process_name = Some(cached_process_info.name.clone()); - } + if enriched > 0 { + debug!("Enriched {} connections with process info", enriched); } + + thread::sleep(interval); } - } + }); Ok(()) } - /// Format a socket address for display - pub fn format_socket_addr(&mut self, addr: std::net::SocketAddr) -> String { - if self.show_hostnames { - // Try to resolve hostname - if let Some(hostname) = self.dns_cache.get(&addr.ip()) { - format!("{}:{}", hostname, addr.port()) - } else { - // Attempt to resolve hostname if not in cache - if let Ok(hostname) = dns_lookup::lookup_addr(&addr.ip()) { - if hostname != addr.ip().to_string() { - // Cache the result - self.dns_cache.insert(addr.ip(), hostname.clone()); - return format!("{}:{}", hostname, addr.port()); - } + /// Start snapshot provider thread for UI updates + fn start_snapshot_provider(&self, connections: Arc>) -> Result<()> { + let snapshot = Arc::clone(&self.connections_snapshot); + let should_stop = Arc::clone(&self.should_stop); + let stats = Arc::clone(&self.stats); + let service_lookup = Arc::clone(&self.service_lookup); + let filter_localhost = self.config.filter_localhost; + let refresh_interval = Duration::from_millis(self.config.refresh_interval); + + thread::spawn(move || { + info!("Snapshot provider thread started"); + + loop { + if should_stop.load(Ordering::Relaxed) { + info!("Snapshot provider thread stopping"); + break; } - // Cache the IP as fallback to avoid repeated lookups - self.dns_cache.insert(addr.ip(), addr.ip().to_string()); - addr.to_string() + + // Create snapshot + let start = Instant::now(); + let total_connections = connections.len(); + + let mut snapshot_data: Vec = connections + .iter() + .map(|entry| { + let mut conn = entry.value().clone(); + + // Enrich with service name + if conn.service_name.is_none() { + if let Some(service) = + service_lookup.lookup(conn.local_addr.port(), conn.protocol) + { + conn.service_name = Some(service.to_string()); + } else if let Some(service) = + service_lookup.lookup(conn.remote_addr.port(), conn.protocol) + { + conn.service_name = Some(service.to_string()); + } + } + + conn + }) + .filter(|conn| { + // Apply filters + if filter_localhost { + !(conn.local_addr.ip().is_loopback() + && conn.remote_addr.ip().is_loopback()) + } else { + true + } + }) + .filter(|conn| conn.is_active()) + .collect(); + + // Sort by last activity + snapshot_data.sort_by(|a, b| b.last_activity.cmp(&a.last_activity)); + + let filtered_count = snapshot_data.len(); + + // Update snapshot + *snapshot.write().unwrap() = snapshot_data; + + // Update stats + stats + .connections_tracked + .store(total_connections as u64, Ordering::Relaxed); + *stats.last_update.write().unwrap() = Instant::now(); + + debug!( + "Snapshot updated in {:?} - Total: {}, Filtered: {}", + start.elapsed(), + total_connections, + filtered_count + ); + + thread::sleep(refresh_interval); } - } else { - addr.to_string() + }); + + Ok(()) + } + + /// Start cleanup thread to remove old connections + fn start_cleanup_thread(&self, connections: Arc>) -> Result<()> { + let should_stop = Arc::clone(&self.should_stop); + let timeout = Duration::from_secs(self.config.connection_timeout); + + thread::spawn(move || { + info!("Cleanup thread started"); + + loop { + if should_stop.load(Ordering::Relaxed) { + info!("Cleanup thread stopping"); + break; + } + + // Remove inactive connections + let now = SystemTime::now(); + let mut removed = 0; + + connections.retain(|_, conn| { + let should_keep = now + .duration_since(conn.last_activity) + .unwrap_or(Duration::from_secs(0)) + < timeout; + + if !should_keep { + removed += 1; + } + + should_keep + }); + + if removed > 0 { + debug!("Removed {} inactive connections", removed); + } + + thread::sleep(Duration::from_secs(10)); + } + }); + + Ok(()) + } + + /// Get current connections for UI display + pub fn get_connections(&self) -> Vec { + self.connections_snapshot.read().unwrap().clone() + } + + /// Get application statistics + pub fn get_stats(&self) -> AppStats { + AppStats { + packets_processed: AtomicU64::new(self.stats.packets_processed.load(Ordering::Relaxed)), + packets_dropped: AtomicU64::new(self.stats.packets_dropped.load(Ordering::Relaxed)), + connections_tracked: AtomicU64::new( + self.stats.connections_tracked.load(Ordering::Relaxed), + ), + last_update: RwLock::new(*self.stats.last_update.read().unwrap()), } } - /// Refresh the application state - pub fn refresh(&mut self) -> Result<()> { - // Trigger a fresh connection update - self.on_tick() + /// Check if application is still loading + pub fn is_loading(&self) -> bool { + self.is_loading.load(Ordering::Relaxed) } - /// Get the current spinner character for loading animation - pub fn get_spinner_char(&self) -> &str { - const SPINNER_CHARS: &[&str] = &["⠋", "⠙", "⠹", "⠸"]; - SPINNER_CHARS[self.loading_spinner_index] + /// Stop all threads gracefully + pub fn stop(&self) { + info!("Stopping application"); + self.should_stop.store(true, Ordering::Relaxed); } } -#[cfg(test)] -mod tests { - use super::*; - use crate::config::Config; - use crate::i18n::I18n; - use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +/// Update or create a connection from a parsed packet +fn update_connection( + connections: &DashMap, + parsed: ParsedPacket, + stats: &AppStats, +) { + let key = parsed.connection_key.clone(); + let now = SystemTime::now(); - fn create_test_app() -> App { - let config = Config::default(); - let i18n = I18n::new("en").unwrap(); - App::new(config, i18n).unwrap() - } + connections + .entry(key.clone()) + .and_modify(|conn| { + *conn = merge_packet_into_connection(conn.clone(), &parsed, now); + }) + .or_insert_with(|| { + debug!("New connection detected: {}", key); + create_connection_from_packet(&parsed, now) + }); +} - #[test] - fn test_dns_toggle_functionality() { - let mut app = create_test_app(); - - // Initially DNS hostnames should be disabled - assert!(!app.show_hostnames); - assert!(app.dns_cache.is_empty()); - - // Test IP address formatting without DNS - let test_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53); - let formatted = app.format_socket_addr(test_addr); - assert_eq!(formatted, "8.8.8.8:53"); - - // Enable DNS resolution - app.show_hostnames = true; - - // Format the same address with DNS enabled - let formatted_with_dns = app.format_socket_addr(test_addr); - // Should either be resolved hostname or cached IP - assert!(!formatted_with_dns.is_empty()); - assert!(formatted_with_dns.contains(":53")); - - // Check that cache is populated - assert!(app.dns_cache.contains_key(&test_addr.ip())); - } - - #[test] - fn test_dns_cache_behavior() { - let mut app = create_test_app(); - app.show_hostnames = true; - - let test_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); - - // First call should populate cache - let first_result = app.format_socket_addr(test_addr); - assert!(app.dns_cache.contains_key(&test_addr.ip())); - - // Second call should use cache - let second_result = app.format_socket_addr(test_addr); - assert_eq!(first_result, second_result); - - // Disable DNS and clear cache - app.show_hostnames = false; - app.dns_cache.clear(); - - let ip_only_result = app.format_socket_addr(test_addr); - assert_eq!(ip_only_result, "127.0.0.1:8080"); - } - - #[test] - fn test_view_mode_switching() { - let mut app = create_test_app(); - - // Should start in Overview mode - assert!(matches!(app.mode, ViewMode::Overview)); - - // Test switching to Help - app.mode = ViewMode::Help; - assert!(matches!(app.mode, ViewMode::Help)); - - // Test switching to Connection Details - app.mode = ViewMode::ConnectionDetails; - assert!(matches!(app.mode, ViewMode::ConnectionDetails)); - } - - #[test] - fn test_connection_selection() { - let mut app = create_test_app(); - - // Initially no connections - assert!(app.connections.is_empty()); - assert_eq!(app.selected_connection_idx, 0); - assert!(app.selected_connection.is_none()); - - // Add some test connections - let conn1 = crate::network::Connection::new( - crate::network::Protocol::TCP, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3000), - crate::network::ConnectionState::Established, - ); - let conn2 = crate::network::Connection::new( - crate::network::Protocol::UDP, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8081), - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3001), - crate::network::ConnectionState::Listen, - ); - - app.connections = vec![conn1, conn2]; - - // Test that we can select connections - assert_eq!(app.connections.len(), 2); - app.selected_connection_idx = 1; - assert_eq!(app.selected_connection_idx, 1); - } - - #[test] - fn test_detail_focus_field() { - let mut app = create_test_app(); - - // Should start with LocalIp focused - assert!(matches!(app.detail_focus, DetailFocusField::LocalIp)); - - // Test switching focus - app.detail_focus = DetailFocusField::RemoteIp; - assert!(matches!(app.detail_focus, DetailFocusField::RemoteIp)); - } - - #[test] - fn test_app_initialization() { - let config = Config::default(); - let i18n = I18n::new("en").unwrap(); - let app_result = App::new(config, i18n); - - assert!(app_result.is_ok()); - let app = app_result.unwrap(); - - // Check initial state - assert!(matches!(app.mode, ViewMode::Overview)); - assert!(app.show_locations); - assert!(!app.show_hostnames); - assert!(app.connections.is_empty()); - assert!(app.dns_cache.is_empty()); - assert_eq!(app.selected_connection_idx, 0); - assert!(app.is_loading); // Should start in loading state - } - - #[test] - fn test_loading_state_and_spinner() { - let mut app = create_test_app(); - - // Should start loading - assert!(app.is_loading); - assert!(!app.loading_message.is_empty()); - - // Test spinner animation - let first_char = app.get_spinner_char().to_string(); - app.loading_spinner_index = (app.loading_spinner_index + 1) % 4; - let second_char = app.get_spinner_char().to_string(); - assert_ne!(first_char, second_char); - - // Test loading completion - app.is_loading = false; - app.loading_message.clear(); - assert!(!app.is_loading); - assert!(app.loading_message.is_empty()); +impl Drop for App { + fn drop(&mut self) { + self.stop(); + // Give threads time to stop gracefully + thread::sleep(Duration::from_millis(100)); } } diff --git a/src/main.rs b/src/main.rs index 161a950..391eeb5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,8 +9,6 @@ use std::path::Path; use std::time::Duration; mod app; -mod config; -mod i18n; mod network; mod ui; @@ -18,12 +16,12 @@ fn main() -> Result<()> { // Set up logging setup_logging()?; - info!("Starting RustNet"); + info!("Starting RustNet Monitor"); // Parse command line arguments let matches = Command::new("rustnet") .version("0.1.0") - .author("Your Name") + .author("Network Monitor") .about("Cross-platform network monitoring tool") .arg( Arg::new("interface") @@ -34,74 +32,67 @@ fn main() -> Result<()> { .required(false), ) .arg( - Arg::new("config") - .short('c') - .long("config") - .value_name("FILE") - .help("Path to configuration file") - .required(false), + Arg::new("no-localhost") + .long("no-localhost") + .help("Filter out localhost connections") + .action(clap::ArgAction::SetTrue), ) .arg( - Arg::new("language") - .short('l') - .long("language") - .value_name("LANG") - .help("Interface language (en, fr, etc.)") - .required(false), - ) - .arg( - Arg::new("packet_processing_interval") - .short('P') - .long("packet-processing-interval") + Arg::new("refresh-interval") + .short('r') + .long("refresh-interval") .value_name("MILLISECONDS") - .help("Interval for packet processing loop sleep (ms). 0 for continuous.") + .help("UI refresh interval in milliseconds") .value_parser(clap::value_parser!(u64)) + .default_value("1000") .required(false), ) + .arg( + Arg::new("no-dpi") + .long("no-dpi") + .help("Disable deep packet inspection") + .action(clap::ArgAction::SetTrue), + ) .get_matches(); - // Initialize configuration - let config_path = matches.get_one::("config").map(String::as_str); - let mut config = config::Config::load(config_path)?; + // Build configuration from command line arguments + let mut config = app::Config::default(); - info!("Configuration loaded"); - - // Override config with command line arguments if provided if let Some(interface) = matches.get_one::("interface") { config.interface = Some(interface.to_string()); info!("Using interface: {}", interface); } - if let Some(language) = matches.get_one::("language") { - config.language = language.to_string(); - info!("Using language: {}", language); + if matches.get_flag("no-localhost") { + config.filter_localhost = true; + info!("Filtering localhost connections"); } - if let Some(interval) = matches.get_one::("packet_processing_interval") { - config.packet_processing_interval_ms = *interval; - info!("Using packet processing interval: {}ms", interval); + if let Some(interval) = matches.get_one::("refresh-interval") { + config.refresh_interval = *interval; + info!("Using refresh interval: {}ms", interval); } - // Initialize internationalization - let i18n = i18n::I18n::new(&config.language)?; - info!( - "Internationalization initialized for language: {}", - config.language - ); + if matches.get_flag("no-dpi") { + config.enable_dpi = false; + info!("Deep packet inspection disabled"); + } // Set up terminal let backend = CrosstermBackend::new(io::stdout()); let mut terminal = ui::setup_terminal(backend)?; info!("Terminal UI initialized"); - // Create app state - let app = app::App::new(config, i18n)?; - info!("Application state initialized"); + // Create and start the application + let mut app = app::App::new(config)?; + app.start()?; + info!("Application started"); - // Run the application - let res = run_app(&mut terminal, app); + // Run the UI loop + let res = run_ui_loop(&mut terminal, &app); - // Restore terminal + // Cleanup + app.stop(); ui::restore_terminal(&mut terminal)?; // Return any error that occurred @@ -110,7 +101,7 @@ fn main() -> Result<()> { println!("Error: {}", err); } - info!("RustNet shutting down"); + info!("RustNet Monitor shutting down"); Ok(()) } @@ -135,66 +126,95 @@ fn setup_logging() -> Result<()> { Ok(()) } -fn run_app( +fn run_ui_loop( terminal: &mut ui::Terminal, - mut app: app::App, + app: &app::App, ) -> Result<()> { - let tick_rate = Duration::from_millis(200); // Faster refresh for better loading animation + let tick_rate = Duration::from_millis(200); let mut last_tick = std::time::Instant::now(); - let mut capture_started = false; + let mut ui_state = ui::UIState::default(); loop { - // Draw the UI first to show loading screen immediately + // Get current connections and stats + let connections = app.get_connections(); + let stats = app.get_stats(); + + // Draw the UI terminal.draw(|f| { - if let Err(err) = ui::draw(f, &mut app) { + if let Err(err) = ui::draw(f, app, &ui_state, &connections, &stats) { error!("UI draw error: {}", err); } })?; - // Start capture on first iteration (after first UI render) - if !capture_started { - if let Err(err) = app.start_capture() { - error!("Failed to start network capture: {}", err); - // Continue anyway, some features may still work - } - info!("Network capture started"); - capture_started = true; - } - - // Handle timeout (for periodic UI updates) + // Handle timeout for periodic updates let timeout = tick_rate .checked_sub(last_tick.elapsed()) .unwrap_or(Duration::from_secs(0)); - // Update app state on tick (especially important during loading for spinner animation) - let should_tick = last_tick.elapsed() >= tick_rate; - if should_tick { - app.on_tick()?; + // Check if we should tick + if last_tick.elapsed() >= tick_rate { last_tick = std::time::Instant::now(); } - // Handle input events (use shorter timeout during loading for responsive spinner) - let input_timeout = if app.is_loading { - Duration::from_millis(100) - } else { - timeout - }; - - if crossterm::event::poll(input_timeout)? { + // Handle input events + if crossterm::event::poll(timeout)? { if let crossterm::event::Event::Key(key) = crossterm::event::read()? { - // Handle key event - if let Some(action) = app.handle_key(key) { - match action { - app::Action::Quit => { - info!("User requested application exit"); - app.shutdown(); - break; - } - app::Action::Refresh => { - info!("User requested refresh"); - app.refresh()?; - } // Add more actions as needed + use crossterm::event::{KeyCode, KeyModifiers}; + + match (key.code, key.modifiers) { + // Quit + (KeyCode::Char('q'), _) | (KeyCode::Char('c'), KeyModifiers::CONTROL) => { + info!("User requested application exit"); + break; } + + // Tab navigation + (KeyCode::Tab, _) => { + ui_state.selected_tab = (ui_state.selected_tab + 1) % 3; + } + + // Help toggle + (KeyCode::Char('h'), _) => { + ui_state.show_help = !ui_state.show_help; + if ui_state.show_help { + ui_state.selected_tab = 2; // Switch to help tab + } else { + ui_state.selected_tab = 0; // Back to overview + } + } + + // Navigation in connection list + (KeyCode::Up, _) | (KeyCode::Char('k'), _) => { + if !connections.is_empty() && ui_state.selected_connection > 0 { + ui_state.selected_connection -= 1; + } + } + + (KeyCode::Down, _) | (KeyCode::Char('j'), _) => { + if !connections.is_empty() + && ui_state.selected_connection < connections.len().saturating_sub(1) + { + ui_state.selected_connection += 1; + } + } + + // Enter to view details + (KeyCode::Enter, _) => { + if ui_state.selected_tab == 0 && !connections.is_empty() { + ui_state.selected_tab = 1; // Switch to details view + } + } + + // Escape to go back + (KeyCode::Esc, _) => { + if ui_state.selected_tab == 1 { + ui_state.selected_tab = 0; // Back to overview + } else if ui_state.selected_tab == 2 { + ui_state.selected_tab = 0; // Back to overview from help + } + } + + _ => {} } } } diff --git a/src/network/capture.rs b/src/network/capture.rs new file mode 100644 index 0000000..e77e793 --- /dev/null +++ b/src/network/capture.rs @@ -0,0 +1,380 @@ +// network/capture.rs - Packet capture setup and utilities +use anyhow::{Result, anyhow}; +use pcap::{Active, Capture, Device, Error as PcapError}; +use std::time::Duration; + +/// Packet capture configuration +#[derive(Debug, Clone)] +pub struct CaptureConfig { + /// Network interface name (None for default) + pub interface: Option, + /// Promiscuous mode + pub promiscuous: bool, + /// Snapshot length (bytes to capture per packet) + pub snaplen: i32, + /// Buffer size for packet capture + pub buffer_size: i32, + /// Read timeout in milliseconds + pub timeout_ms: i32, + /// BPF filter string + pub filter: Option, +} + +impl Default for CaptureConfig { + fn default() -> Self { + Self { + interface: None, + promiscuous: true, + snaplen: 200, // Limit packet size to keep more in buffer (like Sniffnet) + buffer_size: 2_000_000, // 2MB buffer (same as Sniffnet) + timeout_ms: 150, // 150ms timeout for UI responsiveness (like Sniffnet) + filter: None, // Start without filter to ensure we see packets + } + } +} + +/// Find the best active network device +fn find_best_device() -> Result { + let devices = Device::list()?; + + log::info!( + "Scanning {} devices for best active interface...", + devices.len() + ); + + // Log all devices for debugging + for d in &devices { + let has_valid_ip = d.addresses.iter().any(|addr| match &addr.addr { + std::net::IpAddr::V4(v4) => { + !v4.is_link_local() && !v4.is_loopback() && !v4.is_unspecified() + } + std::net::IpAddr::V6(v6) => { + !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified() + } + }); + + log::debug!( + " Device: {} [up: {}, running: {}, has_ip: {}]", + d.name, + d.flags.is_up(), + d.flags.is_running(), + has_valid_ip + ); + } + + if devices.is_empty() { + return Err(anyhow!("No network devices found")); + } + + // Find the best active device + let suitable_device = devices + .iter() + // First priority: up, running, and has a valid IP address + .find(|d| { + !d.name.starts_with("lo") + && d.name != "any" + && d.flags.is_up() + && d.flags.is_running() + && d.addresses.iter().any(|addr| { + match &addr.addr { + std::net::IpAddr::V4(v4) => { + !v4.is_link_local() && !v4.is_loopback() && !v4.is_unspecified() + } + std::net::IpAddr::V6(v6) => false, // Skip IPv6 for now + } + }) + }) + // Second priority: common active interface names + .or_else(|| { + devices.iter().find(|d| { + (d.name == "en0" || d.name == "en1" || d.name.starts_with("eth")) + && d.flags.is_up() + && d.addresses.iter().any(|addr| addr.addr.is_ipv4()) + }) + }) + // Third priority: any up interface with valid addresses (excluding problematic ones) + .or_else(|| { + devices.iter().find(|d| { + !d.name.starts_with("lo") && + !d.name.starts_with("ap") && // Skip Apple's ap interfaces + !d.name.starts_with("awdl") && // Skip Apple Wireless Direct + !d.name.starts_with("llw") && // Skip Low latency WLAN + !d.name.starts_with("bridge") && // Skip bridges + !d.name.starts_with("utun") && // Skip tunnels + !d.name.starts_with("vmnet") && // Skip VM interfaces + d.name != "any" && + d.flags.is_up() && + !d.addresses.is_empty() + }) + }) + .cloned(); + + match suitable_device { + Some(device) => { + log::info!( + "Selected active device: {} ({} addresses)", + device.name, + device.addresses.len() + ); + for addr in &device.addresses { + log::debug!(" Address: {}", addr.addr); + } + Ok(device) + } + None => { + log::error!("No suitable active network device found!"); + log::error!("Try specifying an interface manually with -i flag"); + Err(anyhow!( + "No active network interface found. Use -i to specify one manually." + )) + } + } +} + +/// Setup packet capture with the given configuration +pub fn setup_packet_capture(config: CaptureConfig) -> Result> { + // Find the capture device + let device = find_capture_device(&config.interface)?; + + log::info!( + "Setting up capture on device: {} ({})", + device.name, + device.desc.as_deref().unwrap_or("no description") + ); + + // Create capture handle + let mut cap = Capture::from_device(device)? + .promisc(config.promiscuous) + .snaplen(config.snaplen) + .buffer_size(config.buffer_size) + .timeout(config.timeout_ms) + .immediate_mode(true); // Parse packets ASAP (like Sniffnet) + + // Open the capture + let mut cap = cap.open()?; + + // Apply BPF filter if specified + if let Some(filter) = &config.filter { + log::info!("Applying BPF filter: {}", filter); + cap.filter(filter, true)?; + } + + // Note: We're not setting non-blocking mode as we're using timeout instead + + Ok(cap) +} + +/// Find a capture device by name or return the default +fn find_capture_device(interface_name: &Option) -> Result { + match interface_name { + Some(name) => { + log::info!("Looking for interface: {}", name); + + // List all devices + let devices = Device::list()?; + + // Find exact match first + if let Some(device) = devices.iter().find(|d| d.name == *name) { + return Ok(device.clone()); + } + + // Try case-insensitive match + let name_lower = name.to_lowercase(); + if let Some(device) = devices.iter().find(|d| d.name.to_lowercase() == name_lower) { + return Ok(device.clone()); + } + + // List available interfaces for error message + let available: Vec = devices + .iter() + .map(|d| { + format!( + "{} ({})", + d.name, + d.desc.as_deref().unwrap_or("no description") + ) + }) + .collect(); + + Err(anyhow!( + "Interface '{}' not found. Available interfaces:\n{}", + name, + available.join("\n") + )) + } + None => { + log::info!("No interface specified, using default"); + + // Try to get default device + match Device::lookup() { + Ok(Some(device)) => { + log::info!( + "Found default device: {} ({})", + device.name, + device.desc.as_deref().unwrap_or("no description") + ); + + // Check if the default device is actually active + let has_valid_ip = device.addresses.iter().any(|addr| { + match &addr.addr { + std::net::IpAddr::V4(v4) => { + !v4.is_link_local() && !v4.is_loopback() && !v4.is_unspecified() + } + std::net::IpAddr::V6(v6) => false, // Skip IPv6 for now + } + }); + + // Check if it's a problematic interface type + let is_problematic = device.name.starts_with("ap") + || device.name.starts_with("awdl") + || device.name.starts_with("llw") + || device.name.starts_with("bridge") + || device.name.starts_with("utun") + || device.name.starts_with("vmnet") + || device.name == "any" + || device.flags.is_loopback(); + + if device.flags.is_up() + && device.flags.is_running() + && has_valid_ip + && !is_problematic + { + log::info!("Default device appears active, using it"); + Ok(device) + } else { + log::warn!( + "Default device '{}' is not suitable (up: {}, running: {}, has_ip: {}, problematic: {})", + device.name, + device.flags.is_up(), + device.flags.is_running(), + has_valid_ip, + is_problematic + ); + log::info!("Looking for a better interface..."); + + // Fall through to the device selection logic below + find_best_device() + } + } + Ok(None) => { + log::info!("No default device found"); + find_best_device() + } + Err(e) => Err(e.into()), + } + } + } +} + +/// List available capture devices +pub fn list_devices() -> Result> { + let devices = Device::list()?; + + Ok(devices + .into_iter() + .map(|d| { + // Check if device is active by checking flags and addresses + let is_active = d.flags.is_up() + && d.flags.is_running() + && d.addresses.iter().any(|addr| { + // Has at least one non-link-local address + match &addr.addr { + std::net::IpAddr::V4(v4) => !v4.is_link_local() && !v4.is_loopback(), + std::net::IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast(), + } + }); + + DeviceInfo { + name: d.name, + description: d.desc, + addresses: d + .addresses + .into_iter() + .map(|addr| format!("{}", addr.addr)) + .collect(), + is_loopback: d.flags.is_loopback(), + is_up: d.flags.is_up(), + is_running: d.flags.is_running(), + is_active, + } + }) + .collect()) +} + +/// Information about a network device +#[derive(Debug, Clone)] +pub struct DeviceInfo { + pub name: String, + pub description: Option, + pub addresses: Vec, + pub is_loopback: bool, + pub is_up: bool, + pub is_running: bool, + pub is_active: bool, +} + +/// Simple packet reader that handles timeouts gracefully +pub struct PacketReader { + capture: Capture, +} + +impl PacketReader { + pub fn new(capture: Capture) -> Self { + Self { capture } + } + + /// Read next packet, returning None on timeout + pub fn next_packet(&mut self) -> Result>> { + match self.capture.next_packet() { + Ok(packet) => Ok(Some(packet.data.to_vec())), + Err(PcapError::TimeoutExpired) => Ok(None), + Err(e) => Err(e.into()), + } + } + + /// Get capture statistics + pub fn stats(&mut self) -> Result { + let stats = self.capture.stats()?; + Ok(CaptureStats { + received: stats.received, + dropped: stats.dropped, + if_dropped: stats.if_dropped, + }) + } +} + +/// Packet capture statistics +#[derive(Debug, Clone, Default)] +pub struct CaptureStats { + pub received: u32, + pub dropped: u32, + pub if_dropped: u32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = CaptureConfig::default(); + assert!(config.promiscuous); + assert_eq!(config.snaplen, 1024); + assert!(config.filter.is_some()); + } + + #[test] + fn test_list_devices() { + // This might fail in some test environments + if let Ok(devices) = list_devices() { + for device in devices { + println!("Device: {} - {:?}", device.name, device.description); + println!(" Addresses: {:?}", device.addresses); + println!( + " Loopback: {}, Up: {}, Running: {}", + device.is_loopback, device.is_up, device.is_running + ); + } + } + } +} diff --git a/src/network/dpi/dns.rs b/src/network/dpi/dns.rs new file mode 100644 index 0000000..ec2dc4c --- /dev/null +++ b/src/network/dpi/dns.rs @@ -0,0 +1,75 @@ +use crate::network::types::{DnsInfo, DnsQueryType}; + +pub fn analyze_dns(payload: &[u8]) -> Option { + if payload.len() < 12 { + return None; + } + + let mut info = DnsInfo { + query_name: None, + query_type: None, + response_ips: Vec::new(), + is_response: false, + }; + + // DNS header flags + let flags = u16::from_be_bytes([payload[2], payload[3]]); + info.is_response = (flags & 0x8000) != 0; // QR bit + + // Question count + let qdcount = u16::from_be_bytes([payload[4], payload[5]]); + + if qdcount > 0 { + // Parse first question + let mut offset = 12; + let mut name = String::new(); + + // Parse domain name + while offset < payload.len() { + let label_len = payload[offset] as usize; + if label_len == 0 { + offset += 1; + break; + } + + if label_len >= 0xC0 { + // Compressed name - skip for simplicity + offset += 2; + break; + } + + if offset + 1 + label_len > payload.len() { + break; + } + + if !name.is_empty() { + name.push('.'); + } + + if let Ok(label) = std::str::from_utf8(&payload[offset + 1..offset + 1 + label_len]) { + name.push_str(label); + } + + offset += 1 + label_len; + } + + if !name.is_empty() { + info.query_name = Some(name); + } + + // Query type + if offset + 2 <= payload.len() { + let qtype = u16::from_be_bytes([payload[offset], payload[offset + 1]]); + info.query_type = Some(match qtype { + 1 => DnsQueryType::A, + 28 => DnsQueryType::AAAA, + 5 => DnsQueryType::CNAME, + 15 => DnsQueryType::MX, + 16 => DnsQueryType::TXT, + other => DnsQueryType::Other(other), + }); + } + } + + Some(info) +} diff --git a/src/network/dpi/http.rs b/src/network/dpi/http.rs new file mode 100644 index 0000000..3f16c81 --- /dev/null +++ b/src/network/dpi/http.rs @@ -0,0 +1,130 @@ +use crate::network::types::{HttpInfo, HttpVersion}; + +/// Analyze payload for HTTP protocol +pub fn analyze_http(payload: &[u8]) -> Option { + if !is_likely_http(payload) { + return None; + } + + let mut info = HttpInfo { + version: HttpVersion::Http11, + method: None, + host: None, + path: None, + status_code: None, + user_agent: None, + }; + + // Safe string conversion for HTTP parsing + let text = String::from_utf8_lossy(payload); + let lines: Vec<&str> = text.lines().collect(); + + if lines.is_empty() { + return None; + } + + // Parse first line + let first_line = lines[0]; + let parts: Vec<&str> = first_line.split_whitespace().collect(); + + if parts.len() >= 3 { + if first_line.starts_with("HTTP/") { + // Response line: HTTP/1.1 200 OK + info.version = parse_http_version(parts[0]); + info.status_code = parts[1].parse::().ok(); + } else if is_http_method(parts[0]) { + // Request line: GET /path HTTP/1.1 + info.method = Some(parts[0].to_string()); + info.path = Some(parts[1].to_string()); + if parts.len() >= 3 { + info.version = parse_http_version(parts[2]); + } + } else { + return None; // Not valid HTTP + } + } else { + return None; + } + + // Parse headers + for line in lines.iter().skip(1) { + if line.is_empty() { + break; // End of headers + } + + if let Some((key, value)) = line.split_once(':') { + let key = key.trim().to_lowercase(); + let value = value.trim(); + + match key.as_str() { + "host" => info.host = Some(value.to_string()), + "user-agent" => info.user_agent = Some(value.to_string()), + _ => {} + } + } + } + + Some(info) +} + +/// Quick check if payload might be HTTP +fn is_likely_http(payload: &[u8]) -> bool { + if payload.len() < 4 { + return false; + } + + // HTTP request methods + payload.starts_with(b"GET ") || + payload.starts_with(b"POST ") || + payload.starts_with(b"PUT ") || + payload.starts_with(b"DELETE ") || + payload.starts_with(b"HEAD ") || + payload.starts_with(b"OPTIONS ") || + payload.starts_with(b"CONNECT ") || + payload.starts_with(b"TRACE ") || + payload.starts_with(b"PATCH ") || + // HTTP responses + payload.starts_with(b"HTTP/1.0 ") || + payload.starts_with(b"HTTP/1.1 ") || + payload.starts_with(b"HTTP/2 ") +} + +fn is_http_method(s: &str) -> bool { + matches!( + s, + "GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" | "CONNECT" | "TRACE" | "PATCH" + ) +} + +fn parse_http_version(s: &str) -> HttpVersion { + match s { + "HTTP/1.0" => HttpVersion::Http10, + "HTTP/1.1" => HttpVersion::Http11, + "HTTP/2" => HttpVersion::Http2, + _ => HttpVersion::Http11, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http_request() { + let payload = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let info = analyze_http(payload).unwrap(); + + assert_eq!(info.method.as_deref(), Some("GET")); + assert_eq!(info.path.as_deref(), Some("/index.html")); + assert_eq!(info.host.as_deref(), Some("example.com")); + } + + #[test] + fn test_http_response() { + let payload = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; + let info = analyze_http(payload).unwrap(); + + assert_eq!(info.status_code, Some(200)); + assert!(info.method.is_none()); + } +} diff --git a/src/network/dpi/mod.rs b/src/network/dpi/mod.rs new file mode 100644 index 0000000..09371b6 --- /dev/null +++ b/src/network/dpi/mod.rs @@ -0,0 +1,100 @@ +use crate::network::types::ApplicationProtocol; + +mod dns; +mod http; +mod quic; +mod tls; + +pub use dns::analyze_dns; +pub use http::analyze_http; +pub use quic::is_quic_packet; +pub use tls::analyze_tls; + +/// Result of DPI analysis +#[derive(Debug, Clone)] +pub struct DpiResult { + pub application: ApplicationProtocol, + pub confidence: f32, // 0.0 to 1.0 + pub needs_more_data: bool, // True if more packets would help +} + +/// Analyze a TCP packet payload +pub fn analyze_tcp_packet( + payload: &[u8], + local_port: u16, + remote_port: u16, + is_outgoing: bool, +) -> Option { + if payload.is_empty() { + return None; + } + + // Try protocols in order of likelihood/speed + + // 1. Check for HTTP (fast string matching) + if let Some(http_result) = http::analyze_http(payload) { + return Some(DpiResult { + application: ApplicationProtocol::Http(http_result), + confidence: 1.0, + needs_more_data: false, + }); + } + + // 2. Check for TLS/HTTPS (port 443 or TLS handshake) + if local_port == 443 || remote_port == 443 || tls::is_tls_handshake(payload) { + if let Some(tls_result) = tls::analyze_tls(payload) { + return Some(DpiResult { + application: ApplicationProtocol::Https(tls_result), + confidence: 1.0, + needs_more_data: false, + }); + } + } + + // 3. Check for SSH (port 22 or SSH banner) + if local_port == 22 || remote_port == 22 || payload.starts_with(b"SSH-") { + return Some(DpiResult { + application: ApplicationProtocol::Ssh, + confidence: 1.0, + needs_more_data: false, + }); + } + + // More protocols here... + + None +} + +/// Analyze a UDP packet payload +pub fn analyze_udp_packet( + payload: &[u8], + local_port: u16, + remote_port: u16, + is_outgoing: bool, +) -> Option { + if payload.is_empty() { + return None; + } + + // 1. DNS (port 53) + if local_port == 53 || remote_port == 53 { + if let Some(dns_result) = dns::analyze_dns(payload) { + return Some(DpiResult { + application: ApplicationProtocol::Dns(dns_result), + confidence: 1.0, + needs_more_data: false, + }); + } + } + + // 2. QUIC/HTTP3 (port 443) + if (local_port == 443 || remote_port == 443) && quic::is_quic_packet(payload) { + return Some(DpiResult { + application: ApplicationProtocol::Quic, + confidence: 0.9, // QUIC detection is less certain + needs_more_data: true, + }); + } + + None +} diff --git a/src/network/dpi/quic.rs b/src/network/dpi/quic.rs new file mode 100644 index 0000000..de6360e --- /dev/null +++ b/src/network/dpi/quic.rs @@ -0,0 +1,22 @@ +pub fn is_quic_packet(payload: &[u8]) -> bool { + if payload.len() < 5 { + return false; + } + + // Check for QUIC long header (bit 7 set) + if (payload[0] & 0x80) != 0 { + // Check version + let version = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]); + + // Known QUIC versions + return version == 0x00000001 || // QUIC v1 + version == 0x6b3343cf || // QUIC v2 + version == 0x51303530 || // Google QUIC + version == 0; // Version negotiation + } + + // Could be short header QUIC packet + // These are harder to identify definitively, but if we see them on port 443 UDP, + // they're likely QUIC + true +} diff --git a/src/network/dpi/tls.rs b/src/network/dpi/tls.rs new file mode 100644 index 0000000..4f5a7a1 --- /dev/null +++ b/src/network/dpi/tls.rs @@ -0,0 +1,201 @@ +use crate::network::types::{TlsInfo, TlsVersion}; + +pub fn is_tls_handshake(payload: &[u8]) -> bool { + if payload.len() < 6 { + return false; + } + + // TLS record header: + // - Content type (1 byte): 0x16 for handshake + // - Version (2 bytes): 0x0301-0x0304 for TLS 1.0-1.3 + // - Length (2 bytes) + + payload[0] == 0x16 && // Handshake content type + payload[1] == 0x03 && // Major version 3 + (payload[2] >= 0x01 && payload[2] <= 0x04) // Minor version 1-4 +} + +pub fn analyze_tls(payload: &[u8]) -> Option { + if !is_tls_handshake(payload) || payload.len() < 9 { + return None; + } + + let mut info = TlsInfo { + version: None, + sni: None, + alpn: Vec::new(), + cipher_suite: None, + }; + + // Record layer version + let record_version = match payload[2] { + 0x01 => Some(TlsVersion::Tls10), + 0x02 => Some(TlsVersion::Tls11), + 0x03 => Some(TlsVersion::Tls12), + 0x04 => Some(TlsVersion::Tls13), + _ => None, + }; + + // Skip TLS record header (5 bytes) + let handshake_data = &payload[5..]; + + if handshake_data.len() < 4 { + return Some(info); + } + + let handshake_type = handshake_data[0]; + + match handshake_type { + 0x01 => { + // Client Hello + info.version = record_version; + if let Some((sni, alpn)) = parse_client_hello_extensions(handshake_data) { + info.sni = sni; + info.alpn = alpn; + } + } + 0x02 => { + // Server Hello + info.version = record_version; + // Could parse cipher suite here if needed + } + _ => {} + } + + Some(info) +} + +/// Parse Client Hello extensions for SNI and ALPN +fn parse_client_hello_extensions(handshake_data: &[u8]) -> Option<(Option, Vec)> { + if handshake_data.len() < 38 { + return None; + } + + // Skip to extensions: + // - Handshake type (1) + Length (3) + Version (2) + Random (32) = 38 + let mut offset = 38; + + // Session ID + if offset >= handshake_data.len() { + return None; + } + let session_id_len = handshake_data[offset] as usize; + offset += 1 + session_id_len; + + // Cipher suites + if offset + 2 > handshake_data.len() { + return None; + } + let cipher_suites_len = + u16::from_be_bytes([handshake_data[offset], handshake_data[offset + 1]]) as usize; + offset += 2 + cipher_suites_len; + + // Compression methods + if offset >= handshake_data.len() { + return None; + } + let compression_len = handshake_data[offset] as usize; + offset += 1 + compression_len; + + // Extensions length + if offset + 2 > handshake_data.len() { + return None; + } + let extensions_len = + u16::from_be_bytes([handshake_data[offset], handshake_data[offset + 1]]) as usize; + offset += 2; + + if offset + extensions_len > handshake_data.len() { + return None; + } + + // Parse extensions + let mut sni = None; + let mut alpn = Vec::new(); + let extensions_data = &handshake_data[offset..offset + extensions_len]; + let mut ext_offset = 0; + + while ext_offset + 4 <= extensions_data.len() { + let ext_type = + u16::from_be_bytes([extensions_data[ext_offset], extensions_data[ext_offset + 1]]); + let ext_len = u16::from_be_bytes([ + extensions_data[ext_offset + 2], + extensions_data[ext_offset + 3], + ]) as usize; + + if ext_offset + 4 + ext_len > extensions_data.len() { + break; + } + + match ext_type { + 0x0000 => { + // SNI + sni = + parse_sni_extension(&extensions_data[ext_offset + 4..ext_offset + 4 + ext_len]); + } + 0x0010 => { + // ALPN + alpn = parse_alpn_extension( + &extensions_data[ext_offset + 4..ext_offset + 4 + ext_len], + ); + } + _ => {} + } + + ext_offset += 4 + ext_len; + } + + Some((sni, alpn)) +} + +fn parse_sni_extension(data: &[u8]) -> Option { + if data.len() < 5 { + return None; + } + + // Skip server name list length (2 bytes) + let mut offset = 2; + + while offset + 3 <= data.len() { + let name_type = data[offset]; + let name_len = u16::from_be_bytes([data[offset + 1], data[offset + 2]]) as usize; + + if name_type == 0x00 { + // host_name + if offset + 3 + name_len <= data.len() { + let hostname_bytes = &data[offset + 3..offset + 3 + name_len]; + if let Ok(hostname) = std::str::from_utf8(hostname_bytes) { + return Some(hostname.to_string()); + } + } + } + + offset += 3 + name_len; + } + + None +} + +/// Parse ALPN extension +fn parse_alpn_extension(data: &[u8]) -> Vec { + let mut protocols = Vec::new(); + + if data.len() < 2 { + return protocols; + } + + // Skip ALPN extension length + let mut offset = 2; + + while offset < data.len() { + let proto_len = data[offset] as usize; + if offset + 1 + proto_len <= data.len() { + if let Ok(proto) = std::str::from_utf8(&data[offset + 1..offset + 1 + proto_len]) { + protocols.push(proto.to_string()); + } + } + offset += 1 + proto_len; + } + + protocols +} diff --git a/src/network/linux.rs b/src/network/linux.rs deleted file mode 100644 index e3252be..0000000 --- a/src/network/linux.rs +++ /dev/null @@ -1,183 +0,0 @@ -// linux.rs -use anyhow::Result; -use std::collections::HashMap; -use std::fs; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - -use super::{Connection, Protocol, ProtocolState}; - -/// Get connections with process information from /proc -pub fn get_connections_with_process_info(connections: &mut Vec) -> Result<()> { - // Parse TCP connections - parse_proc_net_file("/proc/net/tcp", Protocol::TCP, connections)?; - parse_proc_net_file("/proc/net/tcp6", Protocol::TCP, connections)?; - - // Parse UDP connections - parse_proc_net_file("/proc/net/udp", Protocol::UDP, connections)?; - parse_proc_net_file("/proc/net/udp6", Protocol::UDP, connections)?; - - // Build a map of inodes to process info - let inode_to_process = build_inode_to_process_map()?; - - // Enrich connections with process info - for conn in connections.iter_mut() { - if let Some(inode) = get_socket_inode(conn) { - if let Some((pid, name)) = inode_to_process.get(&inode) { - conn.pid = Some(*pid); - conn.process_name = Some(name.clone()); - } - } - } - - Ok(()) -} - -/// Parse a /proc/net file and add connections -fn parse_proc_net_file( - path: &str, - protocol: Protocol, - connections: &mut Vec, -) -> Result<()> { - let content = match fs::read_to_string(path) { - Ok(c) => c, - Err(_) => return Ok(()), // File might not exist - }; - - for (i, line) in content.lines().enumerate() { - if i == 0 { - continue; // Skip header - } - - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() < 10 { - continue; - } - - // Parse local address - let local_addr = match parse_hex_address(parts[1]) { - Some(addr) => addr, - None => continue, - }; - - // Parse remote address - let remote_addr = match parse_hex_address(parts[2]) { - Some(addr) => addr, - None => continue, - }; - - // Create a basic connection with minimal state - let state = match protocol { - Protocol::TCP => ProtocolState::Tcp(super::TcpState::Established), - Protocol::UDP => ProtocolState::Udp, - _ => continue, - }; - - let mut conn = Connection::new(protocol, local_addr, remote_addr, state); - - // Try to get inode from column 9 (0-indexed) - if parts.len() > 9 { - if let Ok(inode) = parts[9].parse::() { - // Store inode temporarily (we'll use a hack here - store in bytes_sent) - conn.bytes_sent = inode; - } - } - - connections.push(conn); - } - - Ok(()) -} - -/// Parse hex address from /proc/net format -fn parse_hex_address(hex_addr: &str) -> Option { - let parts: Vec<&str> = hex_addr.split(':').collect(); - if parts.len() != 2 { - return None; - } - - let ip_hex = parts[0]; - let port = u16::from_str_radix(parts[1], 16).ok()?; - - // Determine if IPv4 or IPv6 based on length - if ip_hex.len() == 8 { - // IPv4 - let ip_bytes = u32::from_str_radix(ip_hex, 16).ok()?; - let ip = Ipv4Addr::from(ip_bytes.to_le_bytes()); - Some(SocketAddr::new(IpAddr::V4(ip), port)) - } else if ip_hex.len() == 32 { - // IPv6 - let mut bytes = [0u8; 16]; - for i in 0..4 { - let chunk = &ip_hex[i * 8..(i + 1) * 8]; - let value = u32::from_str_radix(chunk, 16).ok()?; - bytes[i * 4..(i + 1) * 4].copy_from_slice(&value.to_le_bytes()); - } - let ip = Ipv6Addr::from(bytes); - Some(SocketAddr::new(IpAddr::V6(ip), port)) - } else { - None - } -} - -/// Build a map of socket inodes to process information -fn build_inode_to_process_map() -> Result> { - let mut inode_map = HashMap::new(); - - // Iterate through /proc/[pid]/fd/ - for entry in fs::read_dir("/proc")? { - let entry = entry?; - let path = entry.path(); - - // Check if it's a PID directory - if let Some(pid_str) = path.file_name().and_then(|s| s.to_str()) { - if let Ok(pid) = pid_str.parse::() { - // Get process name - let comm_path = path.join("comm"); - let process_name = fs::read_to_string(&comm_path) - .unwrap_or_else(|_| "unknown".to_string()) - .trim() - .to_string(); - - // Check all file descriptors - let fd_dir = path.join("fd"); - if let Ok(fd_entries) = fs::read_dir(&fd_dir) { - for fd_entry in fd_entries { - if let Ok(fd_entry) = fd_entry { - if let Ok(link) = fs::read_link(fd_entry.path()) { - if let Some(link_str) = link.to_str() { - if link_str.starts_with("socket:[") { - if let Some(inode) = extract_socket_inode(link_str) { - inode_map.insert(inode, (pid, process_name.clone())); - } - } - } - } - } - } - } - } - } - } - - Ok(inode_map) -} - -/// Extract inode from socket link like "socket:[12345]" -fn extract_socket_inode(link: &str) -> Option { - if link.starts_with("socket:[") && link.ends_with(']') { - let inode_str = &link[8..link.len() - 1]; - inode_str.parse().ok() - } else { - None - } -} - -/// Get socket inode for a connection -fn get_socket_inode(conn: &Connection) -> Option { - // We stored the inode in bytes_sent temporarily - if conn.bytes_sent > 0 { - Some(conn.bytes_sent) - } else { - None - } -} diff --git a/src/network/macos.rs b/src/network/macos.rs deleted file mode 100644 index 650a6c6..0000000 --- a/src/network/macos.rs +++ /dev/null @@ -1,391 +0,0 @@ -use anyhow::Result; -use log::debug; -use std::collections::HashSet; -use std::net::SocketAddr; -use std::process::Command; - -use super::{Connection, ConnectionState, NetworkMonitor, Process, Protocol}; - -/// Get platform-specific connections for macOS -pub fn get_platform_connections( - monitor: &NetworkMonitor, - connections: &mut Vec, -) -> Result<()> { - // Try different commands to maximize connection detection - // First try netstat - more reliable on macOS than lsof in some cases - monitor.get_connections_from_netstat(connections)?; - debug!("Found {} connections from netstat", connections.len()); - - // Then try lsof for additional connections - let before_count = connections.len(); - monitor.get_connections_from_lsof(connections)?; - debug!( - "Found {} additional connections from lsof", - connections.len() - before_count - ); - - Ok(()) -} - -impl NetworkMonitor { - /// Get connections from lsof command - pub(super) fn get_connections_from_lsof(&self, connections: &mut Vec) -> Result<()> { - // Track unique connections to avoid duplicates - let mut seen_connections = HashSet::new(); - for conn in connections.iter() { - let key = format!( - "{:?}:{}-{:?}:{}", - conn.protocol, conn.local_addr, conn.protocol, conn.remote_addr - ); - seen_connections.insert(key); - } - - // Use more aggressive lsof command with less filtering - let output = Command::new("lsof").args(["-i", "-n", "-P"]).output()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - - for line in text.lines().skip(1) { - // Skip header - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 8 { - continue; - } - - // Get process name and PID - let process_name = fields[0].to_string(); - let pid = fields[1].parse::().unwrap_or(0); - - // Find the field with connection info - format usually has (LISTEN), (ESTABLISHED) etc. - let proto_addr_idx = 8; - if fields.len() <= proto_addr_idx { - continue; - } - - let proto_addr = fields[proto_addr_idx]; - let proto_end = match proto_addr.find(' ') { - Some(pos) => pos, - None => continue, - }; - - let proto_str = &proto_addr[..proto_end].to_lowercase(); - let protocol = if proto_str == "tcp" || proto_str == "tcp4" || proto_str == "tcp6" { - Protocol::TCP - } else if proto_str == "udp" || proto_str == "udp4" || proto_str == "udp6" { - Protocol::UDP - } else { - continue; - }; - - // Parse connection state - let state = if fields.len() > proto_addr_idx + 1 { - match fields[proto_addr_idx + 1] { - "(ESTABLISHED)" => ConnectionState::Established, - "(LISTEN)" => ConnectionState::Listen, - "(TIME_WAIT)" => ConnectionState::TimeWait, - "(CLOSE_WAIT)" => ConnectionState::CloseWait, - "(SYN_SENT)" => ConnectionState::SynSent, - "(SYN_RECEIVED)" | "(SYN_RECV)" => ConnectionState::SynReceived, - "(FIN_WAIT_1)" => ConnectionState::FinWait1, - "(FIN_WAIT_2)" => ConnectionState::FinWait2, - "(LAST_ACK)" => ConnectionState::LastAck, - "(CLOSING)" => ConnectionState::Closing, - _ => ConnectionState::Unknown, - } - } else { - ConnectionState::Unknown - }; - - // Parse addresses - if proto_addr.find("->").is_some() { - // Has local and remote address (ESTABLISHED connection) - let addr_str = &proto_addr[proto_end + 1..]; - let parts: Vec<&str> = addr_str.split("->").collect(); - if parts.len() == 2 { - if let (Some(local), Some(remote)) = - (super::parse_addr(parts[0]), super::parse_addr(parts[1])) - { - // Check if this connection is already in our list - let conn_key = - format!("{:?}:{}-{:?}:{}", protocol, local, protocol, remote); - - if !seen_connections.contains(&conn_key) { - let mut conn = Connection::new(protocol, local, remote, state); - conn.pid = Some(pid); - conn.process_name = Some(process_name); - connections.push(conn); - seen_connections.insert(conn_key); - } - } - } - } else { - // Only local address (likely LISTEN) - let addr_str = &proto_addr[proto_end + 1..]; - if let Some(local) = super::parse_addr(addr_str) { - // Use 0.0.0.0:0 as remote for listening sockets - let remote = if local.ip().is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() - }; - - // Check if this connection is already in our list - let conn_key = - format!("{:?}:{}-{:?}:{}", protocol, local, protocol, remote); - - if !seen_connections.contains(&conn_key) { - let mut conn = Connection::new(protocol, local, remote, state); - conn.pid = Some(pid); - conn.process_name = Some(process_name); - connections.push(conn); - seen_connections.insert(conn_key); - } - } - } - } - } - - Ok(()) - } - - /// Get connections from netstat command - pub(super) fn get_connections_from_netstat(&self, connections: &mut Vec) -> Result<()> { - // Track unique connections to avoid duplicates - let mut seen_connections = HashSet::new(); - - // Get TCP connections - let output = Command::new("netstat") - .args(["-anv", "-p", "tcp"]) - .output()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - - for line in text.lines().skip(2) { - // Skip headers - - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 5 { - continue; - } - - // Protocol is always TCP for this command - let protocol = Protocol::TCP; - - // Parse state - let state_idx = 5; // Index where state info is typically found - let state = if fields.len() > state_idx { - match fields[state_idx] { - "ESTABLISHED" => ConnectionState::Established, - "LISTEN" => ConnectionState::Listen, - "TIME_WAIT" => ConnectionState::TimeWait, - "CLOSE_WAIT" => ConnectionState::CloseWait, - "SYN_SENT" => ConnectionState::SynSent, - "SYN_RCVD" | "SYN_RECV" => ConnectionState::SynReceived, - "FIN_WAIT_1" => ConnectionState::FinWait1, - "FIN_WAIT_2" => ConnectionState::FinWait2, - "LAST_ACK" => ConnectionState::LastAck, - "CLOSING" => ConnectionState::Closing, - _ => ConnectionState::Unknown, - } - } else { - ConnectionState::Unknown - }; - - // Parse local and remote addresses - let local_idx = 3; - let remote_idx = 4; - - if fields.len() <= local_idx || fields.len() <= remote_idx { - continue; - } - - if let (Some(local), Some(remote)) = ( - super::parse_addr(fields[local_idx]), - super::parse_addr(fields[remote_idx]), - ) { - // Check if this connection is already in our list - let conn_key = format!("{:?}:{}-{:?}:{}", protocol, local, protocol, remote); - - if !seen_connections.contains(&conn_key) { - connections.push(Connection::new(protocol, local, remote, state)); - seen_connections.insert(conn_key); - } - } - } - } - - // Get UDP connections - let output = Command::new("netstat") - .args(["-anv", "-p", "udp"]) - .output()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - - for line in text.lines().skip(2) { - // Skip headers - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 4 { - continue; - } - - // Protocol is always UDP for this command - let protocol = Protocol::UDP; - - // Parse local address - let local_idx = 3; - - if fields.len() <= local_idx { - continue; - } - - if let Some(local) = super::parse_addr(fields[local_idx]) { - // Use 0.0.0.0:0 as remote for UDP - let remote = if local.ip().is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() - }; - - // Check if this connection is already in our list - let conn_key = format!("{:?}:{}-{:?}:{}", protocol, local, protocol, remote); - - if !seen_connections.contains(&conn_key) { - connections.push(Connection::new( - protocol, - local, - remote, - ConnectionState::Unknown, - )); - seen_connections.insert(conn_key); - } - } - } - } - - Ok(()) - } -} - -/// Parses the NAME field of lsof output to extract local and remote addresses. -pub(super) fn parse_lsof_addrs(addr_field: &str) -> Option<(SocketAddr, SocketAddr)> { - if let Some(arrow_idx) = addr_field.find("->") { - let local_str = &addr_field[..arrow_idx]; - let remote_str = &addr_field[arrow_idx + 2..]; - let local_addr = super::parse_addr(local_str)?; - let remote_addr = super::parse_addr(remote_str)?; - Some((local_addr, remote_addr)) - } else { - let local_addr = super::parse_addr(addr_field)?; - let remote_addr = "0.0.0.0:0".parse().ok()?; - Some((local_addr, remote_addr)) - } -} - -/// Get process information using lsof command -pub(super) fn try_lsof_command(connection: &Connection) -> Option { - let proto_arg = match connection.protocol { - Protocol::TCP => "TCP", - Protocol::UDP => "UDP", - Protocol::ICMP => return None, - }; - - let output = Command::new("lsof") - .args(["-i", proto_arg, "-n", "-P"]) - .output() - .ok()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - for line in text.lines().skip(1) { - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 9 { - continue; - } - - if let Some((lsof_local, lsof_remote)) = parse_lsof_addrs(fields[8]) { - let c = connection; - let match1 = c.local_addr == lsof_local && c.remote_addr == lsof_remote; - let match2 = c.local_addr == lsof_remote && c.remote_addr == lsof_local; - - if match1 || match2 { - if let Ok(pid) = fields[1].parse::() { - return Some(Process { - pid, - name: fields[0].to_string(), - }); - } - } - } - } - } - None -} - -/// Get process information using netstat command -pub(super) fn try_netstat_command(connection: &Connection) -> Option { - if let Some(process) = try_lsof_command(connection) { - return Some(process); - } - - let output = Command::new("netstat") - .args(["-p", "tcp", "-v"]) - .output() - .ok()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - let local_port = connection.local_addr.port(); - let remote_port = connection.remote_addr.port(); - - for line in text.lines().skip(2) { - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 9 { - continue; - } - - if let Some(local_addr_str) = fields.get(3) { - if let Some(remote_addr_str) = fields.get(4) { - if let (Some(local_addr), Some(remote_addr)) = ( - super::parse_addr(local_addr_str), - super::parse_addr(remote_addr_str), - ) { - if local_addr.port() == local_port && remote_addr.port() == remote_port { - if let Some(pid_str) = fields.get(8) { - if let Ok(pid) = pid_str.parse::() { - return get_process_name_by_pid(pid) - .map(|name| Process { pid, name }); - } - } - } - } - } - } - } - } - None -} - -/// Get process name by PID -#[allow(dead_code)] -pub(super) fn get_process_name_by_pid(pid: u32) -> Option { - let output = Command::new("ps") - .args(["-p", &pid.to_string(), "-o", "comm="]) - .output() - .ok()?; - - if !output.status.success() { - return None; - } - - let text = String::from_utf8_lossy(&output.stdout); - let name = text.trim(); - if name.is_empty() { - None - } else { - Some(name.to_string()) - } -} - diff --git a/src/network/merge.rs b/src/network/merge.rs new file mode 100644 index 0000000..0a4cd85 --- /dev/null +++ b/src/network/merge.rs @@ -0,0 +1,303 @@ +// network/merge.rs - Connection merging and update utilities +use crate::network::dpi::DpiResult; +use crate::network::parser::ParsedPacket; +use crate::network::types::{ApplicationProtocol, Connection, DpiInfo, RateInfo}; +use std::time::{Instant, SystemTime}; + +/// Merge a parsed packet into an existing connection +pub fn merge_packet_into_connection( + mut conn: Connection, + parsed: &ParsedPacket, + now: SystemTime, +) -> Connection { + // Update timing + conn.last_activity = now; + + // Update packet counts and bytes + if parsed.is_outgoing { + conn.packets_sent += 1; + conn.bytes_sent += parsed.packet_len as u64; + } else { + conn.packets_received += 1; + conn.bytes_received += parsed.packet_len as u64; + } + + // Update protocol state (from packet flags/state) + conn.protocol_state = parsed.state; + + // Update DPI info if available and better than what we have + if let Some(dpi_result) = &parsed.dpi_result { + merge_dpi_info(&mut conn, dpi_result); + } + + conn +} + +/// Create a new connection from a parsed packet +pub fn create_connection_from_packet(parsed: &ParsedPacket, now: SystemTime) -> Connection { + let mut conn = Connection::new( + parsed.protocol, + parsed.local_addr, + parsed.remote_addr, + parsed.state, + ); + + // Set initial stats based on packet direction + if parsed.is_outgoing { + conn.packets_sent = 1; + conn.bytes_sent = parsed.packet_len as u64; + } else { + conn.packets_received = 1; + conn.bytes_received = parsed.packet_len as u64; + } + + // Apply DPI results if any + if let Some(dpi_result) = &parsed.dpi_result { + conn.dpi_info = Some(DpiInfo { + application: dpi_result.application.clone(), + first_packet_time: Instant::now(), + last_update_time: Instant::now(), + }); + } + + conn.created_at = now; + conn.last_activity = now; + + conn +} + +/// Merge DPI results into connection +fn merge_dpi_info(conn: &mut Connection, dpi_result: &DpiResult) { + match &conn.dpi_info { + None => { + // No existing DPI info, use the new one + conn.dpi_info = Some(DpiInfo { + application: dpi_result.application.clone(), + first_packet_time: Instant::now(), + last_update_time: Instant::now(), + }); + } + Some(existing) => { + // Only update if new info has higher confidence or is more specific + if should_update_dpi( + &existing.application, + &dpi_result.application, + dpi_result.confidence, + ) { + conn.dpi_info = Some(DpiInfo { + application: dpi_result.application.clone(), + first_packet_time: existing.first_packet_time, + last_update_time: Instant::now(), + }); + } + } + } +} + +/// Determine if we should update DPI info based on confidence and specificity +fn should_update_dpi( + existing: &ApplicationProtocol, + new: &ApplicationProtocol, + new_confidence: f32, +) -> bool { + // High confidence always wins + if new_confidence >= 0.95 { + return true; + } + + // Specific protocols override Unknown + match (existing, new) { + (ApplicationProtocol::Unknown, _) => true, + (_, ApplicationProtocol::Unknown) => false, + // HTTPS is more specific than HTTP + (ApplicationProtocol::Http(_), ApplicationProtocol::Https(_)) => true, + (ApplicationProtocol::Https(_), ApplicationProtocol::Http(_)) => false, + // Otherwise, only update if confidence is good + _ => new_confidence >= 0.8, + } +} + +/// Enrich connection with process information +pub fn enrich_with_process_info( + mut conn: Connection, + pid: u32, + process_name: String, +) -> Connection { + conn.pid = Some(pid); + conn.process_name = Some(process_name); + conn +} + +/// Enrich connection with service name +pub fn enrich_with_service_name(mut conn: Connection, service_name: String) -> Connection { + conn.service_name = Some(service_name); + conn +} + +/// Update connection rates based on current stats +pub fn update_connection_rates(mut conn: Connection, now: Instant) -> Connection { + let elapsed = now + .duration_since(conn.current_rate_bps.last_calculation) + .as_secs_f64(); + + if elapsed > 0.1 { + // Update at most every 100ms + conn.current_rate_bps = RateInfo { + outgoing_bps: (conn.bytes_sent as f64 * 8.0) / elapsed, + incoming_bps: (conn.bytes_received as f64 * 8.0) / elapsed, + last_calculation: now, + }; + + // Update backward compatibility fields + conn.current_incoming_rate_bps = conn.current_rate_bps.incoming_bps; + conn.current_outgoing_rate_bps = conn.current_rate_bps.outgoing_bps; + } + + conn +} + +/// Merge two connections (useful for combining data from different sources) +pub fn merge_connections(mut primary: Connection, secondary: &Connection) -> Connection { + // Use secondary's process info if primary doesn't have it + if primary.pid.is_none() && secondary.pid.is_some() { + primary.pid = secondary.pid; + primary.process_name = secondary.process_name.clone(); + } + + // Use secondary's service name if primary doesn't have it + if primary.service_name.is_none() && secondary.service_name.is_some() { + primary.service_name = secondary.service_name.clone(); + } + + // Merge traffic stats (take the maximum) + primary.bytes_sent = primary.bytes_sent.max(secondary.bytes_sent); + primary.bytes_received = primary.bytes_received.max(secondary.bytes_received); + primary.packets_sent = primary.packets_sent.max(secondary.packets_sent); + primary.packets_received = primary.packets_received.max(secondary.packets_received); + + // Use the earlier creation time + if secondary.created_at < primary.created_at { + primary.created_at = secondary.created_at; + } + + // Use the later last activity time + if secondary.last_activity > primary.last_activity { + primary.last_activity = secondary.last_activity; + } + + // Merge DPI info (prefer more specific) + if let Some(secondary_dpi) = &secondary.dpi_info { + match &primary.dpi_info { + None => primary.dpi_info = Some(secondary_dpi.clone()), + Some(primary_dpi) => { + if should_update_dpi(&primary_dpi.application, &secondary_dpi.application, 0.9) { + primary.dpi_info = Some(secondary_dpi.clone()); + } + } + } + } + + primary +} + +/// Check if two connections represent the same flow +pub fn connections_match(a: &Connection, b: &Connection) -> bool { + a.protocol == b.protocol && a.local_addr == b.local_addr && a.remote_addr == b.remote_addr +} + +/// Check if a connection matches a parsed packet +pub fn connection_matches_packet(conn: &Connection, parsed: &ParsedPacket) -> bool { + conn.protocol == parsed.protocol + && conn.local_addr == parsed.local_addr + && conn.remote_addr == parsed.remote_addr +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::network::types::{Protocol, ProtocolState, TcpState}; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + fn create_test_connection() -> Connection { + Connection::new( + Protocol::TCP, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), 12345), + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 80), + ProtocolState::Tcp(TcpState::Established), + ) + } + + fn create_test_packet(is_outgoing: bool) -> ParsedPacket { + ParsedPacket { + connection_key: "test".to_string(), + protocol: Protocol::TCP, + local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), 12345), + remote_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 80), + state: ProtocolState::Tcp(TcpState::Established), + is_outgoing, + packet_len: 100, + dpi_result: None, + } + } + + #[test] + fn test_merge_packet_into_connection() { + let mut conn = create_test_connection(); + let packet = create_test_packet(true); + + conn = merge_packet_into_connection(conn, &packet, SystemTime::now()); + + assert_eq!(conn.packets_sent, 1); + assert_eq!(conn.bytes_sent, 100); + assert_eq!(conn.packets_received, 0); + } + + #[test] + fn test_create_connection_from_packet() { + let packet = create_test_packet(false); + let conn = create_connection_from_packet(&packet, SystemTime::now()); + + assert_eq!(conn.packets_received, 1); + assert_eq!(conn.bytes_received, 100); + assert_eq!(conn.packets_sent, 0); + } + + #[test] + fn test_enrich_with_process_info() { + let conn = create_test_connection(); + let enriched = enrich_with_process_info(conn, 1234, "firefox".to_string()); + + assert_eq!(enriched.pid, Some(1234)); + assert_eq!(enriched.process_name, Some("firefox".to_string())); + } + + #[test] + fn test_merge_connections() { + let mut primary = create_test_connection(); + primary.bytes_sent = 1000; + + let mut secondary = create_test_connection(); + secondary.pid = Some(5678); + secondary.process_name = Some("chrome".to_string()); + secondary.bytes_sent = 2000; + + let merged = merge_connections(primary, &secondary); + + assert_eq!(merged.pid, Some(5678)); + assert_eq!(merged.process_name, Some("chrome".to_string())); + assert_eq!(merged.bytes_sent, 2000); // Takes the maximum + } + + #[test] + fn test_connections_match() { + let conn1 = create_test_connection(); + let conn2 = create_test_connection(); + + assert!(connections_match(&conn1, &conn2)); + + let mut conn3 = create_test_connection(); + conn3.local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 101)), 12345); + + assert!(!connections_match(&conn1, &conn3)); + } +} diff --git a/src/network/mod.rs b/src/network/mod.rs index 4fe124a..3b6822a 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -1,1713 +1,17 @@ -use anyhow::{Result, anyhow}; -use log::{debug, error, info, warn}; -use pcap::{Capture, Device}; -use std::collections::HashMap; -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::path::Path; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc::Sender; -use std::time::{Duration, Instant, SystemTime}; - -#[cfg(target_os = "linux")] -mod linux; - -#[cfg(target_os = "windows")] -mod windows; -#[cfg(target_os = "windows")] -use windows::*; - -#[cfg(target_os = "macos")] -mod macos; - -/// Transport protocol -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Protocol { - TCP, - UDP, - ICMP, - ARP, -} - -impl std::fmt::Display for Protocol { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Protocol::TCP => write!(f, "TCP"), - Protocol::UDP => write!(f, "UDP"), - Protocol::ICMP => write!(f, "ICMP"), - Protocol::ARP => write!(f, "ARP"), - } - } -} - -/// TCP connection state -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TcpState { - Listen, - SynSent, - SynReceived, - Established, - FinWait1, - FinWait2, - CloseWait, - LastAck, - TimeWait, - Closing, - Closed, -} - -/// Protocol-specific state information -#[derive(Debug, Clone, Copy)] -pub enum ProtocolState { - Tcp(TcpState), - Udp, // UDP is stateless - Icmp { - icmp_type: u8, // 8=Echo Request, 0=Echo Reply, etc. - icmp_code: u8, - }, - Arp { - operation: ArpOperation, - }, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ArpOperation { - Request, - Reply, -} - -/// Application layer protocol detection -#[derive(Debug, Clone)] -pub enum ApplicationProtocol { - Http(HttpInfo), - Https(TlsInfo), - Dns(DnsInfo), - Ssh, - Quic, // Basic QUIC detection without deep parsing - Unknown, -} - -/// HTTP information -#[derive(Debug, Clone)] -pub struct HttpInfo { - pub version: HttpVersion, - pub method: Option, // GET, POST, etc. - pub host: Option, // From Host header - pub path: Option, // Request path - pub status_code: Option, // For responses - pub user_agent: Option, // Useful for identifying clients -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum HttpVersion { - Http10, - Http11, - Http2, - Http3, // Inferred from QUIC -} - -/// TLS/HTTPS information -#[derive(Debug, Clone)] -pub struct TlsInfo { - pub version: Option, - pub sni: Option, - pub alpn: Vec, // Application protocols like "h2", "http/1.1" - pub cipher_suite: Option, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TlsVersion { - Ssl3, - Tls10, - Tls11, - Tls12, - Tls13, -} - -/// DNS information -#[derive(Debug, Clone)] -pub struct DnsInfo { - pub query_name: Option, - pub query_type: Option, - pub response_ips: Vec, - pub is_response: bool, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum DnsQueryType { - A, - AAAA, - CNAME, - MX, - TXT, - Other(u16), -} - -/// Deep packet inspection results -#[derive(Debug, Clone)] -pub struct DpiInfo { - pub application: ApplicationProtocol, - pub first_packet_time: Instant, - pub last_update_time: Instant, -} - -/// Rate information -#[derive(Debug, Clone)] -pub struct RateInfo { - pub incoming_bps: f64, - pub outgoing_bps: f64, - pub last_calculation: Instant, -} - -impl Default for RateInfo { - fn default() -> Self { - Self { - incoming_bps: 0.0, - outgoing_bps: 0.0, - last_calculation: Instant::now(), - } - } -} - -/// Network connection -#[derive(Debug, Clone)] -pub struct Connection { - // Core identification - pub protocol: Protocol, - pub local_addr: SocketAddr, - pub remote_addr: SocketAddr, - - // Protocol state - pub protocol_state: ProtocolState, - - // Process information - pub pid: Option, - pub process_name: Option, - - // Traffic statistics - pub bytes_sent: u64, - pub bytes_received: u64, - pub packets_sent: u64, - pub packets_received: u64, - - // Timing - pub created_at: SystemTime, - pub last_activity: SystemTime, - - // Service identification - pub service_name: Option, // From port lookup - - // Deep packet inspection - pub dpi_info: Option, - - // Performance metrics - pub current_rate_bps: RateInfo, - pub rtt_estimate: Option, // Round-trip time if measurable - - // Backward compatibility fields - pub current_incoming_rate_bps: f64, - pub current_outgoing_rate_bps: f64, -} - -// Add a simple state field for backward compatibility -impl Connection { - pub fn state(&self) -> String { - match &self.protocol_state { - ProtocolState::Tcp(tcp_state) => format!("{:?}", tcp_state), - ProtocolState::Udp => "ACTIVE".to_string(), - ProtocolState::Icmp { icmp_type, .. } => match icmp_type { - 8 => "ECHO_REQUEST".to_string(), - 0 => "ECHO_REPLY".to_string(), - 3 => "DEST_UNREACH".to_string(), - 11 => "TIME_EXCEEDED".to_string(), - _ => "UNKNOWN".to_string(), - }, - ProtocolState::Arp { operation } => match operation { - ArpOperation::Request => "ARP_REQUEST".to_string(), - ArpOperation::Reply => "ARP_REPLY".to_string(), - }, - } - } -} - -impl Connection { - /// Create a new connection - pub fn new( - protocol: Protocol, - local_addr: SocketAddr, - remote_addr: SocketAddr, - state: ProtocolState, - ) -> Self { - let now = SystemTime::now(); - Self { - protocol, - local_addr, - remote_addr, - protocol_state: state, - pid: None, - process_name: None, - bytes_sent: 0, - bytes_received: 0, - packets_sent: 0, - packets_received: 0, - created_at: now, - last_activity: now, - service_name: None, - dpi_info: None, - current_rate_bps: RateInfo::default(), - rtt_estimate: None, - // Backward compatibility - current_incoming_rate_bps: 0.0, - current_outgoing_rate_bps: 0.0, - } - } - - /// Check if connection is active (had activity in the last minute) - pub fn is_active(&self) -> bool { - self.last_activity.elapsed().unwrap_or_default() < Duration::from_secs(60) - } - - /// Get the age of the connection (time since creation) - pub fn age(&self) -> Duration { - self.created_at.elapsed().unwrap_or_default() - } - - /// Get time since last activity - pub fn idle_time(&self) -> Duration { - self.last_activity.elapsed().unwrap_or_default() - } - - /// Update transfer rates - pub fn update_rates(&mut self, new_sent: u64, new_received: u64) { - let now = Instant::now(); - let elapsed = now - .duration_since(self.current_rate_bps.last_calculation) - .as_secs_f64(); - - if elapsed > 0.1 { - // Update rates every 100ms minimum - let sent_diff = new_sent.saturating_sub(self.bytes_sent) as f64; - let recv_diff = new_received.saturating_sub(self.bytes_received) as f64; - - self.current_rate_bps = RateInfo { - outgoing_bps: (sent_diff * 8.0) / elapsed, - incoming_bps: (recv_diff * 8.0) / elapsed, - last_calculation: now, - }; - - // Update backward compatibility fields - self.current_incoming_rate_bps = self.current_rate_bps.incoming_bps; - self.current_outgoing_rate_bps = self.current_rate_bps.outgoing_bps; - } - } -} - -/// Process information -#[derive(Debug, Clone)] -pub struct Process { - pub pid: u32, - pub name: String, -} - -/// Main function for the packet capture thread -pub fn packet_capture_thread( - interface_name: Option, - packet_tx: Sender>, - should_stop: Arc, -) -> Result<()> { - let cap_device = match interface_name { - Some(iface) => { - info!("Searching for specified interface: {}", iface); - Device::list()? - .into_iter() - .find(|d| d.name == iface) - .ok_or_else(|| anyhow!("Interface '{}' not found", iface))? - } - None => { - info!("No interface specified, looking up default."); - Device::lookup()?.ok_or_else(|| anyhow!("No default device found"))? - } - }; - - info!("Opening capture on device: {}", cap_device.name); - let mut cap = Capture::from_device(cap_device)? - .promisc(true) - .snaplen(1024) // Increased for DPI - .buffer_size(2_000_000) - .timeout(0) - .immediate_mode(true) - .open()?; - - info!("Applying BPF filter for IPv4 and IPv6"); - cap.filter( - "(ip and (tcp or udp or icmp)) or (ip6 and (tcp or udp or icmp6)) or arp", - true, - )?; - - loop { - if should_stop.load(Ordering::Relaxed) { - info!("Stop signal received, shutting down capture thread."); - break; - } - match cap.next_packet() { - Ok(packet) => { - if packet_tx.send(packet.data.to_vec()).is_err() { - info!("Packet receiver has disconnected, stopping capture thread."); - break; - } - } - Err(pcap::Error::TimeoutExpired) => { - debug!("Timeout expired, no packet captured this iteration."); - continue; - } - Err(e) => { - error!("Error capturing packet: {}", e); - break; - } - } - } - - Ok(()) -} - -/// Network monitor -pub struct NetworkMonitor { - connections: HashMap, - service_lookup: ServiceLookup, - filter_localhost: bool, - local_ips: std::collections::HashSet, -} - -/// Manages lookup of service names from a services file -#[derive(Debug)] -struct ServiceLookup { - services: HashMap<(u16, Protocol), String>, -} - -impl ServiceLookup { - fn new(file_path_str: &str) -> Result { - let mut services = HashMap::new(); - let file_path = Path::new(file_path_str); - - if !file_path.exists() { - warn!( - "Service definition file not found at '{}'. Service names will not be available.", - file_path_str - ); - return Ok(Self { services }); - } - - let file = File::open(file_path)?; - let reader = BufReader::new(file); - - for line_result in reader.lines() { - let line = match line_result { - Ok(l) => l, - Err(e) => { - warn!("Error reading line from services file: {}", e); - continue; - } - }; - - let line = line.trim(); - if line.is_empty() || line.starts_with('#') { - continue; - } - - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() < 2 { - debug!("Skipping malformed line in services file: {}", line); - continue; - } - - let service_name = parts[0].to_string(); - let port_protocol_str = parts[1]; - - let port_protocol_parts: Vec<&str> = port_protocol_str.split('/').collect(); - if port_protocol_parts.len() != 2 { - debug!( - "Skipping malformed port/protocol in services file: {} from line: {}", - port_protocol_str, line - ); - continue; - } - - let port = match port_protocol_parts[0].parse::() { - Ok(p) => p, - Err(_) => { - debug!( - "Skipping invalid port in services file: {} from line: {}", - port_protocol_parts[0], line - ); - continue; - } - }; - - let protocol = match port_protocol_parts[1].to_lowercase().as_str() { - "tcp" => Protocol::TCP, - "udp" => Protocol::UDP, - _ => continue, - }; - - services.entry((port, protocol)).or_insert(service_name); - } - debug!( - "ServiceLookup initialized with {} entries from '{}'", - services.len(), - file_path_str - ); - Ok(Self { services }) - } - - fn get(&self, port: u16, protocol: Protocol) -> Option { - self.services.get(&(port, protocol)).cloned() - } -} - -/// Sets the service name for a given connection based on its port and protocol -fn set_connection_service_name_for_connection( - conn: &mut Connection, - service_lookup: &ServiceLookup, -) { - let local_port = conn.local_addr.port(); - let remote_port = conn.remote_addr.port(); - let protocol = conn.protocol; - - let mut final_service_name: Option = None; - - match conn.protocol_state { - ProtocolState::Tcp(TcpState::Listen) => { - final_service_name = service_lookup.get(local_port, protocol); - } - _ => { - let local_service_name_opt = service_lookup.get(local_port, protocol); - if local_service_name_opt.is_some() { - final_service_name = local_service_name_opt; - } else { - let remote_service_name_opt = service_lookup.get(remote_port, protocol); - if remote_service_name_opt.is_some() { - final_service_name = remote_service_name_opt; - } - } - } - } - conn.service_name = final_service_name; -} - -impl NetworkMonitor { - /// Create a new network monitor - pub fn new(_interface: Option, filter_localhost: bool) -> Result { - log::info!("NetworkMonitor::new - Initializing"); - - let mut local_ips = std::collections::HashSet::new(); - for iface in pnet_datalink::interfaces() { - for ip_network in iface.ips { - local_ips.insert(ip_network.ip()); - } - } - - if local_ips.is_empty() { - warn!( - "Could not determine any local IP addresses. Connection directionality might be inaccurate." - ); - } else { - debug!("Found local IPs: {:?}", local_ips); - } - - let services_file_path = "assets/services"; - let service_lookup = ServiceLookup::new(services_file_path).unwrap_or_else(|e| { - error!( - "Failed to load service definitions from '{}': {}. Proceeding without service names.", - services_file_path, e - ); - ServiceLookup { - services: HashMap::new(), - } - }); - - log::info!("NetworkMonitor::new - Initialization complete"); - Ok(Self { - local_ips, - service_lookup, - connections: HashMap::new(), - filter_localhost, - }) - } - - /// Get active connections - pub fn get_connections(&mut self) -> Result> { - // Start with pcap-captured connections as the primary source - let mut result_connections: Vec = self - .connections - .values() - .filter(|conn| conn.is_active()) - .cloned() - .collect(); - - debug!( - "get_connections: Found {} active pcap connections", - result_connections.len() - ); - - // Enrich pcap connections with process information from platform - if !result_connections.is_empty() { - // Get connection info with processes from platform - let mut platform_conns: Vec = Vec::new(); - - #[cfg(target_os = "linux")] - { - if let Err(e) = linux::get_connections_with_process_info(&mut platform_conns) { - error!("Error getting process info from platform: {}", e); - } - } - #[cfg(target_os = "macos")] - { - if let Err(e) = macos::get_connections_with_process_info(&mut platform_conns) { - error!("Error getting process info from platform: {}", e); - } - } - #[cfg(target_os = "windows")] - { - if let Err(e) = windows::get_connections_with_process_info(&mut platform_conns) { - error!("Error getting process info from platform: {}", e); - } - } - - debug!( - "Found {} platform connections for process enrichment", - platform_conns.len() - ); - - // Create a lookup map for platform connections - let mut platform_lookup: HashMap = HashMap::new(); - for conn in platform_conns { - if let (Some(pid), Some(name)) = (conn.pid, conn.process_name) { - let key = format!( - "{:?}:{}-{:?}:{}", - conn.protocol, conn.local_addr, conn.protocol, conn.remote_addr - ); - platform_lookup.insert(key, (pid, name)); - } - } - - // Enrich pcap connections with process names - for conn in &mut result_connections { - if conn.process_name.is_none() { - let key = self.get_connection_key_for_merge(conn); - if let Some((pid, name)) = platform_lookup.get(&key) { - debug!( - "Enriching connection {}:{} with process {} (PID: {})", - conn.local_addr, conn.remote_addr, name, pid - ); - conn.process_name = Some(name.clone()); - } - } - } - } - - // Sort by last activity (most recent first) - result_connections.sort_by(|a, b| b.last_activity.cmp(&a.last_activity)); - - // Apply localhost filter if enabled - if self.filter_localhost { - result_connections.retain(|conn| { - !(conn.local_addr.ip().is_loopback() && conn.remote_addr.ip().is_loopback()) - }); - } - - // Set service names - for conn in &mut result_connections { - set_connection_service_name_for_connection(conn, &self.service_lookup); - if conn.current_rate_bps.incoming_bps > 0.0 || conn.current_rate_bps.outgoing_bps > 0.0 - { - debug!( - "Connection: {:?}, Incoming: {:.2} bps, Outgoing: {:.2} bps", - conn.local_addr, - conn.current_rate_bps.incoming_bps, - conn.current_rate_bps.outgoing_bps - ); - } - } - - debug!( - "get_connections: Returning {} total connections", - result_connections.len() - ); - Ok(result_connections) - } - - fn determine_addresses( - &self, - src_ip: IpAddr, - src_port: u16, - dst_ip: IpAddr, - dst_port: u16, - is_outgoing: bool, - ) -> (SocketAddr, SocketAddr) { - if is_outgoing { - ( - SocketAddr::new(src_ip, src_port), - SocketAddr::new(dst_ip, dst_port), - ) - } else { - ( - SocketAddr::new(dst_ip, dst_port), - SocketAddr::new(src_ip, src_port), - ) - } - } - - /// Process a single raw packet from the queue - pub fn process_packet(&mut self, data: &[u8]) { - if data.len() < 14 { - return; - } - - // Check EtherType to determine packet type - let ethertype = u16::from_be_bytes([data[12], data[13]]); - - match ethertype { - 0x0800 => { - // IPv4 packet - self.process_ipv4_packet(data); - } - 0x86dd => { - // IPv6 packet - self.process_ipv6_packet(data); - } - 0x0806 => { - // ARP packet - self.process_arp_packet(data); - } - _ => { - // Other packet types - ignore - } - } - } - - fn process_ipv4_packet(&mut self, data: &[u8]) { - let ip_data = &data[14..]; - if ip_data.len() < 20 { - return; - } - - let version = ip_data[0] >> 4; - if version != 4 { - return; - } - - let protocol = ip_data[9]; - let src_ip = IpAddr::V4(Ipv4Addr::new( - ip_data[12], - ip_data[13], - ip_data[14], - ip_data[15], - )); - let dst_ip = IpAddr::V4(Ipv4Addr::new( - ip_data[16], - ip_data[17], - ip_data[18], - ip_data[19], - )); - - let ihl = ip_data[0] & 0x0F; - let ip_header_len = (ihl as usize) * 4; - - if ip_data.len() < ip_header_len { - return; - } - let transport_data = &ip_data[ip_header_len..]; - - let is_outgoing = self.local_ips.contains(&src_ip); - - match protocol { - 1 => self.process_icmp_packet(data, is_outgoing, transport_data, src_ip, dst_ip), - 6 => self.process_tcp_packet(data, is_outgoing, transport_data, src_ip, dst_ip), - 17 => self.process_udp_packet(data, is_outgoing, transport_data, src_ip, dst_ip), - _ => {} - } - } - - fn process_ipv6_packet(&mut self, data: &[u8]) { - let ip_data = &data[14..]; - if ip_data.len() < 40 { - // IPv6 header is fixed 40 bytes - return; - } - - let version = ip_data[0] >> 4; - if version != 6 { - return; - } - - let next_header = ip_data[6]; // Protocol type - - // Extract IPv6 addresses - let src_ip = IpAddr::V6(Ipv6Addr::new( - u16::from_be_bytes([ip_data[8], ip_data[9]]), - u16::from_be_bytes([ip_data[10], ip_data[11]]), - u16::from_be_bytes([ip_data[12], ip_data[13]]), - u16::from_be_bytes([ip_data[14], ip_data[15]]), - u16::from_be_bytes([ip_data[16], ip_data[17]]), - u16::from_be_bytes([ip_data[18], ip_data[19]]), - u16::from_be_bytes([ip_data[20], ip_data[21]]), - u16::from_be_bytes([ip_data[22], ip_data[23]]), - )); - - let dst_ip = IpAddr::V6(Ipv6Addr::new( - u16::from_be_bytes([ip_data[24], ip_data[25]]), - u16::from_be_bytes([ip_data[26], ip_data[27]]), - u16::from_be_bytes([ip_data[28], ip_data[29]]), - u16::from_be_bytes([ip_data[30], ip_data[31]]), - u16::from_be_bytes([ip_data[32], ip_data[33]]), - u16::from_be_bytes([ip_data[34], ip_data[35]]), - u16::from_be_bytes([ip_data[36], ip_data[37]]), - u16::from_be_bytes([ip_data[38], ip_data[39]]), - )); - - let transport_data = &ip_data[40..]; // IPv6 header is always 40 bytes - let is_outgoing = self.local_ips.contains(&src_ip); - - // Handle extension headers if present - let (final_next_header, transport_offset) = - self.parse_ipv6_extension_headers(next_header, transport_data); - let final_transport_data = &transport_data[transport_offset..]; - - match final_next_header { - 58 => { - self.process_icmpv6_packet(data, is_outgoing, final_transport_data, src_ip, dst_ip) - } - 6 => self.process_tcp_packet(data, is_outgoing, final_transport_data, src_ip, dst_ip), - 17 => self.process_udp_packet(data, is_outgoing, final_transport_data, src_ip, dst_ip), - _ => {} - } - } - - fn parse_ipv6_extension_headers(&self, mut next_header: u8, data: &[u8]) -> (u8, usize) { - let mut offset = 0; - - // Common IPv6 extension headers - const HOP_BY_HOP: u8 = 0; - const ROUTING: u8 = 43; - const FRAGMENT: u8 = 44; - const ENCAPSULATING_SECURITY: u8 = 50; - const AUTHENTICATION: u8 = 51; - const DESTINATION_OPTIONS: u8 = 60; - - loop { - match next_header { - HOP_BY_HOP | ROUTING | DESTINATION_OPTIONS => { - if data.len() < offset + 2 { - return (next_header, offset); - } - next_header = data[offset]; - let header_len = ((data[offset + 1] as usize) + 1) * 8; - offset += header_len; - } - FRAGMENT => { - if data.len() < offset + 8 { - return (next_header, offset); - } - next_header = data[offset]; - offset += 8; // Fragment header is fixed 8 bytes - } - AUTHENTICATION => { - if data.len() < offset + 2 { - return (next_header, offset); - } - next_header = data[offset]; - let header_len = ((data[offset + 1] as usize) + 2) * 4; - offset += header_len; - } - ENCAPSULATING_SECURITY => { - // ESP is complex, just skip for now - return (next_header, offset); - } - _ => { - // Not an extension header, this is the final protocol - return (next_header, offset); - } - } - - if offset >= data.len() { - return (next_header, offset); - } - } - } - - fn process_icmp_packet( - &mut self, - data: &[u8], - is_outgoing: bool, - transport_data: &[u8], - src_ip: IpAddr, - dst_ip: IpAddr, - ) { - if transport_data.is_empty() { - return; - } - - let icmp_type = transport_data[0]; - let icmp_code = if transport_data.len() > 1 { - transport_data[1] - } else { - 0 - }; - - let (local_addr, remote_addr) = self.determine_addresses(src_ip, 0, dst_ip, 0, is_outgoing); - - let conn_key = format!( - "{:?}:{}-{:?}:{}", - Protocol::ICMP, - local_addr, - Protocol::ICMP, - remote_addr - ); - - let state = ProtocolState::Icmp { - icmp_type, - icmp_code, - }; - - let conn = self - .connections - .entry(conn_key) - .or_insert_with(|| Connection::new(Protocol::ICMP, local_addr, remote_addr, state)); - - // Update connection state - conn.protocol_state = state; - conn.last_activity = SystemTime::now(); - - // Update statistics - if is_outgoing { - conn.packets_sent += 1; - conn.bytes_sent += data.len() as u64; - } else { - conn.packets_received += 1; - conn.bytes_received += data.len() as u64; - } - - // Update rates - conn.update_rates(conn.bytes_sent, conn.bytes_received); - - // Set service name - set_connection_service_name_for_connection(conn, &self.service_lookup); - } - - fn process_icmpv6_packet( - &mut self, - data: &[u8], - is_outgoing: bool, - transport_data: &[u8], - src_ip: IpAddr, - dst_ip: IpAddr, - ) { - if transport_data.is_empty() { - return; - } - - let icmp_type = transport_data[0]; - let icmp_code = if transport_data.len() > 1 { - transport_data[1] - } else { - 0 - }; - - // ICMPv6 types are different from ICMPv4 - // 128 = Echo Request, 129 = Echo Reply, 1 = Destination Unreachable, 3 = Time Exceeded - - let (local_addr, remote_addr) = self.determine_addresses(src_ip, 0, dst_ip, 0, is_outgoing); - - let conn_key = format!( - "{:?}:{}-{:?}:{}", - Protocol::ICMP, - local_addr, - Protocol::ICMP, - remote_addr - ); - - let state = ProtocolState::Icmp { - icmp_type, - icmp_code, - }; - - let conn = self - .connections - .entry(conn_key) - .or_insert_with(|| Connection::new(Protocol::ICMP, local_addr, remote_addr, state)); - - // Rest of the processing is the same as ICMPv4 - conn.protocol_state = state; - conn.last_activity = SystemTime::now(); - - if is_outgoing { - conn.packets_sent += 1; - conn.bytes_sent += data.len() as u64; - } else { - conn.packets_received += 1; - conn.bytes_received += data.len() as u64; - } - - conn.update_rates(conn.bytes_sent, conn.bytes_received); - set_connection_service_name_for_connection(conn, &self.service_lookup); - } - - fn process_tcp_packet( - &mut self, - data: &[u8], - is_outgoing: bool, - transport_data: &[u8], - src_ip: IpAddr, - dst_ip: IpAddr, - ) { - if transport_data.len() < 20 { - return; - } - - let src_port = u16::from_be_bytes([transport_data[0], transport_data[1]]); - let dst_port = u16::from_be_bytes([transport_data[2], transport_data[3]]); - let flags = transport_data[13]; - - // Determine TCP state from flags - let tcp_state = match flags { - 0x02 => TcpState::SynSent, - 0x12 => TcpState::SynReceived, - 0x10 => TcpState::Established, - 0x01 => TcpState::FinWait1, - 0x11 => TcpState::FinWait2, - 0x04 => TcpState::Closed, - 0x14 => TcpState::Closing, - _ => TcpState::Established, - }; - - let (local_addr, remote_addr) = - self.determine_addresses(src_ip, src_port, dst_ip, dst_port, is_outgoing); - - let conn_key = format!( - "{:?}:{}-{:?}:{}", - Protocol::TCP, - local_addr, - Protocol::TCP, - remote_addr - ); - - let state = ProtocolState::Tcp(tcp_state); - - // Extract TCP payload for DPI - let tcp_header_len = ((transport_data[12] >> 4) as usize) * 4; - let needs_dpi = if transport_data.len() > tcp_header_len { - let tcp_payload = &transport_data[tcp_header_len..]; - !tcp_payload.is_empty() && !self.connections.contains_key(&conn_key) - } else { - false - }; - - let conn = self - .connections - .entry(conn_key.clone()) - .or_insert_with(|| Connection::new(Protocol::TCP, local_addr, remote_addr, state)); - - // Update connection state - conn.protocol_state = state; - conn.last_activity = SystemTime::now(); - - // Update statistics - if is_outgoing { - conn.packets_sent += 1; - conn.bytes_sent += data.len() as u64; - } else { - conn.packets_received += 1; - conn.bytes_received += data.len() as u64; - } - - // Update rates - conn.update_rates(conn.bytes_sent, conn.bytes_received); - - // Set service name - set_connection_service_name_for_connection(conn, &self.service_lookup); - - // Do DPI after releasing the mutable borrow - if needs_dpi && transport_data.len() > tcp_header_len { - let tcp_payload = &transport_data[tcp_header_len..]; - self.process_tcp_payload_for_dpi( - &conn_key, - tcp_payload, - local_addr.port(), - remote_addr.port(), - ); - } - } - - fn process_udp_packet( - &mut self, - data: &[u8], - is_outgoing: bool, - transport_data: &[u8], - src_ip: IpAddr, - dst_ip: IpAddr, - ) { - if transport_data.len() < 8 { - return; - } - - let src_port = u16::from_be_bytes([transport_data[0], transport_data[1]]); - let dst_port = u16::from_be_bytes([transport_data[2], transport_data[3]]); - - let (local_addr, remote_addr) = - self.determine_addresses(src_ip, src_port, dst_ip, dst_port, is_outgoing); - - let conn_key = format!( - "{:?}:{}-{:?}:{}", - Protocol::UDP, - local_addr, - Protocol::UDP, - remote_addr - ); - - let state = ProtocolState::Udp; - - // Check if we need DPI - let needs_dpi = if transport_data.len() > 8 { - let udp_payload = &transport_data[8..]; - !udp_payload.is_empty() && !self.connections.contains_key(&conn_key) - } else { - false - }; - - let conn = self - .connections - .entry(conn_key.clone()) - .or_insert_with(|| Connection::new(Protocol::UDP, local_addr, remote_addr, state)); - - // Update connection - conn.last_activity = SystemTime::now(); - - if is_outgoing { - conn.packets_sent += 1; - conn.bytes_sent += data.len() as u64; - } else { - conn.packets_received += 1; - conn.bytes_received += data.len() as u64; - } - - // Update rates - conn.update_rates(conn.bytes_sent, conn.bytes_received); - - // Set service name - set_connection_service_name_for_connection(conn, &self.service_lookup); - - // Do DPI after releasing the mutable borrow - if needs_dpi && transport_data.len() > 8 { - let udp_payload = &transport_data[8..]; - self.process_udp_payload_for_dpi( - &conn_key, - udp_payload, - local_addr.port(), - remote_addr.port(), - ); - } - } - - fn process_arp_packet(&mut self, data: &[u8]) { - let arp_data = &data[14..]; - if arp_data.len() < 28 { - return; - } - - // Parse ARP header - let hardware_type = u16::from_be_bytes([arp_data[0], arp_data[1]]); - let protocol_type = u16::from_be_bytes([arp_data[2], arp_data[3]]); - let opcode = u16::from_be_bytes([arp_data[6], arp_data[7]]); - - // We only handle Ethernet (1) and IPv4 (0x0800) - if hardware_type != 1 || protocol_type != 0x0800 { - return; - } - - let sender_ip = IpAddr::from([arp_data[14], arp_data[15], arp_data[16], arp_data[17]]); - let target_ip = IpAddr::from([arp_data[24], arp_data[25], arp_data[26], arp_data[27]]); - - let operation = match opcode { - 1 => ArpOperation::Request, - 2 => ArpOperation::Reply, - _ => return, - }; - - let is_outgoing = self.local_ips.contains(&sender_ip); - let (local_addr, remote_addr) = if is_outgoing { - (SocketAddr::new(sender_ip, 0), SocketAddr::new(target_ip, 0)) - } else { - (SocketAddr::new(target_ip, 0), SocketAddr::new(sender_ip, 0)) - }; - - let conn_key = format!( - "{:?}:{}-{:?}:{}", - Protocol::ARP, - local_addr, - Protocol::ARP, - remote_addr - ); - - let state = ProtocolState::Arp { operation }; - - let conn = self - .connections - .entry(conn_key) - .or_insert_with(|| Connection::new(Protocol::ARP, local_addr, remote_addr, state)); - - // Update connection - conn.protocol_state = state; - conn.last_activity = SystemTime::now(); - - if is_outgoing { - conn.packets_sent += 1; - conn.bytes_sent += data.len() as u64; - } else { - conn.packets_received += 1; - conn.bytes_received += data.len() as u64; - } - - // Update rates - conn.update_rates(conn.bytes_sent, conn.bytes_received); - } - - // DPI helper methods - fn process_tcp_payload_for_dpi( - &mut self, - conn_key: &str, - payload: &[u8], - local_port: u16, - remote_port: u16, - ) { - if let Some(app_protocol) = - self.identify_tcp_application_from_payload(payload, local_port, remote_port) - { - if let Some(conn) = self.connections.get_mut(conn_key) { - conn.dpi_info = Some(DpiInfo { - application: app_protocol, - first_packet_time: Instant::now(), - last_update_time: Instant::now(), - }); - } - } - } - - fn process_udp_payload_for_dpi( - &mut self, - conn_key: &str, - payload: &[u8], - local_port: u16, - remote_port: u16, - ) { - if let Some(app_protocol) = - self.identify_udp_application_from_payload(payload, local_port, remote_port) - { - if let Some(conn) = self.connections.get_mut(conn_key) { - conn.dpi_info = Some(DpiInfo { - application: app_protocol, - first_packet_time: Instant::now(), - last_update_time: Instant::now(), - }); - } - } - } - - fn identify_tcp_application_from_payload( - &self, - payload: &[u8], - local_port: u16, - remote_port: u16, - ) -> Option { - // Check for HTTP/1.x - if self.is_http_payload(payload) { - return Some(ApplicationProtocol::Http(self.parse_http_info(payload))); - } - - // Check for TLS/HTTPS - if (local_port == 443 || remote_port == 443) || self.is_tls_handshake(payload) { - if let Some(tls_info) = self.extract_tls_info(payload) { - return Some(ApplicationProtocol::Https(tls_info)); - } - } - - // Check for SSH - if (local_port == 22 || remote_port == 22) || payload.starts_with(b"SSH-") { - return Some(ApplicationProtocol::Ssh); - } - - None - } - - fn identify_udp_application_from_payload( - &self, - payload: &[u8], - local_port: u16, - remote_port: u16, - ) -> Option { - // DNS - if local_port == 53 || remote_port == 53 { - if let Some(dns_info) = self.parse_dns_packet(payload) { - return Some(ApplicationProtocol::Dns(dns_info)); - } - } - - // QUIC/HTTP3 - if (local_port == 443 || remote_port == 443) && self.is_quic_packet(payload) { - return Some(ApplicationProtocol::Quic); - } - - None - } - - // DPI implementation methods - /// Check if payload looks like HTTP/1.x - fn is_http_payload(&self, payload: &[u8]) -> bool { - if payload.len() < 4 { - return false; - } - - // HTTP request methods - payload.starts_with(b"GET ") || - payload.starts_with(b"POST ") || - payload.starts_with(b"PUT ") || - payload.starts_with(b"DELETE ") || - payload.starts_with(b"HEAD ") || - payload.starts_with(b"OPTIONS ") || - payload.starts_with(b"CONNECT ") || - payload.starts_with(b"TRACE ") || - payload.starts_with(b"PATCH ") || - // HTTP responses - payload.starts_with(b"HTTP/1.0 ") || - payload.starts_with(b"HTTP/1.1 ") - } - - /// Parse HTTP information from payload - fn parse_http_info(&self, payload: &[u8]) -> HttpInfo { - let mut info = HttpInfo { - version: HttpVersion::Http11, // Default - method: None, - host: None, - path: None, - status_code: None, - user_agent: None, - }; - - // Convert to string for easier parsing (only what we can safely convert) - let text = String::from_utf8_lossy(payload); - let lines: Vec<&str> = text.lines().collect(); - - if lines.is_empty() { - return info; - } - - // Parse first line (request or response) - let first_line = lines[0]; - let parts: Vec<&str> = first_line.split_whitespace().collect(); - - if parts.len() >= 3 { - if first_line.starts_with("HTTP/") { - // Response line: HTTP/1.1 200 OK - info.version = if parts[0] == "HTTP/1.0" { - HttpVersion::Http10 - } else { - HttpVersion::Http11 - }; - info.status_code = parts[1].parse::().ok(); - } else { - // Request line: GET /path HTTP/1.1 - info.method = Some(parts[0].to_string()); - info.path = Some(parts[1].to_string()); - info.version = if parts[2] == "HTTP/1.0" { - HttpVersion::Http10 - } else { - HttpVersion::Http11 - }; - } - } - - // Parse headers - for line in lines.iter().skip(1) { - if line.is_empty() { - break; // End of headers - } - - if let Some((key, value)) = line.split_once(':') { - let key = key.trim().to_lowercase(); - let value = value.trim(); - - match key.as_str() { - "host" => info.host = Some(value.to_string()), - "user-agent" => info.user_agent = Some(value.to_string()), - _ => {} - } - } - } - - info - } - - /// Check if this is a TLS handshake packet - fn is_tls_handshake(&self, payload: &[u8]) -> bool { - if payload.len() < 6 { - return false; - } - - // TLS record header: - // - Content type (1 byte): 0x16 for handshake - // - Version (2 bytes): 0x0301-0x0304 for TLS 1.0-1.3 - // - Length (2 bytes) - - payload[0] == 0x16 && // Handshake content type - payload[1] == 0x03 && // Major version 3 - (payload[2] >= 0x01 && payload[2] <= 0x04) // Minor version 1-4 - } - - /// Extract TLS information from handshake - fn extract_tls_info(&self, payload: &[u8]) -> Option { - if !self.is_tls_handshake(payload) || payload.len() < 9 { - return None; - } - - let mut info = TlsInfo { - version: None, - sni: None, - alpn: Vec::new(), - cipher_suite: None, - }; - - // Record layer version - let record_version = match payload[2] { - 0x01 => Some(TlsVersion::Tls10), - 0x02 => Some(TlsVersion::Tls11), - 0x03 => Some(TlsVersion::Tls12), - 0x04 => Some(TlsVersion::Tls13), - _ => None, - }; - - // Skip TLS record header (5 bytes) - let handshake_data = &payload[5..]; - - if handshake_data.len() < 4 { - return Some(info); - } - - let handshake_type = handshake_data[0]; - - match handshake_type { - 0x01 => { - // Client Hello - info.version = record_version; - if let Some((sni, alpn)) = self.parse_client_hello_extensions(handshake_data) { - info.sni = sni; - info.alpn = alpn; - } - } - 0x02 => { - // Server Hello - info.version = record_version; - // Could parse cipher suite here if needed - } - _ => {} - } - - Some(info) - } - - /// Parse Client Hello extensions for SNI and ALPN - fn parse_client_hello_extensions( - &self, - handshake_data: &[u8], - ) -> Option<(Option, Vec)> { - if handshake_data.len() < 38 { - return None; - } - - // Skip to extensions: - // - Handshake type (1) + Length (3) + Version (2) + Random (32) = 38 - let mut offset = 38; - - // Session ID - if offset >= handshake_data.len() { - return None; - } - let session_id_len = handshake_data[offset] as usize; - offset += 1 + session_id_len; - - // Cipher suites - if offset + 2 > handshake_data.len() { - return None; - } - let cipher_suites_len = - u16::from_be_bytes([handshake_data[offset], handshake_data[offset + 1]]) as usize; - offset += 2 + cipher_suites_len; - - // Compression methods - if offset >= handshake_data.len() { - return None; - } - let compression_len = handshake_data[offset] as usize; - offset += 1 + compression_len; - - // Extensions length - if offset + 2 > handshake_data.len() { - return None; - } - let extensions_len = - u16::from_be_bytes([handshake_data[offset], handshake_data[offset + 1]]) as usize; - offset += 2; - - if offset + extensions_len > handshake_data.len() { - return None; - } - - // Parse extensions - let mut sni = None; - let mut alpn = Vec::new(); - let extensions_data = &handshake_data[offset..offset + extensions_len]; - let mut ext_offset = 0; - - while ext_offset + 4 <= extensions_data.len() { - let ext_type = - u16::from_be_bytes([extensions_data[ext_offset], extensions_data[ext_offset + 1]]); - let ext_len = u16::from_be_bytes([ - extensions_data[ext_offset + 2], - extensions_data[ext_offset + 3], - ]) as usize; - - if ext_offset + 4 + ext_len > extensions_data.len() { - break; - } - - match ext_type { - 0x0000 => { - // SNI - sni = self.parse_sni_extension( - &extensions_data[ext_offset + 4..ext_offset + 4 + ext_len], - ); - } - 0x0010 => { - // ALPN - alpn = self.parse_alpn_extension( - &extensions_data[ext_offset + 4..ext_offset + 4 + ext_len], - ); - } - _ => {} - } - - ext_offset += 4 + ext_len; - } - - Some((sni, alpn)) - } - - /// Parse SNI extension - fn parse_sni_extension(&self, data: &[u8]) -> Option { - if data.len() < 5 { - return None; - } - - // Skip server name list length (2 bytes) - let mut offset = 2; - - while offset + 3 <= data.len() { - let name_type = data[offset]; - let name_len = u16::from_be_bytes([data[offset + 1], data[offset + 2]]) as usize; - - if name_type == 0x00 { - // host_name - if offset + 3 + name_len <= data.len() { - let hostname_bytes = &data[offset + 3..offset + 3 + name_len]; - if let Ok(hostname) = std::str::from_utf8(hostname_bytes) { - return Some(hostname.to_string()); - } - } - } - - offset += 3 + name_len; - } - - None - } - - /// Parse ALPN extension - fn parse_alpn_extension(&self, data: &[u8]) -> Vec { - let mut protocols = Vec::new(); - - if data.len() < 2 { - return protocols; - } - - // Skip ALPN extension length - let mut offset = 2; - - while offset < data.len() { - let proto_len = data[offset] as usize; - if offset + 1 + proto_len <= data.len() { - if let Ok(proto) = std::str::from_utf8(&data[offset + 1..offset + 1 + proto_len]) { - protocols.push(proto.to_string()); - } - } - offset += 1 + proto_len; - } - - protocols - } - - /// Parse DNS packet - fn parse_dns_packet(&self, payload: &[u8]) -> Option { - if payload.len() < 12 { - return None; - } - - let mut info = DnsInfo { - query_name: None, - query_type: None, - response_ips: Vec::new(), - is_response: false, - }; - - // DNS header flags - let flags = u16::from_be_bytes([payload[2], payload[3]]); - info.is_response = (flags & 0x8000) != 0; // QR bit - - // Question count - let qdcount = u16::from_be_bytes([payload[4], payload[5]]); - - if qdcount > 0 { - // Parse first question - let mut offset = 12; - let mut name = String::new(); - - // Parse domain name - while offset < payload.len() { - let label_len = payload[offset] as usize; - if label_len == 0 { - offset += 1; - break; - } - - if label_len >= 0xC0 { - // Compressed name - skip for simplicity - offset += 2; - break; - } - - if offset + 1 + label_len > payload.len() { - break; - } - - if !name.is_empty() { - name.push('.'); - } - - if let Ok(label) = std::str::from_utf8(&payload[offset + 1..offset + 1 + label_len]) - { - name.push_str(label); - } - - offset += 1 + label_len; - } - - if !name.is_empty() { - info.query_name = Some(name); - } - - // Query type - if offset + 2 <= payload.len() { - let qtype = u16::from_be_bytes([payload[offset], payload[offset + 1]]); - info.query_type = Some(match qtype { - 1 => DnsQueryType::A, - 28 => DnsQueryType::AAAA, - 5 => DnsQueryType::CNAME, - 15 => DnsQueryType::MX, - 16 => DnsQueryType::TXT, - other => DnsQueryType::Other(other), - }); - } - } - - Some(info) - } - - /// Check if this is a QUIC packet - fn is_quic_packet(&self, payload: &[u8]) -> bool { - if payload.len() < 5 { - return false; - } - - // Check for QUIC long header (bit 7 set) - if (payload[0] & 0x80) != 0 { - // Check version - let version = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]); - - // Known QUIC versions - return version == 0x00000001 || // QUIC v1 - version == 0x6b3343cf || // QUIC v2 - version == 0x51303530 || // Google QUIC - version == 0; // Version negotiation - } - - // Could be short header QUIC packet - // These are harder to identify definitively, but if we see them on port 443 UDP, - // they're likely QUIC - true - } - - fn get_connection_key_for_merge(&self, conn: &Connection) -> String { - format!( - "{:?}:{}-{:?}:{}", - conn.protocol, conn.local_addr, conn.protocol, conn.remote_addr - ) - } -} - -fn parse_addr(addr_str: &str) -> Option { - let addr_str = addr_str.trim(); - - if let Ok(socket_addr) = addr_str.parse::() { - return Some(socket_addr); - } - - if let Ok(port) = addr_str.parse::() { - return Some(std::net::SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), - port, - )); - } - - if let Some(dot_idx) = addr_str.rfind('.') { - if let Some(socket_addr) = parse_with_separator(addr_str, dot_idx) { - return Some(socket_addr); - } - } - - if let Some(colon_idx) = addr_str.rfind(':') { - if let Some(socket_addr) = parse_with_separator(addr_str, colon_idx) { - return Some(socket_addr); - } - } - - None -} - -fn parse_with_separator(addr_str: &str, sep_idx: usize) -> Option { - let (host_part, port_part) = addr_str.split_at(sep_idx); - let port_part = &port_part[1..]; - - let host = if host_part.starts_with('[') && host_part.ends_with(']') { - &host_part[1..host_part.len() - 1] - } else { - host_part - }; - - let ip_addr = host.parse::().ok()?; - let port = if port_part == "*" { - 0 - } else { - port_part.parse::().ok()? - }; - - Some(std::net::SocketAddr::new(ip_addr, port)) -} +use anyhow::Result; + +// submodules +pub mod capture; +pub mod dpi; +pub mod merge; +pub mod parser; +pub mod platform; +pub mod services; +pub mod types; + +// Re-export commonly used items at the module root +pub use capture::setup_packet_capture; +pub use parser::{PacketParser, ParsedPacket}; +pub use platform::{ConnectionKey, ProcessLookup, create_process_lookup}; +pub use services::ServiceLookup; +pub use types::{ApplicationProtocol, Connection, DpiInfo, Protocol, ProtocolState, TcpState}; diff --git a/src/network/mod.rs.old b/src/network/mod.rs.old new file mode 100644 index 0000000..4fe124a --- /dev/null +++ b/src/network/mod.rs.old @@ -0,0 +1,1713 @@ +use anyhow::{Result, anyhow}; +use log::{debug, error, info, warn}; +use pcap::{Capture, Device}; +use std::collections::HashMap; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::Sender; +use std::time::{Duration, Instant, SystemTime}; + +#[cfg(target_os = "linux")] +mod linux; + +#[cfg(target_os = "windows")] +mod windows; +#[cfg(target_os = "windows")] +use windows::*; + +#[cfg(target_os = "macos")] +mod macos; + +/// Transport protocol +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Protocol { + TCP, + UDP, + ICMP, + ARP, +} + +impl std::fmt::Display for Protocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Protocol::TCP => write!(f, "TCP"), + Protocol::UDP => write!(f, "UDP"), + Protocol::ICMP => write!(f, "ICMP"), + Protocol::ARP => write!(f, "ARP"), + } + } +} + +/// TCP connection state +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TcpState { + Listen, + SynSent, + SynReceived, + Established, + FinWait1, + FinWait2, + CloseWait, + LastAck, + TimeWait, + Closing, + Closed, +} + +/// Protocol-specific state information +#[derive(Debug, Clone, Copy)] +pub enum ProtocolState { + Tcp(TcpState), + Udp, // UDP is stateless + Icmp { + icmp_type: u8, // 8=Echo Request, 0=Echo Reply, etc. + icmp_code: u8, + }, + Arp { + operation: ArpOperation, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ArpOperation { + Request, + Reply, +} + +/// Application layer protocol detection +#[derive(Debug, Clone)] +pub enum ApplicationProtocol { + Http(HttpInfo), + Https(TlsInfo), + Dns(DnsInfo), + Ssh, + Quic, // Basic QUIC detection without deep parsing + Unknown, +} + +/// HTTP information +#[derive(Debug, Clone)] +pub struct HttpInfo { + pub version: HttpVersion, + pub method: Option, // GET, POST, etc. + pub host: Option, // From Host header + pub path: Option, // Request path + pub status_code: Option, // For responses + pub user_agent: Option, // Useful for identifying clients +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HttpVersion { + Http10, + Http11, + Http2, + Http3, // Inferred from QUIC +} + +/// TLS/HTTPS information +#[derive(Debug, Clone)] +pub struct TlsInfo { + pub version: Option, + pub sni: Option, + pub alpn: Vec, // Application protocols like "h2", "http/1.1" + pub cipher_suite: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TlsVersion { + Ssl3, + Tls10, + Tls11, + Tls12, + Tls13, +} + +/// DNS information +#[derive(Debug, Clone)] +pub struct DnsInfo { + pub query_name: Option, + pub query_type: Option, + pub response_ips: Vec, + pub is_response: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DnsQueryType { + A, + AAAA, + CNAME, + MX, + TXT, + Other(u16), +} + +/// Deep packet inspection results +#[derive(Debug, Clone)] +pub struct DpiInfo { + pub application: ApplicationProtocol, + pub first_packet_time: Instant, + pub last_update_time: Instant, +} + +/// Rate information +#[derive(Debug, Clone)] +pub struct RateInfo { + pub incoming_bps: f64, + pub outgoing_bps: f64, + pub last_calculation: Instant, +} + +impl Default for RateInfo { + fn default() -> Self { + Self { + incoming_bps: 0.0, + outgoing_bps: 0.0, + last_calculation: Instant::now(), + } + } +} + +/// Network connection +#[derive(Debug, Clone)] +pub struct Connection { + // Core identification + pub protocol: Protocol, + pub local_addr: SocketAddr, + pub remote_addr: SocketAddr, + + // Protocol state + pub protocol_state: ProtocolState, + + // Process information + pub pid: Option, + pub process_name: Option, + + // Traffic statistics + pub bytes_sent: u64, + pub bytes_received: u64, + pub packets_sent: u64, + pub packets_received: u64, + + // Timing + pub created_at: SystemTime, + pub last_activity: SystemTime, + + // Service identification + pub service_name: Option, // From port lookup + + // Deep packet inspection + pub dpi_info: Option, + + // Performance metrics + pub current_rate_bps: RateInfo, + pub rtt_estimate: Option, // Round-trip time if measurable + + // Backward compatibility fields + pub current_incoming_rate_bps: f64, + pub current_outgoing_rate_bps: f64, +} + +// Add a simple state field for backward compatibility +impl Connection { + pub fn state(&self) -> String { + match &self.protocol_state { + ProtocolState::Tcp(tcp_state) => format!("{:?}", tcp_state), + ProtocolState::Udp => "ACTIVE".to_string(), + ProtocolState::Icmp { icmp_type, .. } => match icmp_type { + 8 => "ECHO_REQUEST".to_string(), + 0 => "ECHO_REPLY".to_string(), + 3 => "DEST_UNREACH".to_string(), + 11 => "TIME_EXCEEDED".to_string(), + _ => "UNKNOWN".to_string(), + }, + ProtocolState::Arp { operation } => match operation { + ArpOperation::Request => "ARP_REQUEST".to_string(), + ArpOperation::Reply => "ARP_REPLY".to_string(), + }, + } + } +} + +impl Connection { + /// Create a new connection + pub fn new( + protocol: Protocol, + local_addr: SocketAddr, + remote_addr: SocketAddr, + state: ProtocolState, + ) -> Self { + let now = SystemTime::now(); + Self { + protocol, + local_addr, + remote_addr, + protocol_state: state, + pid: None, + process_name: None, + bytes_sent: 0, + bytes_received: 0, + packets_sent: 0, + packets_received: 0, + created_at: now, + last_activity: now, + service_name: None, + dpi_info: None, + current_rate_bps: RateInfo::default(), + rtt_estimate: None, + // Backward compatibility + current_incoming_rate_bps: 0.0, + current_outgoing_rate_bps: 0.0, + } + } + + /// Check if connection is active (had activity in the last minute) + pub fn is_active(&self) -> bool { + self.last_activity.elapsed().unwrap_or_default() < Duration::from_secs(60) + } + + /// Get the age of the connection (time since creation) + pub fn age(&self) -> Duration { + self.created_at.elapsed().unwrap_or_default() + } + + /// Get time since last activity + pub fn idle_time(&self) -> Duration { + self.last_activity.elapsed().unwrap_or_default() + } + + /// Update transfer rates + pub fn update_rates(&mut self, new_sent: u64, new_received: u64) { + let now = Instant::now(); + let elapsed = now + .duration_since(self.current_rate_bps.last_calculation) + .as_secs_f64(); + + if elapsed > 0.1 { + // Update rates every 100ms minimum + let sent_diff = new_sent.saturating_sub(self.bytes_sent) as f64; + let recv_diff = new_received.saturating_sub(self.bytes_received) as f64; + + self.current_rate_bps = RateInfo { + outgoing_bps: (sent_diff * 8.0) / elapsed, + incoming_bps: (recv_diff * 8.0) / elapsed, + last_calculation: now, + }; + + // Update backward compatibility fields + self.current_incoming_rate_bps = self.current_rate_bps.incoming_bps; + self.current_outgoing_rate_bps = self.current_rate_bps.outgoing_bps; + } + } +} + +/// Process information +#[derive(Debug, Clone)] +pub struct Process { + pub pid: u32, + pub name: String, +} + +/// Main function for the packet capture thread +pub fn packet_capture_thread( + interface_name: Option, + packet_tx: Sender>, + should_stop: Arc, +) -> Result<()> { + let cap_device = match interface_name { + Some(iface) => { + info!("Searching for specified interface: {}", iface); + Device::list()? + .into_iter() + .find(|d| d.name == iface) + .ok_or_else(|| anyhow!("Interface '{}' not found", iface))? + } + None => { + info!("No interface specified, looking up default."); + Device::lookup()?.ok_or_else(|| anyhow!("No default device found"))? + } + }; + + info!("Opening capture on device: {}", cap_device.name); + let mut cap = Capture::from_device(cap_device)? + .promisc(true) + .snaplen(1024) // Increased for DPI + .buffer_size(2_000_000) + .timeout(0) + .immediate_mode(true) + .open()?; + + info!("Applying BPF filter for IPv4 and IPv6"); + cap.filter( + "(ip and (tcp or udp or icmp)) or (ip6 and (tcp or udp or icmp6)) or arp", + true, + )?; + + loop { + if should_stop.load(Ordering::Relaxed) { + info!("Stop signal received, shutting down capture thread."); + break; + } + match cap.next_packet() { + Ok(packet) => { + if packet_tx.send(packet.data.to_vec()).is_err() { + info!("Packet receiver has disconnected, stopping capture thread."); + break; + } + } + Err(pcap::Error::TimeoutExpired) => { + debug!("Timeout expired, no packet captured this iteration."); + continue; + } + Err(e) => { + error!("Error capturing packet: {}", e); + break; + } + } + } + + Ok(()) +} + +/// Network monitor +pub struct NetworkMonitor { + connections: HashMap, + service_lookup: ServiceLookup, + filter_localhost: bool, + local_ips: std::collections::HashSet, +} + +/// Manages lookup of service names from a services file +#[derive(Debug)] +struct ServiceLookup { + services: HashMap<(u16, Protocol), String>, +} + +impl ServiceLookup { + fn new(file_path_str: &str) -> Result { + let mut services = HashMap::new(); + let file_path = Path::new(file_path_str); + + if !file_path.exists() { + warn!( + "Service definition file not found at '{}'. Service names will not be available.", + file_path_str + ); + return Ok(Self { services }); + } + + let file = File::open(file_path)?; + let reader = BufReader::new(file); + + for line_result in reader.lines() { + let line = match line_result { + Ok(l) => l, + Err(e) => { + warn!("Error reading line from services file: {}", e); + continue; + } + }; + + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 2 { + debug!("Skipping malformed line in services file: {}", line); + continue; + } + + let service_name = parts[0].to_string(); + let port_protocol_str = parts[1]; + + let port_protocol_parts: Vec<&str> = port_protocol_str.split('/').collect(); + if port_protocol_parts.len() != 2 { + debug!( + "Skipping malformed port/protocol in services file: {} from line: {}", + port_protocol_str, line + ); + continue; + } + + let port = match port_protocol_parts[0].parse::() { + Ok(p) => p, + Err(_) => { + debug!( + "Skipping invalid port in services file: {} from line: {}", + port_protocol_parts[0], line + ); + continue; + } + }; + + let protocol = match port_protocol_parts[1].to_lowercase().as_str() { + "tcp" => Protocol::TCP, + "udp" => Protocol::UDP, + _ => continue, + }; + + services.entry((port, protocol)).or_insert(service_name); + } + debug!( + "ServiceLookup initialized with {} entries from '{}'", + services.len(), + file_path_str + ); + Ok(Self { services }) + } + + fn get(&self, port: u16, protocol: Protocol) -> Option { + self.services.get(&(port, protocol)).cloned() + } +} + +/// Sets the service name for a given connection based on its port and protocol +fn set_connection_service_name_for_connection( + conn: &mut Connection, + service_lookup: &ServiceLookup, +) { + let local_port = conn.local_addr.port(); + let remote_port = conn.remote_addr.port(); + let protocol = conn.protocol; + + let mut final_service_name: Option = None; + + match conn.protocol_state { + ProtocolState::Tcp(TcpState::Listen) => { + final_service_name = service_lookup.get(local_port, protocol); + } + _ => { + let local_service_name_opt = service_lookup.get(local_port, protocol); + if local_service_name_opt.is_some() { + final_service_name = local_service_name_opt; + } else { + let remote_service_name_opt = service_lookup.get(remote_port, protocol); + if remote_service_name_opt.is_some() { + final_service_name = remote_service_name_opt; + } + } + } + } + conn.service_name = final_service_name; +} + +impl NetworkMonitor { + /// Create a new network monitor + pub fn new(_interface: Option, filter_localhost: bool) -> Result { + log::info!("NetworkMonitor::new - Initializing"); + + let mut local_ips = std::collections::HashSet::new(); + for iface in pnet_datalink::interfaces() { + for ip_network in iface.ips { + local_ips.insert(ip_network.ip()); + } + } + + if local_ips.is_empty() { + warn!( + "Could not determine any local IP addresses. Connection directionality might be inaccurate." + ); + } else { + debug!("Found local IPs: {:?}", local_ips); + } + + let services_file_path = "assets/services"; + let service_lookup = ServiceLookup::new(services_file_path).unwrap_or_else(|e| { + error!( + "Failed to load service definitions from '{}': {}. Proceeding without service names.", + services_file_path, e + ); + ServiceLookup { + services: HashMap::new(), + } + }); + + log::info!("NetworkMonitor::new - Initialization complete"); + Ok(Self { + local_ips, + service_lookup, + connections: HashMap::new(), + filter_localhost, + }) + } + + /// Get active connections + pub fn get_connections(&mut self) -> Result> { + // Start with pcap-captured connections as the primary source + let mut result_connections: Vec = self + .connections + .values() + .filter(|conn| conn.is_active()) + .cloned() + .collect(); + + debug!( + "get_connections: Found {} active pcap connections", + result_connections.len() + ); + + // Enrich pcap connections with process information from platform + if !result_connections.is_empty() { + // Get connection info with processes from platform + let mut platform_conns: Vec = Vec::new(); + + #[cfg(target_os = "linux")] + { + if let Err(e) = linux::get_connections_with_process_info(&mut platform_conns) { + error!("Error getting process info from platform: {}", e); + } + } + #[cfg(target_os = "macos")] + { + if let Err(e) = macos::get_connections_with_process_info(&mut platform_conns) { + error!("Error getting process info from platform: {}", e); + } + } + #[cfg(target_os = "windows")] + { + if let Err(e) = windows::get_connections_with_process_info(&mut platform_conns) { + error!("Error getting process info from platform: {}", e); + } + } + + debug!( + "Found {} platform connections for process enrichment", + platform_conns.len() + ); + + // Create a lookup map for platform connections + let mut platform_lookup: HashMap = HashMap::new(); + for conn in platform_conns { + if let (Some(pid), Some(name)) = (conn.pid, conn.process_name) { + let key = format!( + "{:?}:{}-{:?}:{}", + conn.protocol, conn.local_addr, conn.protocol, conn.remote_addr + ); + platform_lookup.insert(key, (pid, name)); + } + } + + // Enrich pcap connections with process names + for conn in &mut result_connections { + if conn.process_name.is_none() { + let key = self.get_connection_key_for_merge(conn); + if let Some((pid, name)) = platform_lookup.get(&key) { + debug!( + "Enriching connection {}:{} with process {} (PID: {})", + conn.local_addr, conn.remote_addr, name, pid + ); + conn.process_name = Some(name.clone()); + } + } + } + } + + // Sort by last activity (most recent first) + result_connections.sort_by(|a, b| b.last_activity.cmp(&a.last_activity)); + + // Apply localhost filter if enabled + if self.filter_localhost { + result_connections.retain(|conn| { + !(conn.local_addr.ip().is_loopback() && conn.remote_addr.ip().is_loopback()) + }); + } + + // Set service names + for conn in &mut result_connections { + set_connection_service_name_for_connection(conn, &self.service_lookup); + if conn.current_rate_bps.incoming_bps > 0.0 || conn.current_rate_bps.outgoing_bps > 0.0 + { + debug!( + "Connection: {:?}, Incoming: {:.2} bps, Outgoing: {:.2} bps", + conn.local_addr, + conn.current_rate_bps.incoming_bps, + conn.current_rate_bps.outgoing_bps + ); + } + } + + debug!( + "get_connections: Returning {} total connections", + result_connections.len() + ); + Ok(result_connections) + } + + fn determine_addresses( + &self, + src_ip: IpAddr, + src_port: u16, + dst_ip: IpAddr, + dst_port: u16, + is_outgoing: bool, + ) -> (SocketAddr, SocketAddr) { + if is_outgoing { + ( + SocketAddr::new(src_ip, src_port), + SocketAddr::new(dst_ip, dst_port), + ) + } else { + ( + SocketAddr::new(dst_ip, dst_port), + SocketAddr::new(src_ip, src_port), + ) + } + } + + /// Process a single raw packet from the queue + pub fn process_packet(&mut self, data: &[u8]) { + if data.len() < 14 { + return; + } + + // Check EtherType to determine packet type + let ethertype = u16::from_be_bytes([data[12], data[13]]); + + match ethertype { + 0x0800 => { + // IPv4 packet + self.process_ipv4_packet(data); + } + 0x86dd => { + // IPv6 packet + self.process_ipv6_packet(data); + } + 0x0806 => { + // ARP packet + self.process_arp_packet(data); + } + _ => { + // Other packet types - ignore + } + } + } + + fn process_ipv4_packet(&mut self, data: &[u8]) { + let ip_data = &data[14..]; + if ip_data.len() < 20 { + return; + } + + let version = ip_data[0] >> 4; + if version != 4 { + return; + } + + let protocol = ip_data[9]; + let src_ip = IpAddr::V4(Ipv4Addr::new( + ip_data[12], + ip_data[13], + ip_data[14], + ip_data[15], + )); + let dst_ip = IpAddr::V4(Ipv4Addr::new( + ip_data[16], + ip_data[17], + ip_data[18], + ip_data[19], + )); + + let ihl = ip_data[0] & 0x0F; + let ip_header_len = (ihl as usize) * 4; + + if ip_data.len() < ip_header_len { + return; + } + let transport_data = &ip_data[ip_header_len..]; + + let is_outgoing = self.local_ips.contains(&src_ip); + + match protocol { + 1 => self.process_icmp_packet(data, is_outgoing, transport_data, src_ip, dst_ip), + 6 => self.process_tcp_packet(data, is_outgoing, transport_data, src_ip, dst_ip), + 17 => self.process_udp_packet(data, is_outgoing, transport_data, src_ip, dst_ip), + _ => {} + } + } + + fn process_ipv6_packet(&mut self, data: &[u8]) { + let ip_data = &data[14..]; + if ip_data.len() < 40 { + // IPv6 header is fixed 40 bytes + return; + } + + let version = ip_data[0] >> 4; + if version != 6 { + return; + } + + let next_header = ip_data[6]; // Protocol type + + // Extract IPv6 addresses + let src_ip = IpAddr::V6(Ipv6Addr::new( + u16::from_be_bytes([ip_data[8], ip_data[9]]), + u16::from_be_bytes([ip_data[10], ip_data[11]]), + u16::from_be_bytes([ip_data[12], ip_data[13]]), + u16::from_be_bytes([ip_data[14], ip_data[15]]), + u16::from_be_bytes([ip_data[16], ip_data[17]]), + u16::from_be_bytes([ip_data[18], ip_data[19]]), + u16::from_be_bytes([ip_data[20], ip_data[21]]), + u16::from_be_bytes([ip_data[22], ip_data[23]]), + )); + + let dst_ip = IpAddr::V6(Ipv6Addr::new( + u16::from_be_bytes([ip_data[24], ip_data[25]]), + u16::from_be_bytes([ip_data[26], ip_data[27]]), + u16::from_be_bytes([ip_data[28], ip_data[29]]), + u16::from_be_bytes([ip_data[30], ip_data[31]]), + u16::from_be_bytes([ip_data[32], ip_data[33]]), + u16::from_be_bytes([ip_data[34], ip_data[35]]), + u16::from_be_bytes([ip_data[36], ip_data[37]]), + u16::from_be_bytes([ip_data[38], ip_data[39]]), + )); + + let transport_data = &ip_data[40..]; // IPv6 header is always 40 bytes + let is_outgoing = self.local_ips.contains(&src_ip); + + // Handle extension headers if present + let (final_next_header, transport_offset) = + self.parse_ipv6_extension_headers(next_header, transport_data); + let final_transport_data = &transport_data[transport_offset..]; + + match final_next_header { + 58 => { + self.process_icmpv6_packet(data, is_outgoing, final_transport_data, src_ip, dst_ip) + } + 6 => self.process_tcp_packet(data, is_outgoing, final_transport_data, src_ip, dst_ip), + 17 => self.process_udp_packet(data, is_outgoing, final_transport_data, src_ip, dst_ip), + _ => {} + } + } + + fn parse_ipv6_extension_headers(&self, mut next_header: u8, data: &[u8]) -> (u8, usize) { + let mut offset = 0; + + // Common IPv6 extension headers + const HOP_BY_HOP: u8 = 0; + const ROUTING: u8 = 43; + const FRAGMENT: u8 = 44; + const ENCAPSULATING_SECURITY: u8 = 50; + const AUTHENTICATION: u8 = 51; + const DESTINATION_OPTIONS: u8 = 60; + + loop { + match next_header { + HOP_BY_HOP | ROUTING | DESTINATION_OPTIONS => { + if data.len() < offset + 2 { + return (next_header, offset); + } + next_header = data[offset]; + let header_len = ((data[offset + 1] as usize) + 1) * 8; + offset += header_len; + } + FRAGMENT => { + if data.len() < offset + 8 { + return (next_header, offset); + } + next_header = data[offset]; + offset += 8; // Fragment header is fixed 8 bytes + } + AUTHENTICATION => { + if data.len() < offset + 2 { + return (next_header, offset); + } + next_header = data[offset]; + let header_len = ((data[offset + 1] as usize) + 2) * 4; + offset += header_len; + } + ENCAPSULATING_SECURITY => { + // ESP is complex, just skip for now + return (next_header, offset); + } + _ => { + // Not an extension header, this is the final protocol + return (next_header, offset); + } + } + + if offset >= data.len() { + return (next_header, offset); + } + } + } + + fn process_icmp_packet( + &mut self, + data: &[u8], + is_outgoing: bool, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + ) { + if transport_data.is_empty() { + return; + } + + let icmp_type = transport_data[0]; + let icmp_code = if transport_data.len() > 1 { + transport_data[1] + } else { + 0 + }; + + let (local_addr, remote_addr) = self.determine_addresses(src_ip, 0, dst_ip, 0, is_outgoing); + + let conn_key = format!( + "{:?}:{}-{:?}:{}", + Protocol::ICMP, + local_addr, + Protocol::ICMP, + remote_addr + ); + + let state = ProtocolState::Icmp { + icmp_type, + icmp_code, + }; + + let conn = self + .connections + .entry(conn_key) + .or_insert_with(|| Connection::new(Protocol::ICMP, local_addr, remote_addr, state)); + + // Update connection state + conn.protocol_state = state; + conn.last_activity = SystemTime::now(); + + // Update statistics + if is_outgoing { + conn.packets_sent += 1; + conn.bytes_sent += data.len() as u64; + } else { + conn.packets_received += 1; + conn.bytes_received += data.len() as u64; + } + + // Update rates + conn.update_rates(conn.bytes_sent, conn.bytes_received); + + // Set service name + set_connection_service_name_for_connection(conn, &self.service_lookup); + } + + fn process_icmpv6_packet( + &mut self, + data: &[u8], + is_outgoing: bool, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + ) { + if transport_data.is_empty() { + return; + } + + let icmp_type = transport_data[0]; + let icmp_code = if transport_data.len() > 1 { + transport_data[1] + } else { + 0 + }; + + // ICMPv6 types are different from ICMPv4 + // 128 = Echo Request, 129 = Echo Reply, 1 = Destination Unreachable, 3 = Time Exceeded + + let (local_addr, remote_addr) = self.determine_addresses(src_ip, 0, dst_ip, 0, is_outgoing); + + let conn_key = format!( + "{:?}:{}-{:?}:{}", + Protocol::ICMP, + local_addr, + Protocol::ICMP, + remote_addr + ); + + let state = ProtocolState::Icmp { + icmp_type, + icmp_code, + }; + + let conn = self + .connections + .entry(conn_key) + .or_insert_with(|| Connection::new(Protocol::ICMP, local_addr, remote_addr, state)); + + // Rest of the processing is the same as ICMPv4 + conn.protocol_state = state; + conn.last_activity = SystemTime::now(); + + if is_outgoing { + conn.packets_sent += 1; + conn.bytes_sent += data.len() as u64; + } else { + conn.packets_received += 1; + conn.bytes_received += data.len() as u64; + } + + conn.update_rates(conn.bytes_sent, conn.bytes_received); + set_connection_service_name_for_connection(conn, &self.service_lookup); + } + + fn process_tcp_packet( + &mut self, + data: &[u8], + is_outgoing: bool, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + ) { + if transport_data.len() < 20 { + return; + } + + let src_port = u16::from_be_bytes([transport_data[0], transport_data[1]]); + let dst_port = u16::from_be_bytes([transport_data[2], transport_data[3]]); + let flags = transport_data[13]; + + // Determine TCP state from flags + let tcp_state = match flags { + 0x02 => TcpState::SynSent, + 0x12 => TcpState::SynReceived, + 0x10 => TcpState::Established, + 0x01 => TcpState::FinWait1, + 0x11 => TcpState::FinWait2, + 0x04 => TcpState::Closed, + 0x14 => TcpState::Closing, + _ => TcpState::Established, + }; + + let (local_addr, remote_addr) = + self.determine_addresses(src_ip, src_port, dst_ip, dst_port, is_outgoing); + + let conn_key = format!( + "{:?}:{}-{:?}:{}", + Protocol::TCP, + local_addr, + Protocol::TCP, + remote_addr + ); + + let state = ProtocolState::Tcp(tcp_state); + + // Extract TCP payload for DPI + let tcp_header_len = ((transport_data[12] >> 4) as usize) * 4; + let needs_dpi = if transport_data.len() > tcp_header_len { + let tcp_payload = &transport_data[tcp_header_len..]; + !tcp_payload.is_empty() && !self.connections.contains_key(&conn_key) + } else { + false + }; + + let conn = self + .connections + .entry(conn_key.clone()) + .or_insert_with(|| Connection::new(Protocol::TCP, local_addr, remote_addr, state)); + + // Update connection state + conn.protocol_state = state; + conn.last_activity = SystemTime::now(); + + // Update statistics + if is_outgoing { + conn.packets_sent += 1; + conn.bytes_sent += data.len() as u64; + } else { + conn.packets_received += 1; + conn.bytes_received += data.len() as u64; + } + + // Update rates + conn.update_rates(conn.bytes_sent, conn.bytes_received); + + // Set service name + set_connection_service_name_for_connection(conn, &self.service_lookup); + + // Do DPI after releasing the mutable borrow + if needs_dpi && transport_data.len() > tcp_header_len { + let tcp_payload = &transport_data[tcp_header_len..]; + self.process_tcp_payload_for_dpi( + &conn_key, + tcp_payload, + local_addr.port(), + remote_addr.port(), + ); + } + } + + fn process_udp_packet( + &mut self, + data: &[u8], + is_outgoing: bool, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + ) { + if transport_data.len() < 8 { + return; + } + + let src_port = u16::from_be_bytes([transport_data[0], transport_data[1]]); + let dst_port = u16::from_be_bytes([transport_data[2], transport_data[3]]); + + let (local_addr, remote_addr) = + self.determine_addresses(src_ip, src_port, dst_ip, dst_port, is_outgoing); + + let conn_key = format!( + "{:?}:{}-{:?}:{}", + Protocol::UDP, + local_addr, + Protocol::UDP, + remote_addr + ); + + let state = ProtocolState::Udp; + + // Check if we need DPI + let needs_dpi = if transport_data.len() > 8 { + let udp_payload = &transport_data[8..]; + !udp_payload.is_empty() && !self.connections.contains_key(&conn_key) + } else { + false + }; + + let conn = self + .connections + .entry(conn_key.clone()) + .or_insert_with(|| Connection::new(Protocol::UDP, local_addr, remote_addr, state)); + + // Update connection + conn.last_activity = SystemTime::now(); + + if is_outgoing { + conn.packets_sent += 1; + conn.bytes_sent += data.len() as u64; + } else { + conn.packets_received += 1; + conn.bytes_received += data.len() as u64; + } + + // Update rates + conn.update_rates(conn.bytes_sent, conn.bytes_received); + + // Set service name + set_connection_service_name_for_connection(conn, &self.service_lookup); + + // Do DPI after releasing the mutable borrow + if needs_dpi && transport_data.len() > 8 { + let udp_payload = &transport_data[8..]; + self.process_udp_payload_for_dpi( + &conn_key, + udp_payload, + local_addr.port(), + remote_addr.port(), + ); + } + } + + fn process_arp_packet(&mut self, data: &[u8]) { + let arp_data = &data[14..]; + if arp_data.len() < 28 { + return; + } + + // Parse ARP header + let hardware_type = u16::from_be_bytes([arp_data[0], arp_data[1]]); + let protocol_type = u16::from_be_bytes([arp_data[2], arp_data[3]]); + let opcode = u16::from_be_bytes([arp_data[6], arp_data[7]]); + + // We only handle Ethernet (1) and IPv4 (0x0800) + if hardware_type != 1 || protocol_type != 0x0800 { + return; + } + + let sender_ip = IpAddr::from([arp_data[14], arp_data[15], arp_data[16], arp_data[17]]); + let target_ip = IpAddr::from([arp_data[24], arp_data[25], arp_data[26], arp_data[27]]); + + let operation = match opcode { + 1 => ArpOperation::Request, + 2 => ArpOperation::Reply, + _ => return, + }; + + let is_outgoing = self.local_ips.contains(&sender_ip); + let (local_addr, remote_addr) = if is_outgoing { + (SocketAddr::new(sender_ip, 0), SocketAddr::new(target_ip, 0)) + } else { + (SocketAddr::new(target_ip, 0), SocketAddr::new(sender_ip, 0)) + }; + + let conn_key = format!( + "{:?}:{}-{:?}:{}", + Protocol::ARP, + local_addr, + Protocol::ARP, + remote_addr + ); + + let state = ProtocolState::Arp { operation }; + + let conn = self + .connections + .entry(conn_key) + .or_insert_with(|| Connection::new(Protocol::ARP, local_addr, remote_addr, state)); + + // Update connection + conn.protocol_state = state; + conn.last_activity = SystemTime::now(); + + if is_outgoing { + conn.packets_sent += 1; + conn.bytes_sent += data.len() as u64; + } else { + conn.packets_received += 1; + conn.bytes_received += data.len() as u64; + } + + // Update rates + conn.update_rates(conn.bytes_sent, conn.bytes_received); + } + + // DPI helper methods + fn process_tcp_payload_for_dpi( + &mut self, + conn_key: &str, + payload: &[u8], + local_port: u16, + remote_port: u16, + ) { + if let Some(app_protocol) = + self.identify_tcp_application_from_payload(payload, local_port, remote_port) + { + if let Some(conn) = self.connections.get_mut(conn_key) { + conn.dpi_info = Some(DpiInfo { + application: app_protocol, + first_packet_time: Instant::now(), + last_update_time: Instant::now(), + }); + } + } + } + + fn process_udp_payload_for_dpi( + &mut self, + conn_key: &str, + payload: &[u8], + local_port: u16, + remote_port: u16, + ) { + if let Some(app_protocol) = + self.identify_udp_application_from_payload(payload, local_port, remote_port) + { + if let Some(conn) = self.connections.get_mut(conn_key) { + conn.dpi_info = Some(DpiInfo { + application: app_protocol, + first_packet_time: Instant::now(), + last_update_time: Instant::now(), + }); + } + } + } + + fn identify_tcp_application_from_payload( + &self, + payload: &[u8], + local_port: u16, + remote_port: u16, + ) -> Option { + // Check for HTTP/1.x + if self.is_http_payload(payload) { + return Some(ApplicationProtocol::Http(self.parse_http_info(payload))); + } + + // Check for TLS/HTTPS + if (local_port == 443 || remote_port == 443) || self.is_tls_handshake(payload) { + if let Some(tls_info) = self.extract_tls_info(payload) { + return Some(ApplicationProtocol::Https(tls_info)); + } + } + + // Check for SSH + if (local_port == 22 || remote_port == 22) || payload.starts_with(b"SSH-") { + return Some(ApplicationProtocol::Ssh); + } + + None + } + + fn identify_udp_application_from_payload( + &self, + payload: &[u8], + local_port: u16, + remote_port: u16, + ) -> Option { + // DNS + if local_port == 53 || remote_port == 53 { + if let Some(dns_info) = self.parse_dns_packet(payload) { + return Some(ApplicationProtocol::Dns(dns_info)); + } + } + + // QUIC/HTTP3 + if (local_port == 443 || remote_port == 443) && self.is_quic_packet(payload) { + return Some(ApplicationProtocol::Quic); + } + + None + } + + // DPI implementation methods + /// Check if payload looks like HTTP/1.x + fn is_http_payload(&self, payload: &[u8]) -> bool { + if payload.len() < 4 { + return false; + } + + // HTTP request methods + payload.starts_with(b"GET ") || + payload.starts_with(b"POST ") || + payload.starts_with(b"PUT ") || + payload.starts_with(b"DELETE ") || + payload.starts_with(b"HEAD ") || + payload.starts_with(b"OPTIONS ") || + payload.starts_with(b"CONNECT ") || + payload.starts_with(b"TRACE ") || + payload.starts_with(b"PATCH ") || + // HTTP responses + payload.starts_with(b"HTTP/1.0 ") || + payload.starts_with(b"HTTP/1.1 ") + } + + /// Parse HTTP information from payload + fn parse_http_info(&self, payload: &[u8]) -> HttpInfo { + let mut info = HttpInfo { + version: HttpVersion::Http11, // Default + method: None, + host: None, + path: None, + status_code: None, + user_agent: None, + }; + + // Convert to string for easier parsing (only what we can safely convert) + let text = String::from_utf8_lossy(payload); + let lines: Vec<&str> = text.lines().collect(); + + if lines.is_empty() { + return info; + } + + // Parse first line (request or response) + let first_line = lines[0]; + let parts: Vec<&str> = first_line.split_whitespace().collect(); + + if parts.len() >= 3 { + if first_line.starts_with("HTTP/") { + // Response line: HTTP/1.1 200 OK + info.version = if parts[0] == "HTTP/1.0" { + HttpVersion::Http10 + } else { + HttpVersion::Http11 + }; + info.status_code = parts[1].parse::().ok(); + } else { + // Request line: GET /path HTTP/1.1 + info.method = Some(parts[0].to_string()); + info.path = Some(parts[1].to_string()); + info.version = if parts[2] == "HTTP/1.0" { + HttpVersion::Http10 + } else { + HttpVersion::Http11 + }; + } + } + + // Parse headers + for line in lines.iter().skip(1) { + if line.is_empty() { + break; // End of headers + } + + if let Some((key, value)) = line.split_once(':') { + let key = key.trim().to_lowercase(); + let value = value.trim(); + + match key.as_str() { + "host" => info.host = Some(value.to_string()), + "user-agent" => info.user_agent = Some(value.to_string()), + _ => {} + } + } + } + + info + } + + /// Check if this is a TLS handshake packet + fn is_tls_handshake(&self, payload: &[u8]) -> bool { + if payload.len() < 6 { + return false; + } + + // TLS record header: + // - Content type (1 byte): 0x16 for handshake + // - Version (2 bytes): 0x0301-0x0304 for TLS 1.0-1.3 + // - Length (2 bytes) + + payload[0] == 0x16 && // Handshake content type + payload[1] == 0x03 && // Major version 3 + (payload[2] >= 0x01 && payload[2] <= 0x04) // Minor version 1-4 + } + + /// Extract TLS information from handshake + fn extract_tls_info(&self, payload: &[u8]) -> Option { + if !self.is_tls_handshake(payload) || payload.len() < 9 { + return None; + } + + let mut info = TlsInfo { + version: None, + sni: None, + alpn: Vec::new(), + cipher_suite: None, + }; + + // Record layer version + let record_version = match payload[2] { + 0x01 => Some(TlsVersion::Tls10), + 0x02 => Some(TlsVersion::Tls11), + 0x03 => Some(TlsVersion::Tls12), + 0x04 => Some(TlsVersion::Tls13), + _ => None, + }; + + // Skip TLS record header (5 bytes) + let handshake_data = &payload[5..]; + + if handshake_data.len() < 4 { + return Some(info); + } + + let handshake_type = handshake_data[0]; + + match handshake_type { + 0x01 => { + // Client Hello + info.version = record_version; + if let Some((sni, alpn)) = self.parse_client_hello_extensions(handshake_data) { + info.sni = sni; + info.alpn = alpn; + } + } + 0x02 => { + // Server Hello + info.version = record_version; + // Could parse cipher suite here if needed + } + _ => {} + } + + Some(info) + } + + /// Parse Client Hello extensions for SNI and ALPN + fn parse_client_hello_extensions( + &self, + handshake_data: &[u8], + ) -> Option<(Option, Vec)> { + if handshake_data.len() < 38 { + return None; + } + + // Skip to extensions: + // - Handshake type (1) + Length (3) + Version (2) + Random (32) = 38 + let mut offset = 38; + + // Session ID + if offset >= handshake_data.len() { + return None; + } + let session_id_len = handshake_data[offset] as usize; + offset += 1 + session_id_len; + + // Cipher suites + if offset + 2 > handshake_data.len() { + return None; + } + let cipher_suites_len = + u16::from_be_bytes([handshake_data[offset], handshake_data[offset + 1]]) as usize; + offset += 2 + cipher_suites_len; + + // Compression methods + if offset >= handshake_data.len() { + return None; + } + let compression_len = handshake_data[offset] as usize; + offset += 1 + compression_len; + + // Extensions length + if offset + 2 > handshake_data.len() { + return None; + } + let extensions_len = + u16::from_be_bytes([handshake_data[offset], handshake_data[offset + 1]]) as usize; + offset += 2; + + if offset + extensions_len > handshake_data.len() { + return None; + } + + // Parse extensions + let mut sni = None; + let mut alpn = Vec::new(); + let extensions_data = &handshake_data[offset..offset + extensions_len]; + let mut ext_offset = 0; + + while ext_offset + 4 <= extensions_data.len() { + let ext_type = + u16::from_be_bytes([extensions_data[ext_offset], extensions_data[ext_offset + 1]]); + let ext_len = u16::from_be_bytes([ + extensions_data[ext_offset + 2], + extensions_data[ext_offset + 3], + ]) as usize; + + if ext_offset + 4 + ext_len > extensions_data.len() { + break; + } + + match ext_type { + 0x0000 => { + // SNI + sni = self.parse_sni_extension( + &extensions_data[ext_offset + 4..ext_offset + 4 + ext_len], + ); + } + 0x0010 => { + // ALPN + alpn = self.parse_alpn_extension( + &extensions_data[ext_offset + 4..ext_offset + 4 + ext_len], + ); + } + _ => {} + } + + ext_offset += 4 + ext_len; + } + + Some((sni, alpn)) + } + + /// Parse SNI extension + fn parse_sni_extension(&self, data: &[u8]) -> Option { + if data.len() < 5 { + return None; + } + + // Skip server name list length (2 bytes) + let mut offset = 2; + + while offset + 3 <= data.len() { + let name_type = data[offset]; + let name_len = u16::from_be_bytes([data[offset + 1], data[offset + 2]]) as usize; + + if name_type == 0x00 { + // host_name + if offset + 3 + name_len <= data.len() { + let hostname_bytes = &data[offset + 3..offset + 3 + name_len]; + if let Ok(hostname) = std::str::from_utf8(hostname_bytes) { + return Some(hostname.to_string()); + } + } + } + + offset += 3 + name_len; + } + + None + } + + /// Parse ALPN extension + fn parse_alpn_extension(&self, data: &[u8]) -> Vec { + let mut protocols = Vec::new(); + + if data.len() < 2 { + return protocols; + } + + // Skip ALPN extension length + let mut offset = 2; + + while offset < data.len() { + let proto_len = data[offset] as usize; + if offset + 1 + proto_len <= data.len() { + if let Ok(proto) = std::str::from_utf8(&data[offset + 1..offset + 1 + proto_len]) { + protocols.push(proto.to_string()); + } + } + offset += 1 + proto_len; + } + + protocols + } + + /// Parse DNS packet + fn parse_dns_packet(&self, payload: &[u8]) -> Option { + if payload.len() < 12 { + return None; + } + + let mut info = DnsInfo { + query_name: None, + query_type: None, + response_ips: Vec::new(), + is_response: false, + }; + + // DNS header flags + let flags = u16::from_be_bytes([payload[2], payload[3]]); + info.is_response = (flags & 0x8000) != 0; // QR bit + + // Question count + let qdcount = u16::from_be_bytes([payload[4], payload[5]]); + + if qdcount > 0 { + // Parse first question + let mut offset = 12; + let mut name = String::new(); + + // Parse domain name + while offset < payload.len() { + let label_len = payload[offset] as usize; + if label_len == 0 { + offset += 1; + break; + } + + if label_len >= 0xC0 { + // Compressed name - skip for simplicity + offset += 2; + break; + } + + if offset + 1 + label_len > payload.len() { + break; + } + + if !name.is_empty() { + name.push('.'); + } + + if let Ok(label) = std::str::from_utf8(&payload[offset + 1..offset + 1 + label_len]) + { + name.push_str(label); + } + + offset += 1 + label_len; + } + + if !name.is_empty() { + info.query_name = Some(name); + } + + // Query type + if offset + 2 <= payload.len() { + let qtype = u16::from_be_bytes([payload[offset], payload[offset + 1]]); + info.query_type = Some(match qtype { + 1 => DnsQueryType::A, + 28 => DnsQueryType::AAAA, + 5 => DnsQueryType::CNAME, + 15 => DnsQueryType::MX, + 16 => DnsQueryType::TXT, + other => DnsQueryType::Other(other), + }); + } + } + + Some(info) + } + + /// Check if this is a QUIC packet + fn is_quic_packet(&self, payload: &[u8]) -> bool { + if payload.len() < 5 { + return false; + } + + // Check for QUIC long header (bit 7 set) + if (payload[0] & 0x80) != 0 { + // Check version + let version = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]); + + // Known QUIC versions + return version == 0x00000001 || // QUIC v1 + version == 0x6b3343cf || // QUIC v2 + version == 0x51303530 || // Google QUIC + version == 0; // Version negotiation + } + + // Could be short header QUIC packet + // These are harder to identify definitively, but if we see them on port 443 UDP, + // they're likely QUIC + true + } + + fn get_connection_key_for_merge(&self, conn: &Connection) -> String { + format!( + "{:?}:{}-{:?}:{}", + conn.protocol, conn.local_addr, conn.protocol, conn.remote_addr + ) + } +} + +fn parse_addr(addr_str: &str) -> Option { + let addr_str = addr_str.trim(); + + if let Ok(socket_addr) = addr_str.parse::() { + return Some(socket_addr); + } + + if let Ok(port) = addr_str.parse::() { + return Some(std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), + port, + )); + } + + if let Some(dot_idx) = addr_str.rfind('.') { + if let Some(socket_addr) = parse_with_separator(addr_str, dot_idx) { + return Some(socket_addr); + } + } + + if let Some(colon_idx) = addr_str.rfind(':') { + if let Some(socket_addr) = parse_with_separator(addr_str, colon_idx) { + return Some(socket_addr); + } + } + + None +} + +fn parse_with_separator(addr_str: &str, sep_idx: usize) -> Option { + let (host_part, port_part) = addr_str.split_at(sep_idx); + let port_part = &port_part[1..]; + + let host = if host_part.starts_with('[') && host_part.ends_with(']') { + &host_part[1..host_part.len() - 1] + } else { + host_part + }; + + let ip_addr = host.parse::().ok()?; + let port = if port_part == "*" { + 0 + } else { + port_part.parse::().ok()? + }; + + Some(std::net::SocketAddr::new(ip_addr, port)) +} diff --git a/src/network/parser.rs b/src/network/parser.rs new file mode 100644 index 0000000..74c4967 --- /dev/null +++ b/src/network/parser.rs @@ -0,0 +1,479 @@ +// network/parser.rs - Updated with DPI integration +use crate::network::dpi::{self, DpiResult}; +use crate::network::types::*; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +/// Result of parsing a packet +#[derive(Debug)] +pub struct ParsedPacket { + pub connection_key: String, + pub protocol: Protocol, + pub local_addr: SocketAddr, + pub remote_addr: SocketAddr, + pub state: ProtocolState, + pub is_outgoing: bool, + pub packet_len: usize, + pub dpi_result: Option, // DPI results if available +} + +pub struct ParserConfig { + pub enable_dpi: bool, + pub dpi_packet_limit: usize, // Only inspect first N packets per connection +} + +impl Default for ParserConfig { + fn default() -> Self { + Self { + enable_dpi: true, + dpi_packet_limit: 10, // Only inspect first 10 packets + } + } +} + +/// Packet parser - stateless, thread-safe +pub struct PacketParser { + local_ips: std::collections::HashSet, + config: ParserConfig, +} + +impl PacketParser { + pub fn new() -> Self { + let mut local_ips = std::collections::HashSet::new(); + for iface in pnet_datalink::interfaces() { + for ip_network in iface.ips { + local_ips.insert(ip_network.ip()); + } + } + Self { + local_ips, + config: ParserConfig::default(), + } + } + + pub fn with_config(config: ParserConfig) -> Self { + let mut local_ips = std::collections::HashSet::new(); + for iface in pnet_datalink::interfaces() { + for ip_network in iface.ips { + local_ips.insert(ip_network.ip()); + } + } + Self { local_ips, config } + } + + /// Parse a raw packet + pub fn parse_packet(&self, data: &[u8]) -> Option { + if data.len() < 14 { + return None; + } + + let ethertype = u16::from_be_bytes([data[12], data[13]]); + + match ethertype { + 0x0800 => self.parse_ipv4_packet(data), + 0x86dd => self.parse_ipv6_packet(data), + 0x0806 => self.parse_arp_packet(data), + _ => None, + } + } + + fn parse_ipv4_packet(&self, data: &[u8]) -> Option { + let ip_data = &data[14..]; + if ip_data.len() < 20 { + return None; + } + + let version = ip_data[0] >> 4; + if version != 4 { + return None; + } + + let protocol_num = ip_data[9]; + let src_ip = IpAddr::V4(Ipv4Addr::new( + ip_data[12], + ip_data[13], + ip_data[14], + ip_data[15], + )); + let dst_ip = IpAddr::V4(Ipv4Addr::new( + ip_data[16], + ip_data[17], + ip_data[18], + ip_data[19], + )); + + let ihl = ip_data[0] & 0x0F; + let ip_header_len = (ihl as usize) * 4; + + if ip_data.len() < ip_header_len { + return None; + } + + let transport_data = &ip_data[ip_header_len..]; + let is_outgoing = self.local_ips.contains(&src_ip); + + match protocol_num { + 1 => self.parse_icmp(transport_data, src_ip, dst_ip, is_outgoing, data.len()), + 6 => self.parse_tcp(transport_data, src_ip, dst_ip, is_outgoing, data.len()), + 17 => self.parse_udp(transport_data, src_ip, dst_ip, is_outgoing, data.len()), + _ => None, + } + } + + fn parse_ipv6_packet(&self, data: &[u8]) -> Option { + let ip_data = &data[14..]; + if ip_data.len() < 40 { + return None; + } + + let version = ip_data[0] >> 4; + if version != 6 { + return None; + } + + let next_header = ip_data[6]; + + // Extract IPv6 addresses + let src_ip = IpAddr::V6(Ipv6Addr::new( + u16::from_be_bytes([ip_data[8], ip_data[9]]), + u16::from_be_bytes([ip_data[10], ip_data[11]]), + u16::from_be_bytes([ip_data[12], ip_data[13]]), + u16::from_be_bytes([ip_data[14], ip_data[15]]), + u16::from_be_bytes([ip_data[16], ip_data[17]]), + u16::from_be_bytes([ip_data[18], ip_data[19]]), + u16::from_be_bytes([ip_data[20], ip_data[21]]), + u16::from_be_bytes([ip_data[22], ip_data[23]]), + )); + + let dst_ip = IpAddr::V6(Ipv6Addr::new( + u16::from_be_bytes([ip_data[24], ip_data[25]]), + u16::from_be_bytes([ip_data[26], ip_data[27]]), + u16::from_be_bytes([ip_data[28], ip_data[29]]), + u16::from_be_bytes([ip_data[30], ip_data[31]]), + u16::from_be_bytes([ip_data[32], ip_data[33]]), + u16::from_be_bytes([ip_data[34], ip_data[35]]), + u16::from_be_bytes([ip_data[36], ip_data[37]]), + u16::from_be_bytes([ip_data[38], ip_data[39]]), + )); + + let transport_data = &ip_data[40..]; + let is_outgoing = self.local_ips.contains(&src_ip); + + // Handle extension headers if needed + let (final_next_header, transport_offset) = + self.parse_ipv6_extension_headers(next_header, transport_data); + let final_transport_data = &transport_data[transport_offset..]; + + match final_next_header { + 58 => self.parse_icmpv6( + final_transport_data, + src_ip, + dst_ip, + is_outgoing, + data.len(), + ), + 6 => self.parse_tcp( + final_transport_data, + src_ip, + dst_ip, + is_outgoing, + data.len(), + ), + 17 => self.parse_udp( + final_transport_data, + src_ip, + dst_ip, + is_outgoing, + data.len(), + ), + _ => None, + } + } + + fn parse_tcp( + &self, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + is_outgoing: bool, + packet_len: usize, + ) -> Option { + if transport_data.len() < 20 { + return None; + } + + let src_port = u16::from_be_bytes([transport_data[0], transport_data[1]]); + let dst_port = u16::from_be_bytes([transport_data[2], transport_data[3]]); + let flags = transport_data[13]; + + let tcp_state = parse_tcp_flags(flags); + + let (local_addr, remote_addr) = if is_outgoing { + ( + SocketAddr::new(src_ip, src_port), + SocketAddr::new(dst_ip, dst_port), + ) + } else { + ( + SocketAddr::new(dst_ip, dst_port), + SocketAddr::new(src_ip, src_port), + ) + }; + + // Perform DPI if enabled and there's payload + let dpi_result = if self.config.enable_dpi { + let tcp_header_len = ((transport_data[12] >> 4) as usize) * 4; + if transport_data.len() > tcp_header_len { + let payload = &transport_data[tcp_header_len..]; + dpi::analyze_tcp_packet(payload, local_addr.port(), remote_addr.port(), is_outgoing) + } else { + None + } + } else { + None + }; + + Some(ParsedPacket { + connection_key: format!("TCP:{}-TCP:{}", local_addr, remote_addr), + protocol: Protocol::TCP, + local_addr, + remote_addr, + state: ProtocolState::Tcp(tcp_state), + is_outgoing, + packet_len, + dpi_result, + }) + } + + fn parse_udp( + &self, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + is_outgoing: bool, + packet_len: usize, + ) -> Option { + if transport_data.len() < 8 { + return None; + } + + let src_port = u16::from_be_bytes([transport_data[0], transport_data[1]]); + let dst_port = u16::from_be_bytes([transport_data[2], transport_data[3]]); + + let (local_addr, remote_addr) = if is_outgoing { + ( + SocketAddr::new(src_ip, src_port), + SocketAddr::new(dst_ip, dst_port), + ) + } else { + ( + SocketAddr::new(dst_ip, dst_port), + SocketAddr::new(src_ip, src_port), + ) + }; + + // Perform DPI if enabled and there's payload + let dpi_result = if self.config.enable_dpi && transport_data.len() > 8 { + let payload = &transport_data[8..]; + dpi::analyze_udp_packet(payload, local_addr.port(), remote_addr.port(), is_outgoing) + } else { + None + }; + + Some(ParsedPacket { + connection_key: format!("UDP:{}-UDP:{}", local_addr, remote_addr), + protocol: Protocol::UDP, + local_addr, + remote_addr, + state: ProtocolState::Udp, + is_outgoing, + packet_len, + dpi_result, + }) + } + + fn parse_icmp( + &self, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + is_outgoing: bool, + packet_len: usize, + ) -> Option { + if transport_data.is_empty() { + return None; + } + + let icmp_type = transport_data[0]; + let icmp_code = if transport_data.len() > 1 { + transport_data[1] + } else { + 0 + }; + + let (local_addr, remote_addr) = if is_outgoing { + (SocketAddr::new(src_ip, 0), SocketAddr::new(dst_ip, 0)) + } else { + (SocketAddr::new(dst_ip, 0), SocketAddr::new(src_ip, 0)) + }; + + Some(ParsedPacket { + connection_key: format!("ICMP:{}-ICMP:{}", local_addr, remote_addr), + protocol: Protocol::ICMP, + local_addr, + remote_addr, + state: ProtocolState::Icmp { + icmp_type, + icmp_code, + }, + is_outgoing, + packet_len, + dpi_result: None, + }) + } + + fn parse_icmpv6( + &self, + transport_data: &[u8], + src_ip: IpAddr, + dst_ip: IpAddr, + is_outgoing: bool, + packet_len: usize, + ) -> Option { + if transport_data.is_empty() { + return None; + } + + let icmp_type = transport_data[0]; + let icmp_code = if transport_data.len() > 1 { + transport_data[1] + } else { + 0 + }; + + let (local_addr, remote_addr) = if is_outgoing { + (SocketAddr::new(src_ip, 0), SocketAddr::new(dst_ip, 0)) + } else { + (SocketAddr::new(dst_ip, 0), SocketAddr::new(src_ip, 0)) + }; + + Some(ParsedPacket { + connection_key: format!("ICMP:{}-ICMP:{}", local_addr, remote_addr), + protocol: Protocol::ICMP, + local_addr, + remote_addr, + state: ProtocolState::Icmp { + icmp_type, + icmp_code, + }, + is_outgoing, + packet_len, + dpi_result: None, // No DPI for ICMPv6 + }) + } + + fn parse_arp_packet(&self, data: &[u8]) -> Option { + let arp_data = &data[14..]; + if arp_data.len() < 28 { + return None; + } + + let hardware_type = u16::from_be_bytes([arp_data[0], arp_data[1]]); + let protocol_type = u16::from_be_bytes([arp_data[2], arp_data[3]]); + let opcode = u16::from_be_bytes([arp_data[6], arp_data[7]]); + + if hardware_type != 1 || protocol_type != 0x0800 { + return None; + } + + let sender_ip = IpAddr::from([arp_data[14], arp_data[15], arp_data[16], arp_data[17]]); + let target_ip = IpAddr::from([arp_data[24], arp_data[25], arp_data[26], arp_data[27]]); + + let operation = match opcode { + 1 => ArpOperation::Request, + 2 => ArpOperation::Reply, + _ => return None, + }; + + let is_outgoing = self.local_ips.contains(&sender_ip); + let (local_addr, remote_addr) = if is_outgoing { + (SocketAddr::new(sender_ip, 0), SocketAddr::new(target_ip, 0)) + } else { + (SocketAddr::new(target_ip, 0), SocketAddr::new(sender_ip, 0)) + }; + + Some(ParsedPacket { + connection_key: format!("ARP:{}-ARP:{}", local_addr, remote_addr), + protocol: Protocol::ARP, + local_addr, + remote_addr, + state: ProtocolState::Arp { operation }, + is_outgoing, + packet_len: data.len(), + dpi_result: None, + }) + } + + fn parse_ipv6_extension_headers(&self, mut next_header: u8, data: &[u8]) -> (u8, usize) { + let mut offset = 0; + + const HOP_BY_HOP: u8 = 0; + const ROUTING: u8 = 43; + const FRAGMENT: u8 = 44; + const ENCAPSULATING_SECURITY: u8 = 50; + const AUTHENTICATION: u8 = 51; + const DESTINATION_OPTIONS: u8 = 60; + + loop { + match next_header { + HOP_BY_HOP | ROUTING | DESTINATION_OPTIONS => { + if data.len() < offset + 2 { + return (next_header, offset); + } + next_header = data[offset]; + let header_len = ((data[offset + 1] as usize) + 1) * 8; + offset += header_len; + } + FRAGMENT => { + if data.len() < offset + 8 { + return (next_header, offset); + } + next_header = data[offset]; + offset += 8; + } + AUTHENTICATION => { + if data.len() < offset + 2 { + return (next_header, offset); + } + next_header = data[offset]; + let header_len = ((data[offset + 1] as usize) + 2) * 4; + offset += header_len; + } + ENCAPSULATING_SECURITY => { + return (next_header, offset); + } + _ => { + return (next_header, offset); + } + } + + if offset >= data.len() { + return (next_header, offset); + } + } + } +} + +// ... rest of parsing methods + +fn parse_tcp_flags(flags: u8) -> TcpState { + match flags { + 0x02 => TcpState::SynSent, + 0x12 => TcpState::SynReceived, + 0x10 => TcpState::Established, + 0x01 => TcpState::FinWait1, + 0x11 => TcpState::FinWait2, + 0x04 => TcpState::Closed, + 0x14 => TcpState::Closing, + _ => TcpState::Established, + } +} diff --git a/src/network/platform/linux.rs b/src/network/platform/linux.rs new file mode 100644 index 0000000..2e49ab7 --- /dev/null +++ b/src/network/platform/linux.rs @@ -0,0 +1,228 @@ +// network/platform/linux.rs - Linux process lookup +use super::{ConnectionKey, ProcessLookup}; +use crate::types::{Connection, Protocol}; +use anyhow::Result; +use std::collections::HashMap; +use std::fs; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::RwLock; +use std::time::{Duration, Instant}; + +pub struct LinuxProcessLookup { + // Cache: ConnectionKey -> (pid, process_name) + cache: RwLock, +} + +struct ProcessCache { + lookup: HashMap, + last_refresh: Instant, +} + +impl LinuxProcessLookup { + pub fn new() -> Result { + Ok(Self { + cache: RwLock::new(ProcessCache { + lookup: HashMap::new(), + last_refresh: Instant::now() - Duration::from_secs(3600), + }), + }) + } + + /// Build connection -> process mapping + fn build_process_map() -> Result> { + let mut process_map = HashMap::new(); + + // First, build inode -> process mapping + let inode_to_process = Self::build_inode_map()?; + + // Then, parse network files to map connections -> inodes -> processes + Self::parse_and_map( + "/proc/net/tcp", + Protocol::TCP, + &inode_to_process, + &mut process_map, + )?; + Self::parse_and_map( + "/proc/net/tcp6", + Protocol::TCP, + &inode_to_process, + &mut process_map, + )?; + Self::parse_and_map( + "/proc/net/udp", + Protocol::UDP, + &inode_to_process, + &mut process_map, + )?; + Self::parse_and_map( + "/proc/net/udp6", + Protocol::UDP, + &inode_to_process, + &mut process_map, + )?; + + Ok(process_map) + } + + /// Build inode -> (pid, process_name) mapping + fn build_inode_map() -> Result> { + let mut inode_map = HashMap::new(); + + for entry in fs::read_dir("/proc")? { + let entry = entry?; + let path = entry.path(); + + if let Some(pid_str) = path.file_name().and_then(|s| s.to_str()) { + if let Ok(pid) = pid_str.parse::() { + if pid == 0 { + continue; + } + + // Get process name + let comm_path = path.join("comm"); + let process_name = fs::read_to_string(&comm_path) + .unwrap_or_else(|_| "unknown".to_string()) + .trim() + .to_string(); + + // Check file descriptors + let fd_dir = path.join("fd"); + if let Ok(fd_entries) = fs::read_dir(&fd_dir) { + for fd_entry in fd_entries.flatten() { + if let Ok(link) = fs::read_link(fd_entry.path()) { + if let Some(link_str) = link.to_str() { + if let Some(inode) = Self::extract_socket_inode(link_str) { + inode_map.insert(inode, (pid, process_name.clone())); + } + } + } + } + } + } + } + } + + Ok(inode_map) + } + + /// Parse /proc/net file and map connections to processes + fn parse_and_map( + path: &str, + protocol: Protocol, + inode_map: &HashMap, + result: &mut HashMap, + ) -> Result<()> { + let content = match fs::read_to_string(path) { + Ok(c) => c, + Err(_) => return Ok(()), // File might not exist + }; + + for (i, line) in content.lines().enumerate() { + if i == 0 { + continue; // Skip header + } + + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 10 { + continue; + } + + // Parse addresses + let local_addr = match Self::parse_hex_address(parts[1]) { + Some(addr) => addr, + None => continue, + }; + + let remote_addr = match Self::parse_hex_address(parts[2]) { + Some(addr) => addr, + None => continue, + }; + + // Get inode + if let Ok(inode) = parts[9].parse::() { + if let Some((pid, name)) = inode_map.get(&inode) { + let key = ConnectionKey { + protocol, + local_addr, + remote_addr, + }; + result.insert(key, (*pid, name.clone())); + } + } + } + + Ok(()) + } + + fn parse_hex_address(hex_addr: &str) -> Option { + let parts: Vec<&str> = hex_addr.split(':').collect(); + if parts.len() != 2 { + return None; + } + + let ip_hex = parts[0]; + let port = u16::from_str_radix(parts[1], 16).ok()?; + + if ip_hex.len() == 8 { + // IPv4 + let ip_bytes = u32::from_str_radix(ip_hex, 16).ok()?; + let ip = Ipv4Addr::from(ip_bytes.to_le_bytes()); + Some(SocketAddr::new(IpAddr::V4(ip), port)) + } else if ip_hex.len() == 32 { + // IPv6 + let mut bytes = [0u8; 16]; + for i in 0..4 { + let chunk = &ip_hex[i * 8..(i + 1) * 8]; + let value = u32::from_str_radix(chunk, 16).ok()?; + bytes[i * 4..(i + 1) * 4].copy_from_slice(&value.to_le_bytes()); + } + let ip = Ipv6Addr::from(bytes); + Some(SocketAddr::new(IpAddr::V6(ip), port)) + } else { + None + } + } + + fn extract_socket_inode(link: &str) -> Option { + if link.starts_with("socket:[") && link.ends_with(']') { + let inode_str = &link[8..link.len() - 1]; + inode_str.parse().ok() + } else { + None + } + } +} + +impl ProcessLookup for LinuxProcessLookup { + fn get_process_for_connection(&self, conn: &Connection) -> Option<(u32, String)> { + let key = ConnectionKey::from_connection(conn); + + // Try cache first + { + let cache = self.cache.read().unwrap(); + if cache.last_refresh.elapsed() < Duration::from_secs(2) { + if let Some(process_info) = cache.lookup.get(&key) { + return Some(process_info.clone()); + } + } + } + + // Cache is stale or miss, refresh + if self.refresh().is_ok() { + let cache = self.cache.read().unwrap(); + cache.lookup.get(&key).cloned() + } else { + None + } + } + + fn refresh(&self) -> Result<()> { + let process_map = Self::build_process_map()?; + + let mut cache = self.cache.write().unwrap(); + cache.lookup = process_map; + cache.last_refresh = Instant::now(); + + Ok(()) + } +} diff --git a/src/network/platform/macos.rs b/src/network/platform/macos.rs new file mode 100644 index 0000000..dbe4688 --- /dev/null +++ b/src/network/platform/macos.rs @@ -0,0 +1,81 @@ +use super::{ConnectionKey, ProcessLookup}; +use crate::network::types::{Connection, Protocol}; +use anyhow::Result; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::process::Command; +use std::sync::RwLock; + +pub struct MacOSProcessLookup { + cache: RwLock>, +} + +impl MacOSProcessLookup { + pub fn new() -> Result { + Ok(Self { + cache: RwLock::new(HashMap::new()), + }) + } + + fn parse_lsof() -> Result> { + let mut lookup = HashMap::new(); + + // Run lsof to get network connections + let output = Command::new("lsof") + .args(&["-i", "-n", "-P", "+c", "0"]) + .output()?; + + if !output.status.success() { + return Ok(lookup); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + + for line in stdout.lines().skip(1) { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 10 { + continue; + } + + let process_name = parts[0].to_string(); + let pid = match parts[1].parse::() { + Ok(p) => p, + Err(_) => continue, + }; + + // Parse connection from NAME field + if let Some((protocol, local, remote)) = parse_lsof_connection(parts[8]) { + let key = ConnectionKey { + protocol, + local_addr: local, + remote_addr: remote, + }; + lookup.insert(key, (pid, process_name)); + } + } + + Ok(lookup) + } +} + +impl ProcessLookup for MacOSProcessLookup { + fn get_process_for_connection(&self, conn: &Connection) -> Option<(u32, String)> { + let key = ConnectionKey::from_connection(conn); + self.cache.read().unwrap().get(&key).cloned() + } + + fn refresh(&self) -> Result<()> { + let new_cache = Self::parse_lsof()?; + *self.cache.write().unwrap() = new_cache; + Ok(()) + } +} + +fn parse_lsof_connection(name: &str) -> Option<(Protocol, SocketAddr, SocketAddr)> { + // Parse lsof NAME field format: + // "192.168.1.1:443->10.0.0.1:12345" + // Determine protocol and parse addresses + + // Implementation would parse the connection string + None // Placeholder +} diff --git a/src/network/platform/mod.rs b/src/network/platform/mod.rs new file mode 100644 index 0000000..e60a1d9 --- /dev/null +++ b/src/network/platform/mod.rs @@ -0,0 +1,73 @@ +// network/platform/mod.rs - Platform process lookup +use crate::network::types::{Connection, Protocol}; +use anyhow::Result; +use std::net::SocketAddr; + +// Platform-specific modules +#[cfg(target_os = "linux")] +mod linux; +#[cfg(target_os = "macos")] +mod macos; +#[cfg(target_os = "windows")] +mod windows; + +// Re-export the appropriate implementation +#[cfg(target_os = "linux")] +pub use linux::LinuxProcessLookup; +#[cfg(target_os = "macos")] +pub use macos::MacOSProcessLookup; +#[cfg(target_os = "windows")] +pub use windows::WindowsProcessLookup; + +/// Trait for platform-specific process lookup +pub trait ProcessLookup: Send + Sync { + /// Look up process information for a connection + /// Returns (pid, process_name) if found + fn get_process_for_connection(&self, conn: &Connection) -> Option<(u32, String)>; + + /// Refresh internal caches if any (best-effort) + fn refresh(&self) -> Result<()> { + Ok(()) // Default no-op + } +} + +/// Create a platform-specific process lookup +pub fn create_process_lookup() -> Result> { + #[cfg(target_os = "linux")] + { + Ok(Box::new(LinuxProcessLookup::new()?)) + } + + #[cfg(target_os = "windows")] + { + Ok(Box::new(WindowsProcessLookup::new()?)) + } + + #[cfg(target_os = "macos")] + { + Ok(Box::new(MacOSProcessLookup::new()?)) + } + + #[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))] + { + Err(anyhow::anyhow!("Unsupported platform")) + } +} + +/// Connection identifier for lookups +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct ConnectionKey { + pub protocol: Protocol, + pub local_addr: SocketAddr, + pub remote_addr: SocketAddr, +} + +impl ConnectionKey { + pub fn from_connection(conn: &Connection) -> Self { + Self { + protocol: conn.protocol, + local_addr: conn.local_addr, + remote_addr: conn.remote_addr, + } + } +} diff --git a/src/network/platform/windows.rs b/src/network/platform/windows.rs new file mode 100644 index 0000000..f54ca28 --- /dev/null +++ b/src/network/platform/windows.rs @@ -0,0 +1,59 @@ +use super::{ConnectionKey, ProcessLookup}; +use crate::network::types::{Connection, Protocol}; +use anyhow::Result; +use std::collections::HashMap; +use std::sync::RwLock; + +pub struct WindowsProcessLookup { + // Windows can get process info directly from connection tables + cache: RwLock>, +} + +impl WindowsProcessLookup { + pub fn new() -> Result { + Ok(Self { + cache: RwLock::new(HashMap::new()), + }) + } + + fn refresh_tcp_processes( + &self, + cache: &mut HashMap, + ) -> Result<()> { + // Use GetExtendedTcpTable to get TCP connections with PIDs + // This is pseudo-code - actual implementation would use winapi + + // For each connection in the table: + // - Extract local/remote addresses + // - Get PID from dwOwningPid + // - Look up process name from PID + // - Insert into cache + + Ok(()) + } + + fn refresh_udp_processes( + &self, + cache: &mut HashMap, + ) -> Result<()> { + // Similar to TCP using GetExtendedUdpTable + Ok(()) + } +} + +impl ProcessLookup for WindowsProcessLookup { + fn get_process_for_connection(&self, conn: &Connection) -> Option<(u32, String)> { + let key = ConnectionKey::from_connection(conn); + self.cache.read().unwrap().get(&key).cloned() + } + + fn refresh(&self) -> Result<()> { + let mut new_cache = HashMap::new(); + + self.refresh_tcp_processes(&mut new_cache)?; + self.refresh_udp_processes(&mut new_cache)?; + + *self.cache.write().unwrap() = new_cache; + Ok(()) + } +} diff --git a/src/network/services.rs b/src/network/services.rs new file mode 100644 index 0000000..9edb6ae --- /dev/null +++ b/src/network/services.rs @@ -0,0 +1,272 @@ +use crate::network::types::Protocol; +use anyhow::Result; +use std::collections::HashMap; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::Path; + +/// Service name lookup table +#[derive(Debug, Clone)] +pub struct ServiceLookup { + /// Map of (port, protocol) -> service name + services: HashMap<(u16, Protocol), String>, + /// Common alternative names for services + aliases: HashMap, +} + +impl ServiceLookup { + /// Create an empty service lookup + pub fn new() -> Self { + Self { + services: HashMap::new(), + aliases: HashMap::new(), + } + } + + /// Load services from a file (typically /etc/services format) + pub fn from_file>(path: P) -> Result { + let mut services = HashMap::new(); + let mut aliases = HashMap::new(); + + let file = File::open(path)?; + let reader = BufReader::new(file); + + for line in reader.lines() { + let line = line?; + let line = line.trim(); + + // Skip comments and empty lines + if line.is_empty() || line.starts_with('#') { + continue; + } + + // Parse line format: service-name port/protocol [aliases...] + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 2 { + continue; + } + + let service_name = parts[0]; + let port_protocol = parts[1]; + + // Parse port/protocol + let port_parts: Vec<&str> = port_protocol.split('/').collect(); + if port_parts.len() != 2 { + continue; + } + + let port = match port_parts[0].parse::() { + Ok(p) => p, + Err(_) => continue, + }; + + let protocol = match port_parts[1].to_lowercase().as_str() { + "tcp" => Protocol::TCP, + "udp" => Protocol::UDP, + _ => continue, + }; + + // Store the service + services + .entry((port, protocol)) + .or_insert_with(|| service_name.to_string()); + + // Store aliases if any + for &alias in &parts[2..] { + if !alias.starts_with('#') { + aliases.insert(alias.to_string(), service_name.to_string()); + } else { + break; // Rest is comment + } + } + } + + Ok(Self { services, aliases }) + } + + /// Create with common well-known services + pub fn with_defaults() -> Self { + let mut lookup = Self::new(); + + // Common TCP services + lookup.add_service(20, Protocol::TCP, "ftp-data"); + lookup.add_service(21, Protocol::TCP, "ftp"); + lookup.add_service(22, Protocol::TCP, "ssh"); + lookup.add_service(23, Protocol::TCP, "telnet"); + lookup.add_service(25, Protocol::TCP, "smtp"); + lookup.add_service(53, Protocol::TCP, "dns"); + lookup.add_service(80, Protocol::TCP, "http"); + lookup.add_service(110, Protocol::TCP, "pop3"); + lookup.add_service(143, Protocol::TCP, "imap"); + lookup.add_service(443, Protocol::TCP, "https"); + lookup.add_service(445, Protocol::TCP, "microsoft-ds"); + lookup.add_service(587, Protocol::TCP, "submission"); + lookup.add_service(993, Protocol::TCP, "imaps"); + lookup.add_service(995, Protocol::TCP, "pop3s"); + lookup.add_service(1433, Protocol::TCP, "mssql"); + lookup.add_service(3306, Protocol::TCP, "mysql"); + lookup.add_service(3389, Protocol::TCP, "rdp"); + lookup.add_service(5432, Protocol::TCP, "postgresql"); + lookup.add_service(5900, Protocol::TCP, "vnc"); + lookup.add_service(6379, Protocol::TCP, "redis"); + lookup.add_service(8080, Protocol::TCP, "http-alt"); + lookup.add_service(8443, Protocol::TCP, "https-alt"); + lookup.add_service(27017, Protocol::TCP, "mongodb"); + + // Common UDP services + lookup.add_service(53, Protocol::UDP, "dns"); + lookup.add_service(67, Protocol::UDP, "dhcp-server"); + lookup.add_service(68, Protocol::UDP, "dhcp-client"); + lookup.add_service(123, Protocol::UDP, "ntp"); + lookup.add_service(161, Protocol::UDP, "snmp"); + lookup.add_service(443, Protocol::UDP, "https"); // QUIC + lookup.add_service(500, Protocol::UDP, "isakmp"); + lookup.add_service(1194, Protocol::UDP, "openvpn"); + lookup.add_service(4500, Protocol::UDP, "ipsec-nat"); + lookup.add_service(5060, Protocol::UDP, "sip"); + + lookup + } + + /// Add a service mapping + pub fn add_service(&mut self, port: u16, protocol: Protocol, name: &str) { + self.services.insert((port, protocol), name.to_string()); + } + + /// Look up a service name by port and protocol + pub fn lookup(&self, port: u16, protocol: Protocol) -> Option<&str> { + self.services.get(&(port, protocol)).map(|s| s.as_str()) + } + + /// Look up service name with fallback to common names + pub fn lookup_with_fallback(&self, port: u16, protocol: Protocol) -> Option { + if let Some(name) = self.lookup(port, protocol) { + return Some(name.to_string()); + } + + // Common dynamic port ranges with generic names + match port { + 1024..=5000 => Some("user-port".to_string()), + 5001..=32767 => Some("dynamic".to_string()), + 32768..=60999 => Some("private".to_string()), + 61000..=65535 => Some("ephemeral".to_string()), + _ => None, + } + } + + /// Get a display name for a service (formats well-known services better) + pub fn display_name(&self, port: u16, protocol: Protocol) -> String { + match self.lookup(port, protocol) { + Some("http") => "HTTP".to_string(), + Some("https") => "HTTPS".to_string(), + Some("ssh") => "SSH".to_string(), + Some("ftp") => "FTP".to_string(), + Some("smtp") => "SMTP".to_string(), + Some("imap") => "IMAP".to_string(), + Some("pop3") => "POP3".to_string(), + Some("dns") => "DNS".to_string(), + Some("dhcp-server") => "DHCP Server".to_string(), + Some("dhcp-client") => "DHCP Client".to_string(), + Some("ntp") => "NTP".to_string(), + Some("rdp") => "RDP".to_string(), + Some("vnc") => "VNC".to_string(), + Some(name) => { + // Capitalize first letter + let mut chars = name.chars(); + match chars.next() { + None => String::new(), + Some(first) => first.to_uppercase().collect::() + chars.as_str(), + } + } + None => format!("{}/{}", port, protocol), + } + } + + /// Get all services for a specific protocol + pub fn services_by_protocol(&self, protocol: Protocol) -> Vec<(u16, &str)> { + let mut services: Vec<(u16, &str)> = self + .services + .iter() + .filter_map(|((port, proto), name)| { + if *proto == protocol { + Some((*port, name.as_str())) + } else { + None + } + }) + .collect(); + + services.sort_by_key(|(port, _)| *port); + services + } + + /// Get the number of services loaded + pub fn len(&self) -> usize { + self.services.len() + } + + /// Check if the lookup table is empty + pub fn is_empty(&self) -> bool { + self.services.is_empty() + } +} + +impl Default for ServiceLookup { + fn default() -> Self { + Self::with_defaults() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_services() { + let lookup = ServiceLookup::with_defaults(); + + assert_eq!(lookup.lookup(80, Protocol::TCP), Some("http")); + assert_eq!(lookup.lookup(443, Protocol::TCP), Some("https")); + assert_eq!(lookup.lookup(22, Protocol::TCP), Some("ssh")); + assert_eq!(lookup.lookup(53, Protocol::UDP), Some("dns")); + } + + #[test] + fn test_display_names() { + let lookup = ServiceLookup::with_defaults(); + + assert_eq!(lookup.display_name(80, Protocol::TCP), "HTTP"); + assert_eq!(lookup.display_name(443, Protocol::TCP), "HTTPS"); + assert_eq!(lookup.display_name(12345, Protocol::TCP), "12345/TCP"); + } + + #[test] + fn test_lookup_with_fallback() { + let lookup = ServiceLookup::with_defaults(); + + assert_eq!( + lookup.lookup_with_fallback(80, Protocol::TCP), + Some("http".to_string()) + ); + assert_eq!( + lookup.lookup_with_fallback(50000, Protocol::TCP), + Some("private".to_string()) + ); + assert_eq!( + lookup.lookup_with_fallback(65000, Protocol::TCP), + Some("ephemeral".to_string()) + ); + } + + #[test] + fn test_services_by_protocol() { + let lookup = ServiceLookup::with_defaults(); + + let tcp_services = lookup.services_by_protocol(Protocol::TCP); + assert!(tcp_services.iter().any(|(port, _)| *port == 80)); + assert!(tcp_services.iter().any(|(port, _)| *port == 443)); + + let udp_services = lookup.services_by_protocol(Protocol::UDP); + assert!(udp_services.iter().any(|(port, _)| *port == 53)); + } +} diff --git a/src/network/types.rs b/src/network/types.rs new file mode 100644 index 0000000..f337e07 --- /dev/null +++ b/src/network/types.rs @@ -0,0 +1,273 @@ +use std::net::SocketAddr; +use std::time::{Duration, Instant, SystemTime}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Protocol { + TCP, + UDP, + ICMP, + ARP, +} + +impl std::fmt::Display for Protocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Protocol::TCP => write!(f, "TCP"), + Protocol::UDP => write!(f, "UDP"), + Protocol::ICMP => write!(f, "ICMP"), + Protocol::ARP => write!(f, "ARP"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TcpState { + Listen, + SynSent, + SynReceived, + Established, + FinWait1, + FinWait2, + CloseWait, + LastAck, + TimeWait, + Closing, + Closed, +} + +#[derive(Debug, Clone, Copy)] +pub enum ProtocolState { + Tcp(TcpState), + Udp, + Icmp { icmp_type: u8, icmp_code: u8 }, + Arp { operation: ArpOperation }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ArpOperation { + Request, + Reply, +} + +#[derive(Debug, Clone)] +pub enum ApplicationProtocol { + Http(HttpInfo), + Https(TlsInfo), + Dns(DnsInfo), + Ssh, + Quic, + Unknown, +} + +#[derive(Debug, Clone)] +pub struct HttpInfo { + pub version: HttpVersion, + pub method: Option, + pub host: Option, + pub path: Option, + pub status_code: Option, + pub user_agent: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HttpVersion { + Http10, + Http11, + Http2, + Http3, +} + +#[derive(Debug, Clone)] +pub struct TlsInfo { + pub version: Option, + pub sni: Option, + pub alpn: Vec, + pub cipher_suite: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TlsVersion { + Ssl3, + Tls10, + Tls11, + Tls12, + Tls13, +} + +#[derive(Debug, Clone)] +pub struct DnsInfo { + pub query_name: Option, + pub query_type: Option, + pub response_ips: Vec, + pub is_response: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DnsQueryType { + A, + AAAA, + CNAME, + MX, + TXT, + Other(u16), +} + +#[derive(Debug, Clone)] +pub struct DpiInfo { + pub application: ApplicationProtocol, + pub first_packet_time: Instant, + pub last_update_time: Instant, +} + +#[derive(Debug, Clone)] +pub struct RateInfo { + pub incoming_bps: f64, + pub outgoing_bps: f64, + pub last_calculation: Instant, +} + +impl Default for RateInfo { + fn default() -> Self { + Self { + incoming_bps: 0.0, + outgoing_bps: 0.0, + last_calculation: Instant::now(), + } + } +} + +#[derive(Debug, Clone)] +pub struct Connection { + // Core identification + pub protocol: Protocol, + pub local_addr: SocketAddr, + pub remote_addr: SocketAddr, + + // Protocol state + pub protocol_state: ProtocolState, + + // Process information + pub pid: Option, + pub process_name: Option, + + // Traffic statistics + pub bytes_sent: u64, + pub bytes_received: u64, + pub packets_sent: u64, + pub packets_received: u64, + + // Timing + pub created_at: SystemTime, + pub last_activity: SystemTime, + + // Service identification + pub service_name: Option, + + // Deep packet inspection + pub dpi_info: Option, + + // Performance metrics + pub current_rate_bps: RateInfo, + pub rtt_estimate: Option, + + // Backward compatibility fields + pub current_incoming_rate_bps: f64, + pub current_outgoing_rate_bps: f64, +} + +impl Connection { + /// Create a new connection + pub fn new( + protocol: Protocol, + local_addr: SocketAddr, + remote_addr: SocketAddr, + state: ProtocolState, + ) -> Self { + let now = SystemTime::now(); + Self { + protocol, + local_addr, + remote_addr, + protocol_state: state, + pid: None, + process_name: None, + bytes_sent: 0, + bytes_received: 0, + packets_sent: 0, + packets_received: 0, + created_at: now, + last_activity: now, + service_name: None, + dpi_info: None, + current_rate_bps: RateInfo::default(), + rtt_estimate: None, + current_incoming_rate_bps: 0.0, + current_outgoing_rate_bps: 0.0, + } + } + + /// Generate a unique key for this connection + pub fn key(&self) -> String { + format!( + "{:?}:{}-{:?}:{}", + self.protocol, self.local_addr, self.protocol, self.remote_addr + ) + } + + /// Check if connection is active (had activity in the last minute) + pub fn is_active(&self) -> bool { + self.last_activity.elapsed().unwrap_or_default() < Duration::from_secs(60) + } + + /// Get the age of the connection + pub fn age(&self) -> Duration { + self.created_at.elapsed().unwrap_or_default() + } + + /// Get time since last activity + pub fn idle_time(&self) -> Duration { + self.last_activity.elapsed().unwrap_or_default() + } + + /// Get display state + pub fn state(&self) -> String { + match &self.protocol_state { + ProtocolState::Tcp(tcp_state) => format!("{:?}", tcp_state), + ProtocolState::Udp => "ACTIVE".to_string(), + ProtocolState::Icmp { icmp_type, .. } => match icmp_type { + 8 => "ECHO_REQUEST".to_string(), + 0 => "ECHO_REPLY".to_string(), + 3 => "DEST_UNREACH".to_string(), + 11 => "TIME_EXCEEDED".to_string(), + _ => "UNKNOWN".to_string(), + }, + ProtocolState::Arp { operation } => match operation { + ArpOperation::Request => "ARP_REQUEST".to_string(), + ArpOperation::Reply => "ARP_REPLY".to_string(), + }, + } + } + + /// Update transfer rates + pub fn update_rates(&mut self, new_sent: u64, new_received: u64) { + let now = Instant::now(); + let elapsed = now + .duration_since(self.current_rate_bps.last_calculation) + .as_secs_f64(); + + if elapsed > 0.1 { + let sent_diff = new_sent.saturating_sub(self.bytes_sent) as f64; + let recv_diff = new_received.saturating_sub(self.bytes_received) as f64; + + self.current_rate_bps = RateInfo { + outgoing_bps: (sent_diff * 8.0) / elapsed, + incoming_bps: (recv_diff * 8.0) / elapsed, + last_calculation: now, + }; + + // Update backward compatibility fields + self.current_incoming_rate_bps = self.current_rate_bps.incoming_bps; + self.current_outgoing_rate_bps = self.current_rate_bps.outgoing_bps; + } + } +} diff --git a/src/network/windows.rs b/src/network/windows.rs deleted file mode 100644 index 69ab9f0..0000000 --- a/src/network/windows.rs +++ /dev/null @@ -1,212 +0,0 @@ -use anyhow::Result; -use std::process::Command; - -use super::{Connection, ConnectionState, NetworkMonitor, Process, Protocol}; - -/// Get platform-specific connections for Windows -pub fn get_platform_connections( - monitor: &NetworkMonitor, - connections: &mut Vec, -) -> Result<()> { - // Use netstat on Windows for both TCP and UDP - monitor.get_connections_from_netstat(connections)?; - - Ok(()) -} - -// Methods below remain part of NetworkMonitor impl -impl NetworkMonitor { - /// Get platform-specific process for a connection - pub(super) fn get_platform_process_for_connection( - &self, - connection: &Connection, - ) -> Option { - // Try netstat - if let Some(process) = try_netstat_command(connection) { - return Some(process); - } - - // Fall back to API calls if we implement them - try_windows_api(connection) - } - - /// Get process information by PID - pub(super) fn get_process_by_pid(&self, pid: u32) -> Option { - // Use tasklist to get process info - if let Ok(output) = Command::new("tasklist") - .args(["/FI", &format!("PID eq {}", pid), "/FO", "CSV", "/NH"]) - .output() - { - let text = String::from_utf8_lossy(&output.stdout); - let line = text.lines().next().unwrap_or(""); - - // Parse CSV format - let parts: Vec<&str> = line.split(',').collect(); - if parts.len() >= 2 { - // Remove quotes - let name = parts[0].trim_matches('"').to_string(); - - return Some(Process { - pid, - name, - }); - } - } - - None - } - - /// Get connections from netstat command - pub(super) fn get_connections_from_netstat(&self, connections: &mut Vec) -> Result<()> { - let output = Command::new("netstat").args(["-ano"]).output()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - - for line in text.lines().skip(4) { - // Skip headers - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 5 { - continue; - } - - // Parse protocol - let protocol = match fields[0].to_lowercase().as_str() { - "tcp" | "tcp6" => Protocol::TCP, - "udp" | "udp6" => Protocol::UDP, - _ => continue, - }; - - // Parse state - let state_pos = 3; - let state = if fields.len() > state_pos { - match fields[state_pos] { - "ESTABLISHED" => ConnectionState::Established, - "LISTENING" | "LISTEN" => ConnectionState::Listen, - "TIME_WAIT" => ConnectionState::TimeWait, - "CLOSE_WAIT" => ConnectionState::CloseWait, - "SYN_SENT" => ConnectionState::SynSent, - "SYN_RECEIVED" | "SYN_RECV" => ConnectionState::SynReceived, - "FIN_WAIT_1" => ConnectionState::FinWait1, - "FIN_WAIT_2" => ConnectionState::FinWait2, - "LAST_ACK" => ConnectionState::LastAck, - "CLOSING" => ConnectionState::Closing, - _ => ConnectionState::Unknown, - } - } else { - ConnectionState::Unknown - }; - - // Parse local and remote addresses - let local_idx = 1; - let remote_idx = 2; - - if let (Some(local), Some(remote)) = ( - super::parse_addr(fields[local_idx]), - super::parse_addr(fields[remote_idx]), - ) { - let mut conn = Connection::new(protocol, local, remote, state); - - // Parse PID - let pid_pos = 4; - if fields.len() > pid_pos && fields[pid_pos] != "-" { - if let Ok(pid) = fields[pid_pos].parse::() { - conn.pid = Some(pid); - } - } - - connections.push(conn); - } - } - } - - Ok(()) - } -} - -/// Get process information using netstat command -pub(super) fn try_netstat_command(connection: &Connection) -> Option { - let output = Command::new("netstat").args(["-ano"]).output().ok()?; - - if output.status.success() { - let text = String::from_utf8_lossy(&output.stdout); - let local_addr = format!("{}", connection.local_addr); - let remote_addr = format!("{}", connection.remote_addr); - - for line in text.lines().skip(2) { - // Skip headers - let fields: Vec<&str> = line.split_whitespace().collect(); - if fields.len() < 5 { - continue; - } - - // Check if this line matches our connection - let local_idx = 1; - let remote_idx = 2; - let proto_idx = 0; - - let matches_protocol = match connection.protocol { - Protocol::TCP => { - fields[proto_idx].eq_ignore_ascii_case("tcp") - || fields[proto_idx].eq_ignore_ascii_case("tcp6") - } - Protocol::UDP => { - fields[proto_idx].eq_ignore_ascii_case("udp") - || fields[proto_idx].eq_ignore_ascii_case("udp6") - } - _ => false, - }; - - if matches_protocol - && (fields[local_idx].contains(&local_addr) - || fields[local_idx].contains(&format!(":{}", connection.local_addr.port()))) - && (fields[remote_idx].contains(&remote_addr) - || fields[remote_idx].contains(&format!(":{}", connection.remote_addr.port()))) - { - // Found matching connection, get PID - let pid_pos = 4; - if fields.len() > pid_pos && fields[pid_pos] != "-" { - if let Ok(pid) = fields[pid_pos].parse::() { - // Get process name - let name = get_process_name_by_pid(pid) - .unwrap_or_else(|| format!("process-{}", pid)); - - return Some(Process { - pid, - name, - }); - } - } - - break; - } - } - } - - None -} - -/// Try Windows API to get process information -pub(super) fn try_windows_api(_connection: &Connection) -> Option { - // This would require using the Windows API (like GetExtendedTcpTable) - // For simplicity, we'll just return None as a placeholder - // In a real implementation, you'd use the windows crate to make API calls - None -} - -/// Get process name by PID -fn get_process_name_by_pid(pid: u32) -> Option { - let output = Command::new("tasklist") - .args(["/FI", &format!("PID eq {}", pid), "/FO", "CSV", "/NH"]) - .output() - .ok()?; - - let text = String::from_utf8_lossy(&output.stdout); - let line = text.lines().next()?; - - // Parse CSV format (remove quotes) - let name_end = line.find(',')? - 1; - let name = line[1..name_end].to_string(); - - Some(name) -} diff --git a/src/ui.rs b/src/ui.rs index 262e60a..d8d3029 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -6,11 +6,10 @@ use ratatui::{ text::{Line, Span}, widgets::{Block, Borders, Cell, Paragraph, Row, Table, Tabs, Wrap}, }; -// Removed unused import: use std::collections::HashMap; -use std::net::SocketAddr; // Import SocketAddr +use std::time::Instant; -use crate::app::{App, DetailFocusField, ViewMode}; // Added DetailFocusField -use crate::network::Protocol; +use crate::app::{App, AppStats}; +use crate::network::types::{Connection, Protocol}; pub type Terminal = RatatuiTerminal; @@ -40,11 +39,34 @@ pub fn restore_terminal(terminal: &mut Terminal Ok(()) } +/// UI state for managing the interface +pub struct UIState { + pub selected_tab: usize, + pub selected_connection: usize, + pub show_help: bool, +} + +impl Default for UIState { + fn default() -> Self { + Self { + selected_tab: 0, + selected_connection: 0, + show_help: false, + } + } +} + /// Draw the UI -pub fn draw(f: &mut Frame, app: &mut App) -> Result<()> { - // If still loading, show loading screen instead of normal UI - if app.is_loading { - draw_loading_screen(f, app); +pub fn draw( + f: &mut Frame, + app: &App, + ui_state: &UIState, + connections: &[Connection], + stats: &AppStats, +) -> Result<()> { + // If still loading, show loading screen + if app.is_loading() { + draw_loading_screen(f); return Ok(()); } @@ -57,41 +79,35 @@ pub fn draw(f: &mut Frame, app: &mut App) -> Result<()> { ]) .split(f.area()); - draw_tabs(f, app, chunks[0]); + draw_tabs(f, ui_state, chunks[0]); - match app.mode { - ViewMode::Overview => draw_overview(f, app, chunks[1])?, - ViewMode::ConnectionDetails => draw_connection_details(f, app, chunks[1])?, - ViewMode::Help => draw_help(f, app, chunks[1])?, + match ui_state.selected_tab { + 0 => draw_overview(f, ui_state, connections, stats, chunks[1])?, + 1 => draw_connection_details(f, ui_state, connections, chunks[1])?, + 2 => draw_help(f, chunks[1])?, + _ => {} } - draw_status_bar(f, app, chunks[2]); + draw_status_bar(f, connections.len(), chunks[2]); Ok(()) } /// Draw mode tabs -fn draw_tabs(f: &mut Frame, app: &App, area: Rect) { +fn draw_tabs(f: &mut Frame, ui_state: &UIState, area: Rect) { let titles = vec![ - Span::styled(app.i18n.get("overview"), Style::default().fg(Color::Green)), - Span::styled( - app.i18n.get("connections"), - Style::default().fg(Color::Green), - ), - Span::styled(app.i18n.get("help"), Style::default().fg(Color::Green)), + Span::styled("Overview", Style::default().fg(Color::Green)), + Span::styled("Details", Style::default().fg(Color::Green)), + Span::styled("Help", Style::default().fg(Color::Green)), ]; let tabs = Tabs::new(titles.into_iter().map(Line::from).collect::>()) .block( Block::default() .borders(Borders::ALL) - .title(app.i18n.get("rustnet")), + .title("RustNet Monitor"), ) - .select(match app.mode { - ViewMode::Overview => 0, - ViewMode::ConnectionDetails => 1, - ViewMode::Help => 2, - }) + .select(ui_state.selected_tab) .style(Style::default().fg(Color::White)) .highlight_style( Style::default() @@ -103,27 +119,38 @@ fn draw_tabs(f: &mut Frame, app: &App, area: Rect) { } /// Draw the overview mode -fn draw_overview(f: &mut Frame, app: &mut App, area: Rect) -> Result<()> { +fn draw_overview( + f: &mut Frame, + ui_state: &UIState, + connections: &[Connection], + stats: &AppStats, + area: Rect, +) -> Result<()> { let chunks = Layout::default() .direction(Direction::Horizontal) .constraints([Constraint::Percentage(70), Constraint::Percentage(30)]) .split(area); - draw_connections_list(f, app, chunks[0]); - draw_side_panel(f, app, chunks[1])?; + draw_connections_list(f, ui_state, connections, chunks[0]); + draw_stats_panel(f, connections, stats, chunks[1])?; Ok(()) } /// Draw connections list -fn draw_connections_list(f: &mut Frame, app: &mut App, area: Rect) { +fn draw_connections_list( + f: &mut Frame, + ui_state: &UIState, + connections: &[Connection], + area: Rect, +) { let widths = [ Constraint::Length(6), // Protocol Constraint::Length(28), // Local Address - Constraint::Length(38), // Remote Address - Increased Width + Constraint::Length(38), // Remote Address Constraint::Length(12), // State Constraint::Length(10), // Service - Constraint::Length(22), // Bandwidth (Down/Up) + Constraint::Length(22), // Bandwidth Constraint::Min(10), // Process ]; @@ -133,7 +160,7 @@ fn draw_connections_list(f: &mut Frame, app: &mut App, area: Rect) { "Remote Address", "State", "Service", - "Down / Up", // Updated Header + "Down / Up", "Process", ] .iter() @@ -146,167 +173,153 @@ fn draw_connections_list(f: &mut Frame, app: &mut App, area: Rect) { }); let header = Row::new(header_cells).height(1).bottom_margin(1); - let mut rows = Vec::new(); - // Collect addresses to format to avoid borrowing issues with app.format_socket_addr - let addresses_to_format: Vec<(SocketAddr, SocketAddr)> = app - .connections + let rows: Vec = connections .iter() - .map(|conn| (conn.local_addr, conn.remote_addr)) + .map(|conn| { + let pid_str = conn + .pid + .map(|p| p.to_string()) + .unwrap_or_else(|| "-".to_string()); + + let process_str = conn.process_name.clone().unwrap_or_else(|| "-".to_string()); + let process_display = if conn.pid.is_some() { + format!("{} ({})", process_str, pid_str) + } else { + process_str + }; + + let service_display = conn.service_name.clone().unwrap_or_else(|| "-".to_string()); + + let incoming_rate = format_rate(conn.current_incoming_rate_bps); + let outgoing_rate = format_rate(conn.current_outgoing_rate_bps); + let bandwidth_display = format!("{} / {}", incoming_rate, outgoing_rate); + + let cells = [ + Cell::from(conn.protocol.to_string()), + Cell::from(conn.local_addr.to_string()), + Cell::from(conn.remote_addr.to_string()), + Cell::from(conn.state()), + Cell::from(service_display), + Cell::from(bandwidth_display), + Cell::from(process_display), + ]; + Row::new(cells) + }) .collect(); - let mut formatted_addresses = Vec::new(); - for (local_addr, remote_addr) in addresses_to_format { - let local_display = app.format_socket_addr(local_addr); - let remote_display = app.format_socket_addr(remote_addr); - formatted_addresses.push((local_display, remote_display)); - } - - for (idx, conn) in app.connections.iter().enumerate() { - let pid_str = conn - .pid - .map(|p| p.to_string()) - .unwrap_or_else(|| "-".to_string()); - - let process_str = conn.process_name.clone().unwrap_or_else(|| "-".to_string()); - let process_display = format!("{} ({})", process_str, pid_str); - - let (local_display, remote_display) = formatted_addresses[idx].clone(); - let service_display = conn.service_name.clone().unwrap_or_else(|| "-".to_string()); - - let incoming_rate_str = format_rate_from_bytes_per_second(conn.current_incoming_rate_bps); - let outgoing_rate_str = format_rate_from_bytes_per_second(conn.current_outgoing_rate_bps); - let bandwidth_display = format!("{} / {}", incoming_rate_str, outgoing_rate_str); - - let cells = [ - Cell::from(conn.protocol.to_string()), - Cell::from(local_display), - Cell::from(remote_display), - Cell::from(conn.state()), - Cell::from(service_display), - Cell::from(bandwidth_display), // Updated Cell - Cell::from(process_display), - ]; - rows.push(Row::new(cells)); - } - // Create table state with current selection let mut state = ratatui::widgets::TableState::default(); - if !app.connections.is_empty() { - state.select(Some(app.selected_connection_idx)); + if !connections.is_empty() { + state.select(Some( + ui_state + .selected_connection + .min(connections.len().saturating_sub(1)), + )); } - let connections = Table::new(rows, &widths) + let connections_table = Table::new(rows, &widths) .header(header) .block( Block::default() .borders(Borders::ALL) - .title(app.i18n.get("connections")), + .title("Active Connections"), ) .row_highlight_style(Style::default().add_modifier(Modifier::REVERSED)) .highlight_symbol("> "); - f.render_stateful_widget(connections, area, &mut state); + f.render_stateful_widget(connections_table, area, &mut state); } -/// Draw side panel with stats -fn draw_side_panel(f: &mut Frame, app: &App, area: Rect) -> Result<()> { +/// Draw stats panel +fn draw_stats_panel( + f: &mut Frame, + connections: &[Connection], + stats: &AppStats, + area: Rect, +) -> Result<()> { let chunks = Layout::default() .direction(Direction::Vertical) .constraints([ - Constraint::Length(3), // Interface - Constraint::Min(0), // Summary stats (takes remaining space) + Constraint::Length(8), // Connection stats + Constraint::Min(0), // Traffic stats ]) .split(area); - let interface_text = format!( - "{}: {}", - app.i18n.get("interface"), - app.config - .interface - .clone() - .unwrap_or_else(|| app.i18n.get("default").to_string()) - ); - let interface_para = Paragraph::new(interface_text) - .block( - Block::default() - .borders(Borders::ALL) - .title(app.i18n.get("network")), - ) - .style(Style::default().fg(Color::White)); - f.render_widget(interface_para, chunks[0]); - - let tcp_count = app - .connections + // Connection statistics + let tcp_count = connections .iter() .filter(|c| c.protocol == Protocol::TCP) .count(); - let udp_count = app - .connections + let udp_count = connections .iter() .filter(|c| c.protocol == Protocol::UDP) .count(); - let process_count = app.processes.len(); - let stats_text: Vec = vec![ + let conn_stats_text: Vec = vec![ + Line::from(format!("TCP Connections: {}", tcp_count)), + Line::from(format!("UDP Connections: {}", udp_count)), + Line::from(format!("Total Connections: {}", connections.len())), + Line::from(""), Line::from(format!( - "{}: {}", - app.i18n.get("tcp_connections"), - tcp_count + "Packets Processed: {}", + stats + .packets_processed + .load(std::sync::atomic::Ordering::Relaxed) )), Line::from(format!( - "{}: {}", - app.i18n.get("udp_connections"), - udp_count - )), - Line::from(format!( - "{}: {}", - app.i18n.get("total_connections"), - app.connections.len() - )), - Line::from(format!("{}: {}", app.i18n.get("processes"), process_count)), - Line::from(""), // Spacer - Line::from(format!( - "{}: {}", - app.i18n.get("total_incoming"), - format_rate_from_bytes_per_second( - app.connections - .iter() - .map(|c| c.current_incoming_rate_bps) - .sum() - ) - )), - Line::from(format!( - "{}: {}", - app.i18n.get("total_outgoing"), - format_rate_from_bytes_per_second( - app.connections - .iter() - .map(|c| c.current_outgoing_rate_bps) - .sum() - ) + "Packets Dropped: {}", + stats + .packets_dropped + .load(std::sync::atomic::Ordering::Relaxed) )), ]; - let stats_para = Paragraph::new(stats_text) - .block( - Block::default() - .borders(Borders::ALL) - .title(app.i18n.get("statistics")), - ) + let conn_stats = Paragraph::new(conn_stats_text) + .block(Block::default().borders(Borders::ALL).title("Statistics")) .style(Style::default().fg(Color::White)); - f.render_widget(stats_para, chunks[1]); // Render stats into the second chunk which now takes remaining space + f.render_widget(conn_stats, chunks[0]); + + // Traffic statistics + let total_incoming: f64 = connections + .iter() + .map(|c| c.current_incoming_rate_bps) + .sum(); + let total_outgoing: f64 = connections + .iter() + .map(|c| c.current_outgoing_rate_bps) + .sum(); + + let traffic_stats_text: Vec = vec![ + Line::from(format!("Total Incoming: {}", format_rate(total_incoming))), + Line::from(format!("Total Outgoing: {}", format_rate(total_outgoing))), + Line::from(""), + Line::from(format!( + "Last Update: {:?} ago", + stats.last_update.read().unwrap().elapsed() + )), + ]; + + let traffic_stats = Paragraph::new(traffic_stats_text) + .block(Block::default().borders(Borders::ALL).title("Traffic")) + .style(Style::default().fg(Color::White)); + f.render_widget(traffic_stats, chunks[1]); Ok(()) } /// Draw connection details view -fn draw_connection_details(f: &mut Frame, app: &mut App, area: Rect) -> Result<()> { - if app.connections.is_empty() { - let text = Paragraph::new(app.i18n.get("no_connections")) +fn draw_connection_details( + f: &mut Frame, + ui_state: &UIState, + connections: &[Connection], + area: Rect, +) -> Result<()> { + if connections.is_empty() { + let text = Paragraph::new("No connections available") .block( Block::default() .borders(Borders::ALL) - .title(app.i18n.get("connection_details")), + .title("Connection Details"), ) .style(Style::default().fg(Color::Red)) .alignment(ratatui::layout::Alignment::Center); @@ -314,89 +327,46 @@ fn draw_connection_details(f: &mut Frame, app: &mut App, area: Rect) -> Result<( return Ok(()); } - let conn_idx = app.selected_connection_idx; - let local_addr_to_format = app.connections[conn_idx].local_addr; - let remote_addr_to_format = app.connections[conn_idx].remote_addr; - - // Format addresses before further immutable borrows of app.connections - let local_display = app.format_socket_addr(local_addr_to_format); - let remote_display = app.format_socket_addr(remote_addr_to_format); - - let conn = &app.connections[conn_idx]; // Now we can immutably borrow again + let conn_idx = ui_state + .selected_connection + .min(connections.len().saturating_sub(1)); + let conn = &connections[conn_idx]; let chunks = Layout::default() .direction(Direction::Vertical) .constraints([Constraint::Percentage(50), Constraint::Percentage(50)]) .split(area); + // Connection details let mut details_text: Vec = Vec::new(); - // Styles for focused IP - let local_ip_style = if app.detail_focus == DetailFocusField::LocalIp { - Style::default() - .fg(Color::Cyan) - .add_modifier(Modifier::BOLD) - } else { - Style::default() - }; - let remote_ip_style = if app.detail_focus == DetailFocusField::RemoteIp { - Style::default() - .fg(Color::Cyan) - .add_modifier(Modifier::BOLD) - } else { - Style::default() - }; - details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("protocol")), - Style::default().fg(Color::Yellow), - ), + Span::styled("Protocol: ", Style::default().fg(Color::Yellow)), Span::raw(conn.protocol.to_string()), ])); - // Use pre-formatted addresses details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("local_address")), - Style::default().fg(Color::Yellow), - ), - Span::styled(local_display, local_ip_style), // Apply style + Span::styled("Local Address: ", Style::default().fg(Color::Yellow)), + Span::raw(conn.local_addr.to_string()), ])); details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("remote_address")), - Style::default().fg(Color::Yellow), - ), - Span::styled(remote_display, remote_ip_style), // Apply style + Span::styled("Remote Address: ", Style::default().fg(Color::Yellow)), + Span::raw(conn.remote_addr.to_string()), ])); - if app.show_locations && !conn.remote_addr.ip().is_unspecified() { - // Commented out private field access - } - details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("state")), - Style::default().fg(Color::Yellow), - ), + Span::styled("State: ", Style::default().fg(Color::Yellow)), Span::raw(conn.state()), ])); details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("process")), - Style::default().fg(Color::Yellow), - ), + Span::styled("Process: ", Style::default().fg(Color::Yellow)), Span::raw(conn.process_name.clone().unwrap_or_else(|| "-".to_string())), ])); details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("pid")), - Style::default().fg(Color::Yellow), - ), + Span::styled("PID: ", Style::default().fg(Color::Yellow)), Span::raw( conn.pid .map(|p| p.to_string()) @@ -405,76 +375,59 @@ fn draw_connection_details(f: &mut Frame, app: &mut App, area: Rect) -> Result<( ])); details_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("age")), - Style::default().fg(Color::Yellow), - ), - Span::raw(format!("{:?}", conn.age())), + Span::styled("Service: ", Style::default().fg(Color::Yellow)), + Span::raw(conn.service_name.clone().unwrap_or_else(|| "-".to_string())), ])); - details_text.push(Line::from("")); // Spacer - details_text.push(Line::from(Span::styled( - "Use Up/Down to select IP, 'c' to copy.", // Hint text - Style::default().fg(Color::DarkGray), - ))); - let details = Paragraph::new(details_text) .block( Block::default() .borders(Borders::ALL) - .title(app.i18n.get("connection_details")), + .title("Connection Information"), ) .style(Style::default().fg(Color::White)) .wrap(Wrap { trim: true }); f.render_widget(details, chunks[0]); + // Traffic details let mut traffic_text: Vec = Vec::new(); + traffic_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("bytes_sent")), - Style::default().fg(Color::Yellow), - ), + Span::styled("Bytes Sent: ", Style::default().fg(Color::Yellow)), Span::raw(format_bytes(conn.bytes_sent)), ])); traffic_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("bytes_received")), - Style::default().fg(Color::Yellow), - ), + Span::styled("Bytes Received: ", Style::default().fg(Color::Yellow)), Span::raw(format_bytes(conn.bytes_received)), ])); traffic_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("packets_sent")), - Style::default().fg(Color::Yellow), - ), + Span::styled("Packets Sent: ", Style::default().fg(Color::Yellow)), Span::raw(conn.packets_sent.to_string()), ])); traffic_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("packets_received")), - Style::default().fg(Color::Yellow), - ), + Span::styled("Packets Received: ", Style::default().fg(Color::Yellow)), Span::raw(conn.packets_received.to_string()), ])); traffic_text.push(Line::from(vec![ - Span::styled( - format!("{}: ", app.i18n.get("last_activity")), - Style::default().fg(Color::Yellow), - ), - Span::raw(format!("{:?}", conn.idle_time())), + Span::styled("Current Rate (In): ", Style::default().fg(Color::Yellow)), + Span::raw(format_rate(conn.current_incoming_rate_bps)), + ])); + + traffic_text.push(Line::from(vec![ + Span::styled("Current Rate (Out): ", Style::default().fg(Color::Yellow)), + Span::raw(format_rate(conn.current_outgoing_rate_bps)), ])); let traffic = Paragraph::new(traffic_text) .block( Block::default() .borders(Borders::ALL) - .title(app.i18n.get("traffic")), + .title("Traffic Statistics"), ) .style(Style::default().fg(Color::White)) .wrap(Wrap { trim: true }); @@ -485,62 +438,48 @@ fn draw_connection_details(f: &mut Frame, app: &mut App, area: Rect) -> Result<( } /// Draw help screen -fn draw_help(f: &mut Frame, app: &App, area: Rect) -> Result<()> { +fn draw_help(f: &mut Frame, area: Rect) -> Result<()> { let help_text: Vec = vec![ Line::from(vec![ Span::styled( - "RustNet ", + "RustNet Monitor ", Style::default() .fg(Color::Green) .add_modifier(Modifier::BOLD), ), - Span::raw(app.i18n.get("help_intro")), + Span::raw("- Network Connection Monitor"), ]), Line::from(""), Line::from(vec![ Span::styled("q, Ctrl+C ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_quit")), + Span::raw("Quit application"), ]), Line::from(vec![ - Span::styled("r ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_refresh")), + Span::styled("Tab ", Style::default().fg(Color::Yellow)), + Span::raw("Switch between tabs"), ]), Line::from(vec![ Span::styled("↑/k, ↓/j ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_navigate")), + Span::raw("Navigate connections"), ]), Line::from(vec![ Span::styled("Enter ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_select")), + Span::raw("View connection details"), ]), Line::from(vec![ Span::styled("Esc ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_back")), - ]), - Line::from(vec![ - Span::styled("l ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_toggle_location")), - ]), - Line::from(vec![ - Span::styled("d ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_toggle_dns")), + Span::raw("Return to overview"), ]), Line::from(vec![ Span::styled("h ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_toggle_help")), - ]), - Line::from(vec![ - Span::styled("Ctrl+D ", Style::default().fg(Color::Yellow)), - Span::raw(app.i18n.get("help_dump_connections")), + Span::raw("Toggle this help screen"), ]), + Line::from(""), + Line::from("Press any key to continue..."), ]; let help = Paragraph::new(help_text) - .block( - Block::default() - .borders(Borders::ALL) - .title(app.i18n.get("help")), - ) + .block(Block::default().borders(Borders::ALL).title("Help")) .style(Style::default().fg(Color::White)) .wrap(Wrap { trim: true }) .alignment(ratatui::layout::Alignment::Left); @@ -551,12 +490,10 @@ fn draw_help(f: &mut Frame, app: &App, area: Rect) -> Result<()> { } /// Draw status bar -fn draw_status_bar(f: &mut Frame, app: &App, area: Rect) { +fn draw_status_bar(f: &mut Frame, connection_count: usize, area: Rect) { let status = format!( - "{} | {} | {}", - app.i18n.get("press_h_for_help"), - format!("{}: {}", app.i18n.get("language"), app.config.language), - format!("{}: {}", app.i18n.get("connections"), app.connections.len()) + " Press 'h' for help | Connections: {} | Tab to switch views ", + connection_count ); let status_bar = Paragraph::new(status) @@ -566,98 +503,64 @@ fn draw_status_bar(f: &mut Frame, app: &App, area: Rect) { f.render_widget(status_bar, area); } -/// Draw loading screen with progress message -fn draw_loading_screen(f: &mut Frame, app: &App) { +/// Draw loading screen +fn draw_loading_screen(f: &mut Frame) { let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length(3), // Header - Constraint::Min(0), // Content - Constraint::Length(1), // Status - ]) - .split(f.area()); - - // Draw header - let header = Paragraph::new("RustNet - Network Monitor") - .style( - Style::default() - .fg(Color::Green) - .add_modifier(Modifier::BOLD), - ) - .alignment(ratatui::layout::Alignment::Center) - .block(Block::default().borders(Borders::ALL)); - f.render_widget(header, chunks[0]); - - // Draw loading content - let loading_content = Layout::default() .direction(Direction::Vertical) .constraints([ Constraint::Percentage(40), Constraint::Length(5), Constraint::Percentage(40), ]) - .split(chunks[1]); + .split(f.area()); let loading_text = vec![ Line::from(""), Line::from(vec![ - Span::styled(app.get_spinner_char(), Style::default().fg(Color::Yellow)), - Span::styled(" ", Style::default()), - Span::styled(&app.loading_message, Style::default().fg(Color::White)), + Span::styled("⣾ ", Style::default().fg(Color::Yellow)), + Span::styled( + "Loading network connections...", + Style::default().fg(Color::White), + ), ]), Line::from(""), Line::from(vec![Span::styled( - "Please wait while we discover network connections", - Style::default().fg(Color::Cyan), - )]), - Line::from(""), - Line::from(vec![Span::styled( - "This may take 10-30 seconds depending on your system", + "This may take a few seconds", Style::default().fg(Color::DarkGray), )]), ]; let loading_paragraph = Paragraph::new(loading_text) .alignment(ratatui::layout::Alignment::Center) - .block(Block::default().borders(Borders::ALL).title("Loading")); - f.render_widget(loading_paragraph, loading_content[1]); + .block( + Block::default() + .borders(Borders::ALL) + .title("RustNet Monitor"), + ); - // Draw status - let status = "Press Ctrl+C to cancel"; - let status_bar = Paragraph::new(status) - .style(Style::default().fg(Color::White).bg(Color::Blue)) - .alignment(ratatui::layout::Alignment::Center); - f.render_widget(status_bar, chunks[2]); + f.render_widget(loading_paragraph, chunks[1]); } -// format_rate function removed as it's no longer used. - -/// Format rate (given as f64 bytes_per_second) to human readable form (KB/s, MB/s, etc.) -fn format_rate_from_bytes_per_second(bytes_per_second: f64) -> String { +/// Format rate to human readable form +fn format_rate(bytes_per_second: f64) -> String { const KB_PER_SEC: f64 = 1024.0; const MB_PER_SEC: f64 = KB_PER_SEC * 1024.0; const GB_PER_SEC: f64 = MB_PER_SEC * 1024.0; - if bytes_per_second.is_nan() || bytes_per_second.is_infinite() { - return "-".to_string(); - } - if bytes_per_second >= GB_PER_SEC { format!("{:.2} GB/s", bytes_per_second / GB_PER_SEC) } else if bytes_per_second >= MB_PER_SEC { format!("{:.2} MB/s", bytes_per_second / MB_PER_SEC) } else if bytes_per_second >= KB_PER_SEC { format!("{:.2} KB/s", bytes_per_second / KB_PER_SEC) - } else if bytes_per_second > 0.1 || bytes_per_second == 0.0 { - // Show B/s for very small rates or zero + } else if bytes_per_second > 0.0 { format!("{:.0} B/s", bytes_per_second) } else { - // For very small, non-zero rates, indicate less than 1 B/s - "<1 B/s".to_string() + "-".to_string() } } -/// Format bytes to human readable form (KB, MB, etc.) +/// Format bytes to human readable form fn format_bytes(bytes: u64) -> String { const KB: u64 = 1024; const MB: u64 = KB * 1024;