import React, {useState, useRef} from 'react';
import TextField from '@material-ui/core/TextField';
import { CustomSlider } from './ParamControls/CustomSlider';
import {
    FormControl, FormControlLabel, MenuItem,
    Select, InputLabel, IconButton, Switch,
    Collapse, ListItem, ListItemText
} from '@material-ui/core'
import { outputTypes } from '../Model/Outputs'
import { getCenterNeuronCoords } from '../Model/ModelUtils'
import { TuningCurveRenderer } from './TuningCurveComponents/TuningCurveRendererGL'
import { TuningCurveLinePlot } from './TuningCurveComponents/TuningCurveLinePlot'

import { LayerView } from './LayerComponents/LayerView'
import { LogitView } from './LogitComponents/LogitView'
import { getPoolWidthFromDefParam } from '../Model/OutputUtils'
import {
    AddCircleOutline, RemoveCircleOutline,
    ExpandLess, ExpandMore
} from '@material-ui/icons';
import { getNeuronThumbURL, archId, getDefaultNewActivationParams, modelDirNames } from '../Model/ModelMetaInfo'
import { ActivationCanvas } from './ActivationComponents/ActivationRenderer'
import { ClusterView } from './TuningCurveComponents/Clustering/ClusterView'
import {AttributionView} from './LayerComponents/AttributionView'
import { FeatureVisRow} from './FeatureVisComponents/FeatureVisRow'

import * as config from '../config'

// Default activation parameters, one object per activation column.
// `normalize: 0` selects layer-stdev normalization (see ActivationVisColumn),
// `poolWidth: 1` pools over a single position.
export const defaultActivationParams = [
    {
        modelIndex: 0,
        layer: 'mixed4a_pre_relu',
        neurons: [308, 222, 234, 325],
        showAllNeurons: true,
        normalize: 0,
        poolWidth: 1,
    },
    {
        modelIndex: 0,
        layer: 'mixed4b_pre_relu',
        neurons: [443, 429],
        showAllNeurons: false,
        normalize: 0,
        poolWidth: 1,
    },
];

/**
 * Builds placeholder activations for three columns (keys 0, 1, 2).
 * Every entry gets its own freshly allocated Float32Array so the entries can
 * be mutated independently of each other and of previous calls.
 * @returns {{act: Object<number, {data: Float32Array, shape: {w: number, h: number}}>}}
 */
export function getDefaultActivations() {
    const makeEntry = () => ({
        data: new Float32Array([0.5, -.5, 0.2, 0.3]),
        shape: { w: 1, h: 1 },
    });
    const act = {};
    for (const key of [0, 1, 2]) {
        act[key] = makeEntry();
    }
    return { act };
}

// Placeholder activation entry (a single zero value on a 1x1 grid) for a newly added column.
const defaultNewActivations = () => ({
    data: new Float32Array([0]),
    shape: { w: 1, h: 1 },
});

/**
 * Appends a new activation column: a placeholder activation entry, default
 * parameters for the given architecture, and model index 0.
 * NOTE: `activations.act` and `activationParams` are mutated in place and then
 * handed back to their setters (only `modelInds` gets a fresh array), so other
 * holders of these objects observe the new column as well.
 * @param {*} archType architecture passed to getDefaultNewActivationParams.
 */
function addActivationParams(archType, activationParams, setActivationParams,
    activations, setActivations, modelInds, setModelInds) {
    // The new entry is keyed by the current number of entries.
    const nextKey = Object.keys(activations.act).length;
    activations.act[nextKey] = defaultNewActivations();

    const updatedModelInds = [...modelInds, 0];

    activationParams.push(getDefaultNewActivationParams(archType));

    setActivations(activations);
    setActivationParams(activationParams);
    setModelInds(updatedModelInds);
}

/**
 * Returns the key under which a column's output-definition generators are registered.
 * @param {number} index column index.
 * @returns {string} e.g. "ActivationView0".
 */
export function getActivationViewKey(index) {
    const prefix = "ActivationView";
    return prefix + index;
}

