// Copyright Marcus Del Favero 2025
// Licensed under the GNU AGPLv3 with an exception, see `README.md` for details
#[cfg(test)]
mod tests;

use std::{io::Read, fmt::{Display, Formatter, Result as FmtResult}, str::Utf8Error};

use arrayvec::ArrayVec;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use serde::{Serialize, Deserialize, Serializer, Deserializer, de::Error as DeError};

use super::{Error, QuicConnection};

/**
 * A variable-length unsigned integer that supports values from 0 to 0x4040403f
 * (a bit over 2^30) inclusive.
 *
 * This is called "Vu30" as it's of similar size to a hypothetical "u30" type
 * and it's variable length.
 *
 * This integer consists of 1 to 4 bytes with larger numbers requiring more
 * bytes to be represented.
 *
 * The first byte consists of llxx xxxx, where ll specifies the number of
 * remaining bytes (0 – 3) and the xx xxxx bits are the least significant bits
 * of the encoded payload.
 *
 * All remaining bytes are of increasing significance, similar to little-endian.
 *
 * The payload is not the number itself but its offset from the smallest number
 * of that length. Each length covers the range directly after the previous
 * one:
 * - 1 byte encodes 0 to 0x3f.
 * - 2 bytes encode 0x40 to 0x403f.
 * - 3 bytes encode 0x4040 to 0x40403f.
 * - 4 bytes encode 0x404040 to 0x4040403f.
 *
 * This gives every number exactly one encoding, and is where the odd-looking
 * maximum comes from.
 *
 * The wrapped `u32` is kept at most `Vu30::MAX` by the constructors and
 * decoders in this module and by the `Deserialize` impl.
 */
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
// `remote = "Self"` makes the serde derives generate inherent `Vu30::serialize`
// and `Vu30::deserialize` functions instead of trait impls, so the hand-written
// trait impls in this file can wrap them (the `Deserialize` one adds a range
// check).
#[serde(remote = "Self")]
pub struct Vu30(u32);

impl Vu30 {
	pub(super) const ZERO: Vu30 = Vu30(0);

	const B2_BEGIN: u32 = 1 << 6;
	const B3_BEGIN: u32 = Vu30::B2_BEGIN + (1 << (6 + 8));
	const B4_BEGIN: u32 = Vu30::B3_BEGIN + (1 << (6 + 8 * 2));
	pub(super) const MAX: Vu30 = Vu30(Vu30::B4_BEGIN + (1 << (6 + 8 * 3)) - 1);

	pub const fn from_u8(x: u8) -> Vu30 {
		Vu30(x as u32)
	}

	pub(super) const fn try_from_u32(x: u32) -> Option<Vu30> {
		if x <= Vu30::MAX.0 { Some(Vu30(x)) }
		else { None }
	}

	pub(super) const fn try_from_usize(x: usize) -> Option<Vu30> {
		if x <= Vu30::MAX.0 as usize { Some(Vu30(x as u32)) }
		else { None }
	}

	pub(super) const fn get(self) -> u32 {
		self.0
	}

	/// Returns the length when serialised.
	pub const fn len(self) -> usize {
		if self.0 < Vu30::B2_BEGIN { 1 }
		else if self.0 < Vu30::B3_BEGIN { 2 }
		else if self.0 < Vu30::B4_BEGIN { 3 }
		else { 4 }
	}
}

// Doesn't get serialised into the vu30 format, but that doesn't matter
/// Serialises in serde's derived form for the newtype around the inner `u32`
/// (not the variable-length wire format).
impl Serialize for Vu30 {
	fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
		// Calls the inherent function generated by `#[serde(remote = "Self")]`.
		// Inherent associated functions win over trait methods in path
		// resolution, so this is not a recursive call to this trait method.
		Vu30::serialize(self, s)
	}
}

impl<'de> Deserialize<'de> for Vu30 {
	/// Deserialises the derived representation, then rejects anything above
	/// `Vu30::MAX` so the range invariant also holds for deserialised values.
	fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Vu30, D::Error> {
		// The inherent function generated by `#[serde(remote = "Self")]`, not
		// a recursive call to this trait method.
		let unchecked = Vu30::deserialize(d)?;
		Vu30::try_from_u32(unchecked.get()).ok_or_else(|| DeError::custom("invalid vu30"))
	}
}


