// Copyright Marcus Del Favero 2025
// Licensed under the GNU AGPLv3 with an exception, see `README.md` for details
use std::{fs, mem, iter, path::Path, f32::consts::TAU};

use glam::{Vec2, Mat4};
use glium::{Surface, Program, VertexBuffer, Display, Frame, DrawParameters, vertex::PerInstance, index::{NoIndices, PrimitiveType}, uniform};
use glutin::surface::WindowSurface;

use super::{Forcefield, Vertex};

use crate::app::filesystem::{Filesystem, FsBase};
use crate::utils::maths::{self, glam_fix::Fix};

/// How many times each bolt is recursively subdivided.
const MAX_LEVEL: u8 = 6;
/// Number of bolts; consecutive bolts share endpoints and the last wraps back to the first.
const LIGHTNING_COUNT: usize = 19;
/// Points generated per frame: each bolt emits `2^MAX_LEVEL` points (one per leaf segment).
const LINE_VERTEX_COUNT: usize = LIGHTNING_COUNT << MAX_LEVEL;
/// Triangle-strip vertices: two per point, plus one repeated pair that closes the loop.
const VERTEX_COUNT: usize = 2 * (LINE_VERTEX_COUNT + 1);
/// Subdivision points: a bolt with `2^MAX_LEVEL` leaf segments has `2^MAX_LEVEL - 1` of them.
const MIDPOINT_COUNT: usize = LINE_VERTEX_COUNT - LIGHTNING_COUNT;
/// Nominal angular step between consecutive endpoints, in turns
/// (scaled by `1 + 1 / LIGHTNING_COUNT` when the endpoints are created).
const ENDPOINT_DIFF: f32 = 0.25;
/// Maximum random angular offset applied to an endpoint, in turns.
const ENDPOINT_RANGE: f32 = ENDPOINT_DIFF / 6.0;
/// Half the width of the triangle strip drawn for the bolts.
const HALF_WIDTH: f32 = 0.0075;
/// Sent to the shader as `u_pulse_per_vertex`: `6 * TAU` spread evenly over every generated point
/// (presumably six pulse cycles around the whole ring — see the shader).
const PULSE_PER_VERTEX: f32 = 6.0 * TAU / (LINE_VERTEX_COUNT as f32);

/// A closed ring of `LIGHTNING_COUNT` animated lightning bolts, each running from one endpoint
/// to the next (the last bolt wraps back to the first endpoint).
pub struct AnimatedLightning {
	/// Animated displacement of every subdivision point of every bolt (`MIDPOINT_COUNT` in total),
	/// stored in the pre-order in which `AnimatedLightning::lightning` consumes them.
	midpoints: Box<[AnimatedMidpoint]>,
	/// One animated endpoint per bolt; bolt `k` spans endpoint `k` to endpoint `k + 1` (wrapping).
	endpoints: Box<[Endpoint]>,
}

/// Placement of one subdivision point, animated by blending between two random keyframes.
struct AnimatedMidpoint {
	/// Keyframe being blended from.
	m0: Midpoint,
	/// Keyframe being blended towards.
	m1: Midpoint,
	/// Blend progress from `m0` to `m1`; eased with `smoothstep` when sampled.
	t: f32,
	/// How fast `t` advances per unit of `dt`.
	t_vel: f32,
	/// Scale passed to `Midpoint::new` whenever a new keyframe is generated.
	pos_size: f32,
}

/// Where a segment's subdivision point is placed.
struct Midpoint {
	/// Offset added to the interpolated point, in units of the segment's length.
	pos: Vec2,
	/// Interpolation factor along the segment, from its start (0) to its end (1).
	t: f32,
}

/// One animated bolt endpoint, wobbling around a fixed base angle.
struct Endpoint {
	/// Current position, derived from `angle` and the blended `v0`/`v1` variation.
	pos: Vec2,
	/// Base angle in radians.
	angle: f32,
	/// Blend progress from `v0` to `v1`.
	t: f32,
	/// Variation being blended from.
	v0: EndpointVariation,
	/// Variation being blended towards.
	v1: EndpointVariation,
}

/// A random perturbation of an endpoint's placement.
struct EndpointVariation {
	/// Offset added to the endpoint's base angle, in radians.
	angle: f32,
	/// Distance of the endpoint from the origin (blended with the other variation's scale).
	scale: f32,
}

