// Copyright Marcus Del Favero 2025
// Licensed under the GNU AGPLv3 with an exception, see `README.md` for details
mod config;

use std::{sync::Arc, fmt::Write, time::Duration, net::{IpAddr, Ipv6Addr, SocketAddr}};

use tokio::{net, sync::mpsc};
use quinn::{Endpoint, Connection, ConnectionError, TransportErrorCode, TransportConfig, VarInt};

use config::SecurityPolicy;

use crate::net::utils::{IpClassify, IpCategory, hostname::Hostname, security::Security};
use crate::utils::ToStr;

/*
 * Both values are in milliseconds.
 *
 * The keep-alive interval (105 seconds) is shorter than the max idle timeout
 * (120 seconds), presumably so that a keep-alive packet always gets sent
 * before an otherwise idle connection would time out.
 */
const MAX_IDLE_TIMEOUT: u32 = 120_000;
const KEEPALIVE_INTERVAL: u64 = 105_000;

/**
 * Selects which idle timeout and keep-alive settings get applied when building
 * a transport configuration (see `transport_config`).
 */
#[derive(Clone, Copy)]
pub enum IdleConfig {
	/**
	 * The default configuration of a 30 second max idle timeout and no
	 * keep-alive packets.
	 *
	 * This should be used by clients as they might be behind a NAT and timeouts
	 * longer than 30 seconds can cause NAT bindings to time out (see RFC 9308
	 * section 3.2, https://www.rfc-editor.org/rfc/rfc9308.html#section-3.2).
	 */
	Default,

	/**
	 * Uses a larger max idle timeout of two minutes, without sending any
	 * keep-alive packets.
	 *
	 * This would be used rather than `LargerAndKeepAlive` if the server isn't
	 * publishing, as only one node needs to send keep-alive packets and it makes
	 * more sense for that to be the publishing server rather than a publication
	 * accepting server.
	 */
	Larger,

	/**
	 * Uses a larger max idle timeout of two minutes and sends keep-alive packets
	 * to keep connections active indefinitely even if no data is being sent.
	 *
	 * This exists so that publish-subscribe connections don't get timed out when
	 * there is no network activity, which is a normal use case when no changes
	 * need to be published.
	 *
	 * The larger max idle timeout is used to reduce the rate at which keep-alive
	 * packets need to be sent out, which reduces bandwidth and power.
	 *
	 * Since this is intended for communication between two publicly available
	 * servers, NAT problems are less likely to affect things. Though the cited
	 * RFC does mention that some cases, such as firewalls, can cause problems.
	 */
	LargerAndKeepAlive,
}

/**
 * Creates a client QUIC endpoint bound to the unspecified IPv6 address (`::`)
 * with a port of zero.
 *
 * Returns an error message if the endpoint cannot be created.
 */
pub fn client_endpoint() -> Result<Endpoint, String> {
	let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0);

	match Endpoint::client(bind_addr) {
		Ok(endpoint) => Ok(endpoint),
		Err(err) => Err(format!("cannot create QUIC endpoint: {err}")),
	}
}

/**
 * Builds the QUIC transport configuration matching the requested idle
 * behaviour.
 *
 * Annoyingly `quinn` has no way of selectively adjusting these settings after
 * a connection has been created, which is why they have to be chosen up front
 * and apply to all connections using the returned configuration.
 */
pub fn transport_config(idle_config: IdleConfig) -> Arc<TransportConfig> {
	// Expands the idle configuration into which of the two settings get overridden
	let (larger_timeout, keep_alive) = match idle_config {
		IdleConfig::Default => (false, false),
		IdleConfig::Larger => (true, false),
		IdleConfig::LargerAndKeepAlive => (true, true),
	};

	let mut transport = TransportConfig::default();

	if larger_timeout {
		_ = transport.max_idle_timeout(Some(VarInt::from_u32(MAX_IDLE_TIMEOUT).into()));
	}

	if keep_alive {
		_ = transport.keep_alive_interval(Some(Duration::from_millis(KEEPALIVE_INTERVAL)));
	}

	Arc::new(transport)
}

/**
 * Connects to `host` on `port` and returns the first connection that succeeds,
 * along with the security level of that connection.
 *
 * If `host` is a domain name it gets resolved with DNS, and every resulting
 * socket address is attempted in parallel. Whether the TLS certificate is
 * checked for a given address is decided by the `SecurityPolicy` created for
 * `host`.
 *
 * An error message is returned if the client configurations cannot be created,
 * if DNS resolution fails, or if none of the connection attempts succeed (in
 * which case the message contains the error of each attempt).
 */
