import { clamp } from "@/core/utils";
import { mat4 } from "gl-matrix";
import { FLOAT_SIZE_BYTES, QUAD_VERTEX_BUFFER, Renderer, VertexAttributeType, bindVertexAttributes, compileProgram, compileShader, enableBlendSettings, getVertexAttributesTotalFloatCount } from "./renderer";
import { camera, layoutParams, nodeVisualisationParams, currentTime, nebulaVisualisationParams } from "../rendering_state";
import { NodeInstanceData } from "./renderer_node";
import { calculateNodeSpreadFactor } from "../node_layout/node_layout";

/**
 * Creates the renderer for the soft "nebula" clouds drawn around nodes.
 *
 * Each node is rendered as one instance of a camera-facing quad (drawn as a 4-vertex triangle fan) whose
 * fragment shader produces a circular, noise-textured glow. Per-instance data (global/local view position,
 * colour, generation, joined time) lives in a single instanced vertex buffer owned by this renderer.
 *
 * - `prepare(projViewMat)` selects the program, enables additive blending, uploads the per-frame uniforms
 *   (time, nebula tuning params, camera matrices, noise texture) and binds the vertex/instance attributes.
 * - `draw(nodes, totalCount, globalLocalViewRatio)` uploads any node data not yet on the GPU and issues the
 *   instanced draw. It relies on the program and attribute bindings set up by `prepare`, so call `prepare`
 *   first.
 *
 * Upload caching: `nodes` is tracked by array identity. Passing the same array again only uploads nodes
 * appended since the previous call; changes to already-uploaded elements are NOT detected. Passing a
 * different array re-uploads everything.
 *
 * @param gl WebGL2 context to create GPU resources in.
 * @param noiseImage Image used as the noise texture that modulates the clouds' alpha.
 * @returns A {@link Renderer} whose `draw` takes `(nodes, totalCount, globalLocalViewRatio)`:
 *   `totalCount` pre-sizes the instance buffer (so later appends fit without reallocating) and
 *   `globalLocalViewRatio` blends each node between its global-view (0) and local-view (1) position.
 * @throws Error if the instanced vertex buffer or the noise texture cannot be created.
 */