impl AnimatedLightning {
	pub fn new() -> AnimatedLightning {
		let mut midpoints = Vec::with_capacity(MIDPOINT_COUNT);
		for _ in 0..LIGHTNING_COUNT {
			AnimatedMidpoint::generate(MAX_LEVEL, &mut midpoints);
		}
		debug_assert_eq!(midpoints.len(), MIDPOINT_COUNT);

		let mut endpoints = Vec::with_capacity(LIGHTNING_COUNT);
		for i in 0..LIGHTNING_COUNT {
			endpoints.push(Endpoint::new(i as f32 * ENDPOINT_DIFF * (1.0 + 1.0 / LIGHTNING_COUNT as f32) * TAU, i as f32 / LIGHTNING_COUNT as f32));
		}

		AnimatedLightning {
			midpoints: midpoints.into_boxed_slice(),
			endpoints: endpoints.into_boxed_slice(),
		}
	}

	pub fn update(&mut self, dt: f32) {
		for midpoint in &mut self.midpoints { midpoint.update(dt); }
		for endpoint in &mut self.endpoints { endpoint.update(dt); }
	}

	fn generate(&self, points: &mut Vec<Vec2>) {
		let mut midpoint_i = 0;
		for (p0, p1) in self.endpoints.windows(2).map(|p| (&p[0], &p[1])).chain(iter::once((self.endpoints.last().unwrap(), self.endpoints.first().unwrap()))) {
			self.lightning(p0.pos, p1.pos, MAX_LEVEL, &mut midpoint_i, points);
		}
	}

	fn lightning(&self, p0: Vec2, p1: Vec2, level: u8, midpoint_i: &mut usize, points: &mut Vec<Vec2>) {
		if level == 0 {
			points.push(p0);
			return;
		}

		let midpoint = self.midpoints[*midpoint_i].get();
		*midpoint_i += 1;

		let length = p0.distance(p1);
		let p = p0.lerp(p1, midpoint.t) + midpoint.pos * length;

		self.lightning(p0, p, level - 1, midpoint_i, points);
		self.lightning(p, p1, level - 1, midpoint_i, points);
	}
}

impl AnimatedMidpoint {
	/// Appends a full binary tree of `2^level - 1` animated midpoints to `output`, in pre-order
	/// (this node first, then its two subtrees).
	///
	/// Lower levels (deeper subdivisions) get a larger `pos_size` and a faster `t_vel`.
	fn generate(level: u8, output: &mut Vec<AnimatedMidpoint>) {
		debug_assert_ne!(level, 0);

		let depth = level as f32;
		let pos_size = 0.5 / depth.powf(0.4375);
		let m0 = Midpoint::new(pos_size);
		let m1 = Midpoint::new(pos_size);
		let t = rand::random::<f32>();
		output.push(AnimatedMidpoint { m0, m1, t, t_vel: 15.0 / depth, pos_size });

		if level > 1 {
			// One subtree for each half of the segment this midpoint splits.
			for _ in 0..2 {
				AnimatedMidpoint::generate(level - 1, output);
			}
		}
	}

	/// Samples the current placement: `m0` blended towards `m1` by the smoothstep-eased `t`.
	fn get(&self) -> Midpoint {
		let eased = maths::smoothstep(self.t);
		let (from, to) = (&self.m0, &self.m1);
		Midpoint {
			pos: from.pos.lerp(to.pos, eased),
			t: maths::lerp(from.t, to.t, eased),
		}
	}

	/// Advances the blend by `t_vel * dt`, rolling over to new random keyframes as needed.
	fn update(&mut self, dt: f32) {
		self.t += self.t_vel * dt;
		let pos_size = self.pos_size;
		animate_lerp(&mut self.t, &mut self.m0, &mut self.m1, move || Midpoint::new(pos_size));
	}
}

impl Midpoint {
	/// Creates a random keyframe.
	///
	/// Each `pos` component is `pos_size * (2u - 1)` and `t` is `0.4 + 0.2u`, where each `u` is
	/// a fresh `rand::random` sample (presumably uniform in `[0, 1)`).
	fn new(pos_size: f32) -> Midpoint {
		let unit = rand::random::<Vec2>() * 2.0 - 1.0;
		let t = rand::random::<f32>() * 0.2 + 0.4;
		Midpoint { pos: unit * pos_size, t }
	}
}

impl Endpoint {
	/// Creates an endpoint at base angle `angle` (radians) with initial blend phase `t`.
	///
	/// `pos` is computed immediately, so the endpoint is in a valid position even if it is
	/// rendered before the first `update`. Otherwise every bolt would collapse onto the origin.
	fn new(angle: f32, t: f32) -> Endpoint {
		let mut endpoint = Endpoint {
			pos: Vec2::ZERO, angle, t,
			v0: EndpointVariation::new(), v1: EndpointVariation::new(),
		};
		endpoint.refresh_pos();
		endpoint
	}

	/// Advances the blend by `dt` (one variation per unit of `dt`) and recomputes `pos`.
	fn update(&mut self, dt: f32) {
		self.t += dt;
		animate_lerp(&mut self.t, &mut self.v0, &mut self.v1, EndpointVariation::new);
		self.refresh_pos();
	}

