import * as tf from '@tensorflow/tfjs'
import {preProcessRawUint8InputImage, postProcessImageTensor} from './ModelUtils'

// Nominal PGD iteration count. It is only used to derive the per-step size in
// PGDAttack#getAdv (stepSize = eps / NUM_ITERATIONS * 2); the class itself never
// loops this many times — each doStep call performs a single update.
const NUM_ITERATIONS = 8;

/**
 * Projected-gradient-descent (PGD) adversarial attack on a single image,
 * supporting L2 and L-infinity perturbation budgets, targeted or untargeted.
 *
 * Images are held as tensors in the [-1, 1] range (see imagePreProcess) and are
 * mapped to the network's expected pixel range inside getLogits.
 */
export class PGDAttack {
    /**
     * @param {tf.LayersModel} sourceModel - classifier to attack. Logits are taken
     *     from the input of the layer that produces its first output
     *     (see getLogitOutput).
     * @param {{data: ArrayLike<number>, width: number, height: number}} sourceImageData
     *     raw image; `data` is forwarded to preProcessRawUint8InputImage as
     *     values in [0, 255].
     */
    constructor(sourceModel, sourceImageData) {
        this.sourceModel = sourceModel;
        const {data, width, height} = sourceImageData;
        this.targetWidth = width;
        this.targetHeight = height;
        // Pixel range expected by the network (InceptionV1 specific).
        // getLogits maps [-1, 1] onto [inMin, inMin + 255] = [inMin, inMax].
        this.inMin = -117;
        this.inMax = 138;
        // Model input shape is destructured as [batch, height, width, channels].
        const [, h, w] = sourceModel.inputs[0].shape;
        this.sourceImageTensor = this.imagePreProcess(data, width, height, w, h);
        this.logitModel = tf.model({
            inputs: this.sourceModel.inputs, outputs: [this.getLogitOutput()]});
        this.prediction = this.getPrediction();
        console.log(this.prediction);
        // Separate copy so the attacked image can be updated/disposed
        // independently of the source image.
        this.attackedTensor = tf.clone(this.sourceImageTensor);
    }

    /**
     * Classifies an image and reports its top-1 class.
     *
     * @param {tf.Tensor} [imageTensor] - image in the [-1, 1] range; defaults to
     *     the preprocessed source image.
     * @param {tf.Tensor} [logits] - precomputed logits for the image; computed
     *     via getLogits when omitted.
     * @returns {{prob: number, cls: number, origClsProb: number}} top-1
     *     probability and class index, plus the probability now assigned to the
     *     originally predicted class (equal to `prob` while the initial
     *     prediction is being computed in the constructor).
     */
    getPrediction(imageTensor, logits) {
        const image = imageTensor || this.sourceImageTensor;
        // tidy frees the softmax/flatten tensors (and internally computed logits)
        // even when this is called outside an enclosing tidy, e.g. from the
        // constructor. Caller-provided tensors are not disposed.
        const probs = tf.tidy(() => {
            const activations = logits || this.getLogits(image);
            return activations.softmax().flatten().dataSync();
        });

        // Linear arg-max; ties resolve to the lowest index, matching a stable
        // descending sort.
        let cls = 0;
        for (let i = 1; i < probs.length; i++) {
            if (probs[i] > probs[cls]) {
                cls = i;
            }
        }
        const prob = probs[cls];

        // Before this.prediction exists (constructor call) the top class is by
        // definition the original class.
        const origClsProb = this.prediction ? probs[this.prediction.cls] : prob;
        return {prob, cls, origClsProb};
    }

    /** @returns {Array<number|null>} input shape of the source model. */
    getInputShape() {
        return this.sourceModel.inputs[0].shape;
    }

