// Copyright Marcus Del Favero 2025
// Licensed under the GNU AGPLv3 with an exception, see `README.md` for details
mod quic_config;

use std::{mem, sync::Arc, fmt::Write};

use tokio::{net, time::{self, Duration}, sync::mpsc};
use quinn::{Endpoint, Connection, ConnectionError, TransportErrorCode};

use quic_config::SecurityPolicy;

use super::{message::{Message, WorldMessage, MiscMessage}, message_stream::{singleplayer::SingleplayerMessageStream, quic_client::QuicClientMs, DynMs, MsError, Security}};
use super::{Address, text_io::ServerTextOutput, scores::Scores, serp::PlayRequest, timer::Timer, utils::NotPublic};

use crate::world::{World, MAX_STEPS, player::PlayerId, update::WorldUpdate};
use crate::playing::view::Sfx;
use crate::blocks::Blocks;
use crate::utils::ToStr;

/// Maximum number of server text outputs that can accumulate without being
/// taken; exceeding it makes `Client::process_messages` fail with a
/// `ClientError::Limit` error rather than growing without bound.
const MAX_SERVER_TEXT_OUTPUT: usize = 4096;
/// Maximum number of sound effects buffered without being taken; any further
/// sounds received are silently dropped.
const MAX_SOUNDS: usize = 16;

/**
 * The client's end of a game, connected either to a multiplayer server or to a
 * local singleplayer game.
 *
 * Some of the data received (world updates, text output, sounds and sanity
 * check worlds) is buffered here until taken with the corresponding `get_*`
 * method.
 */
pub struct Client {
	/// The message stream, either a `QuicClientMs` (from `multiplayer`) or a `SingleplayerMessageStream` (from `singleplayer`)
	stream: Box<DynMs>,
	/// World updates not yet taken, each paired with the `ack` value received with it
	updates: Vec<(WorldUpdate, u64)>,
	/// Server text output not yet taken, limited by `MAX_SERVER_TEXT_OUTPUT`
	text_output: Vec<ServerTextOutput>,
	/// Sounds not yet taken, at most `MAX_SOUNDS` of them
	sounds: Vec<Sfx>,
	/// The latest scores, validated after every change
	scores: Scores,
	/// The game timer if there is one, along with the most ahead displayed time so far (see `get_time`)
	timer: Option<(Timer, f64)>,
	/// Whether sanity check worlds sent by the server are kept
	sanity_checks: bool,
	/// The latest sanity check world, discarded if a world update arrives before it is taken
	sanity_check_world: Option<World>,
	/// The spectator count last reported by the server, `None` if none has been reported yet
	spectator_count: Option<usize>,
}

/**
 * Categorises the errors returned by methods of `Client`.
 */
pub enum ClientError {
	/// I/O errors in the message stream
	Io(MsError),
	/// Protocol violations, such as the server sending a message the client doesn't expect
	Protocol(String),
	/// When a limit to prevent application layer denial of service attacks is violated
	Limit(String),
}

/**
 * The result of creating a client: the local player's id, the client itself,
 * and the world and blocks from the initial message.
 */
pub type ClientRes = (PlayerId, Client, World, Blocks);

impl Client {
	pub async fn multiplayer(endpoint: Result<Endpoint, String>, addr: Address, request: PlayRequest, sanity_checks: bool) -> Result<ClientRes, String> {
		time::timeout(Duration::from_secs(10), async {
			let (connection, security) = Client::connect(endpoint?, addr.host.as_ref(), addr.port).await?;
			let (stream, init_msg) = QuicClientMs::build(connection, request, security).await?;
			Client::build(init_msg, Box::new(stream), sanity_checks)
		}).await.unwrap_or_else(|_| Err(String::from("connection timed out")))
	}

