//import * as tf from '@tensorflow/tfjs';
import { tf } from '../TFJS'
import { OutputManager } from './OutputManager'
import {
    getLayerReceptiveField, getReceptiveFieldDict, archId, getClassName,
    originalResnet18Layer2Layer, inceptionV1InterestingLayers
} from './ModelMetaInfo'
import { } from './ModelManager';
import {
    transferWeights, transferLayerWeightsByKeyword,
    mixWeights, mixLayerWeightsByKeyword, pruneLayerWeightsByKeyword
} from './ModelUtils'


/**
 * Maps an architecture id to the URL of its model.json under PUBLIC_URL.
 *
 * @param {*} aId - one of the `archId` values.
 * @returns {string} the model URL, or "" for an unknown architecture.
 */
function getModelUrl(aId) {
    if (aId === archId.INCEPTION_V1) {
        return process.env.PUBLIC_URL + '/inceptionv1/model.json';
    }
    if (aId === archId.RESNET_18) {
        return process.env.PUBLIC_URL + '/Resnet18/model.json';
    }
    return "";
}

/**
 * Creates a Model for `archType` and immediately starts loading it.
 *
 * @param {*} archType - one of the `archId` values.
 * @param {Function} callback - passed to Model#init; invoked once loading settles.
 * @returns {Model} the (possibly still loading) model instance.
 */
export function getModel(archType, callback) {
    const url = getModelUrl(archType);
    const created = new Model(archType, url);
    console.log("assigned model")
    created.init(callback);
    return created;
}

/**
 * Wrapper around a TF.js LayersModel loaded from a URL.
 *
 * Loading is asynchronous: callers register callbacks through init(). Queued
 * callbacks run once the load settles (whether it succeeded or failed), or
 * immediately when the model is already loaded.
 */
export class Model {

    /**
     * @param {*} archType - one of the `archId` values; selects class count and
     *   receptive-field metadata.
     * @param {string} modelPath - URL of the model.json file. The name of its
     *   parent directory becomes `modelName`.
     */
    constructor(archType, modelPath) {
        this.archType = archType;
        this.modelPath = modelPath;
        const pathParts = modelPath.split("/");
        // Parent directory of model.json, e.g. ".../inceptionv1/model.json" -> "inceptionv1".
        this.modelName = pathParts[pathParts.length - 2];
        this.sourceModel = null;
        this.outputManager = new OutputManager();
        this.receptiveFieldLayer = "";
        this.isLoading = false;

        // Callbacks waiting for the in-flight (or next) load to settle.
        this.registeredCallbacks = [];
    }

    /** @returns {boolean} true once the underlying LayersModel is available. */
    isLoaded() {
        return this.sourceModel != null;
    }

    /** Dequeues and invokes every registered callback in registration order. */
    processCallbacks() {
        console.log("processing callbacks")
        while (this.registeredCallbacks.length > 0) {
            const cb = this.registeredCallbacks.shift();
            cb();
        }
    }

    /**
     * Loads the model (once) and then invokes `callback`.
     *
     * If a load is already in flight, `callback` is queued and runs when that
     * load settles. Callbacks also run when loading fails, so they should check
     * isLoaded() themselves.
     *
     * @param {Function} [callback] - invoked after the load settles.
     * @param {boolean} [strict=true] - forwarded as the `strict` load option of
     *   tf.loadLayersModel.
     */
    init(callback, strict = true) {
        // Only queue real functions so processCallbacks() cannot crash when a
        // caller (e.g. getModel(arch)) omits the callback.
        if (typeof callback === "function") {
            this.registeredCallbacks.push(callback);
        }
        if (this.isLoading) {
            console.log("model is still loading");
            return;
        }
        if (this.isLoaded()) {
            console.log("model already loaded")
            this.processCallbacks();
            return;
        }
        console.log("model not loaded")
        const modelUrl = this.modelPath;
        console.log(modelUrl)
        this.isLoading = true;
        // tf.loadLayersModel(path, options) takes a single options object;
        // onProgress must sit next to strict or it is silently ignored.
        tf.loadLayersModel(modelUrl, {
            strict,
            onProgress: (fraction) => console.log("loadprogress: " + fraction),
        }).then(
            (model) => {
                this.isLoading = false;
                this.sourceModel = model;
                console.log('mm: loaded', this.modelName);
                this.processCallbacks();
            },
            (reason) => {
                this.isLoading = false;
                console.log('rejected because', reason);
                // Still release waiters; they see isLoaded() === false.
                this.processCallbacks();
            });
    }

    /**
     * @param {number} index - class index.
     * @returns {*} the class label from ModelMetaInfo for this architecture.
     */
    getClassName(index) {
        return getClassName(this.archType, index);
    }