/**
 * Removes the last activation column. This drops its activation entry and parameters,
 * and deregisters its output-definition generators from the model it was using.
 * NOTE: `activations.act` and `activationParams` are mutated in place (see addActivationParams).
 */
function removeActivationParams(modelManager, activationParams, setActivationParams,
    activations, setActivations, modelInds, setModelInds) {
    const lastKey = Object.keys(activations.act).length - 1;
    delete activations.act[lastKey];

    const remainingModelInds = [...modelInds];
    const removedModelInd = remainingModelInds.pop();
    // After the pop, the remaining length equals the removed column's index.
    const removedViewKey = getActivationViewKey(remainingModelInds.length);
    modelManager.getModel(removedModelInd).outputManager.deregisterOutputDefGenerators(removedViewKey);

    activationParams.pop();

    setActivations(activations);
    setActivationParams(activationParams);
    setModelInds(remainingModelInds);
}

export function getPlusMinusView(modelManager, activationParams, setActivationParams,
    activations, setActivations, modelInds, setModelInds, setNeedsUpdate) {
    const displayRemoveButton = Object.keys(activationParams).length > 1;
    const model = modelManager.getModel(modelInds[modelInds.length - 1]);
    return (
        <div className="column" style={{ height: "100%" }}>
            <IconButton size="large"
                onClick={() => {
                    addActivationParams(model.archType,
                        activationParams, setActivationParams,
                        activations, setActivations, modelInds, setModelInds);
                    setNeedsUpdate(true);
                }}
            >
                <AddCircleOutline fontSize="inherit" />
            </IconButton>
            {displayRemoveButton ? <IconButton size="large"
                onClick={() => removeActivationParams(modelManager, activationParams, setActivationParams,
                    activations, setActivations, modelInds, setModelInds)}
            >
                <RemoveCircleOutline fontSize="inherit" />
            </IconButton> : ""}
        </div>);
}

/**
 * Formats the selected neuron indices as a single space-separated string.
 * @param {{neurons: number[]}} params single activation parameter object.
 * @returns {string} e.g. "308 222 234".
 */
function getNeuronString(params) {
    return params.neurons.reduce(
        (text, neuron, position) => (position ? text + " " : text) + neuron,
        "");
}

function getFeatureVisRow(model, params, neuronWidth, usePriorVis) {
    return params.neurons.map((n, i) => {
        const src = getNeuronThumbURL(model, params.layer, n, usePriorVis);
        return (
            <img width={neuronWidth} height={neuronWidth} src={src}></img>
        )
    })
}

function getModelSelect(modelNames, modelInd, onChange) {
    const wrappedModelNames = [...modelNames];
    wrappedModelNames.push("Edited Model");
    return (
        <FormControl >
            <InputLabel id="demo-simple-select-label">Model</InputLabel>
            <Select
                labelId="demo-simple-select-label"
                id="demo-simple-select"
                value={modelInd}
                onChange={onChange}
            >
                {
                    wrappedModelNames.map((m, i) => {
                        return <MenuItem value={i}>{m}</MenuItem>
                    })
                }
            </Select>
        </FormControl>)
}

export function getLayerSelect(layers, layer, onChange) {
    return (
        <FormControl >
            <InputLabel id="demo-simple-select-label">Layer</InputLabel>
            <Select
                labelId="demo-simple-select-label"
                id="demo-simple-select"
                value={layer}
                onChange={onChange}
            >
                {
                    layers.map((l) => {
                        return <MenuItem value={l}>{l}</MenuItem>
                    })
                }
            </Select>
        </FormControl>)
}