impl Display for Vu30 {
	/// Formats exactly like the wrapped `u32`, honouring width, fill and
	/// alignment flags.
	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
		Display::fmt(&self.0, fmt)
	}
}

/// Why one of the `AsyncReadVu30` helpers failed. Convert it into the module's
/// `Error` with `into_error`.
#[derive(Debug, Clone)]
pub(super) enum ReadError {
	/// Reading from the stream failed, including the stream ending before the
	/// declared length was received.
	Io,
	/// The declared length `len` exceeded the caller's `size_limit`.
	TooLong { len: usize, size_limit: usize },
	/// The payload wasn't valid UTF-8 (from `read_string`).
	Utf8(Utf8Error),
}

impl ReadError {
	pub fn into_error(self, conn: &QuicConnection) -> Error {
		match self {
			ReadError::Io => conn.into(),
			ReadError::TooLong { len, size_limit } => Error::Deserialisation(format!("too long, {len} > {size_limit} bytes")),
			ReadError::Utf8(err) => Error::Deserialisation(format!("invalid UTF-8: {err}")),
		}
	}
}

/// Blocking `Vu30` decoding for any `std::io::Read`.
pub trait ReadVu30: Read {
	/**
	 * Decodes one `Vu30`.
	 *
	 * The top two bits of the first byte give the number of trailing bytes
	 * (0 – 3). Those bytes form a little-endian integer that sits above the
	 * first byte's low six bits, and the result is offset by the start of that
	 * length's range. Any read failure, including running out of input, yields
	 * `Err(())`.
	 *
	 * Duplicated by hand in `AsyncReadVu30::read_vu30`; sharing the code
	 * between the sync and async versions isn't straightforward.
	 */
	fn read_vu30(&mut self) -> Result<Vu30, ()> {
		let first = self.read_u8()?;
		let (tail_len, range_start) = match first >> 6 {
			0 => return Ok(Vu30(first as u32)),
			1 => (1, Vu30::B2_BEGIN),
			2 => (2, Vu30::B3_BEGIN),
			_ => (3, Vu30::B4_BEGIN),
		};

		let mut scratch = [0u8; 3];
		let tail = &mut scratch[..tail_len];
		self.read_exact(tail).map_err(|_| ())?;
		// Most significant byte comes last, so fold from the end.
		let high = tail.iter().rev().fold(0u32, |acc, &byte| (acc << 8) | byte as u32);
		Ok(Vu30(range_start + ((high << 6) | (first & 0x3f) as u32)))
	}

	/// Reads exactly one byte, mapping any I/O error (including EOF) to `()`.
	fn read_u8(&mut self) -> Result<u8, ()> {
		let mut byte = [0u8; 1];
		match self.read_exact(&mut byte) {
			Ok(()) => Ok(byte[0]),
			Err(_) => Err(()),
		}
	}

	/// Reads a little-endian `u16`, mapping any I/O error (including EOF) to `()`.
	fn read_u16_le(&mut self) -> Result<u16, ()> {
		let mut raw = [0u8; 2];
		self.read_exact(&mut raw).map_err(|_| ())?;
		Ok(u16::from_le_bytes(raw))
	}
}

impl<R: Read> ReadVu30 for R {}

