import { clamp } from "@/core/utils";
import { mat4, vec3 } from "gl-matrix";
import { FLOAT_SIZE_BYTES, QUAD_VERTEX_BUFFER, Renderer, VertexAttributeType, bindVertexAttributes, compileProgram, compileShader, enableBlendSettings, getVertexAttributesTotalFloatCount } from "./renderer";
import { camera, layoutParams, nodeVisualisationParams, currentTime, BACKGROUND_COLOR } from "../rendering_state";
import { calculateNodeSpreadFactor } from "../node_layout/node_layout";

/**
 * Per-node input for the node renderer; each entry becomes one instance of the node quad.
 *
 * `globalViewPos`, `localViewPos`, `colour`, `generation` and `joinedTime` are uploaded as
 * instance attributes. `parentIdx` is not read by the renderer in this file.
 */
export type NodeInstanceData = {
    /** Node position used when the global/local view blend ratio is 0. */
    globalViewPos: vec3;
    /** Node position used when the global/local view blend ratio is 1. */
    localViewPos: vec3;
    /** Base RGB colour of the node before depth falloff is applied. */
    colour: vec3;
    /** Depth in the tree; node size and ambient motion are scaled by `nodeGenerationScaleFactor ** generation`. */
    generation: number;
    /** Timeline position at which the node joined; the node fades in as the timeline advances past it. */
    joinedTime: number;
    /** NOTE(review): presumably the index of this node's parent (null for a root) — confirm against the producer. */
    parentIdx: number | null;
};

/**
 * Creates the instanced renderer that draws graph nodes as camera-facing circular billboards.
 *
 * Every node is one instance of the shared quad in `QUAD_VERTEX_BUFFER`. Per-node data
 * (global/local view positions, colour, generation, joined time) lives in a dedicated instance
 * buffer. When `draw` is called again with the same `nodes` array, only the entries appended
 * since the previous upload are written to the GPU.
 *
 * `prepare(projViewMat)` sets frame-wide state: blend mode, program, time, background colour,
 * camera matrices and attribute bindings. It is expected to run before `draw`.
 *
 * `draw(nodes, totalCount, globalLocalViewRatio)` does three things:
 * - uploads any new instance data;
 * - sets the per-draw uniforms;
 * - draws `clamp(nodeVisualisationParams.numNodes, 1, nodes.length)` instances.
 *
 * `totalCount` is the node capacity reserved when a new `nodes` array is seen. The buffer is
 * reallocated if `nodes` outgrows it. `globalLocalViewRatio` blends each node's position from
 * `globalViewPos` (0) to `localViewPos` (1).
 *
 * @param gl WebGL2 context used to compile the program and allocate the instance buffer.
 * @returns The node renderer.
 * @throws Error if the instance vertex buffer cannot be created.
 */