    /**
     * @returns {number} number of output classes for this architecture.
     * @throws {Error} for an unknown architecture.
     */
    getNumClasses() {
        switch (this.archType) {
            case archId.INCEPTION_V1:
                return 1000;
            case archId.RESNET_18:
                return 9;
            default:
                throw new Error("invalid architecture!");
        }
    }

    /** @returns {Array|undefined} shape of the first model input, or undefined if not loaded. */
    getInputShape() {
        if (this.sourceModel) {
            return this.sourceModel.inputs[0].shape;
        }
        return undefined;
    }

    /**
     * Lists layer names of the loaded model ([] if not loaded).
     *
     * Declared as an arrow-function field so it stays bound when passed around
     * as a callback.
     *
     * @param {boolean} [onlyValid=false] - for ResNet-18, keep only layers that
     *   have an entry in originalResnet18Layer2Layer.
     * @returns {string[]}
     */
    getLayerNames = (onlyValid = false) => {
        if (!this.isLoaded()) {
            return [];
        }
        const layerNames = this.sourceModel.layers.map((l) => l.name);
        if (onlyValid && this.archType === archId.RESNET_18) {
            return layerNames.filter((name) => name in originalResnet18Layer2Layer);
        }
        return layerNames;
    }

    /** @param {string} layerName - layer used by getReceptiveField(); "" disables it. */
    setReceptiveFieldLayer(layerName) {
        this.receptiveFieldLayer = layerName;
    }

    /**
     * Receptive-field info for the layer chosen via setReceptiveFieldLayer().
     *
     * `fmWidth` and `inputWidth` are read from shape index 2. NOTE(review):
     * this assumes NHWC layout, so index 2 is the width; confirm.
     *
     * @returns {{rfWidth: *, fmWidth: number, inputWidth: number}|undefined}
     *   undefined when the model is not loaded or no layer is selected.
     */
    getReceptiveField() {
        if (!this.sourceModel || this.receptiveFieldLayer === "") {
            return undefined;
        }
        return {
            rfWidth: getLayerReceptiveField(
                getReceptiveFieldDict(this.archType),
                this.receptiveFieldLayer),
            fmWidth: this.sourceModel.getLayer(this.receptiveFieldLayer).output.shape[2],
            inputWidth: this.sourceModel.input.shape[2],
        };
    }

    /** Deregisters all outputs and releases the loaded model, if any. */
    dispose() {
        this.outputManager.deregisterAll();
        if (this.sourceModel) {
            this.sourceModel.dispose();
            this.sourceModel = null;
        }
    }
}

/**
 * Debug helper for Inception V1 models.
 *
 * Runs the DOM image with id "test1" through the model. It logs the flattened
 * `conv2d0` activations, the arg-max class and the flattened first model
 * output. Does nothing (besides logging) for other architectures or when the
 * image element is missing.
 *
 * @param {Model} model - a loaded Model wrapper.
 */
export function testOutput(model) {
    const { archType, sourceModel } = model;
    if (archType !== archId.INCEPTION_V1) {
        console.log("not inception model");
        return;
    }
    const testImage = (id) => {
        const image = document.getElementById(id);
        if (!image) {
            console.log("test image not found: " + id);
            return;
        }
        // tf.tidy disposes every intermediate tensor created here (pixels,
        // mean-subtracted copy, predictions, flattened views); arraySync()
        // yields plain JS arrays, so nothing needs to escape the scope.
        tf.tidy(() => {
            // Subtract 117, presumably the per-pixel mean used by this
            // Inception V1 export, and add a batch dimension.
            const input = tf.expandDims(tf.browser.fromPixels(image).sub(117), 0);
            const conv2d0out = sourceModel.getLayer("conv2d0").output;
            const auxModel = tf.model({ inputs: sourceModel.inputs, outputs: [conv2d0out] });
            const convOut = auxModel.predict(input).flatten().arraySync();
            console.log("convout", convOut);
            // predict() is indexed with [0]; assumes a multi-output model whose
            // first output holds the class scores.
            const classOut = sourceModel.predict(input)[0].flatten().arraySync();
            let max = -Infinity;
            let maxInd = -1;
            classOut.forEach((score, i) => {
                if (score > max) {
                    max = score;
                    maxInd = i;
                }
            });
            console.log("class: " + model.getClassName(maxInd), maxInd, max);
            console.log("classout", classOut);
        });
    }
    testImage("test1");
}