export function createNebulaRenderer(gl: WebGL2RenderingContext, noiseImage: TexImageSource): Renderer<(nodes:NodeInstanceData[], totalCount: number, globalLocalViewRatio: number)=>void> {
    const vertexShader = compileShader(
        gl,
        `#version 300 es
        // Standard global uniforms
        uniform float uTime;
        uniform mat4 uProjViewMat;
        uniform mat4 uUnrotateMat;
        uniform float uSpreadFactor;
        uniform float uTimelinePosition;

        uniform float uGlobalLocalViewRatio;

        uniform float uDepthRange;
        uniform float uDepthSharpness;

        // Tuning global uniforms
        uniform float uNodeSize;
        uniform float uNodeGenerationScaleFactor;

        uniform float uNebulaScale;
        uniform float uNebulaMaxGeneration;
        uniform float uNebulaMinGeneration;

        // Shared
        in vec3 aVertexPosition;
        in vec2 aVertexUV;

        // Instanced
        in vec3 aNodePositionGlobalView;
        in vec3 aNodePositionLocalView;
        in vec3 aNodeColour;
        in float aNodeGeneration;
        in float aNodeJoinedTime;

        out vec2 vUV;
        out vec3 vColour;
        out float instanceTime;
        out float vAppearFactor;

        const float appearAgeRange = 250000.0;

        const vec4 DEPTH_OFFSET = vec4(0, 0, 0, 0.1); // To bias nebulas to be rendered in front of nodes

        void main(){
            instanceTime = uTime*0.5 + float(gl_InstanceID%1000)*1234.0;

            // Nodes with indices close to the current shown node count will "blink" into existence
            float age = clamp(uTimelinePosition-aNodeJoinedTime, 0.0, appearAgeRange);
            vAppearFactor = pow(age/appearAgeRange, 1.2);

            /*
             * Calculate local vertex position
             */

            // Nodes which are of a higher generation - i.e. deeper into the tree - may be scaled smaller
            float generationScale = pow(uNodeGenerationScaleFactor, aNodeGeneration);

            // Should be larger when appear effect is active
            generationScale = mix(0.07, generationScale, vAppearFactor);

            // Cull based on generation
            generationScale *= step(aNodeGeneration, uNebulaMaxGeneration);
            generationScale *= step(uNebulaMinGeneration, aNodeGeneration);

            float finalNodeSize = uNodeSize * generationScale * uNebulaScale;

            // Calculate local vertex position by de-rotating so the node always faces the camera
            vec3 vertexLocalPos = vec3(uUnrotateMat * vec4(aVertexPosition, 1)) * finalNodeSize;

            /*
             * Calculate node world position
             */

            // Transiting between global and local view position
            vec3 worldPosInCurrentView = mix(aNodePositionGlobalView, aNodePositionLocalView, uGlobalLocalViewRatio);

            // Apply spacing factor (which causes all nodes to spread out away from the origin)
            vec3 nodeWorldPos = (worldPosInCurrentView*uSpreadFactor);

            /*
             * Calculate final vertex world position
             */
            gl_Position = uProjViewMat * vec4(nodeWorldPos + vertexLocalPos, 1.0) + DEPTH_OFFSET;

            /*
             * Blend from the node's original colour, towards black, as depth increases.
             * This creates a fog / brightness falloff effect
             */
            float depthBrightnessFactor = pow(clamp(uDepthRange/gl_Position.w, 0.0, 1.0), uDepthSharpness);

            // Also fade out when very close, to create a "clouds clearing" effect
            depthBrightnessFactor *= 1.0-pow(clamp(0.06*uDepthRange/gl_Position.w, 0.0, 1.0), 1.0);

            // Calculate final colour
            vColour = mix(vec3(0,0,0), aNodeColour, depthBrightnessFactor);

            // Cull clouds which are very dim (by scaling down to zero)
            gl_Position *= step(0.001, depthBrightnessFactor);

            /*
             * Just pass the UV coordinates  through as-is
             */
            vUV = aVertexUV;
        }
        `,
        gl.VERTEX_SHADER
    );

    const fragmentShader = compileShader(
        gl,
        `#version 300 es
        precision highp float;

        uniform float uNebulaIntensity;
        uniform float uNebulaSharpness;
        uniform float uNebulaTexture;

        uniform sampler2D uNoiseTexture;

        in vec2 vUV;
        in vec3 vColour;
        in float instanceTime;
        in float vAppearFactor;

        out vec4 fragColour;
        
        void main() {
            /*
             * Define circle shape
             */
            float distSq = (vUV[0]-0.5)*(vUV[0]-0.5) + (vUV[1]-0.5)*(vUV[1]-0.5);
            float insideCircleFactor = distSq / (0.5*0.5);
            float alpha = clamp(1.0-insideCircleFactor,0.0,1.0);
            float appearAlpha = pow(alpha,5.0);

            alpha = pow(alpha, uNebulaSharpness) * uNebulaIntensity;

            vec2 noiseUV0 = vUV + vec2(instanceTime*0.03, instanceTime*0.02+7.3);
            vec2 noiseUV1 = vUV + vec2(-instanceTime*0.025, -instanceTime*0.02+1.2);
            float noiseValue = texture(uNoiseTexture, noiseUV0)[0] + texture(uNoiseTexture, noiseUV1)[0];
            alpha *= mix(1.0, noiseValue, uNebulaTexture);

            /*
             * Determine final colour
             */
            vec4 colour = mix(vec4(0,0,0,1), vec4(vColour,1), alpha);

            vec4 appearColour = mix(vec4(0,0,0,1), mix(vec4(1,1,1,1), vec4(vColour,1), vAppearFactor), appearAlpha);

            fragColour = mix(appearColour, colour, vAppearFactor);
        }
        `,
        gl.FRAGMENT_SHADER
    );

    const program = compileProgram(gl, vertexShader, fragmentShader);

    const uTime = gl.getUniformLocation(program, "uTime");
    const uTimelinePosition = gl.getUniformLocation(program, "uTimelinePosition");
    const uProjViewMat = gl.getUniformLocation(program, "uProjViewMat");
    const uUnrotateMat = gl.getUniformLocation(program, "uUnrotateMat");
    const uSpreadFactor = gl.getUniformLocation(program, "uSpreadFactor");

    const uGlobalLocalViewRatio = gl.getUniformLocation(program, "uGlobalLocalViewRatio");

    const uNoiseTexture = gl.getUniformLocation(program, "uNoiseTexture");

    const uNodeSize = gl.getUniformLocation(program, "uNodeSize");
    const uNodeGenerationScaleFactor = gl.getUniformLocation(program, "uNodeGenerationScaleFactor");

    const uNebulaScale = gl.getUniformLocation(program, "uNebulaScale");
    const uNebulaIntensity = gl.getUniformLocation(program, "uNebulaIntensity");
    const uNebulaSharpness = gl.getUniformLocation(program, "uNebulaSharpness");
    const uNebulaTexture = gl.getUniformLocation(program, "uNebulaTexture");
    const uNebulaMaxGeneration = gl.getUniformLocation(program, "uNebulaMaxGeneration");
    const uNebulaMinGeneration = gl.getUniformLocation(program, "uNebulaMinGeneration");

    const uDepthRange = gl.getUniformLocation(program, "uDepthRange");
    const uDepthSharpness = gl.getUniformLocation(program, "uDepthSharpness");

    // Shared quad geometry. Only 2 floats of position are supplied; the shader's vec3 aVertexPosition
    // gets z = 0 from the default attribute value.
    const vertexAttributes = [
        {location: gl.getAttribLocation(program, "aVertexPosition"), type: VertexAttributeType.VEC2F},
        {location: gl.getAttribLocation(program, "aVertexUV"),       type: VertexAttributeType.VEC2F},
    ];

    // Per-instance layout. The order here must match the write order in packNodeInstance().
    const instanceAttributes = [
        {location: gl.getAttribLocation(program, "aNodePositionGlobalView"), type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aNodePositionLocalView"),  type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aNodeColour"),             type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aNodeGeneration"),         type: VertexAttributeType.FLOAT},
        {location: gl.getAttribLocation(program, "aNodeJoinedTime"),         type: VertexAttributeType.FLOAT},
    ];

    // Scratch matrix reused every frame for the inverse camera rotation (billboarding).
    const unrotateMat = mat4.create();

    const noiseTexture = gl.createTexture();
    if (!noiseTexture) throw new Error(`Failed to create nebula noise texture`);
    gl.bindTexture(gl.TEXTURE_2D, noiseTexture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
    gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, noiseImage);

    const instancedVertexBuffer = gl.createBuffer();
    if (!instancedVertexBuffer) throw new Error(`Failed to create new instanced vertex buffer`);

    // CPU-side mirror of the instanced vertex buffer, plus bookkeeping for incremental uploads.
    let instanceData: Float32Array | null = null;
    let initializedInstanceData: NodeInstanceData[] | null = null;
    let initializedNodeCount = 0;
    const floatsPerInstance = getVertexAttributesTotalFloatCount(instanceAttributes);

    return {
        prepare: (projViewMat: mat4) => {
            enableBlendSettings(gl, "additive", true, false);

            gl.useProgram(program);

            gl.uniform1f(uTime, currentTime%1000000); // Wrap to avoid precision issues

            gl.uniform1f(uNebulaScale, nebulaVisualisationParams.scale);
            gl.uniform1f(uNebulaIntensity, nebulaVisualisationParams.intensity);
            gl.uniform1f(uNebulaSharpness, nebulaVisualisationParams.sharpness);
            gl.uniform1f(uNebulaTexture, nebulaVisualisationParams.texture);
            gl.uniform1f(uNebulaMaxGeneration, nebulaVisualisationParams.maxGeneration);
            gl.uniform1f(uNebulaMinGeneration, nebulaVisualisationParams.minGeneration);

            gl.uniformMatrix4fv(uProjViewMat, false, projViewMat);

            // Inverse of the camera rotation: applied to quad vertices so each cloud faces the camera.
            mat4.fromQuat(unrotateMat, camera.rotation);
            mat4.invert(unrotateMat, unrotateMat);
            gl.uniformMatrix4fv(uUnrotateMat, false, unrotateMat);

            gl.activeTexture(gl.TEXTURE0);
            gl.bindTexture(gl.TEXTURE_2D, noiseTexture);
            gl.uniform1i(uNoiseTexture, 0);

            bindVertexAttributes(gl, QUAD_VERTEX_BUFFER!, false, vertexAttributes);
            bindVertexAttributes(gl, instancedVertexBuffer, true, instanceAttributes);
        },
        draw: (nodes: NodeInstanceData[], totalCount: number, globalLocalViewRatio: number) => {
            let data = instanceData;

            if (initializedInstanceData !== nodes || data === null || data.length < floatsPerInstance * nodes.length) {
                // New node list, or the tracked one outgrew its allocation. Size for `totalCount` so later
                // appends fit, but never smaller than what must be stored now (otherwise set() would throw).
                data = new Float32Array(floatsPerInstance * Math.max(totalCount, nodes.length));
                instanceData = data;
                initializedInstanceData = nodes;
                initializedNodeCount = 0;

                // Bind explicitly rather than relying on prepare() having left this buffer bound last.
                gl.bindBuffer(gl.ARRAY_BUFFER, instancedVertexBuffer);
                gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW);
            }

            if (initializedNodeCount < nodes.length) {
                // Only pack and upload the nodes appended since the last upload.
                const numNewNodes = nodes.length - initializedNodeCount;
                const srcStartOffset = floatsPerInstance * initializedNodeCount;

                let offset = srcStartOffset;
                for (let idx = initializedNodeCount; idx < nodes.length; ++idx) {
                    offset = packNodeInstance(data, offset, nodes[idx]);
                }

                // bufferSubData's srcOffset/length are in elements (floats), dstByteOffset is in bytes.
                gl.bindBuffer(gl.ARRAY_BUFFER, instancedVertexBuffer);
                gl.bufferSubData(gl.ARRAY_BUFFER, srcStartOffset * FLOAT_SIZE_BYTES, data, srcStartOffset, numNewNodes * floatsPerInstance);

                initializedNodeCount = nodes.length;
            }

            // Nothing to draw. clamp(x, 1, 0) below could otherwise request an instance with no valid data.
            if (nodes.length === 0) return;

            const numNodesToDraw = clamp(nodeVisualisationParams.numNodes, 1, nodes.length);

            gl.uniform1f(uTimelinePosition, nodeVisualisationParams.timelinePosition);

            gl.uniform1f(uSpreadFactor, calculateNodeSpreadFactor(numNodesToDraw)); // Spread nodes out as total number increases
            gl.uniform1f(uGlobalLocalViewRatio, globalLocalViewRatio);

            gl.uniform1f(uNodeSize, layoutParams.nodeSize);
            gl.uniform1f(uNodeGenerationScaleFactor, layoutParams.nodeGenerationScaleFactor);

            gl.uniform1f(uDepthRange, nodeVisualisationParams.depthRange);
            gl.uniform1f(uDepthSharpness, nodeVisualisationParams.depthSharpness);

            gl.drawArraysInstanced(gl.TRIANGLE_FAN, 0, 4, numNodesToDraw);
        }
    };
}

/**
 * Writes one node's per-instance attributes into `dst`, starting at float index `offset`.
 * The layout is global view pos (3), local view pos (3), colour (3), generation (1), joined time (1),
 * matching the instance attribute list in createNebulaRenderer.
 *
 * @returns The float index just past the written instance.
 */
function packNodeInstance(dst: Float32Array, offset: number, node: NodeInstanceData): number {
    let cursor = offset;
    dst.set(node.globalViewPos, cursor); cursor += 3;
    dst.set(node.localViewPos, cursor);  cursor += 3;
    dst.set(node.colour, cursor);        cursor += 3;
    // Scalars are written directly to avoid allocating a temporary array per node.
    dst[cursor++] = node.generation;
    dst[cursor++] = node.joinedTime;
    return cursor;
}