export function getNeuronIndicatorDivs(neuronWidth, activations, singleActivationParams) {
    let divArray = [];
    let currentLeft = 0;
    const shape = activations.shape;
    for (let i = 0; i < singleActivationParams.neurons.length; i++) {
        const centerNeuronWidth = neuronWidth / shape.w;
        const { x, y } = getCenterNeuronCoords(shape);
        const poolWidth = getPoolWidthFromDefParam(singleActivationParams.poolWidth, shape.h, shape.w);
        const cnMarginX = currentLeft + (x - Math.floor(poolWidth / 2)) * centerNeuronWidth;
        const cnMarginY = (y - Math.floor(poolWidth / 2)) * centerNeuronWidth;
        currentLeft += neuronWidth;

        divArray.push(<div style={{
            position: 'absolute',
            pointerEvents: 'none',
            top: cnMarginY,
            left: cnMarginX,
            width: centerNeuronWidth * poolWidth,
            height: centerNeuronWidth * poolWidth,
            zIndex: 1,
            border: '2px dotted white',
            visibility: config.SHOW_RECEPTIVE_FIELD ? 'visible' : 'hidden'
        }}>
        </div>);
    }
    return divArray;
}

export function ActivationVisColumn(columnProps) {
    const { activationParams, setActivationParams, minColWidth,
        activations,
        layerActivations, logitActivations, neuronWidth, updateReceptiveField,
        modelInds, setModelInds, setNeedsUpdate, usePriorVis,
        modelManager, tuningCurveResults,
        componentVisibility, setComponentVisibility, index } = columnProps;

    const [neuronText, setNeuronText] = useState(undefined);

    const neuronTextFieldRef=useRef();

    const params = activationParams[index];

    if(neuronText) {
        const tokens = neuronText.trim().split(" ");
        const neurons = tokens.map((t) => parseInt(t));
        let change = false;
        params.neurons.forEach((n,i) => {
            if(n !== neurons[i]) {
                change = true;
            }
        });
        if(params.neurons.length !== neurons.length) {
            change = true;
        }

        if(change) {
            setNeuronText(getNeuronString(params));
        }
    } else {
        setNeuronText(getNeuronString(params));
    }

    const model = modelManager.getModel(modelInds[index], true, () => {
        setNeedsUpdate(true)
    });
    const modelLayers = model.getLayerNames(true);
    const colWidth = neuronWidth * Math.max(params.neurons.length, minColWidth);
    const width = neuronWidth * params.neurons.length;
    let neuronString = getNeuronString(params);
    const normString = activationParams[index].normalize == 0 ? 'layer stdev.' : Math.floor(Math.pow(10, activationParams[index].normalize));

    const getConditionalStyle = (show) => {
        return {
            transition: "height 1s",
            height: show ? "auto" : 0,
            overflow: "hidden"
        }
    }

    const modelNames = modelManager.modelDirs;
    const canvasWidth = getCanvasWidth(params, neuronWidth);
    return (
        <div className="column">
            <ListItem button onClick={() => {
                setComponentVisibility('activations', !componentVisibility['activations']);
                setNeedsUpdate(true);
            }}>
                <ListItemText primary="Neuron Activations" />
                {componentVisibility['activations'] ? <ExpandLess /> : <ExpandMore />}
            </ListItem>
            <Collapse in={componentVisibility['activations']} timeout="auto">
                <div className="row">
                    {(<div style={{
                        width: canvasWidth, height: neuronWidth,
                        position: 'relative'
                    }}>
                        <ActivationCanvas width={canvasWidth} height={neuronWidth}
                            activationParams={params} activations={activations.act[index]} />
                        {config.SHOW_RECEPTIVE_FIELD ? getNeuronIndicatorDivs(
                            neuronWidth, activations.act[index], activationParams[index]) : ""}
                    </div>)}
                </div>
                <div className='row'>
                    <FeatureVisRow
                    model={model}
                    params={params}
                    neuronWidth={neuronWidth}
                    usePriorVis={usePriorVis}
                    />
                </div>
                <div className={'column'} style={{ padding: "10px", width: width }}>
                    <TextField ref={neuronTextFieldRef} id="standard-required" label="Neurons"
                    value={neuronText} defaultValue={neuronString}
                        onChange={(evt) => {
                            const tokens = evt.target.value.trim().split(" ");
                            const neurons = tokens.map((t) => parseInt(t));
                            let newParams = [...activationParams];
                            newParams[index].neurons = neurons;
                            setNeuronText(evt.target.value);
                            setActivationParams(newParams);
                            setNeedsUpdate(true);
                        }} />
                </div>
            </Collapse>
            <ListItem button onClick={() => {
                setComponentVisibility('layer', !componentVisibility['layer']);
                setNeedsUpdate(true);
            }}>
                <ListItemText primary="All Layer Neurons" />
                {componentVisibility['layer'] ? <ExpandLess /> : <ExpandMore />}
            </ListItem>
            <Collapse in={componentVisibility['layer']} timeout="auto">
                <div className='row'
                    style={getConditionalStyle(componentVisibility['layer'])}>
                    <LayerView
                        model={model}
                        activationParams={activationParams}
                        layerActivations={layerActivations}
                        width={colWidth}
                        height={neuronWidth}
                        index={index}
                        usePriorVis={usePriorVis} />
                </div>
            </Collapse>
            <ListItem button onClick={() => {
                setComponentVisibility('logits', !componentVisibility['logits']);
                setNeedsUpdate(true);
            }}>
                <ListItemText primary="Classification" />
                {componentVisibility['logits'] ? <ExpandLess /> : <ExpandMore />}
            </ListItem>
            <Collapse in={componentVisibility['logits']} timeout="auto">
                {<div className='row'
                    style={getConditionalStyle(componentVisibility['logits'])}>
                    <LogitView
                        model={model}
                        activationParams={activationParams}
                        logitActivations={logitActivations}
                        width={colWidth}
                        height={neuronWidth}
                        index={index} />
                </div>}
            </Collapse>
            <ListItem button onClick={() => {
                setComponentVisibility('tuningCurves', !componentVisibility['tuningCurves']);
            }}>
                <ListItemText primary="Tuning Curves" />
                {componentVisibility['tuningCurves'] ? <ExpandLess /> : <ExpandMore />}
            </ListItem>
            <Collapse in={componentVisibility['tuningCurves']} timeout="auto">
                <ListItem button onClick={() => {
                    setComponentVisibility('activationsInTuningCurves', !componentVisibility['activationsInTuningCurves']);
                    setNeedsUpdate(true);
                }}>
                    <ListItemText primary="Neuron Activations" />
                    {componentVisibility['activationsInTuningCurves'] ? <ExpandLess /> : <ExpandMore />}
                </ListItem>
                <Collapse in={componentVisibility['activationsInTuningCurves']} timeout="auto">
                    <div className="row">
                        {(<div style={{
                            width: canvasWidth, height: neuronWidth,
                            position: 'relative'
                        }}>
                            <ActivationCanvas width={canvasWidth} height={neuronWidth}
                                activationParams={params} activations={activations.act[index]} />
                            {config.SHOW_RECEPTIVE_FIELD ? getNeuronIndicatorDivs(
                                neuronWidth, activations.act[index], activationParams[index]) : ""}
                        </div>)}
                    </div>
                    <div className='row'>
                        {getFeatureVisRow(model, params, neuronWidth, usePriorVis)}
                    </div>
                </Collapse>
                <div className='row'
                    style={getConditionalStyle(componentVisibility['tuningCurves'] &&
                        Object.entries(tuningCurveResults).length !== 0)}>
                    {componentVisibility['tuningCurves'] &&
                        <TuningCurveRenderer
                            visible={componentVisibility['tuningCurves']}
                            tuningCurveResult={tuningCurveResults[index]}
                            neurons={activationParams[index].neurons}
                            width={width}
                            height={neuronWidth}
                            index={index} />}
                </div>
                <ListItem button onClick={() => {
                    setComponentVisibility('tuningCurveClusterPlot', !componentVisibility['tuningCurveClusterPlot']);
                }}>
                    <ListItemText primary="Clusters" />
                    {componentVisibility['tuningCurveClusterPlot'] ? <ExpandLess /> : <ExpandMore />}
                </ListItem>
                <Collapse in={componentVisibility['tuningCurveClusterPlot']} timeout="auto">
                    <div className="row">
                        <ClusterView
                            model={model}
                            tuningCurveResult={tuningCurveResults[index]}
                            activationParams={activationParams}
                            setActivationParams={setActivationParams}
                            setNeedsUpdate={setNeedsUpdate}
                            index={index}
                            usePriorVis={usePriorVis}
                            width={colWidth}
                            height={Math.min(colWidth / 1.2, neuronWidth*6)} />
                    </div>
                </Collapse>
            </Collapse>
            <div className={'column'} style={{ padding: "10px", width: colWidth }}>
                {getModelSelect(modelNames,
                    modelInds[index], (evt) => {
                        let newModelInds = [...modelInds]
                        newModelInds[index] = evt.target.value;
                        setModelInds(newModelInds);
                        setNeedsUpdate(true);
                    })}
                {getLayerSelect(modelLayers, params.layer, (evt) => {
                    let newParams = [...activationParams];
                    newParams[index].layer = evt.target.value;
                    setActivationParams(newParams);
                    updateReceptiveField();
                    setNeedsUpdate(true);
                })}
                <FormControlLabel
                    control={
                        <Switch checked={activationParams[index].showAllNeurons}
                            onChange={(evt) => {
                                let newParams = [...activationParams];
                                newParams[index].showAllNeurons = evt.target.checked;
                                setActivationParams(newParams);
                            }} />}
                    label={activationParams[index].showAllNeurons ?
                        "Show all neurons" : "Show top neurons"}
                />
                <CustomSlider
                    labelText={"Pool width: " + (activationParams[index].poolWidth === -1 ? "pool all" : activationParams[index].poolWidth)}
                    valueLabelDisplay={"auto"}
                    step={2}
                    min={-1}
                    max={21}
                    value={activationParams[index].poolWidth}
                    onChange={(evt, value) => {
                        let newParams = [...activationParams];
                        newParams[index].poolWidth = value;
                        setActivationParams(newParams);
                        setNeedsUpdate(true);
                    }}
                />
                <CustomSlider
                    labelText={"Normalization: div. by " + normString}
                    valueLabelDisplay={"auto"}
                    step={0.02}
                    min={0}
                    max={3}
                    value={activationParams[index].normalize}
                    onChange={(evt, value) => {
                        let newParams = [...activationParams];
                        newParams[index].normalize = value;
                        setActivationParams(newParams);
                        setNeedsUpdate(true);
                    }}
                />
            </div>
        </div>);
}

