import {Model, MixedModel} from './Model'
import * as tf from '@tensorflow/tfjs'
import {getArchFromDirName, getModelPathFromDirName} from './ModelMetaInfo'
import {transferWeights} from './ModelUtils'
import {create} from 'zustand'

/**
 * Debug helper: zeroes the leading entries of one layer's kernel weights in
 * `model.sourceModel`, leaving the layer's second weight (bias) untouched.
 *
 * @param {object} model - wrapper exposing the underlying layers model as
 *   `sourceModel` (must provide `getLayer(name)`).
 * @param {string} layerName - name of the layer whose kernel is modified.
 * @param {number} [count=4000] - number of leading kernel entries to zero.
 *   Clamped to the kernel size, so small layers no longer throw a RangeError.
 */
function tempModifyWeights (model, layerName, count = 4000) {
    const layer = model.sourceModel.getLayer(layerName);
    const [kernelTensor, biasTensor] = layer.getWeights();
    // dataSync() is not guaranteed to return a copy (some backends hand back
    // the tensor's own storage), so copy before editing to avoid mutating the
    // original kernel in place. slice() keeps the same typed-array type.
    const wtData = kernelTensor.dataSync().slice();
    const zeroCount = Math.max(0, Math.min(count, wtData.length));
    wtData.fill(0, 0, zeroCount);
    const newKernelWtTensor = tf.tensor(wtData, kernelTensor.shape);
    layer.setWeights([newKernelWtTensor, biasTensor]);
    console.log("modified weights of layer " + layer.name);
}

/**
 * Debug helper: for each hard-coded layer-name fragment, copies the weights of
 * every matching layer (substring match) from `sourceModel` to `targetModel`.
 *
 * @param {object} sourceModel - wrapper providing `getLayerNames()` and the
 *   underlying model as `.sourceModel`.
 * @param {object} targetModel - wrapper exposing the underlying model as `.sourceModel`.
 */
function transferWeightsTest(sourceModel, targetModel) {
    const fragments = ["conv2d0", "conv2d1", "conv2d2"];
    for (const fragment of fragments) {
        // Layer names are fetched afresh for each fragment.
        for (const name of sourceModel.getLayerNames()) {
            if (!name.includes(fragment)) {
                continue;
            }
            transferWeights(sourceModel.sourceModel, targetModel.sourceModel, name);
        }
    }
}

// Store of per-slot "loaded" flags. Keys 0-2 presumably mirror the
// ModelManager.getModel indices (2 = mixed model) — confirm with callers.
//
// Older zustand releases returned a `[useStore, api]` tuple from `create`,
// while releases that provide the named `create` export return the hook
// itself; array-destructuring the latter throws "not iterable" at module load,
// so accept either shape.
const createdStore = create(set => ({
    loaded: {0: false, 1: false, 2: false},
    // Replace `loaded` with a fresh object so subscribers see a new reference.
    setLoaded: (index, status) => set(state => ({
        loaded: {...state.loaded, [index]: status},
    })),
}));
const useStore = Array.isArray(createdStore) ? createdStore[0] : createdStore;


/**
 * Holds a fixed number of `Model` slots, each built from a configured model
 * directory, plus a lazily created `MixedModel` built from slots 0 and 1.
 * `getModel(2)` returns the mixed model.
 */
export class ModelManager {
    /**
     * @param {{model1: string, model2: string}} modelDirs - directory names for slots 0 and 1.
     * @param {number} [capacity=2] - number of model slots.
     */
    constructor(modelDirs, capacity=2){
        // Slots start empty; models are created on demand by setModel.
        this.models = Array.from({length: capacity}, () => null);
        this.capacity = capacity;
        this.mixedModel = null;
        this.useStore = useStore;
        this.modelDirs = [modelDirs.model1, modelDirs.model2];
    }