    /**
     * Performs one PGD update on the current attacked image and stores the
     * result as the new attacked image.
     *
     * @param {number} eps - perturbation budget (radius of the L2 / L-inf ball).
     * @param {string} lp - norm of the budget: "2" or "inf".
     * @param {number} targetClass - class to push towards, or -1 to push away
     *     from the originally predicted class.
     * @returns {{data: Uint8Array, w: number, h: number, startPred: object, currentPred: object}}
     *     post-processed pixels of the new attacked image plus the initial and
     *     current predictions.
     */
    doStep(eps, lp, targetClass) {
        const [ret, attacked] = tf.tidy(() => {
            const step = this.getAdv(this.sourceImageTensor, eps, lp, targetClass);
            const {attacked: next} = step(this.attackedTensor);
            const currentPred = this.getPrediction(next, null);

            const postProcessed = postProcessImageTensor(next,
                {width: this.targetWidth, height: this.targetHeight});
            const [, h, w] = postProcessed.shape;
            return [{data: new Uint8Array(postProcessed.dataSync()), w, h,
                startPred: this.prediction, currentPred}, next];
        });

        // The previous attacked tensor lives outside any tidy scope; release it
        // now that it has been replaced, otherwise every step leaks a full image
        // tensor.
        const previous = this.attackedTensor;
        this.attackedTensor = attacked;
        if (previous && previous !== attacked && previous !== this.sourceImageTensor
                && !previous.isDisposed) {
            previous.dispose();
        }
        return ret;
    }

    /**
     * Renders a blend of the source image and the current perturbation without
     * modifying the attack state.
     *
     * @param {number} [alpha=1.0] - scale applied to the perturbation
     *     (attacked - source).
     * @param {number} [origAlpha=1.0] - scale applied to the source image, which
     *     fades it towards 0, the midpoint of the [-1, 1] range.
     * @returns {{data: Uint8Array, w: number, h: number, startPred: object, currentPred: object}}
     */
    getCurrent(alpha = 1.0, origAlpha = 1.0) {
        return tf.tidy(() => {
            let image = this.attackedTensor;
            if (alpha !== 1.0 || origAlpha !== 1.0) {
                const perturbation = this.attackedTensor.sub(this.sourceImageTensor);
                const base = origAlpha !== 1.0
                    ? this.sourceImageTensor.mul(origAlpha)
                    : this.sourceImageTensor;
                image = base.add(perturbation.mul(alpha));
            }
            image = tf.clipByValue(image, -1, 1);

            const currentPred = this.getPrediction(image, null);
            const postProcessed = postProcessImageTensor(image,
                {width: this.targetWidth, height: this.targetHeight});
            const [, h, w] = postProcessed.shape;
            return {data: new Uint8Array(postProcessed.dataSync()), w, h,
                startPred: this.prediction, currentPred};
        });
    }

