import {tf} from '../TFJS'
import {preProcessRawUint8InputImage, getCenterNeuronCoords, getCenterRectCoordsFromWidth} from './ModelUtils'
import {getPoolWidthFromDefParam} from './OutputUtils'
import {subgraphFunctors} from'./Attribution'

/**
 * Output kinds an output definition can request through its `type` field.
 * Frozen so importing modules cannot accidentally mutate this shared enum.
 */
export const outputTypes = Object.freeze({
    NEURONS: 'neurons',
    LAYER: 'layer',
    LOGITS: 'logits',
    GRADIENT: 'gradient',
    GRADCAM: 'gradcam'
});

/**
 * Placeholder, judging by its name, for building an output spec shaped like
 * `exampleOutputSpec`.
 *
 * NOTE(review): not implemented — the body is empty, so every call currently
 * returns undefined.
 *
 * @param {Object} inputInfo - input description, for example:
 *   {
 *       rawInputShape: [1, 300, 300, 4],
 *       inputShape: [1, 224, 224, 3],
 *       rawInputRange: [0, 255],
 *       rawOutputRange: [-1, 1]
 *   }
 * @param {Array<Object>} outputDefs - output definitions (presumably like the
 *   `outputDefs` entries of exampleOutputSpec — TODO confirm).
 * @param {number} [poolWidth=1] - unused by the current body.
 * @param {string} [poolType='avg'] - unused by the current body.
 * @returns {undefined} always, until implemented.
 */
export function getOutputSpec(inputInfo, outputDefs, poolWidth = 1, poolType = 'avg') {
    // Intentionally empty placeholder.
    return undefined;
}


/**
 * Illustrative output spec: `inputInfo` describes the raw and model input, and
 * each entry of `outputGroups` is served by one auxiliary model in Inferencer.
 *
 * NOTE(review): getOutputFromOutputDef in this file has no case for
 * outputTypes.LAYER or outputTypes.GRADIENT. An Inferencer built from this exact
 * spec therefore cannot construct those outputs.
 */
export const exampleOutputSpec = {
    inputInfo: {
        rawInputShape: [1, 300, 300, 4],
        inputShape: [1, 224, 224, 3],
        rawInputRange: [0, 255],
        rawOutputRange: [-1, 1]
    },
    outputGroups: [
        // Group 1: NEURONS defs carry a poolType, so getOutputFromOutputDef builds
        // them as PooledNeuronsOutput; followed by two LAYER defs.
        {
            outputDefs: [
                {
                    type: outputTypes.NEURONS,
                    layer: 'mixed4a_pre_relu',
                    poolWidth: 1,
                    poolType: 'avg',
                    channels: [308, 224],
                    callback: (outputArray) => console.log(outputArray)
                },
                {
                    type: outputTypes.NEURONS,
                    layer: 'mixed4b_pre_relu',
                    poolWidth: 1,
                    poolType: 'avg',
                    channels: [308, 224],
                    callback: (outputArray) => console.log(outputArray)
                },
                {
                    type: outputTypes.LAYER,
                    layer: 'mixed4a_pre_relu',
                    poolWidth: -1,
                    poolType: 'avg',
                    callback: (outputArray) => console.log(outputArray)
                },
                {
                    type: outputTypes.LAYER,
                    layer: 'mixed4b_pre_relu',
                    poolWidth: -1,
                    poolType: 'avg',
                    callback: (outputArray) => console.log(outputArray)
                }
            ]
        },
        // Group 2: raw values of the final softmax layer.
        {
            outputDefs: [
                {
                    type: outputTypes.LOGITS,
                    layer: 'softmax_2',
                    callback: (outputArray) => console.log(outputArray)
                }
            ]
        },
        // Group 3: gradient request for class index 0.
        {
            outputDefs: [
                {
                    type: outputTypes.GRADIENT,
                    layer: 'softmax_2',
                    classIndex: 0,
                    callback: (outputArray) => console.log(outputArray)
                }
            ]
        }
    ]
}

/**
 * Instantiates the Output subclass matching `outputDef.type`:
 *  - NEURONS with a truthy `poolType` -> PooledNeuronsOutput, otherwise NeuronsOutput
 *  - LOGITS  -> LogitOutput
 *  - GRADCAM -> GradCamOutput
 *
 * @param {*} sourceModel - model passed through to the Output constructor.
 * @param {Object} outputDef - output definition; `type` must be one of the handled outputTypes.
 * @returns {Object} the constructed Output instance.
 * @throws {Error} when `outputDef.type` has no corresponding Output class.
 */
