import { clamp } from "@/core/utils";
import { mat4, vec4 } from "gl-matrix";
import { NodeInstanceData } from "./renderer_node";
import { FLOAT_SIZE_BYTES, Renderer, VertexAttributeType, bindVertexAttributes, compileProgram, compileShader, enableBlendSettings, getVertexAttributesTotalFloatCount } from "./renderer";
import { BACKGROUND_COLOR, nodeVisualisationParams } from "../rendering_state";
import { calculateNodeSpreadFactor } from "../node_layout/node_layout";

/**
 * Creates the renderer that draws one line segment per node, joining the node to its parent.
 *
 * Every vertex carries two positions (global view and local view); the vertex shader blends
 * between them using the `globalLocalViewRatio` passed to `draw`, then scales the result by the
 * spread factor. The fragment shader fades the line colour towards `BACKGROUND_COLOR` based on
 * `gl_FragCoord.w`, `depthRange` and `depthSharpness`.
 *
 * Vertex data is built lazily inside `draw`: a new `nodes` array (compared by reference) resets
 * the buffer, and nodes appended to an already-known array are uploaded incrementally.
 *
 * @param gl WebGL2 context used to compile the program and own the vertex buffer.
 * @returns A renderer whose `prepare(projViewMat)` must be called before `draw(nodes, totalCount,
 *          globalLocalViewRatio)`, because `draw` sets uniforms on the currently active program.
 * @throws Error if the vertex buffer cannot be created.
 */
