// Copyright Marcus Del Favero 2025
// Licensed under the GNU AGPLv3 with an exception, see `README.md` for details
pub mod udp;
pub mod quic;
pub mod hostname;
pub mod security;

use std::net::{SocketAddr, SocketAddrV4, IpAddr, Ipv4Addr, Ipv6Addr};

use sysinfo::{Networks, NetworkData};
use quinn::Endpoint;
use tokio::{task, time::{self, Duration}};

/**
 * An IPv4 address together with the prefix length of the network it belongs
 * to, e.g. `192.168.1.5/24`.
 */
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ipv4Network {
	/** An address within the network. */
	pub addr: Ipv4Addr,
	/** The number of leading bits of `addr` that identify the network. */
	pub net_prefix: u8,
}

impl Ipv4Network {
	/**
	 * Shortest network prefix that `contains_secure` is willing to trust.
	 */
	const MIN_SECURE_PREFIX: u8 = 16;

	/**
	 * Returns an iterator over every IPv4 network reported for the given
	 * interfaces, skipping any IPv6 networks.
	 */
	pub fn networks(networks: &Networks) -> impl Iterator<Item = Ipv4Network> + '_ {
		networks
			.values()
			.flat_map(NetworkData::ip_networks)
			.filter_map(|ip_net| match ip_net.addr {
				IpAddr::V4(addr) => Some(Ipv4Network { addr, net_prefix: ip_net.prefix }),
				IpAddr::V6(_) => None,
			})
	}

	/**
	 * Returns the directed broadcast address of this network: `addr` with all
	 * host bits set.
	 *
	 * A /0 network broadcasts to `255.255.255.255` and a /32 (or longer)
	 * network "broadcasts" to `addr` itself.
	 */
	pub fn directed_broadcast(self) -> Ipv4Addr {
		Ipv4Addr::from_bits(self.addr.to_bits() | self.not_netmask())
	}

	/**
	 * Returns whether `addr` shares the first `net_prefix` bits with
	 * `self.addr`.
	 *
	 * A /0 network contains every address, while a prefix of 32 or more only
	 * contains `self.addr` itself.
	 */
	pub fn contains(self, addr: Ipv4Addr) -> bool {
		let netmask = !self.not_netmask();
		(self.addr.to_bits() & netmask) == (addr.to_bits() & netmask)
	}

	/**
	 * Custom `contains` implementation for determining whether an IP address
	 * belongs to a local network.
	 *
	 * This method mitigates the problem of a client being part of a malicious
	 * network with a DHCP server that reports the network's prefix as 0, or
	 * something else very small.
	 *
	 * This would result in all, or large parts of the internet being considered
	 * on the local network, and so a connecting client wouldn't get a TLS
	 * warning and the status bar would say that the connection's local.
	 *
	 * This method mitigates this by always returning `false` for networks with
	 * a prefix that's too small (shorter than `MIN_SECURE_PREFIX`, i.e. /16).
	 *
	 * Note that this fix isn't perfect as a DHCP server can choose a /16 prefix
	 * which a particular targeted server's IP address would fall in.
	 *
	 * Also this doesn't solve the problem of a DHCP server specifying a DNS
	 * server that could redirect things to a node on the local network.
	 *
	 * I think that /16 is good enough for IPv4 as it's very unlikely that
	 * someone would be allocated a larger block of public IP addresses.
	 *
	 * I'm using /48 for IPv6 as that seems to be the largest block allocated to
	 * regular users
	 * (https://en.wikipedia.org/wiki/IPv6_address#General_allocation).
	 */
	pub fn contains_secure(self, addr: Ipv4Addr) -> bool {
		self.net_prefix >= Self::MIN_SECURE_PREFIX && self.contains(addr)
	}

	/**
	 * Returns a mask with only the host bits set, i.e. the lowest
	 * `32 - net_prefix` bits.
	 *
	 * A checked shift is used so that a /0 prefix (which a malicious DHCP
	 * server can hand out) yields all ones instead of overflowing the shift.
	 * A prefix above 32 yields 0 (treated like /32) instead of underflowing.
	 */
	fn not_netmask(self) -> u32 {
		u32::MAX.checked_shr(u32::from(self.net_prefix)).unwrap_or(0)
	}
}

/**
 * An IPv6 address together with the prefix length of the network it belongs
 * to, e.g. `2001:db8::1/64`.
 */
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ipv6Network {
	/** An address within the network. */
	pub addr: Ipv6Addr,
	/** The number of leading bits of `addr` that identify the network. */
	pub net_prefix: u8,
}

impl Ipv6Network {
	/**
	 * Shortest network prefix that `contains_secure` is willing to trust.
	 * /48 seems to be the largest block typically allocated to regular users.
	 */
	const MIN_SECURE_PREFIX: u8 = 48;