/**
 * A model whose weights are a per-layer blend of two source models.
 *
 * It loads a skeleton model from "inceptionv1_no_weights/model.json" and
 * writes weights into it via the ModelUtils helpers.
 * - For each keyword in `weightLayers`, `layerMix[keyword]` is the factor
 *   passed to mixLayerWeightsByKeyword.
 * - `layerPruning[keyword]` is the range passed to pruneLayerWeightsByKeyword.
 * - The "all" entries are stored, but this class never applies them directly.
 */
export class MixedModel extends Model {
    /**
     * @param {*} archType - architecture id stored on the instance.
     * @param {Model} model1 - first weight source.
     * @param {Model} model2 - second weight source.
     */
    constructor(archType, model1, model2) {
        // NOTE(review): the skeleton path and layer list are Inception V1
        // specific regardless of archType.
        const modelPath = "inceptionv1_no_weights/model.json";
        super(archType, modelPath);
        this.isMixedModel = true;
        this.model1 = model1;
        this.model2 = model2;
        this.layerMix = { all: 0.5 };
        this.weightLayers = [...inceptionV1InterestingLayers, 'softmax2'];
        this.layerPruning = { all: [0, 1] };
        for (const layerKeyword of this.weightLayers) {
            this.layerMix[layerKeyword] = 0.5;
            // Fresh array per layer so ranges can be changed independently.
            this.layerPruning[layerKeyword] = [0, 1];
        }
    }

    /**
     * Loads the skeleton model (once both source models are loaded), mixes
     * weights into it, then calls `callback`.
     *
     * If either source model is not loaded yet, this only logs and returns:
     * `callback` is NOT called. Load model1/model2 first, then call init again.
     *
     * @param {Function} [callback] - invoked after the weights are mixed.
     */
    init(callback) {
        const notify = () => {
            if (typeof callback === "function") {
                callback();
            }
        };
        if (this.isLoaded()) {
            console.log("mm isloaded")
            this.updateWeights();
            notify();
            return;
        }
        if (!(this.model1.isLoaded() && this.model2.isLoaded())) {
            console.log("mm: source models not loaded yet, skipping init");
            return;
        }
        console.log("mm superinit")
        // strict=false: the skeleton presumably ships without (all) weights,
        // which a strict load would reject.
        super.init(() => {
            this.updateWeights();
            notify();
        }, false);
    }

    /** @returns {boolean} true when the skeleton and both source models are loaded. */
    isLoaded() {
        return super.isLoaded() && this.model1.isLoaded() && this.model2.isLoaded();
    }

    /** @returns {*} the stored mixing factor for `layerKeyword`. */
    getLayerMix(layerKeyword) {
        return this.layerMix[layerKeyword];
    }

    /**
     * Stores the mixing factor and, for real layer keywords, re-mixes that layer.
     * The "all" entry is only stored.
     */
    setLayerMix(layerKeyword, value) {
        this.layerMix[layerKeyword] = value;
        if (layerKeyword !== "all") {
            this.applyMixedWeights(layerKeyword);
        }
    }

    /**
     * Stores the pruning range and, for real layer keywords, re-applies it.
     * The "all" entry is only stored.
     */
    setLayerPruningRange(layerKeyword, value) {
        this.layerPruning[layerKeyword] = value;
        if (layerKeyword !== "all") {
            this.applyPruningRange(layerKeyword);
        }
    }

    /**
     * Re-mixes the layer, then prunes it with its stored range. Presumably
     * re-mixing first keeps successive prunings from compounding; confirm
     * against pruneLayerWeightsByKeyword.
     */
    applyPruningRange(layerKeyword) {
        this.applyMixedWeights(layerKeyword);
        pruneLayerWeightsByKeyword(this, layerKeyword, this.layerPruning[layerKeyword]);
    }

    /** Replaces the weight sources; does not re-mix (call updateWeights()). */
    setSourceModels(model1, model2) {
        this.model1 = model1;
        this.model2 = model2;
    }

    /**
     * Re-mixes every layer in `weightLayers` when everything is loaded.
     * Pruning ranges are not re-applied here.
     */
    updateWeights() {
        if (this.isLoaded()) {
            console.log("updating weights")
            this.weightLayers.forEach(layerKeyword => {
                this.applyMixedWeights(layerKeyword);
            });
        }
    }

    /** Blends model1/model2 weights of matching layers into this model. */
    applyMixedWeights(layerKeyword) {
        console.log("mix:", this.layerMix[layerKeyword])
        mixLayerWeightsByKeyword(this.model1, this.model2, this, layerKeyword, this.layerMix[layerKeyword]);
    }

    /** Copies weights of every layer of model1 into this model, by layer name. */
    initWithModel1Weights() {
        const layerNames = this.model1.getLayerNames();
        layerNames.forEach((name) => {
            transferWeights(this.model1.sourceModel, this.sourceModel, name)
        });
    }
}