import React, { useState, useMemo } from 'react';
import './App.css';
import './styles.css'
import { getModel } from './Model/Model'
import { Sidebar } from './Components/Sidebar'
import { Side } from 'three';
import { Inferencer } from './Model/Outputs'
import * as tf from '@tensorflow/tfjs'
import { Animation2D } from './Components/SceneComponents/Animation2D'
import { Animation } from './Components/SceneComponents/Animation'
import { handleAnimation, handleTuningCurveData, handleModel } from './AppUtil'
import { model, time } from '@tensorflow/tfjs';
import { SceneView } from './Components/SceneView'
import { ActivationView } from './Components/ActivationView'
import {archId, getModelPathFromDirName, modelDirNames} from './Model/ModelMetaInfo'
import {ModelManager} from './Model/ModelManager'
import {TextField} from '@material-ui/core'
import * as md5 from 'md5'
import {TuningCurveClusterPlot} from './Components/TuningCurveComponents/Clustering/TuningCurveClusterPlot'
import { formatPrefix } from 'd3';

// Width of the scene canvas. NOTE(review): not referenced anywhere in the visible
// part of this file; units are presumably pixels — confirm or remove.
const sceneCanvasWidth = 300;

/**
 * Evaluates the part of a tf.js layers graph that lies between `inputLayer`
 * and `outputLayer`, with `x` standing in for the output of `inputLayer`.
 *
 * Evaluation is data-driven: starting from `inputLayer`'s outbound nodes, a
 * layer is computed once every layer in its first inbound node has an output,
 * and each layer is computed at most once. Propagation stops at `outputLayer`.
 * Two substitutions are applied:
 *   - layers whose name contains "re_lu" pass their first input through as-is;
 *   - layers whose name contains "max_pool" are replaced by `tf.avgPool` using
 *     the layer's `poolSize`, `strides` and `padding`.
 * NOTE(review): presumably this linearizes the subgraph for gradient analysis —
 * confirm the intent. Only `inboundNodes[0]` is consulted, so layers shared
 * across several call sites are handled through their first use only.
 *
 * @param {object} inputLayer - tf.js layer whose output is replaced by `x`.
 * @param {object} outputLayer - tf.js layer at which evaluation stops.
 * @param {tf.Tensor} x - tensor fed in place of `inputLayer`'s output.
 * @returns {tf.Tensor} the computed output of `outputLayer`.
 * @throws {Error} if `outputLayer` is never reached from `inputLayer`.
 */
function runSubgraph(inputLayer, outputLayer, x) {
    const outputs = { [inputLayer.name]: x };

    const processLayer = (layer) => {
        // A layer reached again through another path (e.g. a skip connection)
        // has already been computed; recomputing it would also redo all of its
        // downstream layers.
        if (layer.name in outputs) {
            return;
        }
        const inboundLayers = layer.inboundNodes[0].inboundLayers;
        // Not ready yet: the last inbound layer to finish will call back here.
        if (!inboundLayers.every((inbound) => inbound.name in outputs)) {
            return;
        }
        const inputs = inboundLayers.map((inbound) => outputs[inbound.name]);

        if (layer.name.includes('re_lu')) {
            outputs[layer.name] = inputs[0];
        } else if (layer.name.includes('max_pool')) {
            outputs[layer.name] = tf.avgPool(inputs[0], layer.poolSize, layer.strides, layer.padding);
        } else {
            outputs[layer.name] = layer.apply(inputs);
        }

        if (layer.name === outputLayer.name) {
            return;
        }
        layer.outboundNodes.forEach((outNode) => processLayer(outNode.outboundLayer));
    };

    inputLayer.outboundNodes.forEach((outNode) => processLayer(outNode.outboundLayer));

    const output = outputs[outputLayer.name];
    if (output === undefined) {
        throw new Error(`Layer "${outputLayer.name}" is not reachable from layer "${inputLayer.name}"`);
    }
    return output;
}

