import React, {useEffect, useState} from 'react'

import {ActivationCanvas} from '../ActivationComponents/ActivationRenderer'
import {getDefaultActivations, getLayerSelect} from '../ActivationViewUtil'
import {outputTypes} from '../../Model/Outputs'
import {FormControl, Select, InputLabel,
    MenuItem} from '@material-ui/core'

/**
 * Builds a zero-argument generator that produces a Grad-CAM output definition.
 *
 * `outputTypes.GRADCAM` and `actParams.layer` are read once, when this function
 * is called; later changes to `actParams` do not affect the generator. Each
 * generator call returns a fresh object.
 *
 * @param {Object} actParams - Activation parameters; only `layer` is read here.
 * @param {Function} callback - Stored on the definition as `callback`.
 * @param {boolean} includeMoments - Stored on the definition as `includeMoments`.
 * @param {*} gradCamTarget - Stored on the definition as `gradCamTarget`.
 * @param {string} [gradCamLayer="softmax2"] - Stored on the definition as
 *     `gradCamLayer`. Defaults to the previously hard-coded layer name so
 *     existing four-argument callers are unaffected.
 * @returns {Function} Generator returning
 *     `{type, layer, includeMoments, callback, gradCamTarget, gradCamLayer}`.
 */
function getOutputDefGenerator(actParams, callback, includeMoments, gradCamTarget, gradCamLayer = "softmax2") {
    const type = outputTypes.GRADCAM;
    const {layer} = actParams;

    return () => ({
        type,
        layer,
        includeMoments,
        callback,
        gradCamTarget,
        gradCamLayer,
    });
}

export function GradCamView(props) {
    const {modelManager, width, setNeedsUpdate} = props;
    const [activations, setActivations] = useState(getDefaultActivations());
    const [targetClass, setTargetClass] = useState(0);
    const [targetLayer, setTargetLayer] = useState("mixed5b")
    const [attackModelInd, setAttackModelInd] = useState(0);

    const model = modelManager.getModel(attackModelInd, true);

    const gradCamCallback = (ret) => {
        const actCopy = { ...activations };
        let newActivations = actCopy;
        newActivations.act[0].data = ret.data;
        newActivations.act[0].shape.w = ret.w;
        newActivations.act[0].shape.h = ret.h;
        if ('variance' in ret) {
            newActivations.act[0].normalize = 4 / Math.sqrt(ret.variance);
        } else {
            delete newActivations.act[0].normalize;
        }
        setActivations(newActivations);
        //modelManager.getModel(attackModelInd).outputManager.deregisterOutputDefGenerators("gradcamview");
    };

    const defGenerator = getOutputDefGenerator({layer: targetLayer, neurons: [0]}, gradCamCallback, true, targetClass);
    modelManager.getModel(attackModelInd).outputManager.registerOutputDefGenerators("gradcamview", [defGenerator]);

    let modelNames = [...modelManager.modelDirs];
    modelNames.push("Edited Model");

    const modelLayers = model.getLayerNames(true);

    return(
        <div className="column">
            <ActivationCanvas width={width} height={width}
                activationParams={{ neurons: [0] }} activations={activations.act[0]} />
            <div className="column" style={{ padding: "10px", width: "100%" }}>
                <div className="row" style={{ width: "100%" }}>
                    <FormControl style={{ width: "100%" }}>
                        <InputLabel id="attackModelLabel">Model</InputLabel>
                        <Select
                            labelId="attackModelLabel"
                            id="demo-simple-select"
                            value={attackModelInd}
                            onChange={(evt) => {
                                setAttackModelInd(evt.target.value);
                                setNeedsUpdate(true);
                            }}
                        >
                            {
                                modelNames.map((l, i) => {
                                    return <MenuItem value={i}>{l}</MenuItem>
                                })
                            }
                        </Select>
                    </FormControl>
                </div>
                <div className="row" style={{ width: "100%" }}>
                    <FormControl style={{ width: "100%" }}>
                        <InputLabel id="demo-simple-select-label">Target Class</InputLabel>
                        <Select
                            labelId="demo-simple-select-label"
                            id="demo-simple-select"
                            value={targetClass}
                            onChange={(evt) => {
                                setTargetClass(evt.target.value);
                                setNeedsUpdate(true);
                            }}
                        >
                            {(() => {
                                const menuItems = [];
                                for (let i = 0; i < model.getNumClasses(); i++) {
                                    menuItems.push(<MenuItem value={i}>{i + " " + model.getClassName(i)}</MenuItem>);
                                }
                                return menuItems;
                            })()
                            }
                        </Select>
                    </FormControl>
                </div>
                <div className="row" style={{ width: "100%" }}>
                {getLayerSelect(modelLayers, targetLayer, (evt) => {
                    setTargetLayer(evt.target.value);
                    setNeedsUpdate(true);
                })}
                </div>
            </div>
        </div>);
}