pub async fn connect(endpoint: Endpoint, host: &str, port: u16, idle_config: IdleConfig) -> Result<(Connection, Security), String> {
	let mut policy = SecurityPolicy::new(host);
	let insecure_config = config::new_insecure(idle_config)?;
	let secure_config = config::new_secure(idle_config)?;

	// Each connection task reports its result (success or failure) through this channel
	let (conn_sender, mut conn_receiver) = mpsc::channel(16);

	// Performs DNS resolution if a domain name is provided
	let socket_addrs = net::lookup_host((host, port)).await.map_err(|err| format!("failed DNS resolution of host \"{host}\": {err}"))?.collect::<Vec<_>>();
	let host: Arc<str> = Arc::from(host);
	let mut conn_tasks = Vec::with_capacity(socket_addrs.len());

	/*
	 * Attempts to connect to every result from DNS resolution in parallel.
	 *
	 * To see the purpose of this, suppose that the server is bound to
	 * 127.0.0.1 and the user connects to `localhost`. This hostname first
	 * resolves to `::1` and then `127.0.0.1`. If this were done sequentially
	 * rather than in parallel, then a lot of time would be spent waiting for
	 * a connection to `::1`, before nearly immediately connecting to
	 * `127.0.0.1` once the first connection times out.
	 */
	for socket_addr in socket_addrs {
		let ip_addr = socket_addr.ip();
		let check_tls_cert = policy.needs_security(ip_addr);

		let security = match policy.classifier().classify(ip_addr) {
			/*
			 * Note, if `check_tls_cert` is true, at this point in time the
			 * connection might be actively eavesdropped, but the user would get
			 * a TLS warning and the connection won't succeed.
			 */
			_ if check_tls_cert => Security::Secure,
			IpCategory::Loopback => Security::Secure,
			IpCategory::Local => Security::Local,
			IpCategory::Global => Security::Insecure,
		};

		let config = if check_tls_cert { secure_config.clone() } else { insecure_config.clone() };

		let conn_sender = conn_sender.clone();

		/*
		 * When connecting to an IPv6 link-local address (fe80::/10), the scope
		 * id is important to resolve the ambiguity of which interface to use.
		 * However, scope ids are expressed as a '%' in the address, which
		 * isn't allowed for the server name, so when the certificate isn't
		 * checked I'm using any other name that works (in this case
		 * `localhost`).
		 */
		let server_name: &str = if check_tls_cert { &host } else { "localhost" };
		let connection_res = endpoint.connect_with(config, socket_addr, server_name);

		let handle = tokio::spawn(async move {
			match connection_res {
				Ok(connecting) => {
					let res = connecting.await.map_or_else(
						|err| {
							/*
							 * Provides a custom error message if the CA isn't
							 * recognised telling the user how to bypass the
							 * certificate check by manually entering the IP address.
							 *
							 * Note that TLS alert code 48 means the CA is unknown in
							 * both TLS 1.2 and 1.3: RFC 8446 (page
							 * 86) and RFC 5246 (page 69).
							 */
							if let ConnectionError::TransportError(err) = &err && err.code == TransportErrorCode::crypto(48) {
								return Err((ip_addr, format!("unknown certificate authority: if you still want to connect to the server, manually enter the IP address of `{ip_addr}` to bypass the certificate check. WARNING: Doing this makes you vulnerable to active eavesdropping and impersonation attacks (similar to browsing the web without using HTTPS).")));
							}

							Err((ip_addr, err.to_string()))
						},
						|connection| Ok((connection, security)),
					);

					conn_sender.send(res).await
				},
				// The connection attempt couldn't even be started
				Err(err) => conn_sender.send(Err((ip_addr, err.to_string()))).await,
			}
		});

		conn_tasks.push(handle);
	}

	/*
	 * Drops the original sender so that the channel closes (and `recv`
	 * returns `None`) once every task has finished and dropped its clone.
	 */
	drop(conn_sender);

	// Iterates through all connections and uses the first one
	let mut errors = Vec::new();
	let res = loop {
		match conn_receiver.recv().await {
			Some(Ok((conn, security))) => {
				log::debug!("connecting to {host}, security = {}", security.to_str());
				break (conn, security);
			},
			Some(Err(err)) => errors.push(err),

			// The channel closed without any task reporting a result
			None if errors.is_empty() => return Err(format!("failed connecting to \"{host}\", no socket addresses were resolved for domain name")),

			// Only one failure, so its error is reported without the IP address
			None if errors.len() == 1 => return Err(format!("failed connecting to \"{host}\": {}", errors.last().map_or("[no error returned, something went wrong]", |err| err.1.as_str()))),

			// Multiple failures, so each error is listed with the IP address it came from
			None => {
				let mut err = format!("failed connecting to \"{host}\", {} socket addresses resolved and attempted for domain name, errors were: ", conn_tasks.len());
				for (i, e) in errors.into_iter().enumerate() {
					if i != 0 { err.push_str(", "); }
					let _ = write!(err, "\"{}\" from IP address `{}`", e.1, e.0);
				}
				return Err(err);
			}
		}
	};

	// Cancels all tasks to prevent any connections still attempted from wasting resources
	for task in conn_tasks {
		task.abort();
	}

	Ok(res)
}

/**
 * Connects directly to a socket address, skipping DNS resolution entirely.
 *
 * The TLS certificate **isn't checked**, so callers should treat the returned
 * connection as insecure. This is intended for use on the LAN, where
 * transport security isn't as important.
 */
pub async fn connect_to_addr(endpoint: Endpoint, addr: SocketAddr, idle_config: IdleConfig) -> Result<Connection, String> {
	let client_config = config::new_insecure(idle_config)?;

	// The server name doesn't matter as the certificate isn't checked
	let connecting = match endpoint.connect_with(client_config, addr, "localhost") {
		Ok(connecting) => connecting,
		Err(err) => return Err(format!("failed starting connection: {err}")),
	};

	connecting.await.map_err(|err| format!("failed connecting to {addr}: {err}"))
}

/**
 * Connects directly to a socket address, skipping DNS resolution, while still
 * checking the TLS certificate against the provided hostname.
 */
pub async fn connect_to_addr_secure(endpoint: Endpoint, addr: SocketAddr, host: &Hostname, idle_config: IdleConfig) -> Result<Connection, String> {
	let client_config = config::new_secure(idle_config)?;

	let connecting = match endpoint.connect_with(client_config, addr, host.get()) {
		Ok(connecting) => connecting,
		Err(err) => return Err(format!("failed starting connection: {err}")),
	};

	match connecting.await {
		Ok(connection) => Ok(connection),
		Err(err) => Err(format!("failed connecting to {addr}: {err}")),
	}
}