/**
 * Returns the scalar at the spatial centre of `channel` in the first batch
 * element of `output`, assumed to be laid out as [batch, height, width, channels].
 *
 * @param {tf.Tensor} output - rank-4 activation tensor.
 * @param {number} channel - channel index in the last dimension.
 * @returns {tf.Scalar} the selected activation.
 */
function centerActivation(output, channel) {
    const [, height, width] = output.shape;
    const cy = Math.floor(height / 2);
    const cx = Math.floor(width / 2);
    // slice + reshape([]) selects exactly one element; `gather` would only
    // index along axis 0 and return whole slices.
    return output.slice([0, cy, cx, channel], [1, 1, 1, 1]).reshape([]);
}

/**
 * Builds one function per output channel of `outputLayer`. Each function maps
 * a tensor standing in for `inputLayer`'s output to the scalar activation at
 * the spatial centre of its channel (see `runSubgraph` for how the subgraph is
 * evaluated), which makes it suitable for `tf.grad`.
 *
 * @param {object} inputLayer - tf.js layer whose output is the functors' input.
 * @param {object} outputLayer - tf.js layer whose channels are read out.
 * @returns {Array<function(tf.Tensor): tf.Scalar>} one functor per channel.
 */
function subgraphFunctors(inputLayer, outputLayer) {
    const outShape = outputLayer.output.shape;
    const outChannels = outShape[outShape.length - 1];
    return Array.from({ length: outChannels }, (_, channel) =>
        (x) => centerActivation(runSubgraph(inputLayer, outputLayer, x), channel));
}

/**
 * Computes the gradient of every per-channel functor between the "conv2d2" and
 * "mixed4a_pre_relu" layers of `sourceModel` at a constant 0.5 input, and logs
 * how long each gradient evaluation takes.
 *
 * Each gradient result is released by `tf.tidy` after being read back; the
 * shared probe input is disposed when the loop ends, even on error.
 *
 * @param {object} sourceModel - tf.js layers model exposing `getLayer(name)`.
 */
function benchmarkChannelGradients(sourceModel) {
    const conv2d2 = sourceModel.getLayer("conv2d2");
    const mixed4aPreRelu = sourceModel.getLayer("mixed4a_pre_relu");

    const grads = subgraphFunctors(conv2d2, mixed4aPreRelu).map((functor) => tf.grad(functor));

    // NOTE(review): shape presumably matches conv2d2's output — confirm.
    const probe = tf.fill([1, 28, 28, 192], 0.5);
    try {
        grads.forEach((grad) => {
            const start = Date.now();
            tf.tidy(() => {
                grad(probe).dataSync();
            });
            console.log("took", Date.now() - start, "millisecs");
        });
    } finally {
        probe.dispose();
    }
}

/**
 * Root component. Loads model 0 through a ModelManager and, once its
 * `sourceModel` is available, runs the channel-gradient benchmark. Renders
 * nothing.
 */
function App() {

    // Only the setter is needed: toggling the value forces a re-render.
    const [, setUpdateState] = useState(true);
    const [modelDirs] = useState({model1: modelDirNames[1], model2: modelDirNames[2]});
    // Functional update: the callback memoized below keeps the `update` from
    // its first render, so toggling a captured value would stop re-rendering
    // after the first call.
    const update = () => { setUpdateState((prev) => !prev); };

    const modelManager = useMemo(() => new ModelManager(2), []);
    modelManager.setModelDirs(modelDirs);

    const model = useMemo(() => {
        // NOTE(review): the callback is presumably invoked when loading
        // progresses/finishes — confirm against ModelManager.getModel.
        return modelManager.getModel(0, true, () => { update(); });
    }, [modelDirs.model1]);

    const sourceModel = model.sourceModel;
    // Run the expensive benchmark as an effect, once per loaded source model,
    // instead of inside render on every re-render.
    React.useEffect(() => {
        if (sourceModel) {
            benchmarkChannelGradients(sourceModel);
        }
    }, [sourceModel]);

    return null;
}

export default App;