import React, { useEffect, useRef, useMemo, useState } from 'react'
import { Math as THREEMath, MeshStandardMaterial, ShaderMaterial, DoubleSide, TextureLoader, UniformsUtils, DataTexture} from 'three'
import { CustomBlending, AddEquation, SrcAlphaFactor, LessEqualDepth, dataTexture, LuminanceFormat, FloatType, RepeatWrapping} from 'three'
import { useFrame, useLoader,  } from 'react-three-fiber'
import {ActivationShader} from './Shaders'


/**
 * Creates a new ShaderMaterial from the activation shader definition.
 *
 * The uniforms are passed through UniformsUtils.clone, so each material gets
 * its own copy of the uniform objects. Setting a uniform value on one
 * activation plane therefore does not affect the others.
 *
 * @returns {ShaderMaterial} a fresh material using ActivationShader's programs.
 */
function getActivationMaterial() {
    const {uniforms, vertexShader, fragmentShader} = ActivationShader;
    return new ShaderMaterial({
        uniforms: UniformsUtils.clone(uniforms),
        vertexShader,
        fragmentShader,
    });
}

export function ActivationPlane(props) {
    const {w, x, activation, params} = props;
    const textureData = useRef();

    const neurons = params.neurons;
    const activationTexture = useRef();
    const material = useMemo(() => getActivationMaterial(), []);
    const mesh = useRef();

    let roundedCornerMaskTexture = useLoader(TextureLoader, "assets/rounded_corner_mask.jpg");
    roundedCornerMaskTexture.wrapT = roundedCornerMaskTexture.wrapS = RepeatWrapping;

    useFrame(() => {
        const {data, shape, normalize} = activation;
        if(material.uniforms.cornerMaskTexture.value !== roundedCornerMaskTexture) {
            material.uniforms.cornerMaskTexture.value = roundedCornerMaskTexture;
            material.uniforms.cornerMaskTexture.value.needsUpdate = true;
        }
        if(!textureData.current || textureData.current.length !== data.length) {
            textureData.current = new Float32Array(data.length);
        }
        textureData.current.set(data);
        if(!activationTexture.current || activationTexture.current.image.data !== textureData.current) {
            activationTexture.current = new DataTexture(
                textureData.current, shape.w*neurons.length, shape.h, LuminanceFormat, FloatType);
            activationTexture.current.needsUpdate = true;
        }

        mesh.current.position.set(x, 0, 0);
        mesh.current.material = material;
        material.uniforms.dataTexture.value = activationTexture.current;
        material.uniforms.dataTexture.value.needsUpdate = true;
        material.uniforms.neurons.value = neurons.length;
        material.uniforms.normalize.value = normalize ? normalize/4 : Math.pow(10, -params.normalize);
    });
    return (
        <mesh 
            ref={mesh} >
            <planeBufferGeometry attach="geometry" args={[w*1, 1, 1]} />
        </mesh>
    );
}