export function getOutputFromOutputDef(sourceModel, outputDef) {
    switch (outputDef.type) {
        case outputTypes.NEURONS:
            return outputDef.poolType
                ? new PooledNeuronsOutput(sourceModel, outputDef)
                : new NeuronsOutput(sourceModel, outputDef);
        case outputTypes.LOGITS:
            return new LogitOutput(sourceModel, outputDef);
        case outputTypes.GRADCAM:
            return new GradCamOutput(sourceModel, outputDef);
        default:
            // Fail fast. Returning undefined only surfaces later as an opaque
            // TypeError when Inferencer calls getLayerOutput() on the result.
            throw new Error(`Unsupported output type: ${outputDef.type}`);
    }
}

/**
 * Represents an inference request. Receives an output spec (shaped like
 * exampleOutputSpec) holding input properties and groups of output definitions.
 * Pre-processes the input and builds one auxiliary model per output group, whose
 * outputs are that group's layer outputs. It concatenates each group's
 * post-processed results into one flat tensor, so each group needs a single
 * device read. It then slices that data back apart and calls the output callbacks.
 */
export class Inferencer {
    /**
     * @param {Object} outputSpec - {inputInfo, outputGroups: [{outputDefs: [...]}, ...]}.
     * @param {*} sourceModel - model providing `.input` and `.getLayer(name)`
     *     (presumably a tf.LayersModel — TODO confirm).
     */
    constructor(outputSpec, sourceModel) {
        this.outputSpec = outputSpec;
        this.sourceModel = sourceModel;
        this.auxModelsAndOutputs = this.getAuxModelsAndOutputs(outputSpec);
    }

    /**
     * Converts raw input data into the model input using the shapes and ranges
     * in `outputSpec.inputInfo`. infer() disposes the returned tensor when done.
     */
    preProcessInput(flatData) {
        const {rawInputShape, inputShape, rawInputRange, rawOutputRange, imageNetWhiten} =
            this.outputSpec.inputInfo;
        const [, rh, rw] = rawInputShape;
        const [, h, w] = inputShape;
        const [iMin, iMax] = rawInputRange;
        const [oMin, oMax] = rawOutputRange;
        return preProcessRawUint8InputImage(
            flatData, {w: rw, h: rh}, {w: w, h: h}, {min: iMin, max: iMax}, {min: oMin, max: oMax}, imageNetWhiten);
    }

    /**
     * Runs every output group on one input and invokes each output's callback.
     *
     * @param {*} flatData - raw input data handed to preProcessInput().
     * @param {*} metaData - forwarded untouched as the second callback argument.
     * @param {boolean} asynchronous - read results with data() instead of dataSync().
     * @returns {Promise|undefined} In asynchronous mode, a promise that settles once
     *     every group's callbacks have run; otherwise undefined.
     */
    infer(flatData, metaData, asynchronous) {
        // Every group consumes the same pre-processed input, so build it once.
        const ppInputTensor = this.preProcessInput(flatData);
        const pendingReads = [];
        try {
            this.outputSpec.outputGroups.forEach((outputGroup, i) => {
                const {model: auxModel, outputs} = this.auxModelsAndOutputs[i];
                const gradients = {};
                // tf.tidy disposes every intermediate tensor (predictions, moments,
                // gradient tensors) and keeps only the returned concatenation.
                const concatenatedTensor = tf.tidy(() => {
                    let outputTensors = auxModel.predict(ppInputTensor);
                    if (outputs.length === 1) {
                        // A single-output model's predict() yields a bare tensor, not an array.
                        outputTensors = [outputTensors];
                    }
                    const flatOutputTensors = [];
                    outputs.forEach((output, j) => {
                        const {activationTensor, momentsTensor} = output.getPostProcessedTensor(outputTensors[j]);
                        flatOutputTensors.push(activationTensor.flatten());
                        // getFlatLength() counts the two moment values only when the
                        // output requested them. Appending them unconditionally would
                        // shift the offset of every later output in the group.
                        if (momentsTensor && output.includesMoments) {
                            const {mean, variance} = momentsTensor;
                            flatOutputTensors.push(tf.concat([mean.reshape([1]), variance.reshape([1])]).flatten());
                        }
                    });
                    outputs.forEach((output, j) => {
                        if (output.includeGradients) {
                            gradients[j] = output.getGradient(ppInputTensor);
                        }
                    });
                    return tf.concat(flatOutputTensors);
                });

                const fulfill = (data) => {
                    let arrayOffset = 0;
                    outputs.forEach((output, j) => {
                        const ppData = output.getPostProcessedData(data, arrayOffset);
                        arrayOffset += output.getFlatLength();
                        if (j in gradients) {
                            ppData.gradients = gradients[j];
                        }
                        output.callback(ppData, metaData);
                    });
                };

                if (asynchronous) {
                    const dispose = () => concatenatedTensor.dispose();
                    pendingReads.push(concatenatedTensor.data()
                        .then(fulfill)
                        .then(dispose, (err) => {
                            dispose();
                            throw err;
                        }));
                } else {
                    try {
                        fulfill(concatenatedTensor.dataSync());
                    } finally {
                        concatenatedTensor.dispose();
                    }
                }
            });
        } finally {
            // All work that reads the input tensor has run synchronously by now.
            ppInputTensor.dispose();
        }
        if (asynchronous) {
            return Promise.all(pendingReads);
        }
        return undefined;
    }