	/// Recomputes `pos` from the base angle and the smoothstep-eased blend of `v0` towards `v1`.
	fn refresh_pos(&mut self) {
		let t = maths::smoothstep(self.t);
		let angle = self.angle + maths::lerp(self.v0.angle, self.v1.angle, t);
		let scale = maths::lerp(self.v0.scale, self.v1.scale, t);
		self.pos = Vec2::from_angle(angle) * scale;
	}
}

impl EndpointVariation {
	/// Creates a random variation.
	///
	/// The angle offset is `(2u - 1) * ENDPOINT_RANGE` turns and the scale is `0.875 + 0.125u`,
	/// where each `u` is a fresh `rand::random` sample (presumably uniform in `[0, 1)`).
	fn new() -> EndpointVariation {
		let angle_unit = rand::random::<f32>() * 2.0 - 1.0;
		let scale_unit = rand::random::<f32>();
		EndpointVariation {
			angle: angle_unit * ENDPOINT_RANGE * TAU,
			scale: scale_unit * 0.125 + 0.875,
		}
	}
}

/// Rolls a two-keyframe animation forward once its progress `t` passes the end of a segment.
///
/// * `t >= 2.0`: at least one whole segment was skipped, so both keyframes are regenerated
///   with `new` and only the fractional part of `t` is kept.
/// * `1.0 <= t < 2.0`: the target `x1` becomes the new start `x0`, a fresh target is generated
///   with `new`, and `t` drops by one.
/// * Otherwise (including NaN) nothing changes.
fn animate_lerp<T>(t: &mut f32, x0: &mut T, x1: &mut T, new: impl Fn() -> T) {
	let phase = *t;
	if phase >= 2.0 {
		*x0 = new();
		*x1 = new();
		*t = phase.fract();
	} else if phase >= 1.0 {
		*x0 = mem::replace(x1, new());
		*t = phase - 1.0;
	}
}

/// Draws an `AnimatedLightning` as an instanced triangle strip.
pub struct LightningRenderer {
	/// Scratch buffer for the generated bolt points; cleared and refilled on every render
	/// so its allocation is reused between frames.
	points: Vec<Vec2>,
	/// Persistent buffer of `VERTEX_COUNT` strip vertices, rewritten on every render.
	vbo: VertexBuffer<Vertex>,
	/// Shader program built from `shaders/forcefield_lightning.{vsh,fsh}`.
	program: Program,
}

impl LightningRenderer {
	pub fn new(display: &Display<WindowSurface>, fs: &Filesystem) -> LightningRenderer {
		let vsh = fs::read_to_string(fs.get(FsBase::Static, Path::new("shaders/forcefield_lightning.vsh"))).unwrap();
		let fsh = fs::read_to_string(fs.get(FsBase::Static, Path::new("shaders/forcefield_lightning.fsh"))).unwrap();

		LightningRenderer {
			points: Vec::new(),
			vbo: VertexBuffer::empty_persistent(display, VERTEX_COUNT).unwrap(),
			program: Program::from_source(display, &vsh, &fsh, None).unwrap(),
		}
	}

	pub fn render(&mut self, lightning: &AnimatedLightning, frame: &mut Frame, i_vbo: PerInstance<'_>, proj_matrix: &Mat4, params: &DrawParameters) {
		self.points.clear();
		lightning.generate(&mut self.points);

		{
			let mut buf = self.vbo.map_write();
			let mut i = 0;

			// Not perfect but good enough
			let len = self.points.len();
			let points = self.points
				.windows(3)
				.map(|p| (p[0], p[1], p[2]))
				.chain(iter::once((self.points[len - 2], self.points[len - 1], self.points[0])))
				.chain(iter::once((self.points[len - 1], self.points[0], self.points[1])))
				.chain(iter::once((self.points[0], self.points[1], self.points[2])));

			for (p0, p, p1) in points {
				let dir = (p1 - p0).normalise_or(Vec2::X);
				let off = Vec2::new(dir.y, -dir.x) * HALF_WIDTH;
				buf.set(i    , Vertex { v_pos: (p - off).to_array() });
				buf.set(i + 1, Vertex { v_pos: (p + off).to_array() });
				i += 2;
			}
		}

		let uniforms = uniform! { u_matrix: proj_matrix.to_cols_array_2d(), u_radius: Forcefield::RADIUS, u_pulse_per_vertex: PULSE_PER_VERTEX };
		frame.draw((&self.vbo, i_vbo), NoIndices(PrimitiveType::TriangleStrip), &self.program, &uniforms, params).unwrap();
	}
}