/**
 * Creates a function that can be registered with a model's OutputManager to
 * request neuron activations of one layer.
 * @param {*} actParams single activation parameter object (one entry of the
 *     column-level activationParams array). Its layer, neurons and poolWidth are
 *     read once, when the generator is created.
 * @param {*} callback called by the inferencer when activation data is available.
 * @param {boolean} includeMoments copied into the definition's `includeMoments`.
 * @param {boolean} pool when truthy, the definition also carries `poolWidth`
 *     and `poolType: 'avg'`.
 * @param {boolean} layerOutput when truthy, requests all channels (`channels: -1`)
 *     instead of only `actParams.neurons`.
 * @returns {Function} zero-argument function returning a fresh definition object.
 */
export function getOutputDefGenerator(actParams, callback, includeMoments = false, pool = false, layerOutput = false) {
    const type = outputTypes.NEURONS;
    const layer = actParams.layer;
    const channels = layerOutput ? -1 : actParams.neurons;
    const poolWidth = actParams.poolWidth;
    const withPooling = Boolean(pool);

    return () => {
        const definition = { type, layer, channels, includeMoments, callback };
        if (withPooling) {
            definition.poolWidth = poolWidth;
            definition.poolType = 'avg';
        }
        return definition;
    };
}

/**
 * Creates a function that can be registered with a model's OutputManager to
 * request the output of the layer feeding the model's first output (the logits).
 * @param {*} model model object exposing `outputs[0].shape` and
 *     `outputs[0].inputs[0].sourceLayer` (presumably a TF.js LayersModel — confirm).
 * @param {*} callback called by the inferencer when logit data is available.
 * @param {boolean} includeMoments copied into the definition.
 * @param {boolean} includeGradients copied into the definition.
 * @returns {Function} zero-argument function returning a fresh definition object.
 * @throws {Error} if the last dimension of the source layer's output differs from the
 *     last dimension of the model output, i.e. logits and softmax do not appear to be
 *     separate layers.
 */