    /** Builds, per output group, an auxiliary model exposing that group's layer outputs. */
    getAuxModelsAndOutputs(outputSpec) {
        return outputSpec.outputGroups.map((outputGroup) => {
            const groupOutputs = this.getGroupOutputs(outputGroup);
            const groupLayerOutputs = this.getLayerOutputsFromGroupOutputs(groupOutputs);
            const auxModel = tf.model({inputs: this.sourceModel.input, outputs: groupLayerOutputs});
            return {model: auxModel, outputs: groupOutputs};
        });
    }

    /** Instantiates one Output object per output definition of the group. */
    getGroupOutputs(outputGroup) {
        return outputGroup.outputDefs.map((outputDef) => getOutputFromOutputDef(this.sourceModel, outputDef));
    }

    /** Collects the symbolic layer output each Output reads from. */
    getLayerOutputsFromGroupOutputs(groupOutputs) {
        return groupOutputs.map((output) => output.getLayerOutput());
    }
}

/**
 * Abstract base for one inference output read from layer `outputDef.layer` of
 * `sourceModel`.
 *
 * Subclasses must implement:
 *  - getPostProcessedTensor(predictionTensor) -> {activationTensor, momentsTensor}
 *  - getPostProcessedData(flatData, offset)   -> object handed to the callback
 *  - getFlatLength()                          -> values occupied in the flat result
 *
 * When `outputDef.includeGradients` is truthy, the constructor prepares a
 * gradient function (via tf.grad) of the post-processed activation with respect
 * to the model input; getGradient() evaluates it.
 */
class Output {
    constructor(sourceModel, outputDef) {
        // Validate the subclass contract before doing any model work.
        if (this.constructor === Output) {
            throw new TypeError('Abstract class "Output" cannot be instantiated directly.');
        }

        if (this.getPostProcessedTensor === undefined) {
            throw new TypeError('Classes extending the Output abstract class must implement getPostProcessedTensor');
        }

        if (this.getPostProcessedData === undefined) {
            throw new TypeError('Classes extending the Output abstract class must implement getPostProcessedData');
        }

        if (this.getFlatLength === undefined) {
            throw new TypeError('Classes extending the Output abstract class must implement getFlatLength');
        }

        this.sourceModel = sourceModel;
        this.layer = outputDef.layer;
        this.outputDef = outputDef;
        this.callback = outputDef.callback;
        this.layerOutput = this.getLayerOutput();
        this.includeGradients = Boolean(outputDef.includeGradients);
        if (this.includeGradients) {
            const auxModel = tf.model({inputs: this.sourceModel.input, outputs: [this.layerOutput]});
            const outputFunc = (inputTensor) => {
                // NOTE(review): gather(0, -1) takes index 0 of the last axis — confirm
                // this is the intended slice before post-processing.
                const predictionTensor = auxModel.apply(inputTensor).gather(0, -1);
                const {activationTensor} = this.getPostProcessedTensor(predictionTensor);
                return activationTensor;
            };
            this.gradFun = tf.grad(outputFunc);
        }
    }

    /**
     * Evaluates the prepared gradient for `inputTensor` and reads it back.
     *
     * @returns {{data: *, w: number, h: number, shape: {w: number, h: number}}}
     *     w and h are taken from axes 2 and 1 of the gradient's shape.
     * @throws {Error} if the output was created without `includeGradients`.
     */
    getGradient(inputTensor) {
        if (!this.includeGradients) {
            throw new Error("This Output hasn't been initialized with included gradients.");
        }
        const gradTensor = tf.tidy(() => this.gradFun(inputTensor));
        try {
            const [, h, w] = gradTensor.shape;
            return {data: gradTensor.dataSync(), w: w, h: h, shape: {w: w, h: h}};
        } finally {
            // The values have been copied out; release the tensor's memory.
            gradTensor.dispose();
        }
    }

