// Copyright Marcus Del Favero 2025
// Licensed under the GNU AGPLv3 with an exception, see `README.md` for details
mod socket;

use std::{net::SocketAddr, collections::{BTreeMap, btree_map::Entry}};

use quinn::Endpoint;

use socket::Socket;

use super::ServerId;

use crate::net::request;
use crate::protocol::discovery::DiscoveryMessage;
use crate::utils::task::Task;

// Mitigates DoS attacks by bounding how much state can be created from what is
// received on the discovery socket
/// Maximum number of distinct server IDs tracked at once; IDs received beyond
/// this limit are ignored
const MAX_SERVER_IDS: usize = 16;
/// Maximum number of addresses for which a request is started per server ID;
/// addresses received beyond this limit are ignored
const MAX_ADDRS_PER_ID: usize = 16;

/// Progress of discovering a single server ID.
enum DiscoveryState {
	/// No request for this server has succeeded yet. Holds the request tasks
	/// started so far, keyed by the address that was received alongside the
	/// server ID.
	Waiting(BTreeMap<SocketAddr, Task<DiscoveryMessage>>),
	/// A request for this server has already produced a message; further
	/// addresses received for this ID are ignored.
	Found,
}

/// Client that receives `(ServerId, SocketAddr)` pairs from a socket and
/// starts a discovery request for each address, reporting one message per
/// server ID.
pub struct LanDiscoveryClient {
	/// Socket the pairs are received from; `None` while discovery is not open.
	socket: Option<Socket>,
	/// Per-server discovery progress, limited to `MAX_SERVER_IDS` entries.
	discovery: BTreeMap<ServerId, DiscoveryState>,
}

impl LanDiscoveryClient {
	/// Creates a client with no open socket and no tracked servers.
	pub fn new() -> LanDiscoveryClient {
		Self { socket: None, discovery: BTreeMap::new() }
	}

	/// Begins a fresh discovery round.
	///
	/// An already open socket is reset, otherwise a new one is built. Every
	/// server tracked by the previous round is then forgotten.
	///
	/// # Errors
	///
	/// Propagates the error from `Socket::build`; in that case the tracked
	/// servers are left untouched.
	pub fn discover(&mut self) -> Result<(), String> {
		if let Some(socket) = self.socket.as_mut() {
			socket.reset();
		} else {
			self.socket = Some(Socket::build()?);
		}

		self.discovery.clear();
		Ok(())
	}

	/// Drains the socket, starts requests for newly seen addresses and collects
	/// the servers whose requests produced a successful result.
	///
	/// `endpoint` is cloned into every request that gets started.
	///
	/// Returns one `(address, message)` pair for each server ID that got its
	/// first successful result during this call. The result is empty when no
	/// socket is open (the tracked servers are cleared) or when the socket's own
	/// `update` returns `false` (the socket is dropped).
	pub fn update(&mut self, endpoint: Result<Endpoint, String>) -> Vec<(SocketAddr, DiscoveryMessage)> {
		let socket = match self.socket.as_mut() {
			Some(socket) => socket,
			None => {
				self.discovery.clear();
				return Vec::new();
			},
		};

		if !socket.update() {
			self.socket = None;
			return Vec::new();
		}

		// Phase 1: take everything the socket currently has to offer
		loop {
			let (id, addr) = match socket.receive() {
				Ok(Some(announcement)) => announcement,
				Ok(None) => break,
				Err(err) => {
					log::warn!("failed receiving from socket: {err}");
					break;
				},
			};

			// Unknown IDs are only admitted while below the limit
			let known_ids = self.discovery.len();
			let state = match self.discovery.entry(id) {
				Entry::Occupied(slot) => slot.into_mut(),
				Entry::Vacant(_) if known_ids >= MAX_SERVER_IDS => continue,
				Entry::Vacant(slot) => slot.insert(DiscoveryState::Waiting(BTreeMap::new())),
			};

			// Servers already found need no more requests, and each address
			// gets at most one request
			if let DiscoveryState::Waiting(tasks) = state {
				if tasks.len() < MAX_ADDRS_PER_ID {
					tasks.entry(addr).or_insert_with(|| {
						Task::run(request::discover_with_socket_addr(endpoint.clone(), addr))
					});
				}
			}
		}

		// Phase 2: report the first successful result of every waiting server
		let mut messages = Vec::new();
		for state in self.discovery.values_mut() {
			if let DiscoveryState::Waiting(tasks) = state {
				let first_success = tasks.iter_mut().find_map(|(addr, task)| {
					task.result()
						.and_then(|outcome| outcome.ok())
						.map(|msg| (*addr, msg))
				});

				if let Some(found) = first_success {
					messages.push(found);
					*state = DiscoveryState::Found;
				}
			}
		}
		messages
	}

	/// Stops discovery by dropping the socket. The tracked servers are not
	/// touched here; the next `update` call clears them.
	pub fn stop(&mut self) {
		self.socket = None;
	}

	/// Returns whether a socket is currently held.
	pub fn open(&self) -> bool {
		self.socket.is_some()
	}
}