export function getLogitOutputDefGenerator(model, callback, includeMoments = false, includeGradients = false) {
    const layer = model.outputs[0].inputs[0].sourceLayer;
    const layerName = layer.name;
    const logitShape = layer.output.shape;
    const outputShape = model.outputs[0].shape;
    // Compare the class dimensions directly. Writing `!a === b` would parse as
    // `(!a) === b`, a boolean-vs-number comparison that can never be true.
    if (logitShape[logitShape.length - 1] !== outputShape[outputShape.length - 1]) {
        throw new Error("Logit- and Softmax layers seem to not be separated!");
    }
    const type = outputTypes.LOGITS;
    return () => ({
        type: type, layer: layerName, includeMoments: includeMoments, includeGradients: includeGradients,
        callback: callback
    });
}

/**
 * Returns an inference callback that writes neuron results into `activationsRef.current` in place.
 * It copies the data and the w/h shape. When the result carries a variance, it stores
 * `normalize = 1 / stdev`; otherwise it removes any stale `normalize` value.
 * @param {{current: {data: *, shape: {w: number, h: number}, normalize?: number}}} activationsRef
 */
export const getNeuronInferenceCallback = (activationsRef) => (ret) => {
    const target = activationsRef.current;
    target.data = ret.data;
    Object.assign(target.shape, { w: ret.w, h: ret.h });
    if (!('variance' in ret)) {
        delete target.normalize;
        return;
    }
    target.normalize = 1 / Math.sqrt(ret.variance);
};