    /** Returns the symbolic output of the layer named `this.layer`. */
    getLayerOutput() {
        return this.sourceModel.getLayer(this.layer).output;
    }

    getOutputShape() {
        return this.layerOutput.shape;
    }

    /** Indices 0..c-1 for the channel count c found at axis 3 of the output shape. */
    getAllChannels() {
        const [, , , c] = this.getOutputShape();
        return Array.from({length: c}, (_, i) => i);
    }

    /** Maps 'avg' to tf.mean and 'max' to tf.max; any other value falls back to tf.mean. */
    getPoolOp(poolString) {
        if (poolString === 'max') {
            return tf.max;
        }
        return tf.mean;
    }
}

/**
 * Output exposing the full spatial activations of selected channels of one layer.
 *
 * `outputDef.channels` is either an array of channel indices or -1 for every
 * channel. When `outputDef.includeMoments` is truthy, two extra trailing values
 * (mean, variance) are reserved in the flat layout.
 */
export class NeuronsOutput extends Output {

    constructor(sourceModel, outputDef) {
        super(sourceModel, outputDef);
        const requested = outputDef.channels;
        this.channels = requested === -1 ? this.getAllChannels() : requested;
        this.includesMoments = outputDef.includeMoments;
    }

    /** h * w * (selected channel count), plus 2 when moments are included. */
    getFlatLength () {
        const [, h, w] = this.getOutputShape();
        const extra = this.includesMoments ? 2 : 0;
        return h * w * this.channels.length + extra;
    }

    /**
     * Selects `this.channels` (axis 3) and transposes to [h, channel, w, b].
     * The trailing reshape to [1, channels, w, h] keeps that element order, and
     * only succeeds when b is 1. Once flattened, each row therefore holds the
     * selected channels' w values side by side. Moments are computed over the
     * whole, unselected activation.
     */
    getPostProcessedTensor (predictionTensor)  {
        const [, h, w] = this.getOutputShape();
        const momentsTensor = tf.moments(predictionTensor);
        const selected = predictionTensor.gather(this.channels, 3);
        const activationTensor = selected
            .transpose([1, 3, 2, 0])
            .reshape([1, this.channels.length, w, h]);
        return {activationTensor, momentsTensor};
    }

    /** Slices this output's values (and optional moments) out of the group's flat data. */
    getPostProcessedData  (flatData, offset)  {
        const [, h, w] = this.getOutputShape();
        const end = offset + h * w * this.channels.length;
        const data = flatData.slice(offset, end);
        if (!this.includesMoments) {
            return {data, w, h};
        }
        const [mean, variance] = flatData.slice(end, end + 2);
        return {data, w, h, mean, variance};
    }
}

/**
 * NeuronsOutput variant that reduces each selected channel to a single value by
 * pooling a centered poolWidth x poolWidth window ('avg' or 'max', per
 * getPoolOp). The result is laid out as a 1-row strip with one value per channel.
 */
export class PooledNeuronsOutput extends NeuronsOutput {
    constructor(sourceModel, outputDef) {
        super(sourceModel, outputDef);
        const [, height, width] = this.getOutputShape();
        this.poolWidth = getPoolWidthFromDefParam(outputDef.poolWidth, height, width);
        this.poolType = outputDef.poolType;
        this.includesMoments = outputDef.includeMoments;
    }

    // TODO: pooling other than 1 width doesn't work!
    getFlatLength () {
        const extra = this.includesMoments ? 2 : 0;
        return this.channels.length + extra;
    }

    getPostProcessedTensor (predictionTensor)  {
        const momentsTensor = tf.moments(predictionTensor);
        const [, height, width] = this.getOutputShape();
        const channelCount = this.channels.length;
        const selected = predictionTensor.gather(this.channels, 3);
        const corner = getCenterRectCoordsFromWidth(this.poolWidth, {w: width, h: height});
        const window = selected.slice(
            [0, corner.y, corner.x, 0],
            [1, this.poolWidth, this.poolWidth, channelCount]);
        // Reduce over the two spatial axes of the window.
        const pool = this.getPoolOp(this.poolType);
        const pooled = pool(window, [1, 2]);
        return {activationTensor: pooled.reshape([channelCount]), momentsTensor};
    }

    getPostProcessedData  (flatData, offset)  {
        const channelCount = this.channels.length;
        const end = offset + channelCount;
        const data = flatData.slice(offset, end);
        if (!this.includesMoments) {
            return {data, w: channelCount, h: 1};
        }
        const [mean, variance] = flatData.slice(end, end + 2);
        return {data, w: channelCount, h: 1, mean, variance};
    }
}