export function createNodeLineRenderer(gl: WebGL2RenderingContext): Renderer<(nodes:NodeInstanceData[], totalCount: number, globalLocalViewRatio: number)=>void> {
    const vertexShader = compileShader(
        gl, 
        `#version 300 es

        // Standard global uniforms
        uniform mat4 uProjViewMat;
        uniform float uSpreadFactor;

        uniform float uGlobalLocalViewRatio;

        // Shared
        in vec3 aVertexPositionGlobalView;
        in vec3 aVertexPositionLocalView;

        const vec4 DEPTH_OFFSET = vec4(0, 0, 0, -0.5); // To bias lines to be rendered behind nodes

        void main(){
            // Transiting between global and local view position
            vec3 worldPosInCurrentView = mix(aVertexPositionGlobalView, aVertexPositionLocalView, uGlobalLocalViewRatio);

            gl_Position = uProjViewMat * vec4(worldPosInCurrentView*uSpreadFactor, 1.0) + DEPTH_OFFSET;
        }
        `,
        gl.VERTEX_SHADER
    );

    const fragmentShader = compileShader(
        gl,
        `#version 300 es
        precision highp float;

        uniform vec4 uLineColour;
        uniform vec4 uBackgroundColor;
        uniform float uDepthRange;
        uniform float uDepthSharpness;

        out vec4 fragColour;
        
        void main() {
            /*
             * Blend from the line's original colour, towards the scene's background colour, as depth increases.
             * This creates a fog / brightness falloff effect
             */
            float depthFactor = pow(clamp(gl_FragCoord.w*uDepthRange, 0.0, 1.0), uDepthSharpness);
            vec4 colourWithFalloff = mix(uBackgroundColor, uLineColour, depthFactor);
            
            fragColour = colourWithFalloff;
        }
        `,
        gl.FRAGMENT_SHADER
    );

    const program = compileProgram(gl, vertexShader, fragmentShader);

    const uProjViewMat = gl.getUniformLocation(program, "uProjViewMat");
    const uSpreadFactor = gl.getUniformLocation(program, "uSpreadFactor");

    const uGlobalLocalViewRatio = gl.getUniformLocation(program, "uGlobalLocalViewRatio");

    const uLineColour = gl.getUniformLocation(program, "uLineColour");
    const uBackgroundColor = gl.getUniformLocation(program, "uBackgroundColor");
    const uDepthRange = gl.getUniformLocation(program, "uDepthRange");
    const uDepthSharpness = gl.getUniformLocation(program, "uDepthSharpness");

    // Per-vertex layout: global-view position followed by local-view position.
    const vertexAttributes = [
        {location: gl.getAttribLocation(program, "aVertexPositionGlobalView"), type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aVertexPositionLocalView"), type: VertexAttributeType.VEC3F},
    ];

    const lineVertexBuffer = gl.createBuffer();
    if (!lineVertexBuffer) throw new Error(`Failed to create new instanced vertex buffer`);

    // Each node contributes one line segment, i.e. two vertices (node end + parent end).
    const floatsPerNode = getVertexAttributesTotalFloatCount(vertexAttributes) * 2;

    // Fixed line colour, created once so that `prepare` does not allocate on every call.
    const lineColour = vec4.fromValues(1, 0, 0, 1);

    // CPU-side staging copy of the vertex buffer, and the number of nodes it has room for.
    let vertexData: Float32Array|null = null;
    let nodeCapacity = 0;

    // The node array the buffer was last built from, and how many of its nodes are already uploaded.
    let initializedInstanceData: NodeInstanceData[]|null = null;
    let initializedNodeCount = 0;

    return {
        prepare: (projViewMat: mat4) => {
            enableBlendSettings(gl, "alpha", true, true);
        
            gl.useProgram(program);

            gl.uniform4fv(uBackgroundColor, BACKGROUND_COLOR);
            gl.uniform4fv(uLineColour, lineColour);

            gl.uniformMatrix4fv(uProjViewMat, false, projViewMat);

            bindVertexAttributes(gl, lineVertexBuffer, false, vertexAttributes);
        },
        draw: (nodes: NodeInstanceData[], totalCount: number, globalLocalViewRatio: number) => {
            let data = vertexData;

            // (Re)allocate when a different node array is supplied, or when the current one has
            // outgrown the capacity that was reserved from `totalCount`.
            const isNewNodeList = nodes !== initializedInstanceData;
            if (isNewNodeList || data === null || nodes.length > nodeCapacity) {
                nodeCapacity = Math.max(totalCount, nodes.length);
                data = new Float32Array(floatsPerNode * nodeCapacity);
                vertexData = data;
                initializedInstanceData = nodes;
                initializedNodeCount = 0; // Everything must be written into the fresh buffer again

                // Bind explicitly rather than relying on whatever is bound to ARRAY_BUFFER at this point.
                gl.bindBuffer(gl.ARRAY_BUFFER, lineVertexBuffer);
                // The buffer is appended to as more nodes arrive, hence DYNAMIC_DRAW.
                gl.bufferData(gl.ARRAY_BUFFER, data, gl.DYNAMIC_DRAW);
            }

            // Append vertex data for any nodes that have been added since the last upload.
            if (initializedNodeCount < nodes.length) {
                const numNewNodes = nodes.length - initializedNodeCount;
                const srcStartOffset = floatsPerNode * initializedNodeCount;

                let offset = srcStartOffset;
                for (let idx = initializedNodeCount; idx < nodes.length; ++idx) {
                    const node = nodes[idx];
                    // Node 0 is linked back to itself, so effectively has no visible line - hacky, but simplifies the code here considerably
                    const parent = idx === 0 ? node : nodes[node.parentIdx!];

                    // First vertex of the segment: the node itself
                    data.set(node.globalViewPos, offset);     offset += 3;
                    data.set(node.localViewPos, offset);      offset += 3;
                    // Second vertex of the segment: the node's parent
                    data.set(parent.globalViewPos, offset);   offset += 3;
                    data.set(parent.localViewPos, offset);    offset += 3;
                }

                // Upload only the newly written range; the destination is in bytes, the source offset/length in floats.
                const dstStartByte = srcStartOffset * FLOAT_SIZE_BYTES;
                gl.bindBuffer(gl.ARRAY_BUFFER, lineVertexBuffer);
                gl.bufferSubData(gl.ARRAY_BUFFER, dstStartByte, data, srcStartOffset, numNewNodes * floatsPerNode);

                initializedNodeCount = nodes.length;
            }

            // Nothing to draw; also avoids clamping against an inverted [1, 0] range below.
            if (nodes.length === 0) return;

            const numNodesToDraw = clamp(nodeVisualisationParams.numNodes, 1, nodes.length);

            gl.uniform1f(uSpreadFactor, calculateNodeSpreadFactor(numNodesToDraw)); // Spread nodes out as total number increases
            gl.uniform1f(uGlobalLocalViewRatio, globalLocalViewRatio);

            gl.uniform1f(uDepthRange, nodeVisualisationParams.depthRange);
            gl.uniform1f(uDepthSharpness, nodeVisualisationParams.depthSharpness);

            // Two vertices per node: one line segment each
            gl.drawArrays(gl.LINES, 0, numNodesToDraw * 2);
        }
    };
}