pub(super) trait AsyncReadVu30: AsyncReadExt + Unpin {
	/**
	 * Async counterpart of `ReadVu30::read_vu30` (the two bodies are kept in
	 * sync by hand). Decodes one `Vu30`, returning `Err(())` if any read fails,
	 * including the stream ending part-way through the value.
	 */
	async fn read_vu30(&mut self) -> Result<Vu30, ()> {
		let first = self.read_u8().await.map_err(|_| ())?;
		match first & 0xc0 {
			0x00 => Ok(Vu30(first as u32)),
			0x40 => {
				let next = self.read_u8().await.map_err(|_| ())?;
				Ok(Vu30(Vu30::B2_BEGIN + (((next as u32) << 6) | (first & 0x3f) as u32)))
			},
			0x80 => {
				let next = self.read_u16_le().await.map_err(|_| ())?;
				Ok(Vu30(Vu30::B3_BEGIN + (((next as u32) << 6) | (first & 0x3f) as u32)))
			},
			_ /* 0xc0 */ => {
				let mut next = [0; 3];
				self.read_exact(&mut next).await.map_err(|_| ())?;
				let next_u24_le = next[0] as u32 | ((next[1] as u32) << 8) | ((next[2] as u32) << 16);
				Ok(Vu30(Vu30::B4_BEGIN + ((next_u24_le << 6) | (first & 0x3f) as u32)))
			},
		}
	}

	/**
	 * Reads a Vu30 for a length, then reads exactly that many bytes and
	 * appends them to the `bytes` Vec<u8> supplied.
	 *
	 * `size_limit` guards against application layer denial of service
	 * attacks. The length is checked against it before any payload space is
	 * reserved or read, and exceeding it fails with `ReadError::TooLong`.
	 *
	 * If the stream fails or ends before `len` bytes arrive, the call returns
	 * `ReadError::Io`. On any error, `bytes` keeps its original contents (it
	 * may have extra capacity reserved), so a partially received payload never
	 * reaches the caller.
	 */
	async fn read_into_bytes(&mut self, bytes: &mut Vec<u8>, size_limit: usize) -> Result<(), ReadError> {
		let len = self.read_vu30().await.map_err(|()| ReadError::Io)?.get() as usize;
		if len > size_limit {
			return Err(ReadError::TooLong { len, size_limit });
		}

		// `read_to_end` appends whatever it managed to read before failing or
		// hitting EOF, so remember where our data starts in order to roll back.
		let start = bytes.len();
		bytes.reserve(len);
		let mut reader = self.take(len as u64);
		let complete = match reader.read_to_end(bytes).await {
			Ok(count) => {
				debug_assert!(count <= len);
				count == len
			},
			Err(_) => false,
		};

		if !complete {
			bytes.truncate(start);
			return Err(ReadError::Io);
		}

		Ok(())
	}

	/**
	 * A wrapper over `read_into_bytes` which creates and returns a new Vec.
	 */
	async fn read_bytes(&mut self, size_limit: usize) -> Result<Vec<u8>, ReadError> {
		let mut bytes = Vec::new();
		self.read_into_bytes(&mut bytes, size_limit).await?;
		Ok(bytes)
	}

	/**
	 * Similar to `read_bytes` but requires the data to be valid UTF-8,
	 * failing with `ReadError::Utf8` otherwise.
	 */
	async fn read_string(&mut self, size_limit: usize) -> Result<String, ReadError> {
		let bytes = self.read_bytes(size_limit).await?;
		String::from_utf8(bytes).map_err(|err| ReadError::Utf8(err.utf8_error()))
	}
}

impl<R> AsyncReadVu30 for R where R: AsyncReadExt + Unpin {}

/// Why one of the `AsyncWriteVu30` helpers failed. Convert it into the
/// module's `Error` with `into_error`.
#[derive(Debug)]
pub(super) enum WriteError {
	/// Writing to the underlying stream failed.
	Io,
	/// A length didn't fit in a `Vu30` (it exceeded `Vu30::MAX`).
	Vu30TooLarge,
}

impl WriteError {
	pub fn into_error(self, conn: &QuicConnection) -> Error {
		match self {
			WriteError::Io => conn.into(),
			WriteError::Vu30TooLarge => WriteError::vu30_too_large(),
		}
	}

	fn vu30_too_large() -> Error {
		Error::Serialisation(String::from("vu30 too large"))
	}
}

pub(super) trait AsyncWriteVu30: AsyncWriteExt + Unpin {
	/// Encodes `x` into a fixed-capacity `ArrayVec` and writes all 1 – 4 bytes.
	async fn write_vu30(&mut self, x: Vu30) -> Result<(), WriteError> {
		let encoded = ArrayVec::<u8, 4>::new().vu30(x);
		self.write_all(encoded.as_slice()).await.map_err(|_| WriteError::Io)
	}

