import {tf} from '../TFJS'

/**
 * Turns a raw RGBA pixel buffer into a normalized RGB image tensor for a model.
 *
 * Pipeline: build a [1, fH, fW, 4] float32 tensor from `data`, bilinearly
 * resize it to `toSize`, linearly map values from `inValRange` to
 * `outValRange`, flip it along the height axis, drop the alpha channel and,
 * optionally, apply ImageNet mean/std whitening.
 *
 * All intermediate tensors are released; only the returned tensor is kept.
 *
 * @param {ArrayLike<number>} data - flat RGBA values; must contain fW * fH * 4 entries.
 * @param {{w: number, h: number}} fromSize - size of the image stored in `data`.
 * @param {{w: number, h: number}} [toSize={w:224, h:224}] - size of the output image.
 * @param {{min: number, max: number}} [inValRange={min:0, max:255}] - value range of `data`.
 * @param {{min: number, max: number}} [outValRange={min:-1, max:1}] - value range after rescaling.
 * @param {boolean} [imageNetWhiten=false] - subtract the ImageNet per-channel mean and divide
 *     by its std after rescaling; requires `outValRange` to be {min: 0, max: 1}.
 * @returns {tf.Tensor4D} float32 tensor of shape [1, toSize.h, toSize.w, 3].
 * @throws {Error} if `imageNetWhiten` is set and `outValRange` is not [0, 1].
 */
export function preProcessRawUint8InputImage (
    data, fromSize, toSize={w:224, h:224},
    inValRange={min:0, max:255}, outValRange={min:-1, max:1}, imageNetWhiten=false) {
    const {w: fW, h: fH} = fromSize;
    const {w: tW, h: tH} = toSize;
    const {min: iMin, max: iMax} = inValRange;
    const {min: oMin, max: oMax} = outValRange;

    // Validate before allocating any tensors.
    if (imageNetWhiten && (oMin !== 0 || oMax !== 1)) {
        throw new Error("invalid output range for whitening!");
    }

    const iScale = iMax - iMin;
    const oScale = oMax - oMin;

    // tidy() disposes every intermediate tensor and keeps only the returned one.
    return tf.tidy(() => {
        const raw = tf.tensor(data, [1, fH, fW, 4], 'float32');
        // resizeBilinear expects [newHeight, newWidth].
        const resized = tf.image.resizeBilinear(raw, [tH, tW]);
        const rescaled = resized.sub(iMin).div(iScale).mul(oScale).add(oMin);
        // Axis 1 is the height axis, so this flips the image upside down.
        // NOTE(review): presumably the source buffer is stored bottom-up — confirm with the caller.
        const flipped = tf.reverse(rescaled, 1);
        // Keep only the first three channels (drop alpha).
        let img = flipped.slice([0, 0, 0, 0], [1, tH, tW, 3]);
        if (imageNetWhiten) {
            // Per-channel statistics broadcast over the last (channel) axis.
            const mean = tf.tensor([0.485, 0.456, 0.406]);
            const std = tf.tensor([0.229, 0.224, 0.225]);
            img = img.sub(mean).div(std);
        }
        return img;
    });
}

/**
 * Converts a model's output image tensor into an RGBA int32 tensor for display.
 *
 * Appends an alpha channel, bilinearly resizes to `fromSize`, linearly maps
 * values from `inValRange` to `outValRange`, and casts to int32. No rounding is
 * applied before the cast. The alpha channel is filled with `inValRange.max`,
 * so after range mapping it always equals `outValRange.max` (fully opaque) for
 * any input range. All intermediate tensors are released.
 *
 * @param {tf.Tensor4D} imageTensor - [batch, height, width, channels]; presumably
 *     RGB (3 channels) — confirm with callers.
 * @param {{width: number, height: number}} fromSize - size of the returned image
 *     (keys are `width`/`height`).
 * @param {{min: number, max: number}} [inValRange={min:-1, max:1}] - value range of `imageTensor`.
 * @param {{min: number, max: number}} [outValRange={min:0, max:255}] - value range of the result.
 * @returns {tf.Tensor4D} int32 tensor of shape [batch, fromSize.height, fromSize.width, channels + 1].
 */
export function postProcessImageTensor(imageTensor, fromSize, inValRange={min:-1, max:1}, outValRange={min:0, max:255}) {
    const {width, height} = fromSize;
    const {min: iMin, max: iMax} = inValRange;
    const {min: oMin, max: oMax} = outValRange;
    const iScale = iMax - iMin;
    const oScale = oMax - oMin;

    // tidy() disposes every intermediate tensor and keeps only the returned one.
    return tf.tidy(() => {
        const [b, h, w] = imageTensor.shape;
        // iMax maps to oMax below, i.e. an opaque alpha for any input range.
        const alpha = tf.fill([b, h, w, 1], iMax);
        const rgba = tf.concat([imageTensor, alpha], 3);
        // resizeBilinear expects [newHeight, newWidth].
        const resized = tf.image.resizeBilinear(rgba, [height, width]);
        const rangeCorrected = resized.sub(iMin).div(iScale).mul(oScale).add(oMin);
        return rangeCorrected.asType('int32');
    });
}