    /**
     * Builds a single PGD step function around the clean image `x`.
     *
     * The returned function takes the current adversarial image and returns
     * `{attacked, grad}`: the updated image projected back onto the eps-ball
     * around `x` (and into [-1, 1]), and the normalised gradient used. It
     * allocates tensors, so call it inside tf.tidy.
     *
     * @param {tf.Tensor} x - clean image in [-1, 1].
     * @param {number} eps - perturbation budget.
     * @param {string} lp - "2" or "inf".
     * @param {number} targetClass - class to push towards, or -1 for an
     *     untargeted attack away from the original prediction.
     * @returns {function(tf.Tensor): {attacked: tf.Tensor, grad: tf.Tensor}}
     * @throws {Error} if `lp` is not "2" or "inf".
     */
    getAdv(x, eps, lp, targetClass) {
        if (lp !== "2" && lp !== "inf") {
            throw new Error(`Unsupported norm "${lp}"; expected "2" or "inf".`);
        }

        // Targeted: ascend log p(target). Untargeted: descend log p(original).
        let y = targetClass;
        let dir = 1;
        if (targetClass === -1) {
            y = this.prediction.cls;
            dir = -1;
        }

        const stepSize = eps / NUM_ITERATIONS * 2;

        // Identity whose gradient is zero, so the projected result never carries
        // gradient information back through the update.
        const stopGradient = tf.customGrad((x_, save) => {
            save([x_]);
            return {
                value: x_.add(0), // returning x_ itself would bypass the gradient override
                gradFunc: (dy, saved) => [tf.zerosLike(saved[0])],
            };
        });

        const lossGrad = tf.grad((image) => this.getLossFromLogits(this.getLogits(image), y));

        // Lower bound on norms so a vanishing gradient/perturbation gives 0, not NaN.
        const MIN_NORM = 1e-12;
        let normalizeGrad;
        let project;
        if (lp === "2") {
            normalizeGrad = (g) => g.div(tf.maximum(tf.norm(g, 2), MIN_NORM));
            project = (v) => {
                const diff = tf.clipByValue(v, -1, 1).sub(x);
                const norm = tf.norm(diff, 2);
                // Rescale so that ||diff|| <= eps: diff * min(eps, ||diff||) / ||diff||.
                const scale = tf.minimum(eps, norm).div(tf.maximum(norm, MIN_NORM));
                return x.add(diff.mul(scale));
            };
        } else {
            normalizeGrad = (g) => tf.sign(g);
            project = (v) => {
                const clipped = tf.clipByValue(v, -1, 1);
                return tf.minimum(tf.maximum(clipped, x.sub(eps)), x.add(eps));
            };
        }

        return (adv) => {
            const grad = normalizeGrad(lossGrad(adv));
            const next = project(adv.add(grad.mul(stepSize).mul(dir)));
            return {attacked: stopGradient(next), grad};
        };
    }

    /**
     * Negated softmax cross-entropy of `logits` against class `y`, i.e. the
     * log-probability of `y`; increasing it makes `y` more likely.
     *
     * @param {tf.Tensor} logits - logits for a single image (flattened here).
     * @param {number} y - class index.
     * @returns {tf.Scalar}
     * @throws {RangeError} if `y` is not a valid class index.
     */
    getLossFromLogits(logits, y) {
        const flatLogits = logits.flatten();
        const numClasses = flatLogits.shape[0];
        const cls = Number(y);
        if (!Number.isInteger(cls) || cls < 0 || cls >= numClasses) {
            throw new RangeError(`Class index ${y} is outside [0, ${numClasses}).`);
        }
        // Float labels match the float logits' dtype.
        const oneHot = new Float32Array(numClasses);
        oneHot[cls] = 1;
        return tf.losses.softmaxCrossEntropy(tf.tensor(oneHot), flatLogits).mul(-1);
    }

    /**
     * Runs the logit model on an image in [-1, 1], first mapping it onto the
     * network's pixel range [inMin, inMin + 255].
     */
    getLogits(x) {
        const correctRangeImage = x.add(1).div(2).mul(255).add(this.inMin);
        return this.logitModel.predict(correctRangeImage);
    }

    /**
     * Symbolic tensor feeding the layer that produces the model's first output;
     * presumably the pre-softmax logits for InceptionV1 — TODO confirm for
     * other models.
     */
    getLogitOutput() {
        return this.sourceModel.outputs[0].inputs[0];
    }

    /**
     * Converts raw image data (values in [0, 255]) into a model-sized tensor in
     * [-1, 1] via preProcessRawUint8InputImage (presumably resizing from the
     * input size to the model size — TODO confirm).
     */
    imagePreProcess(flatData, inWidth, inHeight, modelInWidth, modelInHeight) {
        return preProcessRawUint8InputImage(
            flatData, {w: inWidth, h: inHeight}, {w: modelInWidth, h: modelInHeight},
            {min: 0, max: 255}, {min: -1, max: 1});
    }

    /**
     * Releases the image tensors owned by this attack. The logit model is not
     * disposed because it is built from sourceModel's own layers and would free
     * their weights.
     */
    dispose() {
        tf.dispose([this.sourceImageTensor, this.attackedTensor]);
    }
}