	/// Writes `data` prefixed by its length as a `Vu30`. If the length exceeds
	/// `Vu30::MAX`, it fails with `WriteError::Vu30TooLarge` before writing
	/// anything.
	async fn write_bytes(&mut self, data: &[u8]) -> Result<(), WriteError> {
		let Some(len) = Vu30::try_from_usize(data.len()) else {
			return Err(WriteError::Vu30TooLarge);
		};
		self.write_vu30(len).await?;
		self.write_all(data).await.map_err(|_| WriteError::Io)
	}
}

impl<W: AsyncWriteExt + Unpin> AsyncWriteVu30 for W {}

pub trait ExtendVu30 {
	fn extend_vu30(&mut self, mut x: Vu30) {
		if x.0 < Vu30::B2_BEGIN { // 1 byte
			self.push_u8(x.0 as u8);
		} else if x.0 < Vu30::B3_BEGIN { // 2 bytes
			x.0 -= Vu30::B2_BEGIN;
			let first = 0x40 | (x.0 & 0x3f) as u8;
			let next = (x.0 >> 6) as u8;
			self.push_u8(first);
			self.push_u8(next);
		} else if x.0 < Vu30::B4_BEGIN { // 3 bytes
			x.0 -= Vu30::B3_BEGIN;
			let first = 0x80 | (x.0 & 0x3f) as u8;
			let next = (x.0 >> 6) as u16;
			self.push_u8(first);
			self.push_u16(next);
		} else { // 4 bytes
			x.0 -= Vu30::B4_BEGIN;
			let first = 0xc0 | (x.0 & 0x3f) as u8;
			let next = x.0 >> 6;
			self.push_u8(first);
			self.push_u16((next & 0xffff) as u16);
			self.push_u8((next >> 16) as u8);
		}
	}

	fn extend_bytes(&mut self, data: &[u8]) -> Result<(), Error> {
		let len = Vu30::try_from_usize(data.len()).ok_or_else(WriteError::vu30_too_large)?;
		self.extend_vu30(len);
		self.push_slice(data);
		Ok(())
	}

	// Fluent interface for convenience
	fn vu30(mut self, x: Vu30) -> Self where Self: Sized {
		self.extend_vu30(x);
		self
	}

	fn string(mut self, s: &str) -> Result<Self, Error> where Self: Sized {
		self.extend_bytes(s.as_bytes())?;
		Ok(self)
	}

	fn push_u8(&mut self, x: u8);
	fn push_u16(&mut self, x: u16);
	fn push_slice(&mut self, x: &[u8]);
}

/// Growable buffer: appends never fail.
impl ExtendVu30 for Vec<u8> {
	fn push_u8(&mut self, x: u8) {
		self.extend_from_slice(&[x]);
	}

	fn push_u16(&mut self, x: u16) {
		let [low, high] = x.to_le_bytes();
		self.push(low);
		self.push(high);
	}

	fn push_slice(&mut self, x: &[u8]) {
		self.extend(x);
	}
}

/// Fixed-capacity buffer for encoding a single `Vu30` without a heap
/// allocation. Four bytes is the longest encoding.
///
/// # Panics
///
/// Each method panics if the push would exceed the capacity of 4 bytes.
/// Writing one `Vu30` into an empty buffer never does, but e.g. `extend_bytes`
/// with more than 3 bytes of data will.
impl ExtendVu30 for ArrayVec<u8, 4> {
	fn push_u8(&mut self, x: u8) {
		self.try_push(x).expect("ArrayVec<u8, 4> overflowed: it only has room for one encoded Vu30");
	}

	fn push_u16(&mut self, x: u16) {
		self.try_extend_from_slice(&x.to_le_bytes()).expect("ArrayVec<u8, 4> overflowed: it only has room for one encoded Vu30");
	}

	fn push_slice(&mut self, x: &[u8]) {
		self.try_extend_from_slice(x).expect("ArrayVec<u8, 4> overflowed: it only has room for one encoded Vu30");
	}
}