/**
 * Returns an inference callback that stores layer-wide results for column `i`.
 * A shallow copy of `layerActivations` is passed to the setter. The existing per-column
 * entry object is mutated in place, not replaced.
 * NOTE(review): `layerActivations` is captured when the callback is created, so the shared
 * inner objects are presumably what keeps other columns' data from going stale — keep
 * this in mind before switching to a fully immutable update.
 * @param {Object|undefined} layerActivations per-column layer activation entries.
 * @param {Function} setLayerActivations setter receiving the new top-level object.
 * @param {{normalize: number}} activationParams single activation parameter object;
 *     when no variance is returned, normalization falls back to 1 / 10^normalize.
 * @param {number} i column index.
 */
const getLayerInferenceCallback = (layerActivations, setLayerActivations, activationParams, i) => (ret) => {
    const nextActivations = layerActivations ? { ...layerActivations } : {};
    if (!(i in nextActivations)) {
        nextActivations[i] = {};
    }
    const entry = nextActivations[i];
    entry.data = ret.data;
    entry.normalize = 'variance' in ret
        ? 1 / Math.sqrt(ret.variance)
        : 1 / Math.pow(10, activationParams.normalize);
    setLayerActivations(nextActivations);
};

/**
 * Returns an inference callback that stores logit results (and gradients, if
 * present) for column `i`.
 * A shallow copy of `logitActivations` is passed to the setter. An existing per-column
 * entry is mutated in place. A missing entry is created on first use (as in the layer
 * callback) instead of causing a TypeError.
 * @param {Object|undefined} logitActivations per-column logit activation entries.
 * @param {Function} setLogitActivations setter receiving the new top-level object.
 * @param {{normalize: number}} activationParams single activation parameter object;
 *     when no variance is returned, normalization falls back to 1 / 10^normalize.
 * @param {number} i column index.
 */
export const getLogitInferenceCallback = (logitActivations, setLogitActivations, activationParams, i) => (ret) => {
    const activations = { ...logitActivations };
    if (activations[i] == null) {
        activations[i] = {};
    }
    const entry = activations[i];
    entry.data = ret.data;
    if ('gradients' in ret) {
        entry.gradients = ret.gradients;
    }
    if ('variance' in ret) {
        entry.normalize = 1 / Math.sqrt(ret.variance);
    } else {
        entry.normalize = 1 / Math.pow(10, activationParams.normalize);
    }
    setLogitActivations(activations);
};