/**
 * Returns the grid coordinates of the central cell of a w x h input.
 * For an even size, the higher of the two middle indices is chosen
 * (e.g. 16 for a size of 32).
 *
 * @param {{w: number, h: number}} [inputShape={w:32, h:32}] - grid size.
 * @returns {{x: number, y: number}} column (x) and row (y) of the center cell.
 */
export function getCenterNeuronCoords(inputShape={w:32, h:32}) {
    const middleIndex = (size) => Math.ceil((size - 1) / 2.0);
    return {
        x: middleIndex(inputShape.w),
        y: middleIndex(inputShape.h),
    };
}

/**
 * Returns the square of side `width` centered in a w x h grid, given as
 * top-left corner plus size. When the square cannot be centered exactly, it
 * is shifted toward higher indices. A `width` of -1 selects the whole grid.
 *
 * @param {number} width - side length of the square, or -1 for the full input.
 * @param {{w: number, h: number}} [inputShape={w:32, h:32}] - grid size.
 * @returns {{x: number, y: number, w: number, h: number}} the rectangle.
 */
export function getCenterRectCoordsFromWidth(width, inputShape={w:32, h:32}) {
    const {w, h} = inputShape;
    const wholeInput = width === -1;
    if (wholeInput) {
        return {x: 0, y: 0, w: w, h: h};
    }
    // First index of a `width`-long run centered on an axis of length `size`.
    const startIndex = (size) => Math.ceil((size - 1) / 2.0 + 0.5 - width / 2.0);
    return {x: startIndex(w), y: startIndex(h), w: width, h: width};
}

/**
 * Copies the weights of the layer named `layerName` from `sourceModel` into
 * the same-named layer of `targetModel` and logs the change.
 * Nothing is written if the source layer's getWeights() returns a falsy value.
 *
 * @param {object} sourceModel - model exposing getLayer(name).
 * @param {object} targetModel - model exposing getLayer(name).
 * @param {string} layerName - name of the layer in both models.
 */
export function transferWeights (sourceModel, targetModel, layerName) {
    // Resolve the same-named layer in both models, source first.
    const [fromLayer, toLayer] = [sourceModel, targetModel].map(
        (model) => model.getLayer(layerName));
    const copied = fromLayer.getWeights();
    if (copied) {
        toLayer.setWeights(copied);
        console.log(`modified weights of layer ${layerName}`);
    }
}

/**
 * Writes a linear blend of two models' weights into a third model's layer:
 * target = (1 - mixRatio) * source1 + mixRatio * source2, weight by weight.
 *
 * All tensors created here are released. The intermediates are freed by
 * tf.tidy, and the blended results are disposed once setWeights has assigned
 * them to the target's variables.
 *
 * @param {object} sourceModel1 - model exposing getLayer(name); weight at mixRatio = 0.
 * @param {object} sourceModel2 - model exposing getLayer(name); weight at mixRatio = 1.
 * @param {object} targetModel - model whose layer receives the blend.
 * @param {string} layerName - name of the layer in all three models.
 * @param {number} mixRatio - blend factor applied to sourceModel2's weights.
 */
export function mixWeights (sourceModel1, sourceModel2, targetModel, layerName, mixRatio) {
    const sourceLayer1 = sourceModel1.getLayer(layerName);
    const sourceLayer2 = sourceModel2.getLayer(layerName);
    const targetLayer = targetModel.getLayer(layerName);
    const weights1 = sourceLayer1.getWeights();
    const weights2 = sourceLayer2.getWeights();
    if (!weights1 || !weights2) return;

    // tidy() disposes the scaled intermediates; only the blended tensors survive.
    const newWeights = tf.tidy(() => weights1.map((w1, i) =>
        w1.mul(1 - mixRatio).add(weights2[i].mul(mixRatio))));

    targetLayer.setWeights(newWeights);
    // The target variables hold their own reference to the assigned data
    // (tf.Variable.assign), so these temporaries can be released.
    tf.dispose(newWeights);
}

/**
 * Copies weights for every layer whose name contains `layerNameKeyword`.
 * Layer names come from `sourceModel.getLayerNames()`; the copy itself is
 * done on the wrapped `.sourceModel` of both arguments.
 *
 * @param {object} sourceModel - wrapper exposing getLayerNames() and `.sourceModel`.
 * @param {object} targetModel - wrapper exposing `.sourceModel`.
 * @param {string} layerNameKeyword - substring a layer name must contain.
 */