	/**
	 * Returns an iterator over every IPv6 network reported for the given
	 * interfaces, skipping any IPv4 networks.
	 */
	pub fn networks(networks: &Networks) -> impl Iterator<Item = Ipv6Network> + '_ {
		networks
			.values()
			.flat_map(NetworkData::ip_networks)
			.filter_map(|ip_net| match ip_net.addr {
				IpAddr::V4(_) => None,
				IpAddr::V6(addr) => Some(Ipv6Network { addr, net_prefix: ip_net.prefix }),
			})
	}

	/**
	 * Returns whether `addr` shares the first `net_prefix` bits with
	 * `self.addr`.
	 *
	 * A /0 network contains every address, while a prefix of 128 or more only
	 * contains `self.addr` itself.
	 */
	fn contains(self, addr: Ipv6Addr) -> bool {
		let netmask = !self.not_netmask();
		(self.addr.to_bits() & netmask) == (addr.to_bits() & netmask)
	}

	/**
	 * Like `contains`, but always returns `false` for networks with a prefix
	 * shorter than `MIN_SECURE_PREFIX` (/48). This stops a malicious network
	 * advertising a huge prefix from making arbitrary addresses look local.
	 *
	 * See `Ipv4Network::contains_secure` for more information.
	 */
	pub fn contains_secure(self, addr: Ipv6Addr) -> bool {
		self.net_prefix >= Self::MIN_SECURE_PREFIX && self.contains(addr)
	}

	/**
	 * Returns a mask with only the host bits set, i.e. the lowest
	 * `128 - net_prefix` bits.
	 *
	 * A checked shift is used so that a /0 prefix yields all ones instead of
	 * overflowing the shift. A prefix above 128 yields 0 (treated like /128)
	 * instead of underflowing.
	 */
	fn not_netmask(self) -> u128 {
		u128::MAX.checked_shr(u32::from(self.net_prefix)).unwrap_or(0)
	}
}

pub trait SocketAddrToCanonical {
	/**
	 * Converts a socket address whose IP is an IPv4-mapped IPv6 address into
	 * the equivalent IPv4 socket address, keeping the port.
	 *
	 * All other addresses are returned unchanged.
	 */
	fn to_canonical(self) -> SocketAddr;
}

impl SocketAddrToCanonical for SocketAddr {
	fn to_canonical(self) -> SocketAddr {
		match self.ip().to_canonical() {
			IpAddr::V4(ipv4) => SocketAddr::from(SocketAddrV4::new(ipv4, self.port())),
			IpAddr::V6(_) => self,
		}
	}
}

pub trait WaitIdleTimeout {
	/**
	 * Waits for the endpoint to become idle, giving up after a fixed grace
	 * period of 15 seconds.
	 */
	async fn wait_idle_timeout(&self);
}

impl WaitIdleTimeout for Endpoint {
	async fn wait_idle_timeout(&self) {
		const GRACE_PERIOD: Duration = Duration::from_secs(15);

		log::info!("Gracefully finishing any connections...");
		let outcome = time::timeout(GRACE_PERIOD, self.wait_idle()).await;
		if outcome.is_err() {
			// The timeout elapsed before `wait_idle` completed.
			log::info!("Finishing connections timed out, exiting now");
		}
	}
}

/**
 * How close an IP address is to this host, as determined by
 * `IpClassify::classify`.
 */
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpCategory {
	/** A loopback address, or one of this host's own interface addresses. */
	Loopback,
	/** An address in a private or link-local range, or on one of the host's networks. */
	Local,
	/** Any other address. */
	Global,
}

pub trait IpClassify {
	/**
	 * Classifies `addr` relative to this host.
	 *
	 * IPv4-mapped IPv6 addresses are converted to IPv4 first. The checks then
	 * run from most to least local:
	 * 1. `Loopback` if the address is a loopback address or equals one of the
	 *    addresses assigned to the host's interfaces.
	 * 2. `Local` if the address is private or link-local (IPv4), link-local
	 *    or unique local (IPv6), or falls inside one of the host's networks
	 *    according to `contains_secure`.
	 * 3. `Global` otherwise.
	 */
	fn classify(&mut self, addr: IpAddr) -> IpCategory {
		let addr = addr.to_canonical();
		if addr.is_loopback() { return IpCategory::Loopback; }

		let networks = self.get_networks();

		// Connections to one of this host's own addresses are treated as loopback
		let is_own_addr = networks
			.values()
			.flat_map(NetworkData::ip_networks)
			.any(|ip_net| ip_net.addr == addr);
		if is_own_addr { return IpCategory::Loopback; }

		let is_local = match addr {
			IpAddr::V4(v4) => {
				v4.is_private()
					|| v4.is_link_local()
					|| Ipv4Network::networks(networks).any(|net| net.contains_secure(v4))
			},
			IpAddr::V6(v6) => {
				v6.is_unicast_link_local() // Link-local
					|| v6.is_unique_local() // Private networks
					|| Ipv6Network::networks(networks).any(|net| net.contains_secure(v6))
			},
		};

		if is_local { IpCategory::Local } else { IpCategory::Global }
	}

	/**
	 * Returns the host's network interfaces used by `classify`.
	 */
	fn get_networks(&mut self) -> &Networks;
}

impl IpClassify for &Networks {
	/**
	 * The networks are already available, so the borrowed list is returned
	 * as-is.
	 */
	fn get_networks(&mut self) -> &Networks {
		*self
	}
}

impl IpClassify for Option<Networks> {
	/**
	 * Returns the cached network list, building and storing it on the first
	 * call.
	 */
	fn get_networks(&mut self) -> &Networks {
		/*
		 * Building the refreshed list appears to perform I/O. It benchmarked at a
		 * bit over a millisecond in the dev profile, but `block_in_place` is still
		 * used in case edge cases make it take longer.
		 *
		 * NOTE(review): per tokio's docs, `block_in_place` panics when called from
		 * a current-thread runtime. Confirm callers run on the multi-threaded
		 * runtime.
		 */
		let load = || task::block_in_place(Networks::new_with_refreshed_list);
		self.get_or_insert_with(load)
	}
}