export function createNodeRenderer(gl: WebGL2RenderingContext): Renderer<(nodes:NodeInstanceData[], totalCount: number, globalLocalViewRatio: number)=>void> {
    const vertexShader = compileShader(
        gl, 
        `#version 300 es
        // Standard global uniforms
        uniform float uTime;
        uniform mat4 uProjViewMat;
        uniform mat4 uUnrotateMat;
        uniform float uSpreadFactor;
        uniform float uTimelinePosition;

        uniform float uGlobalLocalViewRatio;

        // Tuning global uniforms
        uniform float uBobSpeed;
        uniform float uBobAmount;
        uniform float uSwimSpeed;
        uniform float uSwimAmount;
        uniform float uPulseSpeed;
        uniform float uPulseAmount;

        uniform float uNodeSize;
        uniform float uNodeGenerationScaleFactor;

        // Shared
        in vec3 aVertexPosition;
        in vec2 aVertexUV;

        // Instanced
        in vec3 aNodePositionGlobalView;
        in vec3 aNodePositionLocalView;
        in vec3 aNodeColour;
        in float aNodeGeneration;
        in float aNodeJoinedTime;

        out vec2 vUV;
        out vec3 vColour;
        out float vAppearFactor;

        const float appearAgeRange = 250000.0;

        void main(){
            float instanceTime = uTime*0.5 + float(gl_InstanceID%1000);

            // Nodes fade ("blink") into existence over appearAgeRange as the timeline position advances past their joined time
            float age = clamp(uTimelinePosition-aNodeJoinedTime, 0.0, appearAgeRange);
            vAppearFactor = pow(age/appearAgeRange, 1.2);

            /*
             * Calculate local vertex position
             */

            // Nodes which are of a higher generation - i.e. deeper into the tree - may be scaled smaller
            float generationScale = pow(uNodeGenerationScaleFactor, aNodeGeneration);

            // Ambient pulsing effect
            float pulseScale = 1.0 + sin(instanceTime*uPulseSpeed)*uPulseAmount;

            float finalNodeSize = uNodeSize * generationScale * pulseScale;

            // Calculate local vertex position by de-rotating so the node always faces the camera
            vec3 vertexLocalPos = vec3(uUnrotateMat * vec4(aVertexPosition, 1)) * finalNodeSize;

            /*
             * Calculate node world position
             */

            // Ambient motion effects
            vec3 bobOffset = vec3(0,uBobAmount,0)*sin(instanceTime*uBobSpeed);
            
            float swimTime = instanceTime*uSwimSpeed;
            vec3 swimOffset = vec3(uSwimAmount,uSwimAmount,uSwimAmount) * vec3(
                sin((swimTime+0.0)) * sin((swimTime+0.0)*1.1),
                sin((swimTime+3.0)*1.2) * sin((swimTime+4.0)*1.4),
                sin((swimTime+5.0)*1.3) * sin((swimTime+6.0)*1.7)
            );

            // These effects are also scaled by generation scale, so that ambient motion is proportional to the node's final size
            vec3 ambientMotionOffset = (bobOffset + swimOffset)*generationScale;

            // Transiting between global and local view position
            vec3 worldPosInCurrentView = mix(aNodePositionGlobalView, aNodePositionLocalView, uGlobalLocalViewRatio);

            // Apply spacing factor (which causes all nodes to spread out away from the origin), and any ambient motion
            vec3 nodeWorldPos = (worldPosInCurrentView*uSpreadFactor) + ambientMotionOffset;

            /*
             * Calculate final vertex world position
             */
            gl_Position = uProjViewMat * vec4(nodeWorldPos + vertexLocalPos, 1.0);

            /*
             * Just pass the UV coordinates and colour through as-is
             */
            vUV = aVertexUV;
            vColour = aNodeColour;
        }
        `,
        gl.VERTEX_SHADER
    );

    const fragmentShader = compileShader(
        gl,
        `#version 300 es
        precision highp float;

        uniform vec4 uBackgroundColor;
        uniform float uDepthRange;
        uniform float uDepthSharpness;

        in vec2 vUV;
        in vec3 vColour;
        in float vAppearFactor;

        out vec4 fragColour;
        
        void main() {
            /*
             * Define circle shape
             */
            float distSq = (vUV[0]-0.5)*(vUV[0]-0.5) + (vUV[1]-0.5)*(vUV[1]-0.5);
            float insideCircleFactor = distSq / (0.5*0.5);
            float alpha = 1.0-smoothstep(0.95, 1.0, insideCircleFactor);
            alpha *= vAppearFactor;

            /*
             * Blend from the node's original colour, towards the scene's background colour, as depth increases.
             * This creates a fog / brightness falloff effect
             */
            float depthFactor = pow(clamp(gl_FragCoord.w*uDepthRange, 0.0, 1.0), uDepthSharpness);
            vec4 colourWithFalloff = mix(uBackgroundColor, vec4(vColour, 1), depthFactor);
            
            /*
             * Determine final colour
             */
            fragColour = mix(uBackgroundColor, colourWithFalloff, alpha);
            if (alpha == 0.0) discard;
        }
        `,
        gl.FRAGMENT_SHADER
    );

    const program = compileProgram(gl, vertexShader, fragmentShader);

    const uTime = gl.getUniformLocation(program, "uTime");
    const uTimelinePosition = gl.getUniformLocation(program, "uTimelinePosition");
    const uProjViewMat = gl.getUniformLocation(program, "uProjViewMat");
    const uUnrotateMat = gl.getUniformLocation(program, "uUnrotateMat");
    const uSpreadFactor = gl.getUniformLocation(program, "uSpreadFactor");

    const uGlobalLocalViewRatio = gl.getUniformLocation(program, "uGlobalLocalViewRatio");

    const uNodeSize = gl.getUniformLocation(program, "uNodeSize");
    const uNodeGenerationScaleFactor = gl.getUniformLocation(program, "uNodeGenerationScaleFactor");

    const uBobSpeed = gl.getUniformLocation(program, "uBobSpeed");
    const uBobAmount = gl.getUniformLocation(program, "uBobAmount");
    const uSwimSpeed = gl.getUniformLocation(program, "uSwimSpeed");
    const uSwimAmount = gl.getUniformLocation(program, "uSwimAmount");
    const uPulseSpeed = gl.getUniformLocation(program, "uPulseSpeed");
    const uPulseAmount = gl.getUniformLocation(program, "uPulseAmount");

    const uBackgroundColor = gl.getUniformLocation(program, "uBackgroundColor");
    const uDepthRange = gl.getUniformLocation(program, "uDepthRange");
    const uDepthSharpness = gl.getUniformLocation(program, "uDepthSharpness");

    // Shared quad geometry: per-vertex position and UV.
    // The shader declares aVertexPosition as vec3; with a 2-component attribute, z is filled in as 0.
    const vertexAttributes = [
        {location: gl.getAttribLocation(program, "aVertexPosition"), type: VertexAttributeType.VEC2F},
        {location: gl.getAttribLocation(program, "aVertexUV"),       type: VertexAttributeType.VEC2F},
    ];

    // Per-instance layout. The upload loop in `draw` writes fields in exactly this order.
    const instanceAttributes = [
        {location: gl.getAttribLocation(program, "aNodePositionGlobalView"),  type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aNodePositionLocalView"),   type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aNodeColour"),     type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aNodeGeneration"), type: VertexAttributeType.FLOAT},
        {location: gl.getAttribLocation(program, "aNodeJoinedTime"), type: VertexAttributeType.FLOAT},
    ];

    // Reused every frame to avoid allocating a new matrix in `prepare`.
    const unrotateMat = mat4.create();

    const instancedVertexBuffer = gl.createBuffer();
    if (!instancedVertexBuffer) throw new Error(`Failed to create new instanced vertex buffer`);

    // CPU-side mirror of the instance buffer, sized for `instanceCapacity` nodes.
    let instanceData: Float32Array|null = null;
    let instanceCapacity = 0;

    // The node array whose data is currently in the buffer, and how many of its entries have been uploaded.
    let initializedInstanceData: NodeInstanceData[]|null = null;
    let initializedNodeCount = 0;
    const floatsPerInstance = getVertexAttributesTotalFloatCount(instanceAttributes);

    return {
        prepare: (projViewMat: mat4) => {
            enableBlendSettings(gl, "alpha", true, true);

            gl.useProgram(program);

            gl.uniform1f(uTime, currentTime%1000000); // Wrap to avoid precision issues
            gl.uniform4fv(uBackgroundColor, BACKGROUND_COLOR);

            gl.uniformMatrix4fv(uProjViewMat, false, projViewMat);

            // Inverse of the camera rotation, so quad vertices can be rotated to face the camera.
            mat4.fromQuat(unrotateMat, camera.rotation);
            mat4.invert(unrotateMat, unrotateMat);
            gl.uniformMatrix4fv(uUnrotateMat, false, unrotateMat);

            bindVertexAttributes(gl, QUAD_VERTEX_BUFFER!, false, vertexAttributes);
            bindVertexAttributes(gl, instancedVertexBuffer, true, instanceAttributes);
        },
        draw: (nodes: NodeInstanceData[], totalCount: number, globalLocalViewRatio: number) => {
            // Nothing to draw. Returning early also avoids an instanced draw over empty or zeroed instance data.
            if (nodes.length === 0) return;

            // bufferData/bufferSubData act on the ARRAY_BUFFER binding.
            // Bind the instance buffer explicitly rather than relying on what `prepare` (or other code) left bound.
            gl.bindBuffer(gl.ARRAY_BUFFER, instancedVertexBuffer);

            // (Re)allocate when given a different node array, or when the current one has outgrown the reserved
            // capacity. Without the capacity check, writing past the end would throw a RangeError.
            let data = instanceData;
            if (data === null || initializedInstanceData !== nodes || nodes.length > instanceCapacity) {
                instanceCapacity = Math.max(totalCount, nodes.length);
                data = new Float32Array(floatsPerInstance * instanceCapacity);
                instanceData = data;
                initializedInstanceData = nodes;
                initializedNodeCount = 0;

                gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW);
            }

            // Upload only the nodes appended since the last upload.
            if (initializedNodeCount < nodes.length) {
                const firstNewNode = initializedNodeCount;
                const srcStartOffset = floatsPerInstance * firstNewNode;

                for (let idx = firstNewNode; idx < nodes.length; ++idx) {
                    const node = nodes[idx];
                    // Field order must match `instanceAttributes`.
                    let offset = idx * floatsPerInstance;
                    data.set(node.globalViewPos, offset); offset += 3;
                    data.set(node.localViewPos, offset);  offset += 3;
                    data.set(node.colour, offset);        offset += 3;
                    data[offset++] = node.generation;
                    data[offset++] = node.joinedTime;
                }

                // For typed arrays, WebGL2 bufferSubData takes srcOffset and length in elements, not bytes.
                const numNewFloats = (nodes.length - firstNewNode) * floatsPerInstance;
                const dstStartByte = srcStartOffset * FLOAT_SIZE_BYTES;
                gl.bufferSubData(gl.ARRAY_BUFFER, dstStartByte, data, srcStartOffset, numNewFloats);

                initializedNodeCount = nodes.length;
            }

            const numNodesToDraw = clamp(nodeVisualisationParams.numNodes, 1, nodes.length);

            gl.uniform1f(uTimelinePosition, nodeVisualisationParams.timelinePosition);

            gl.uniform1f(uSpreadFactor, calculateNodeSpreadFactor(numNodesToDraw)); // Spread nodes out as total number increases
            gl.uniform1f(uGlobalLocalViewRatio, globalLocalViewRatio);

            gl.uniform1f(uNodeSize, layoutParams.nodeSize);
            gl.uniform1f(uNodeGenerationScaleFactor, layoutParams.nodeGenerationScaleFactor);

            gl.uniform1f(uBobSpeed, nodeVisualisationParams.bobSpeed);
            gl.uniform1f(uBobAmount, nodeVisualisationParams.bobAmount);
            gl.uniform1f(uSwimSpeed, nodeVisualisationParams.swimSpeed);
            gl.uniform1f(uSwimAmount, nodeVisualisationParams.swimAmount);
            gl.uniform1f(uPulseSpeed, nodeVisualisationParams.pulseSpeed);
            gl.uniform1f(uPulseAmount, nodeVisualisationParams.pulseAmount);
            gl.uniform1f(uDepthRange, nodeVisualisationParams.depthRange);
            gl.uniform1f(uDepthSharpness, nodeVisualisationParams.depthSharpness);

            // The shared quad is drawn as a 4-vertex triangle fan, once per node instance.
            gl.drawArraysInstanced(gl.TRIANGLE_FAN, 0, 4, numNodesToDraw);
        }
    };
}