import {tf} from '../../TFJS'
import {drawPixelsToCanvas} from './FeatureVisUtil'


/**
 * Drives a feature-visualization optimisation loop: each step minimises a loss
 * over a set of tfjs variables with an Adam optimizer and then renders the
 * current variables to a canvas. After the first (synchronous) step, each
 * further step runs in its own animation frame.
 */
export class FeatureVisHandler {
    /**
     * @param {Function} lossF - Called as `lossF(variables)` each step; its result
     *   is passed as the first argument of `optimizer.minimize`, which in tfjs
     *   expects a `() => tf.Scalar` closure.
     *   NOTE(review): presumably lossF returns such a closure — confirm at callers.
     * @param {Function} deprocessF - Called as `deprocessF(variables)` inside
     *   `tf.tidy`; the result's `dataSync()` values are passed to
     *   `drawPixelsToCanvas`.
     * @param {Array} variables - Variables updated by each optimisation step.
     *   They are disposed by `reset` when replaced.
     * @param {{current: ?Object}} canvasRef - Ref-like object whose `current`
     *   property is the canvas to draw into (drawing is skipped while it is null).
     * @param {number} resolution - Used as both width and height when drawing.
     * @param {number} [learningRate=0.05] - Adam learning rate; reused by `reset`.
     */
    constructor(lossF, deprocessF, variables, canvasRef, resolution, learningRate = 0.05) {
        this.lossF = lossF;
        this.deprocessF = deprocessF;
        this.variables = variables;
        this.learningRate = learningRate;
        // tf.train.adam is a factory function, not a constructor — no `new`.
        this.optimizer = tf.train.adam(learningRate);
        this.canvasRef = canvasRef;
        this.resolution = resolution;
        this.shouldForceStop = false;
        this.animFrameRequest = undefined;
    }

    /**
     * Renders the current variables to the canvas via `deprocessF`.
     * Does nothing when the canvas ref is currently empty.
     */
    drawCurrentVarsToCanvas() {
        const canvas = this.canvasRef.current;
        if (!canvas) {
            // e.g. the owning component unmounted while a run was still scheduled.
            return;
        }
        // tidy disposes every intermediate tensor; only the plain typed array escapes.
        const pixelData = tf.tidy(() => this.deprocessF(this.variables).dataSync());
        drawPixelsToCanvas(pixelData, canvas, this.resolution, this.resolution);
    }

    /** Cancels the scheduled animation frame, if there is one. */
    cancelPendingFrame() {
        if (this.animFrameRequest !== undefined) {
            cancelAnimationFrame(this.animFrameRequest);
            this.animFrameRequest = undefined;
        }
    }

    /**
     * Starts (or restarts) an optimisation run of `iterations` steps. The first
     * step runs synchronously; the remaining steps run one per animation frame.
     * Any run already in progress is cancelled first so two loops never overlap.
     *
     * @param {number} [iterations=1] - Number of steps; values <= 0 do nothing.
     */
    doSteps = (iterations = 1) => {
        this.cancelPendingFrame();
        // Clear a stop request left over from an earlier (possibly idle) forceStop.
        this.shouldForceStop = false;
        this.stepLoop(iterations);
    }

    /**
     * Runs one optimisation step and schedules the next one, until `remaining`
     * reaches 0 or a stop has been requested.
     *
     * @param {number} remaining - Steps left in the current run, including this one.
     */
    stepLoop(remaining) {
        if (this.shouldForceStop || remaining <= 0) {
            this.shouldForceStop = false;
            this.animFrameRequest = undefined;
            return;
        }
        this.iterations = remaining;
        // returnCost=false: the cost scalar was never used, and returning it
        // leaked one tensor per step.
        this.optimizer.minimize(this.lossF(this.variables), false, this.variables);
        this.drawCurrentVarsToCanvas();
        this.animFrameRequest = requestAnimationFrame(() => this.stepLoop(remaining - 1));
    }

    /**
     * Stops the current run. The pending frame is cancelled, and the flag also
     * stops a run if this is called from inside a step (e.g. from lossF).
     */
    forceStop = () => {
        this.shouldForceStop = true;
        this.cancelPendingFrame();
    }

    /**
     * Stops any run, disposes the old optimizer and the old variables that are
     * not reused, installs `newVariables` with a fresh Adam optimizer, and redraws.
     *
     * @param {Array} newVariables - Variables to optimise from now on.
     */
    reset = (newVariables) => {
        this.cancelPendingFrame();
        this.shouldForceStop = false;
        // Don't dispose variables the caller is carrying over into the new set.
        const kept = new Set(newVariables);
        this.variables.forEach((variable) => {
            if (!kept.has(variable)) {
                variable.dispose();
            }
        });
        // Release the old optimizer's internal state (e.g. Adam moment accumulators).
        this.optimizer.dispose();
        this.optimizer = tf.train.adam(this.learningRate);
        this.variables = newVariables;
        this.drawCurrentVarsToCanvas();
    }
}