	pub(super) async fn connect(endpoint: Endpoint, host: &str, port: u16) -> Result<(Connection, Security), String> {
		let mut policy = SecurityPolicy::new(host);
		let insecure_config = quic_config::new_insecure()?;
		let secure_config = quic_config::new_secure()?;

		let (conn_sender, mut conn_receiver) = mpsc::channel(16);

		// Performs DNS resolution if a domain name is provided
		let socket_addrs = net::lookup_host((host, port)).await.map_err(|err| format!("failed DNS resolution of host \"{host}\": {err}"))?.collect::<Vec<_>>();
		let host: Arc<str> = Arc::from(host);
		let mut conn_tasks = Vec::with_capacity(socket_addrs.len());

		/*
		 * Attempts to connect to every result from DNS resolution in parallel.
		 *
		 * To see the purpose of this, suppose that the server is bound to
		 * 127.0.0.1 and the user connects to `localhost`. This hostname first
		 * resolves to `::1` and then `127.0.0.1`. If this is done in parallel,
		 * then a lot of time will be spent waiting for a connection to `::1`,
		 * before nearly immediately connecting to `127.0.0.1` once the first
		 * connection times out.
		 */
		for socket_addr in socket_addrs {
			let ip_addr = socket_addr.ip();
			let check_tls_cert = policy.needs_security(ip_addr);

			let security = if socket_addr.ip().is_loopback() || check_tls_cert {
				/*
				 * Note, if `check_tls_cert` is true, at this point in time the
				 * connection might be actively eavesdropped, but the user would get
				 * a TLS warning and the connection won't succeed.
				 */
				Security::Secure
			} else if policy.ip_addr_not_public(ip_addr) {
				Security::Local
			} else {
				Security::Insecure
			};

			let config = if check_tls_cert { secure_config.clone() } else { insecure_config.clone() };

			let conn_sender = conn_sender.clone();

			/*
			 * When connecting to an IPv6 link-local address (fe80::/10), the scope
			 * id is important to resolve the ambiguity of which interface to use.
			 * However, scope ids are expressed as a '%' in the address, which
			 * isn't allowed for the server name, so I'm using any other address
			 * that works (in this case `localhost`).
			 */
			let server_name = if check_tls_cert { &host } else { "localhost" };
			let connection_res = endpoint.connect_with(config, socket_addr, server_name);

			let handle = tokio::spawn(async move {
				match connection_res {
					Ok(connecting) => {
						let res = connecting.await.map_or_else(
							|err| {
								if let ConnectionError::TransportError(err) = &err {
									/*
									 * Provides a custom error message if the CA isn't
									 * recognised telling the user how to bypass the
									 * certificate check by manually entering the IP
									 * address.
									 *
									 * Note that TLS alert code 48 means the CA is
									 * unknown in both TLS 1.2 and 1.3: RFC 8446 (page
									 * 86) and RFC 5246 (page 69).
									 */
									if err.code == TransportErrorCode::crypto(48) {
										return Err((ip_addr, format!("unknown certificate authority: if you still want to connect to the server, manually enter the IP address of `{ip_addr}` to bypass the certificate check. WARNING: Doing this makes you vulnerable to active eavesdropping and impersonation attacks (similar to browsing the web without using HTTPS).")));
									}
								}

								Err((ip_addr, err.to_string()))
							},
							|connection| Ok((connection, security)),
						);

						conn_sender.send(res).await
					}
					Err(err) => conn_sender.send(Err((ip_addr, err.to_string()))).await,
				}
			});

			conn_tasks.push(handle);
		}

		drop(conn_sender);

		// Iterates through all connections and uses the first one
		let mut errors = Vec::new();
		let res = loop {
			match conn_receiver.recv().await {
				Some(Ok((conn, security))) => {
					log::debug!("connecting to {host}, security = {}", security.to_str());
					break (conn, security);
				},
				Some(Err(err)) => errors.push(err),
				None if errors.is_empty() => return Err(format!("failed connecting to \"{host}\", no socket adresses were resolved for domain name")),
				None if errors.len() == 1 => return Err(format!("failed connecting to \"{host}\": {}", errors.last().map_or("[no error returned, something went wrong]", |err| err.1.as_str()))),
				None => {
					let mut err = format!("failed connecting to \"{host}\", {} socket addresses resolved and attempted for domain name, errors were: ", conn_tasks.len());
					for (i, e) in errors.into_iter().enumerate() {
						if i != 0 { err.push_str(", "); }
						let _ = write!(err, "\"{}\" from IP address `{}`", e.1, e.0);
					}
					return Err(err);
				}
			}
		};

		// Cancels all tasks to prevent any connections still attempted from wasting resources
		for task in conn_tasks {
			task.abort();
		}

		Ok(res)
	}

	pub fn singleplayer(request: PlayRequest, sanity_checks: bool) -> ClientRes {
		let (init_msg, stream) = SingleplayerMessageStream::new(request, sanity_checks);
		Client::build(init_msg, Box::new(stream), sanity_checks).unwrap()
	}