export function transferLayerWeightsByKeyword(sourceModel, targetModel, layerNameKeyword) {
    sourceModel.getLayerNames().forEach((name) => {
        if (!name.includes(layerNameKeyword)) {
            return;
        }
        transferWeights(sourceModel.sourceModel, targetModel.sourceModel, name);
    });
}

/**
 * Blends weights for every layer whose name contains `layerNameKeyword`.
 * Layer names come from `sourceModel1.getLayerNames()`; the blend is done on
 * the wrapped `.sourceModel` of all three arguments.
 *
 * @param {object} sourceModel1 - wrapper exposing getLayerNames() and `.sourceModel`.
 * @param {object} sourceModel2 - wrapper exposing `.sourceModel`.
 * @param {object} targetModel - wrapper exposing `.sourceModel`.
 * @param {string} layerNameKeyword - substring a layer name must contain.
 * @param {number} mixRatio - blend factor forwarded to mixWeights.
 */
export function mixLayerWeightsByKeyword(sourceModel1, sourceModel2, targetModel, layerNameKeyword, mixRatio) {
    sourceModel1.getLayerNames().forEach((name) => {
        if (!name.includes(layerNameKeyword)) {
            return;
        }
        mixWeights(
            sourceModel1.sourceModel,
            sourceModel2.sourceModel,
            targetModel.sourceModel,
            name,
            mixRatio,
        );
    });
}

/**
 * Prunes whole output channels of a layer's kernel (its first weight) in place.
 *
 * Channels are ranked by mean absolute kernel weight, in ascending order. A
 * channel at rank r out of n is zeroed when r < pmin * (n - 1) or
 * r > pmax * (n - 1). Any other weights (e.g. a bias) are left unchanged.
 * Pruning is cumulative: zeroed channels stay zero and rank lowest on later
 * calls.
 *
 * Supported kernels are rank 4 (conv, [h, w, in, out]) and rank 2
 * (fully connected, [in, out]). Layers without weights are ignored. Other
 * kernel ranks are skipped with a warning instead of throwing.
 *
 * @param {object} sourceModel - model exposing getLayer(name).
 * @param {string} layerName - name of the layer to prune.
 * @param {[number, number]} pruneRange - [pmin, pmax]: fractions of the
 *     ascending ranking that are kept.
 */
export function pruneLayer(sourceModel, layerName, pruneRange) {
    const [pmin, pmax] = pruneRange;

    const layer = sourceModel.getLayer(layerName);
    const weights = layer.getWeights();
    if (weights.length === 0) {
        // Pooling or other parameter-free layer: nothing to prune.
        return;
    }

    const wts = weights[0];
    // Axes averaged away so that one value per output channel (last axis) remains.
    const reduceAxesByRank = {4: [0, 1, 2], 2: [0]};
    const reduceAxes = reduceAxesByRank[wts.shape.length];
    if (!reduceAxes) {
        // Previously this path crashed on an undefined index list.
        console.warn(`pruneLayer: skipping layer ${layerName} with unsupported kernel rank ${wts.shape.length}`);
        return;
    }

    const meanAbs = tf.tidy(() => tf.mean(wts.abs(), reduceAxes));
    // Copy to a plain array before releasing the tensor.
    const channelMeans = Array.from(meanAbs.dataSync());
    meanAbs.dispose();

    const sorted = channelMeans
        .map((d, i) => [d, i])
        .sort(([a1], [a2]) => a1 - a2);
    const data = sorted.map(([d]) => d);
    const inds = sorted.map(([, i]) => i);

    const last = inds.length - 1;
    const mask = inds.map(() => 1);
    inds.forEach((ind, rank) => {
        if (rank < pmin * last || rank > pmax * last) {
            mask[ind] = 0;
        }
    });
    console.log("prunemask", mask, inds, data);

    // The 1-D mask broadcasts over the last (output-channel) axis of the kernel.
    const newWts = tf.tidy(() => wts.mul(tf.tensor1d(mask)));
    const newWeights = [...weights];
    newWeights[0] = newWts;
    layer.setWeights(newWeights);
    // The layer's variable holds its own reference to the assigned data.
    newWts.dispose();
}

/**
 * Applies pruneLayer to every layer whose name contains `layerNameKeyword`.
 * Layer names come from `model.getLayerNames()`; pruning is done on the
 * wrapped `model.sourceModel`.
 *
 * @param {object} model - wrapper exposing getLayerNames() and `.sourceModel`.
 * @param {string} layerNameKeyword - substring a layer name must contain.
 * @param {[number, number]} pruneRange - forwarded unchanged to pruneLayer.
 */
export function pruneLayerWeightsByKeyword(model, layerNameKeyword, pruneRange) {
    model.getLayerNames().forEach((name) => {
        if (!name.includes(layerNameKeyword)) {
            return;
        }
        pruneLayer(model.sourceModel, name, pruneRange);
    });
}