import { clamp } from "@/core/utils";
import { mat4 } from "gl-matrix";
import { NodeInstanceData } from "./renderer_node";
import { FLOAT_SIZE_BYTES, Renderer, VertexAttributeType, bindVertexAttributes, compileProgram, compileShader, enableBlendSettings, getVertexAttributesTotalFloatCount } from "./renderer";
import { nodeVisualisationParams } from "../rendering_state";
import { calculateNodeSpreadFactor } from "../node_layout/node_layout";

/** Capacity, in lines, of the local-line vertex buffer; every line is stored as two vertices. */
const MAX_LINES = 100_000;

/**
 * Creates a renderer for the "local" tree lines around a focus node:
 *  - the path from the focus node back towards node 0, drawn at full alpha, and
 *  - a line from every descendant of the focus node (at most `maxGenerationFromFocus` generations
 *    below it) to its parent, faded out with generation distance from the focus node.
 *
 * Each vertex is interpolated between its global-view and local-view position by
 * `globalLocalViewRatio` in the vertex shader.
 *
 * Vertex data is built incrementally. While `nodes` (compared by reference), `focusNodeIdx` and
 * `maxGenerationFromFocus` are unchanged, each `draw` call only processes nodes appended since the
 * previous call. Changing any of them rebuilds the line set from scratch. At most `MAX_LINES` lines
 * are stored; any further lines are dropped, with a one-time console warning.
 *
 * @param gl WebGL2 context used to compile the program and own the vertex buffer.
 * @returns A renderer. `prepare` sets blend state, program, the projection-view matrix and the vertex
 *          attributes. `draw` uploads any new lines, then draws the lines belonging to the first
 *          `nodeVisualisationParams.numNodes` nodes.
 * @throws Error if the vertex buffer cannot be created.
 */