    /**
     * Replaces the configured directories. Any slot whose current model does
     * not match its new directory has the old model disposed and a new one
     * created and initialized.
     *
     * @param {{model1: string, model2: string}} modelDirs
     */
    setModelDirs(modelDirs) {
        this.modelDirs = [modelDirs.model1, modelDirs.model2];
        this.modelDirs.forEach((modelDir, i) => {
            const current = this.models[i];
            if (current && modelDir === current.modelName) {
                return; // slot already holds the requested model
            }
            if (current) {
                current.dispose();
            }
            this.setModel(i, () => {}, true);
        });
    }

    /**
     * Returns the model for a slot, creating/initializing it as needed.
     *
     * @param {number} index - slot index; 2 selects the mixed model.
     * @param {boolean} [forceInit=false] - initialize the model when it is
     *   (re)created or exists but is not yet loaded.
     * @param {Function} [callback] - forwarded to `init` when initialization
     *   runs; invoked immediately when an existing model needs no init. Not
     *   invoked when a new model is created with `forceInit` false.
     * @returns the model in that slot (or the mixed model for index 2).
     */
    getModel(index, forceInit=false, callback=()=>{}) {
        // Loose equality kept so a string index such as "2" is also routed here.
        if (index == 2) {
            return this.getMixedModel(forceInit, callback);
        }
        const dirName = this.modelDirs[index];
        const current = this.models[index];
        if (!current || dirName !== current.modelName) {
            this.setModel(index, callback, forceInit);
        } else if (!current.isLoaded() && forceInit) {
            current.init(callback);
        } else {
            callback();
        }
        return this.models[index];
    }

    /**
     * Returns the mixed model built from slots 0 and 1, creating it, updating
     * its sources, and initializing it as needed.
     *
     * @param {boolean} [forceInit=true] - currently unused: the mixed model is
     *   always initialized when it is not loaded.
     * @param {Function} [callback] - forwarded to `init`, or invoked
     *   immediately when the mixed model is already loaded (matching getModel).
     * @throws {Error} if either source slot is empty or the two source models
     *   have different architectures.
     */
    getMixedModel(forceInit=true, callback=()=>{}) {
        const [model1, model2] = this.models;
        if (!model1 || !model2) {
            throw new Error("both source models must be created (e.g. via getModel) before mixing");
        }
        if (model1.archType !== model2.archType) {
            throw new Error("can't mix models of different architecture!");
        }
        if (!this.mixedModel) {
            console.log("mm: new mixed model");
            this.mixedModel = new MixedModel(model1.archType, model1, model2);
        }
        // Slots may have been replaced (e.g. by setModelDirs) since the mixed
        // model was built; re-point it at the current sources.
        if (model1 !== this.mixedModel.model1 || model2 !== this.mixedModel.model2) {
            console.log("mm: setsourcemodels");
            this.mixedModel.setSourceModels(model1, model2);
        }
        if (!this.mixedModel.isLoaded()) {
            console.log("mm: initing");
            this.mixedModel.init(callback);
        } else {
            callback();
        }
        return this.mixedModel;
    }

    /**
     * Creates a new Model for a slot from its configured directory,
     * replacing whatever the slot held (without disposing it).
     *
     * @param {number} index - slot index in [0, capacity).
     * @param {Function} callback - forwarded to `init` when `init` is truthy.
     * @param {boolean} init - whether to initialize the new model immediately.
     * @throws {Error} "invalid index!" when index is outside [0, capacity).
     */
    setModel(index, callback, init) {
        // Validate first so a bad index never creates a model or writes a
        // bogus slot (negative indices were previously accepted).
        if (index < 0 || index > this.capacity - 1) {
            throw new Error("invalid index!");
        }
        const dirName = this.modelDirs[index];
        const archType = getArchFromDirName(dirName);
        const modelPath = getModelPathFromDirName(dirName);
        this.models[index] = new Model(archType, modelPath);
        if (init) {
            this.models[index].init(callback);
        }
    }
}