import {
    tf
} from '../../TFJS'

import {
    jitter,
    fixedScale,
    compose,
    standardTransforms
} from '../../LucidJS/optvis/transform';
import {
    naiveFromImage,
    imgLaplacianPyramid,
    randLaplacianPyramid
} from '../../LucidJS/optvis/param/image';
import {
    channel,
    deepdream,
    neuron,
    spatial,
    output,
    activationModification,
    style
} from '../../LucidJS/optvis/objectives';

/**
 * Creates the image parameterization that is optimized, using
 * `randLaplacianPyramid` for a square `inputSize` x `inputSize`, 3-channel image.
 *
 * @param {number} inputSize width and height of the generated image
 * @param {number} pyramidLayers forwarded as the last argument of `randLaplacianPyramid`
 * @param {boolean} decorrelate forwarded unchanged to `randLaplacianPyramid`
 * @returns {Array} `[paramF, trainable]`, where `paramF(trainable)` calls the pyramid
 *     function returned by `randLaplacianPyramid`, and `trainable` is its second return value
 */
export function getInput(inputSize, pyramidLayers, decorrelate) {
    const NUM_CHANNELS = 3;
    // NOTE(review): the literals 1 and 0.01 are interpreted by randLaplacianPyramid
    // (presumably batch size and initial scale) — confirm against its signature.
    const [pyramidFn, variables] = randLaplacianPyramid(
        inputSize, inputSize, NUM_CHANNELS, 1, 0.01, decorrelate, pyramidLayers);
    const paramF = (vars) => pyramidFn(vars);
    return [paramF, variables];
}

/**
 * Builds the transform applied to the parameterized image before the objective
 * is evaluated: a single `jitter` transform wrapped by `compose`.
 *
 * @param {number} [jitterAmount=5] value passed to `jitter`. This is presumably a
 *     maximum pixel offset; confirm in LucidJS. The default matches the previous
 *     hard-coded value.
 * @returns {*} the result of `compose([jitter(jitterAmount)])`
 */
export function getTransformF(jitterAmount = 5) {
    const transforms = [jitter(jitterAmount)];
    return compose(transforms);
}

/**
 * Builds a LucidJS `neuron` objective for channel `n` of `layer`.
 *
 * The third argument passed to `neuron` is the ratio between `inputWidth` and
 * `model.input.shape[1]`.
 *
 * @param {*} model model passed to `neuron`; `model.input.shape[1]` is read.
 *     This presumably holds the model's input size for a square image — TODO confirm.
 * @param {*} layer layer identifier forwarded as `layer`
 * @param {number} n channel index forwarded as `channel`
 * @param {number} inputWidth width of the image being optimized
 * @returns {*} whatever `neuron` returns
 */
export function getNeuronObjective(model, layer, n, inputWidth) {
    const modelInputSize = model.input.shape[1];
    return neuron(model, { layer, channel: n }, inputWidth / modelInputSize);
}

/**
 * Builds a LucidJS `channel` objective for channel `ch` of `layer`.
 *
 * @param {*} model model passed through to `channel`
 * @param {*} layer layer identifier forwarded as `layer`
 * @param {number} ch channel index forwarded as `channel`
 * @returns {*} whatever `channel` returns
 */
export function getChannelObjective(model, layer, ch) {
    return channel(model, { layer, channel: ch });
}

/**
 * Evaluates `objectiveF(transformF(paramF(trainable)))` inside `tf.tidy`.
 * Intermediate tensors are therefore disposed, and only the returned value survives.
 *
 * @param {*} trainable value passed to `paramF`
 * @param {Function} paramF maps `trainable` to an image
 * @param {Function} objectiveF maps the transformed image to the loss
 * @param {Function} transformF transform applied to the image before the objective
 * @returns {*} the value returned by `objectiveF`
 */
export function getLoss(trainable, paramF, objectiveF, transformF) {
    return tf.tidy(() => {
        const image = paramF(trainable);
        const transformed = transformF(image);
        return objectiveF(transformed);
    });
}

/**
 * Min-max normalizes a tensor and rescales it to [0, 255] for display.
 *
 * All work happens inside `tf.tidy`, so intermediate tensors are disposed and
 * only the returned tensor survives. A constant input (max === min) maps to all
 * zeros instead of NaN.
 *
 * @param {tf.Tensor} x tensor to normalize; not modified
 * @returns {tf.Tensor} float32 tensor, broadcast to the shape of `x`, with values in [0, 255]
 */