/**
 * Output exposing the raw values of a layer whose shape is read as [b, c]
 * (e.g. class scores), laid out as a 1-row strip of c values.
 * Its "moments" are not statistics: mean is fixed at 0 and variance is the
 * square of the largest absolute value.
 */
export class LogitOutput extends Output {

    constructor(sourceModel, outputDef) {
        super(sourceModel, outputDef);
        this.includesMoments = outputDef.includeMoments;
    }

    getFlatLength () {
        const width = this.getOutputShape()[1];
        return width + (this.includesMoments ? 2 : 0);
    }

    getPostProcessedTensor (predictionTensor)  {
        const peakSquared = tf.max(predictionTensor.flatten().abs()).pow(2);
        return {
            activationTensor: predictionTensor,
            momentsTensor: {mean: tf.tensor([0]), variance: peakSquared}
        };
    }

    getPostProcessedData  (flatData, offset)  {
        const width = this.getOutputShape()[1];
        const end = offset + width;
        const data = flatData.slice(offset, end);
        if (!this.includesMoments) {
            return {data, w: width, h: 1};
        }
        const [mean, variance] = flatData.slice(end, end + 2);
        return {data, w: width, h: 1, mean, variance};
    }
}

/**
 * Grad-CAM style output. It weights the feature map of `layer` by the gradient
 * of a target head output and averages over channels, yielding one h x w map.
 * The target head output is selected by `gradCamTarget` and reached through the
 * sub-graph up to `gradCamLayer`.
 * Its "moments" are not statistics: mean is fixed at 0 and variance is the
 * square of the map's largest absolute value.
 */
export class GradCamOutput extends Output {

    constructor(sourceModel, outputDef) {
        super(sourceModel, outputDef);
        this.includesMoments = outputDef.includeMoments;
        this.gradCamTarget = outputDef.gradCamTarget;
        this.gradCamLayer = outputDef.gradCamLayer;

        this.gradCamLayerOutput = this.getGradCamLayerOutput();
        // Maps the feature map of `this.layer` to the selected head output.
        // NOTE(review): assumes subgraphFunctors(from, to) returns an indexable
        // collection of functions over the sub-graph between the two layers — confirm.
        // Deliberately no debug dataSync() here: it would force a blocking device
        // read on every inference, even when infer() runs asynchronously.
        const outputFunc = (featureMapTensor) => {
            const headFunctors = subgraphFunctors(
                this.sourceModel.getLayer(this.layer),
                this.sourceModel.getLayer(this.gradCamLayer));
            return headFunctors[this.gradCamTarget](featureMapTensor);
        };
        this.gradCamFun = tf.grad(outputFunc);
    }

    getOutputShape ()  {
        return this.getLayerOutput().shape;
    }

    getGradCamLayerOutput() {
        return this.sourceModel.getLayer(this.gradCamLayer).output;
    }

    /** h * w for the single-channel map, plus 2 when moments are included. */
    getFlatLength () {
        const [, h, w] = this.getOutputShape();
        return h * w + (this.includesMoments ? 2 : 0);
    }

    getPostProcessedTensor (predictionTensor)  {
        const grad = this.gradCamFun(predictionTensor);
        // Per-channel importance: gradient averaged over axes 0-2 ([b, h, w] in
        // the [b, h, w, c] layout used throughout this file).
        const neuronImportance = grad.mean([0, 1, 2]);
        // Weight each channel, then average over the channel axis (kept as size 1).
        const gradCam = predictionTensor.mul(neuronImportance).mean([3], true);
        const max = tf.max(gradCam.flatten().abs()).pow(2);
        const mean = tf.tensor([0]);
        return {activationTensor: gradCam, momentsTensor: {mean: mean, variance: max}};
    }

    getPostProcessedData  (flatData, offset)  {
        const [, h, w] = this.getOutputShape();
        const length = h * w;
        const data = flatData.slice(offset, offset + length);
        if (this.includesMoments) {
            const moments = flatData.slice(offset + length, offset + length + 2);
            return {data: data, w: w, h: h, mean: moments[0], variance: moments[1]};
        }
        return {data: data, w: w, h: h};
    }
}

/**
 * Placeholder for an input-gradient output. Every required method throws, so an
 * instance can be constructed (the inherited Output constructor runs) but cannot
 * be used for inference. getOutputFromOutputDef in this file has no case that
 * creates it.
 */
export class GradientOutput extends Output {

    getFlatLength () {
        throw new Error("not implemented!");
    }

    getPostProcessedTensor (predictionTensor)  {
        throw new Error("not implemented!");
    }

    getPostProcessedData  (flatData, offset)  {
        throw new Error("not implemented!");
    }
}