/**
 * Returns an inference callback that forwards results to `animation.callback`.
 * When the inference data has no variance, `mean = 0` and
 * `variance = (10^normalize)^2` are written onto `inferenceData` in place first.
 * @param {{callback: Function}} animation receives a payload object.
 * @param {{normalize: number}} activationParams single activation parameter object.
 * @param {number} paramN forwarded as `paramNumber`.
 * @param {number} i forwarded as `paramIndex`.
 */
const getAnimationCallback = (animation, activationParams, paramN, i) => (inferenceData, metaData) => {
    const hasMoments = 'variance' in inferenceData;
    if (!hasMoments) {
        inferenceData.mean = 0;
        inferenceData.variance = Math.pow(Math.pow(10, activationParams.normalize), 2);
    }
    const payload = {
        animation,
        inferenceData,
        activationParams,
        paramNumber: paramN,
        paramIndex: i,
        metaData,
    };
    animation.callback(payload);
};

/**
 * Decides which output-definition generators a column needs, based on which
 * sections are visible. Neuron activations are always enabled.
 * @param {{componentVisibility: Object<string, boolean>}} props column props.
 * @returns {{neurons: true, layer?: true, logits?: true}} enabled generator flags;
 *     consumers test membership with the `in` operator.
 */
export function getEnabledDefGenerators(props) {
    const { componentVisibility } = props;
    const enabledGenerators = { "neurons": true };
    if (componentVisibility['layer']) {
        enabledGenerators["layer"] = true;
    }
    if (componentVisibility['logits']) {
        enabledGenerators["logits"] = true;
    }
    return enabledGenerators;
}

/**
 * Collects the output-definition generators for column `index`, in this order:
 * neurons, layer, logits (only when the model is loaded), animation.
 * Moments are requested whenever the column's `normalize` equals 0 (layer-stdev
 * normalization).
 * NOTE(review): `setActivations` is passed to getNeuronInferenceCallback, which reads
 * `.current` from it, so callers presumably pass a ref object here despite the name — confirm.
 * @param {Object} enabledGenerators flags as produced by getEnabledDefGenerators.
 * @returns {Function[]} generator functions to register with the model's OutputManager.
 */
export function getDefGenerators(enabledGenerators, activations, index, setActivations, activationParams, animation,
    layerActivations, setLayerActivations, logitActivations, setLogitActivations, model) {
    const params = activationParams[index];
    const includeMoments = params.normalize == 0;
    const generators = [];

    if ("neurons" in enabledGenerators) {
        const neuronCallback = getNeuronInferenceCallback(setActivations);
        generators.push(getOutputDefGenerator(params, neuronCallback, includeMoments));
    }
    if ("layer" in enabledGenerators) {
        const layerCallback = getLayerInferenceCallback(
            layerActivations, setLayerActivations, params, index);
        generators.push(getOutputDefGenerator(params, layerCallback, includeMoments, true, true));
    }
    if ("logits" in enabledGenerators && model.isLoaded()) {
        const logitCallback = getLogitInferenceCallback(
            logitActivations, setLogitActivations, params, index);
        generators.push(getLogitOutputDefGenerator(model.sourceModel, logitCallback, includeMoments, false));
    }
    if (animation) {
        const animationCallback = getAnimationCallback(
            animation, params, activationParams.length, index);
        generators.push(getOutputDefGenerator(params, animationCallback, includeMoments, true, true));
    }
    return generators;
}

/**
 * Width in pixels of an activation canvas showing one `neuronWidth`-wide tile per selected neuron.
 * @param {{neurons: number[]}} singleActivationParams single activation parameter object.
 * @param {number} neuronWidth tile edge length in pixels.
 * @returns {number}
 */
export function getCanvasWidth(singleActivationParams, neuronWidth) {
    const { neurons } = singleActivationParams;
    return neurons.length * neuronWidth;
}