export function deprocessImage(x) {
    // Lower bound for the normalization denominator. Without it, a constant
    // tensor gives 0 / 0 = NaN for every element. tf.maximum leaves any
    // non-degenerate range untouched, unlike adding the epsilon.
    const EPSILON = 1e-7;
    return tf.tidy(() => {
        const min = x.min();
        const range = x.max().sub(min);
        const normalized = x.sub(min).div(tf.maximum(range, EPSILON));
        // Clip to [0, 1] to absorb float overshoot. After scaling by 255 the values
        // are already within [0, 255], so no second clip is needed.
        return tf.clipByValue(normalized, 0, 1).mul(255).asType('float32');
    });
}

/**
 * Copies one packed RGB image into an RGBA pixel array (for example `ImageData.data`),
 * setting every alpha value to 255.
 *
 * `rgbData` may contain several RGB images stored back to back, each
 * `width * height * 3` values long; `channel` selects which one is copied.
 * Both arrays are indexed row-major (pixel index = y * width + x).
 *
 * @param {ArrayLike<number>} canvasPixels destination RGBA array, written in place
 * @param {ArrayLike<number>} rgbData source RGB array
 * @param {number} width image width in pixels
 * @param {number} height image height in pixels
 * @param {number} [channel=0] index of the RGB image within `rgbData`
 * @param {number} [mult=1] factor applied to each source value (e.g. 255 to convert normalized 0-1 data to 0-255)
 */
export function fillCanvasPixelsWithRgbAndAlpha(canvasPixels, rgbData, width, height, channel = 0, mult = 1) {
    const srcOffset = channel * width * height * 3;
    // Iterate rows in the outer loop so both buffers are walked sequentially,
    // matching their row-major layout.
    for (let y = 0; y < height; y++) {
        const rowStart = y * width;
        for (let x = 0; x < width; x++) {
            const pixel = rowStart + x;
            const dst = pixel * 4;
            const src = srcOffset + pixel * 3;
            canvasPixels[dst] = rgbData[src] * mult;
            canvasPixels[dst + 1] = rgbData[src + 1] * mult;
            canvasPixels[dst + 2] = rgbData[src + 2] * mult;
            canvasPixels[dst + 3] = 255;
        }
    }
}

/**
 * Draws a packed RGB pixel array onto a canvas, stretching it to the canvas
 * size when the dimensions differ. Every drawn pixel is fully opaque.
 *
 * @param {ArrayLike<number>} pixelData packed RGB array (3 values per pixel, row-major)
 * @param {HTMLCanvasElement} canvas canvas to draw onto
 * @param {number} w width of the pixel data
 * @param {number} h height of the pixel data
 * @param {number} [channel=0] index of the RGB image to draw when `pixelData`
 *     holds several `w * h * 3` images back to back
 */
export function drawPixelsToCanvas(pixelData, canvas, w, h, channel = 0) {
    const canvCtx = canvas.getContext("2d");
    const cw = canvas.width;
    const ch = canvas.height;

    if (w === cw && h === ch) {
        const imData = canvCtx.createImageData(w, h);
        fillCanvasPixelsWithRgbAndAlpha(imData.data, pixelData, w, h, channel);
        canvCtx.putImageData(imData, 0, 0);
        return;
    }

    // putImageData cannot scale and ignores the context transform, so the
    // pixels are rendered to an offscreen canvas first and then drawn with drawImage.
    const tempCanvas = document.createElement("canvas");
    tempCanvas.width = w;
    tempCanvas.height = h;
    const tempCtx = tempCanvas.getContext('2d');
    const imData = tempCtx.createImageData(w, h);
    fillCanvasPixelsWithRgbAndAlpha(imData.data, pixelData, w, h, channel);
    tempCtx.putImageData(imData, 0, 0);

    // Stretch via drawImage's destination size instead of scale() followed by
    // scale(1/ratio). The inverse multiply does not always restore the exact
    // transform (e.g. 49 * (1 / 49) !== 1), so the context transform would drift.
    canvCtx.drawImage(tempCanvas, 0, 0, cw, ch);
}