	fn build(init_msg: Message, stream: Box<DynMs>, sanity_checks: bool) -> Result<ClientRes, String> {
		// Receives the response to connecting
		let Message::World(WorldMessage::InitWorld { player_id: id, world, blocks: compressed_blocks, extension: _ }) = init_msg else {
			return Err(String::from("expected InitWorld message, got something else"));
		};

		world.validate_init()?;

		let client = Client {
			stream,
			updates: Vec::new(),
			text_output: Vec::new(),
			sounds: Vec::new(),
			scores: Scores::default(),
			timer: None,
			sanity_checks,
			sanity_check_world: None,
			spectator_count: None,
		};

		Ok((id, client, world, Blocks::try_from(compressed_blocks)?))
	}

	pub fn get_scores(&self) -> &Scores { &self.scores }
	pub fn get_spectator_count(&self) -> Option<usize> { self.spectator_count }
	pub fn get_ping(&self) -> Duration { self.stream.ping() }
	pub fn get_security(&self) -> Security { self.stream.security() }

	pub fn get_time(&mut self, shift: f32) -> Option<String> {
		self.timer.as_mut().map(|(timer, max)| {
			let displayed = timer.displayed_seconds(timer.raw_seconds(), shift as f64);
			*max = timer.most_ahead(displayed, *max);
			timer.format_string(*max)
		})
	}

	pub fn is_networked(&self) -> bool {
		self.stream.is_networked()
	}

	pub fn send(&mut self, msg: Message) -> Result<(), ClientError> {
		self.stream.send(msg).map_err(ClientError::Io)
	}

	pub fn flush(&mut self) -> Result<(), ClientError> {
		self.stream.flush().map_err(ClientError::Io)
	}

	pub fn process_messages(&mut self) -> Result<(), ClientError> {
		loop {
			match self.stream.receive() {
				Ok(Some(Message::World(WorldMessage::WorldUpdate(update, ack)))) => {
					self.updates.push((update, ack));

					if self.updates.len() > MAX_STEPS { // Client has been inactive for a while
						return Err(ClientError::Limit(String::from("too many steps accumulated")));
					}

					if self.sanity_check_world.take().is_some() {
						log::debug!("ignoring sanity check, received another world event"); // The world sent by the server is outdated, so this must be ignored
					}
				},
				Ok(Some(Message::World(WorldMessage::InitScores(scores)))) => {
					self.scores = scores;
					self.scores.validate()?;
				},
				Ok(Some(Message::World(WorldMessage::DScores(delta)))) => {
					delta.apply(&mut self.scores);
					self.scores.validate()?;
				},
				Ok(Some(Message::World(WorldMessage::Timer(timer)))) => self.timer = timer.map(|t| { let s = t.raw_seconds(); (t, s) }),
				Ok(Some(Message::World(WorldMessage::SanityCheck(world)))) => {
					if self.sanity_checks {
						self.sanity_check_world = Some(world);
					}
				},
				Ok(Some(Message::World(WorldMessage::GameStateTransition(trans)))) => {
					if self.sounds.len() < MAX_SOUNDS {
						self.sounds.push(Sfx::GameStateTransition(trans));
					}
				},
				Ok(Some(Message::Misc(MiscMessage::ServerTextOutput(msg)))) => {
					self.text_output.push(msg);

					if self.text_output.len() > MAX_SERVER_TEXT_OUTPUT {
						return Err(ClientError::Limit(String::from("too much server text output accumulated")));
					}
				},
				Ok(Some(Message::Misc(MiscMessage::SpectatorCount(count)))) => self.spectator_count = Some(count),
				Ok(Some(Message::World(WorldMessage::Extension(_)) | Message::Misc(MiscMessage::Extension(_)))) => (),
				Ok(Some(_)) => return Err(ClientError::Protocol(String::from("sending an invalid message"))),
				Ok(None) => return Ok(()),
				Err(err) => return Err(ClientError::Io(err)),
			}
		}
	}

	pub fn update_timer(&mut self, dt: f32) {
		if let Some((timer, _)) = &mut self.timer {
			timer.update(dt);
		}
	}

	pub fn get_text_output(&mut self) -> Vec<ServerTextOutput> { mem::take(&mut self.text_output) }
	pub fn get_world_updates(&mut self) -> Vec<(WorldUpdate, u64)> { mem::take(&mut self.updates) }
	pub fn get_sounds(&mut self) -> Vec<Sfx> { mem::take(&mut self.sounds) }
	pub fn get_sanity_check_world(&mut self) -> Option<World> { self.sanity_check_world.take() }
}