export function createNodeLocalLineRenderer(gl: WebGL2RenderingContext): Renderer<(nodes:NodeInstanceData[], focusNodeIdx: number, maxGenerationFromFocus: number, totalCount: number, globalLocalViewRatio: number)=>void> {
    // NOTE(review): DEPTH_OFFSET in the shader below is added to clip-space w (not z), so it also scales
    // the projected x/y of every vertex. Confirm this is intended; it may deliberately match the node renderer.
    const vertexShader = compileShader(
        gl, 
        `#version 300 es

        // Standard global uniforms
        uniform mat4 uProjViewMat;
        uniform float uSpreadFactor;

        uniform float uGlobalLocalViewRatio;

        // Shared
        in vec3 aVertexPositionGlobalView;
        in vec3 aVertexPositionLocalView;
        in vec4 aVertexColour;

        out vec4 vColour;

        const vec4 DEPTH_OFFSET = vec4(0, 0, 0, -0.5); // To bias lines to be rendered behind nodes

        void main(){
            // Transiting between global and local view position
            vec3 worldPosInCurrentView = mix(aVertexPositionGlobalView, aVertexPositionLocalView, uGlobalLocalViewRatio);

            gl_Position = uProjViewMat * vec4(worldPosInCurrentView*uSpreadFactor, 1.0) + DEPTH_OFFSET;
            vColour = aVertexColour;
        }
        `,
        gl.VERTEX_SHADER
    );

    const fragmentShader = compileShader(
        gl,
        `#version 300 es
        precision highp float;

        uniform float uDepthRange;
        uniform float uDepthSharpness;

        in vec4 vColour;

        out vec4 fragColour;
        
        void main() {
            /*
             * Blend from the line's original colour, towards the scene's background colour, as depth increases.
             * This creates a fog / brightness falloff effect
             */
            float depthFactor = pow(clamp(gl_FragCoord.w*uDepthRange, 0.0, 1.0), uDepthSharpness);
            vec4 colourWithFalloff = mix(vec4(0), vec4(vColour[0],vColour[1],vColour[2],1.0), depthFactor*vColour[3]*0.85);
            
            fragColour = colourWithFalloff;
        }
        `,
        gl.FRAGMENT_SHADER
    );

    const program = compileProgram(gl, vertexShader, fragmentShader);

    const uProjViewMat = gl.getUniformLocation(program, "uProjViewMat");
    const uSpreadFactor = gl.getUniformLocation(program, "uSpreadFactor");

    const uGlobalLocalViewRatio = gl.getUniformLocation(program, "uGlobalLocalViewRatio");

    const uDepthRange = gl.getUniformLocation(program, "uDepthRange");
    const uDepthSharpness = gl.getUniformLocation(program, "uDepthSharpness");

    const vertexAttributes = [
        {location: gl.getAttribLocation(program, "aVertexPositionGlobalView"), type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aVertexPositionLocalView"), type: VertexAttributeType.VEC3F},
        {location: gl.getAttribLocation(program, "aVertexColour"), type: VertexAttributeType.VEC4F},
    ];

    const lineVertexBuffer = gl.createBuffer();
    if (!lineVertexBuffer) throw new Error(`Failed to create new instanced vertex buffer`);

    // Each line is two vertices, and each vertex holds every attribute above.
    const floatsPerLine = getVertexAttributesTotalFloatCount(vertexAttributes) * 2;

    // CPU-side staging copy of the GPU buffer. It is allocated once and reused across rebuilds.
    // Only the slots [0, localLineCount) are ever uploaded or drawn, so stale data past that is harmless.
    const vertexData = new Float32Array(floatsPerLine * MAX_LINES);

    // numLinesToDrawByNodeCount[i] = number of lines contributed by nodes 0..i. It is used to hide
    // the lines of nodes that are not currently shown.
    let numLinesToDrawByNodeCount: number[] = [];

    // Parameters the current vertex data was built for. Any mismatch triggers a full rebuild.
    const initializedInstanceData = {
       nodes: null as (NodeInstanceData[]|null),
       focusNodeIdx: null as (number|null),
       maxGenerationFromFocus: null as (number|null)
    };
    let initializedNodeCount = 0;
    let localLineCount = 0;
    let warnedLineOverflow = false;

    // Writes one vertex at `offset`: global pos (3), local pos (3), colour (3), alpha (1).
    // Returns the offset just past the vertex.
    const writeVertex = (node: NodeInstanceData, alpha: number, offset: number): number => {
        vertexData.set(node.globalViewPos, offset);     offset += 3;
        vertexData.set(node.localViewPos, offset);      offset += 3;
        vertexData.set(node.colour, offset);            offset += 3;
        vertexData[offset] = alpha;                     offset += 1;
        return offset;
    };

    // Appends the line from nodes[nodeIdx] to its parent into the next free line slot.
    // Does nothing if the node has no parent or the buffer is already full.
    const writeLine = (nodes: NodeInstanceData[], nodeIdx: number, distFromFocus: number, maxGenerationFromFocus: number): void => {
        if (localLineCount >= MAX_LINES) {
            if (!warnedLineOverflow) {
                console.warn(`Local line renderer capacity (${MAX_LINES} lines) exceeded; extra lines are not drawn`);
                warnedLineOverflow = true;
            }
            return;
        }

        const node = nodes[nodeIdx];
        const parentIdx = node.parentIdx;
        if (parentIdx == null) return; // A root has no line to its parent

        const parent = nodes[parentIdx];
        // The +0.25 keeps the furthest generation from fading all the way to zero alpha.
        const falloff = maxGenerationFromFocus + 0.25;
        const nodeAlpha = 1.0 - distFromFocus / falloff;
        const parentAlpha = 1.0 - Math.max(0, distFromFocus - 1) / falloff;

        const afterNodeVertex = writeVertex(node, nodeAlpha, localLineCount * floatsPerLine);
        writeVertex(parent, parentAlpha, afterNodeVertex);
        localLineCount++;
    };

    // Appends every line contributed by node `idx`.
    const appendLinesForNode = (nodes: NodeInstanceData[], idx: number, focusNodeIdx: number, maxGenerationFromFocus: number): void => {
        if (idx === focusNodeIdx) {
            // Walk the path from the focus node back towards node 0, drawing every line at full alpha.
            let pathIdx: number | null = idx;
            while (pathIdx != null && pathIdx !== 0) {
                writeLine(nodes, pathIdx, 0, maxGenerationFromFocus);
                pathIdx = nodes[pathIdx].parentIdx;
            }
            return;
        }

        // Count generations up to the focus node, giving up once past maxGenerationFromFocus.
        let distFromFocus = 0;
        let ancestorIdx: number | null = idx;
        while (distFromFocus <= maxGenerationFromFocus && ancestorIdx != null && ancestorIdx !== focusNodeIdx) {
            ++distFromFocus;
            ancestorIdx = nodes[ancestorIdx].parentIdx;
        }

        // Draw only if the walk actually reached the focus node, within range.
        if (distFromFocus > 0 && distFromFocus <= maxGenerationFromFocus && ancestorIdx != null) {
            writeLine(nodes, idx, distFromFocus, maxGenerationFromFocus);
        }
    };

    return {
        prepare: (projViewMat: mat4) => {
            enableBlendSettings(gl, "additive", true, false);
        
            gl.useProgram(program);

            gl.uniformMatrix4fv(uProjViewMat, false, projViewMat);

            bindVertexAttributes(gl, lineVertexBuffer, false, vertexAttributes);
        },
        draw: (nodes: NodeInstanceData[], focusNodeIdx: number, maxGenerationFromFocus: number, totalCount: number, globalLocalViewRatio: number) => {
            // The uploads below target ARRAY_BUFFER. Bind explicitly rather than relying on prepare()
            // having left this buffer bound.
            gl.bindBuffer(gl.ARRAY_BUFFER, lineVertexBuffer);

            const needsRebuild =
                initializedInstanceData.nodes !== nodes ||
                initializedInstanceData.focusNodeIdx !== focusNodeIdx ||
                initializedInstanceData.maxGenerationFromFocus !== maxGenerationFromFocus;

            if (needsRebuild) {
                initializedInstanceData.nodes = nodes;
                initializedInstanceData.focusNodeIdx = focusNodeIdx;
                initializedInstanceData.maxGenerationFromFocus = maxGenerationFromFocus;
                initializedNodeCount = 0;
                localLineCount = 0;
                numLinesToDrawByNodeCount = new Array(totalCount);

                // Reallocate GPU storage. WebGL zero-initialises it, so the staging array need not be uploaded.
                gl.bufferData(gl.ARRAY_BUFFER, vertexData.byteLength, gl.STATIC_DRAW);
            }

            if (initializedNodeCount < nodes.length) {
                const oldLocalLineCount = localLineCount;

                for (let idx = initializedNodeCount; idx < nodes.length; ++idx) {
                    appendLinesForNode(nodes, idx, focusNodeIdx, maxGenerationFromFocus);
                    numLinesToDrawByNodeCount[idx] = localLineCount;
                }

                // bufferSubData treats length 0 as "everything from srcOffset to the end".
                // Skip the call entirely when no lines were added.
                const numNewLines = localLineCount - oldLocalLineCount;
                if (numNewLines > 0) {
                    const srcStartOffset = oldLocalLineCount * floatsPerLine;
                    const dstStartByte = srcStartOffset * FLOAT_SIZE_BYTES;
                    gl.bufferSubData(gl.ARRAY_BUFFER, dstStartByte, vertexData, srcStartOffset, numNewLines * floatsPerLine);
                }

                initializedNodeCount = nodes.length;
            }

            if (nodes.length === 0) return; // Nothing to draw

            const numNodesToDraw = clamp(nodeVisualisationParams.numNodes, 1, nodes.length);

            // Determine how many lines we should draw, based on the number of nodes shown.
            // This prevents us from drawing lines for nodes which are hidden.
            const numLinesToDraw = numLinesToDrawByNodeCount[numNodesToDraw - 1] ?? 0;

            gl.uniform1f(uSpreadFactor, calculateNodeSpreadFactor(numNodesToDraw)); // Spread nodes out as total number increases
            gl.uniform1f(uGlobalLocalViewRatio, globalLocalViewRatio);

            gl.uniform1f(uDepthRange, nodeVisualisationParams.depthRange);
            gl.uniform1f(uDepthSharpness, nodeVisualisationParams.depthSharpness);

            gl.drawArrays(gl.LINES, 0, numLinesToDraw * 2);
        }
    };
}