import * as tf from '@tensorflow/tfjs'

/**
 * Builds one function per output channel of `outputLayer`. Each function
 * feeds a tensor in as the output of `inputLayer`, re-applies the layers
 * lying between `inputLayer` and `outputLayer`, and returns a single channel
 * of the resulting activation.
 *
 * The layers are presumably tf.layers.Layer instances taken from a
 * LayersModel (TODO confirm); the code only relies on `name`,
 * `inboundNodes[0].inboundLayers`, `outboundNodes[].outboundLayer`,
 * `apply()` and, for `outputLayer`, `output.shape`.
 *
 * @param {object} inputLayer - Layer whose output is replaced by the functor's argument.
 * @param {object} outputLayer - Layer whose activation is read out.
 * @returns {Function[]} One functor per entry of the last dimension of
 *   `outputLayer.output.shape`. For a rank-4 activation the functor returns the
 *   spatially central value of its channel (shape `[batch, 1, 1, 1]`); for a
 *   rank-2 activation it returns that column (shape `[batch, 1]`).
 * @throws {Error} When a functor is called and `outputLayer` cannot be
 *   computed from `inputLayer`, or its activation is neither rank 2 nor rank 4.
 */
export function subgraphFunctors(inputLayer, outputLayer) {
    const outputShape = outputLayer.output.shape;
    const outChannels = outputShape[outputShape.length - 1];

    // Runs the forward pass from inputLayer to outputLayer with `x` substituted
    // for inputLayer's output, and returns outputLayer's activation.
    const runSubgraph = (x) => {
        // A Map (rather than a plain object) so layer names can never collide
        // with inherited Object.prototype keys.
        const outputs = new Map([[inputLayer.name, x]]);

        const processLayer = (layer) => {
            // A layer reachable along several paths only needs to be applied
            // once; re-applying it would also re-walk everything downstream.
            if (outputs.has(layer.name)) {
                return;
            }
            // Only the first inbound node of each layer is considered.
            const { inboundLayers } = layer.inboundNodes[0];
            const isReady = inboundLayers.every((inbound) => outputs.has(inbound.name));
            if (!isReady) {
                // Some input is still missing; the layer is revisited when
                // another of its inbound layers has been computed.
                return;
            }
            const inputs = inboundLayers.map((inbound) => outputs.get(inbound.name));
            outputs.set(layer.name, layer.apply(inputs));
            if (layer.name === outputLayer.name) {
                // Nothing past the output layer is needed.
                return;
            }
            layer.outboundNodes.forEach((outNode) => {
                processLayer(outNode.outboundLayer);
            });
        };

        inputLayer.outboundNodes.forEach((outNode) => {
            processLayer(outNode.outboundLayer);
        });

        if (!outputs.has(outputLayer.name)) {
            throw new Error(
                `Layer "${outputLayer.name}" is not reachable from layer "${inputLayer.name}".`
            );
        }
        return outputs.get(outputLayer.name);
    };

    // Extracts the requested channel from the output activation.
    const pickChannel = (activation, channel) => {
        const rank = activation.shape.length;
        if (rank === 4) {
            const [, height, width] = activation.shape;
            const cy = Math.floor(height / 2);
            const cx = Math.floor(width / 2);
            // An explicit size is required: without it tf.slice extends to the
            // end of every axis and would include other positions and channels.
            return tf.slice(activation, [0, cy, cx, channel], [-1, 1, 1, 1]);
        }
        if (rank === 2) {
            return tf.gather(activation, [channel], 1);
        }
        throw new Error(
            `Unsupported output rank ${rank} for layer "${outputLayer.name}"; expected 2 or 4.`
        );
    };

    return Array.from(
        { length: outChannels },
        (_, channel) => (x) => pickChannel(runSubgraph(x), channel)
    );
}

/**
 * Estimates how strongly each channel of one layer influences a single
 * channel ("neuron") of a later layer.
 *
 * A random tensor with batch size 1 is fed in as the output of `layerName1`,
 * the gradient of the selected functor from `subgraphFunctors` is taken with
 * respect to that tensor, and the gradient is summed over every axis except
 * the last. The per-channel sums are divided by twice their standard
 * deviation and returned as absolute values.
 *
 * The probe input is drawn with `tf.truncatedNormal`, so results vary from
 * call to call.
 *
 * @param {object} model - Object exposing `sourceModel.getLayer(name)`.
 * @param {string} layerName1 - Name of the layer whose channels are attributed.
 * @param {string} layerName2 - Name of the layer containing the target neuron.
 * @param {number} [neuron=0] - Index of the target channel in `layerName2`.
 * @returns {tf.Tensor} Rank-1 tensor with one non-negative value per channel
 *   of `layerName1`.
 * @throws {RangeError} If `neuron` does not select an existing functor.
 */
export function getAttribution(model, layerName1, layerName2, neuron=0) {
    return tf.tidy(() => {
        const layer1 = model.sourceModel.getLayer(layerName1);
        const layer2 = model.sourceModel.getLayer(layerName2);

        const functors = subgraphFunctors(layer1, layer2);
        const functor = functors[neuron];
        if (functor === undefined) {
            throw new RangeError(
                `neuron index ${neuron} is out of range; layer "${layerName2}" has ${functors.length} channels.`
            );
        }
        const grad = tf.grad(functor);

        // The batch dimension of a layer's symbolic shape is not a concrete
        // size, so probe with a single example.
        const inputShape = [...layer1.output.shape];
        inputShape[0] = 1;
        const input = tf.truncatedNormal(inputShape, 0, 25);

        // Collapse every axis except the last (channel) one, so both rank-4
        // (conv) and rank-2 (dense) source layers are supported.
        const reduceAxes = inputShape.slice(0, -1).map((_, axis) => axis);
        const channelSums = tf.sum(grad(input), reduceAxes);

        const { variance } = tf.moments(channelSums);
        const scale = variance.sqrt().mul(2);
        return channelSums.div(scale).abs();
    });
}

/**
 * Placeholder for a Grad-CAM computation; not implemented yet.
 *
 * Currently it only looks up `layerName` on `model.sourceModel` inside a
 * `tf.tidy` scope and produces no value.
 *
 * @param {object} model - Object exposing `sourceModel.getLayer(name)`.
 * @param {string} layerName - Name of the layer to look up.
 * @param {number} [outClass=0] - Currently unused.
 * @returns {undefined} Whatever `tf.tidy` returns for a callback with no result.
 */
export function getGradCam(model, layerName, outClass=0) {
    // TODO: compute the class activation map from the resolved layer.
    const resolveLayer = () => {
        model.sourceModel.getLayer(layerName);
    };
    return tf.tidy(resolveLayer);
}