import { imnetNameFromIndex, restrictedImnetNameFromIndex, inceptionNameFromIndex } from '../ImageNetUtil'


/**
 * Display names of the supported network architectures. Model objects carry
 * one of these values in their `archType` field, which the URL builders below
 * switch on.
 */
export const archId = {
    INCEPTION_V1: 'Inception V1',
    RESNET_18: 'Resnet 18',
};

/**
 * Training regimes a model can come from; accepted as the `modelType`
 * argument of getModelDirName and getModelIterations.
 */
export const modelTypes = {
    STD: 'Standard',
    STYL: 'Stylized',
    ADV: 'Adv. trained',
};

// Directory names of the InceptionV1 checkpoints, each of the form
// "inception_v1_<adv|stylized>_<iteration zero-padded to 6 digits>".
// getModelIterations parses the trailing number of every entry matching a
// training type, and getModelDirName indexes into that result by position,
// so the order of entries within each training type is load-bearing.
export const modelDirNames = [
    // Checkpoints for modelTypes.ADV (getModelDirName also maps
    // modelTypes.STD to the first of these).
    'inception_v1_adv_000001',
    'inception_v1_adv_000010',
    'inception_v1_adv_000100',
    'inception_v1_adv_001000',
    'inception_v1_adv_003000',
    'inception_v1_adv_005000',
    'inception_v1_adv_007000',
    'inception_v1_adv_010000',
    'inception_v1_adv_011000',
    'inception_v1_adv_012000',
    'inception_v1_adv_013000',
    'inception_v1_adv_014000',
    'inception_v1_adv_015000',
    'inception_v1_adv_020000',
    'inception_v1_adv_025000',
    'inception_v1_adv_030000',
    'inception_v1_adv_040000',
    'inception_v1_adv_050000',
    // Checkpoints for modelTypes.STYL.
    'inception_v1_stylized_000001',
    'inception_v1_stylized_000010',
    'inception_v1_stylized_000100',
    'inception_v1_stylized_001000',
    'inception_v1_stylized_005000',
    'inception_v1_stylized_025000',
    'inception_v1_stylized_035000',
    'inception_v1_stylized_090000'
]

/**
 * Resolves the URL of a model's model.json from its directory name.
 * Directory names containing "ckpt" or "model-" are looked up under
 * "Resnet18/"; every other directory sits directly under PUBLIC_URL.
 *
 * @param {string} dirName - model directory name, e.g. "inception_v1_adv_000010".
 * @returns {string} PUBLIC_URL-prefixed path ending in "/model.json".
 */
export function getModelPathFromDirName(dirName) {
    const isResnetDir = dirName.includes('ckpt') || dirName.includes('model-');
    const relativeDir = isResnetDir ? "Resnet18/" + dirName : dirName;
    return process.env.PUBLIC_URL + relativeDir + "/model.json";
}

/**
 * Infers the architecture of a model from its directory name: names containing
 * "ckpt" or "model-" are ResNet-18, anything else is InceptionV1.
 *
 * @param {string} dirName - model directory name.
 * @returns {string} archId.RESNET_18 or archId.INCEPTION_V1.
 */
export function getArchFromDirName(dirName) {
    const looksLikeResnet = ['ckpt', 'model-'].some((marker) => dirName.includes(marker));
    return looksLikeResnet ? archId.RESNET_18 : archId.INCEPTION_V1;
}

/**
 * Returns the checkpoint directory name for a model type and an iteration slot.
 *
 * @param {string} modelType - one of the modelTypes values.
 * @param {number} iteration - position in getModelIterations(modelType)
 *   (ignored for modelTypes.STD).
 * @returns {string|undefined} the directory name, or undefined for an
 *   unrecognised model type. An out-of-range position yields a name ending
 *   in "undefined".
 */
export const getModelDirName = (modelType, iteration) => {
    // The "Standard" model is the earliest adversarial-training checkpoint.
    if (modelType === modelTypes.STD) {
        return "inception_v1_adv_000001";
    }
    let prefix;
    if (modelType === modelTypes.STYL) {
        prefix = "inception_v1_stylized_";
    } else if (modelType === modelTypes.ADV) {
        prefix = "inception_v1_adv_";
    } else {
        return undefined;
    }
    const available = getModelIterations(modelType);
    const padded = String(available[iteration]).padStart(6, '0');
    return prefix + padded;
};

/**
 * Lists the training iterations available for a model type, in the order the
 * matching directories appear in modelDirNames.
 *
 * @param {string} modelType - one of the modelTypes values.
 * @returns {number[]} [0] for modelTypes.STD; the parsed trailing iteration
 *   numbers for modelTypes.STYL / modelTypes.ADV; [] for any other value.
 */
export const getModelIterations = (modelType) => {
    if (modelType === modelTypes.STD) {
        return [0];
    }
    let key;
    if (modelType === modelTypes.STYL) {
        key = "stylized";
    } else if (modelType === modelTypes.ADV) {
        key = "adv";
    } else {
        // Unknown type: nothing matches. Returning explicitly avoids an empty
        // key, which would match every directory via includes("").
        return [];
    }
    // Match the whole "_<key>_" segment so the key cannot hit an unrelated substring.
    const segment = "_" + key + "_";
    return modelDirNames
        .filter((dirName) => dirName.includes(segment))
        .map((dirName) => {
            // Trailing token is the zero-padded iteration, e.g. "000100" -> 100.
            const tokens = dirName.split('_');
            return parseInt(tokens[tokens.length - 1], 10);
        });
};

// Receptive-field size per InceptionV1 mixed layer (presumably in input
// pixels — confirm). getLayerReceptiveField matches these keys as substrings
// of a layer name, so e.g. "mixed4a_pre_relu" resolves to the mixed4a entry.
const inceptionV1receptiveFieldDict = {
    mixed3a: 43,
    mixed3b: 59,
    mixed4a: 107,
    mixed4b: 139,
    mixed4c: 171,
    mixed4d: 203,
    mixed4e: 235,
    mixed5a: 315,
    mixed5b: 379,
};

/**
 * InceptionV1 layer names of interest, listed as the conv2d* stem layers
 * followed by the mixed3*, mixed4* and mixed5* blocks.
 */
export const inceptionV1InterestingLayers = [
    'conv2d0', 'conv2d1', 'conv2d2',
    'mixed3a', 'mixed3b',
    'mixed4a', 'mixed4b', 'mixed4c', 'mixed4d', 'mixed4e',
    'mixed5a', 'mixed5b',
];

// No per-layer receptive fields are recorded for ResNet-18, so
// getLayerReceptiveField falls back to its default for every ResNet layer.
const resnet18receptiveFieldDict = {};

// Maps ResNet-18 layer names (keys, presumably the names in the loaded/converted
// model — confirm) to op names of the original network graph (values).
// getNeuronThumbURL and getLayerImageURL look a layer name up here and pass the
// value to getResnetLayerName, which derives an image-path layer name only for
// "group*/block*/conv1/..." and "group*/block*/conv2/..." values.
export const originalResnet18Layer2Layer = {
    // Stem convolution and its batch norm (outside the residual groups, so no
    // image-path layer name is derived for them).
    'conv2d_1': 'conv0/W',
    'batch_normalization_1': 'conv0/bn/FusedBatchNormV3',

    // Residual group 0: blocks 0 and 1, each with conv1/conv2 and their batch norms.
    'res0a_branch2a': 'group0/block0/conv1/W',
    'bn0a_branch2a': 'group0/block0/conv1/bn/FusedBatchNormV3',
    'res0a_branch2b': 'group0/block0/conv2/W',
    'bn0a_branch2b': 'group0/block0/conv2/bn/FusedBatchNormV3',
    'res0b_branch2a': 'group0/block1/conv1/W',
    'bn0b_branch2a': 'group0/block1/conv1/bn/FusedBatchNormV3',
    'res0b_branch2b': 'group0/block1/conv2/W',
    'bn0b_branch2b': 'group0/block1/conv2/bn/FusedBatchNormV3',

    // Residual groups 1-3 additionally have a "convshortcut" convolution in
    // block 0 (*_branch1 keys); no image-path layer name is derived for it.
    'res1a_branch2a': 'group1/block0/conv1/W',
    'bn1a_branch2a': 'group1/block0/conv1/bn/FusedBatchNormV3',
    'res1a_branch2b': 'group1/block0/conv2/W',
    'bn1a_branch2b': 'group1/block0/conv2/bn/FusedBatchNormV3',
    'res1a_branch1': 'group1/block0/convshortcut/W',
    'bn1a_branch1': 'group1/block0/convshortcut/bn/FusedBatchNormV3',
    'res1b_branch2a': 'group1/block1/conv1/W',
    'bn1b_branch2a': 'group1/block1/conv1/bn/FusedBatchNormV3',
    'res1b_branch2b': 'group1/block1/conv2/W',
    'bn1b_branch2b': 'group1/block1/conv2/bn/FusedBatchNormV3',

    'res2a_branch2a': 'group2/block0/conv1/W',
    'bn2a_branch2a': 'group2/block0/conv1/bn/FusedBatchNormV3',
    'res2a_branch2b': 'group2/block0/conv2/W',
    'bn2a_branch2b': 'group2/block0/conv2/bn/FusedBatchNormV3',
    'res2a_branch1': 'group2/block0/convshortcut/W',
    'bn2a_branch1': 'group2/block0/convshortcut/bn/FusedBatchNormV3',
    'res2b_branch2a': 'group2/block1/conv1/W',
    'bn2b_branch2a': 'group2/block1/conv1/bn/FusedBatchNormV3',
    'res2b_branch2b': 'group2/block1/conv2/W',
    'bn2b_branch2b': 'group2/block1/conv2/bn/FusedBatchNormV3',

    'res3a_branch2a': 'group3/block0/conv1/W',
    'bn3a_branch2a': 'group3/block0/conv1/bn/FusedBatchNormV3',
    'res3a_branch2b': 'group3/block0/conv2/W',
    'bn3a_branch2b': 'group3/block0/conv2/bn/FusedBatchNormV3',
    'res3a_branch1': 'group3/block0/convshortcut/W',
    'bn3a_branch1': 'group3/block0/convshortcut/bn/FusedBatchNormV3',
    'res3b_branch2a': 'group3/block1/conv1/W',
    'bn3b_branch2a': 'group3/block1/conv1/bn/FusedBatchNormV3',
    'res3b_branch2b': 'group3/block1/conv2/W',
    'bn3b_branch2b': 'group3/block1/conv2/bn/FusedBatchNormV3',

    // Final dense layer.
    'dense_1': 'linear'
}

/**
 * Converts an original ResNet-18 op name (a value of originalResnet18Layer2Layer)
 * into the layer name used in the ResNet neuron-image paths:
 *   "groupG/blockB/conv1/..." -> "layer_import_groupG_blockB_conv1_Relu"
 *   "groupG/blockB/conv2/..." -> "layer_import_groupG_blockB_Relu"
 *
 * @param {string|undefined} originalResnetLayerName - e.g. "group1/block0/conv2/W".
 * @returns {string} the derived layer name, or "" when none can be derived
 *   (missing input, ops outside the residual groups such as "conv0/W" or
 *   "linear", and "convshortcut" ops).
 */
function getResnetLayerName(originalResnetLayerName) {
    if (!originalResnetLayerName || !originalResnetLayerName.includes("group")) {
        return "";
    }
    // Parse whole numeric indices rather than single characters so multi-digit
    // group/block indices work; "convshortcut" has no conv index and does not match.
    const match = /^group(\d+)\/block(\d+)\/conv(\d+)(?:\/|$)/.exec(originalResnetLayerName);
    if (!match) {
        return "";
    }
    const group = parseInt(match[1], 10);
    const block = parseInt(match[2], 10);
    const conv = parseInt(match[3], 10);
    const prefix = "layer_import_group" + group + "_block" + block;
    if (conv === 1) {
        return prefix + "_conv1_Relu";
    }
    if (conv === 2) {
        return prefix + "_Relu";
    }
    // No ReLU name exists for other conv indices; return "" (not undefined) so
    // callers that concatenate this into a URL never embed "undefined".
    return "";
}

/**
 * Builds the URL of the pre-rendered visualization thumbnail for one neuron.
 *
 * @param {Object} model - provides `archType` (an archId value) and `modelPath`.
 * @param {string} layerName - layer name in the loaded model.
 * @param {number} neuron - channel index; zero-padded to 3 digits (Inception)
 *   or 4 digits (ResNet) in the file name.
 * @param {boolean} [usePriorVis=false] - choose the prior-based renders:
 *   "fft/" for Inception mixed* layers, "with_priors/" for ResNet.
 * @returns {string} the image URL, or "" for an unknown architecture or a
 *   ResNet layer with no entry in originalResnet18Layer2Layer.
 */
export function getNeuronThumbURL(model, layerName, neuron, usePriorVis=false) {
    const base = process.env.PUBLIC_URL;
    switch (model.archType) {
        case archId.INCEPTION_V1: {
            // "mixed4a_pre_relu" -> "mixed4a"
            const cleanLayer = layerName.split("_")[0];
            // Expects the first "/" segment of modelPath to be the checkpoint dir
            // (e.g. "inception_v1_adv_000010"); its last "_" token is the iteration.
            const modelPathTokens = model.modelPath.split("/")[0].split("_");
            const modelNum = parseInt(modelPathTokens[modelPathTokens.length - 1], 10);
            let trainingType;
            if (model.modelPath.includes("adv")) {
                trainingType = "adv";
            } else if (model.modelPath.includes("stylized")) {
                trainingType = "stylized";
            }
            // NOTE(review): trainingType stays undefined (and "undefined" lands in the
            // URL) if modelPath names neither regime — confirm that cannot happen.
            // The fft/ renders are only requested for mixed* layers.
            const visDir = usePriorVis && cleanLayer.includes("mixed") ? "fft/" : "no_fft/";
            const neuronStr = (neuron + '').padStart(3, '0');
            return base + "neuron_thumbs/inceptionv1/" + trainingType + "/" + visDir +
                "model-" + modelNum + "/" + cleanLayer + "_" + neuronStr + ".jpg";
        }
        case archId.RESNET_18: {
            // Expects the second "/" segment of modelPath to be the checkpoint dir.
            const modelPath = model.modelPath.split("/")[1];
            const origLayer = originalResnet18Layer2Layer[layerName];
            if (!origLayer) {
                return "";
            }
            // "ckpt" dirs are trained checkpoints; everything else is a transfer run.
            let root;
            if (modelPath.includes("ckpt")) {
                root = base + "neuron_thumbs_resnet/trained/";
            } else {
                let subdir;
                if (modelPath.includes("adv_to_std")) {
                    subdir = "adv_to_std/";
                } else if (modelPath.includes("std_to_adv")) {
                    subdir = "std_to_adv/";
                }
                root = base + "neuron_thumbs_resnet/transfer/" + subdir;
            }
            const priorDir = usePriorVis ? "with_priors/" : "no_priors/";
            const neuronStr = (neuron + '').padStart(4, '0');
            return root + priorDir + modelPath + "/neuron_objective/" +
                getResnetLayerName(origLayer) + "/neuron_" + neuronStr + ".png";
        }
        default:
            return "";
    }
}

/**
 * Builds the URL of the per-layer neuron image (under layer_neuron_images*).
 *
 * @param {Object} model - provides `archType` (an archId value) and `modelPath`.
 * @param {string} layerName - layer name in the loaded model.
 * @param {boolean} [usePriorVis=false] - ResNet only: use the "with_priors/" renders.
 * @returns {string} the image URL, or "" for an unknown architecture or a
 *   ResNet layer with no entry in originalResnet18Layer2Layer (matching
 *   getNeuronThumbURL's conventions).
 */
export function getLayerImageURL(model, layerName, usePriorVis=false) {
    const base = process.env.PUBLIC_URL;
    switch (model.archType) {
        case archId.INCEPTION_V1: {
            // Inception layer images depend on neither the checkpoint nor usePriorVis.
            const cleanLayer = layerName.split("_")[0];
            return base + "layer_neuron_images/" + cleanLayer + ".jpg";
        }
        case archId.RESNET_18: {
            const origLayer = originalResnet18Layer2Layer[layerName];
            if (!origLayer) {
                // Unmapped layer: there is no image, and undefined must not reach
                // getResnetLayerName.
                return "";
            }
            // Expects the second "/" segment of modelPath to be the checkpoint dir.
            const modelPath = model.modelPath.split("/")[1];
            const priorDir = usePriorVis ? "with_priors/" : "no_priors/";
            let neuronImageDir;
            if (modelPath.includes("ckpt")) {
                neuronImageDir = base + "layer_neuron_images_resnet/trained/" + priorDir + modelPath + "/";
            } else {
                let subdir;
                if (modelPath.includes("adv_to_std")) {
                    subdir = "adv_to_std/";
                } else if (modelPath.includes("std_to_adv")) {
                    subdir = "std_to_adv/";
                }
                neuronImageDir = base + "layer_neuron_images_resnet/transfer/" + subdir + priorDir + modelPath + "/";
            }
            return neuronImageDir + getResnetLayerName(origLayer) + ".jpg";
        }
        default:
            return "";
    }
}

/**
 * URL of the example image for a class.
 *
 * @param {number} classIndex - class index; zero-padded to at least 4 digits
 *   in the file name.
 * @returns {string} PUBLIC_URL-prefixed path under class_examples/.
 */
export function getClassExampleUrl(classIndex) {
    const paddedIndex = ("" + classIndex).padStart(4, '0');
    return `${process.env.PUBLIC_URL}class_examples/class_example_${paddedIndex}.jpg`;
}

/**
 * Looks up the receptive-field size for a layer.
 *
 * @param {Object<string, number>|undefined} rfDict - key -> size table (see
 *   getReceptiveFieldDict); keys are matched as substrings of layerName and
 *   the first match in insertion order wins.
 * @param {string} layerName - e.g. "mixed4a_pre_relu".
 * @param {number} [fallback=224] - returned when rfDict is missing or no key
 *   matches (presumably the full input resolution — confirm).
 * @returns {number}
 */
export function getLayerReceptiveField(rfDict, layerName, fallback = 224) {
    // getReceptiveFieldDict yields undefined for an unknown architecture;
    // Object.entries(undefined) would throw.
    if (!rfDict) {
        return fallback;
    }
    for (const [key, value] of Object.entries(rfDict)) {
        if (layerName.includes(key)) {
            return value;
        }
    }
    return fallback;
}

/**
 * Returns the receptive-field table for an architecture.
 *
 * @param {string} aId - one of the archId values.
 * @returns {Object<string, number>|undefined} the shared table object (not a
 *   copy), or undefined for an unknown architecture.
 */
export function getReceptiveFieldDict(aId) {
    if (aId === archId.INCEPTION_V1) {
        return inceptionV1receptiveFieldDict;
    }
    if (aId === archId.RESNET_18) {
        return resnet18receptiveFieldDict;
    }
    return undefined;
}

/**
 * Initial activation-view parameter list for an architecture.
 *
 * @param {string} aId - one of the archId values.
 * @returns {Object[]|undefined} a freshly built one-element array (safe for the
 *   caller to mutate), or undefined for an unknown architecture.
 */
export function getDefaultActivationParams(aId) {
    let layer;
    let neurons;
    if (aId === archId.INCEPTION_V1) {
        layer = 'mixed4a_pre_relu';
        neurons = [308];
    } else if (aId === archId.RESNET_18) {
        layer = 'bn3b_branch2b';
        neurons = [450, 437, 204, 256];
    } else {
        return undefined;
    }
    const params = {
        archId: aId,
        layer,
        neurons,
        showAllNeurons: true,
        normalize: 0,
        poolWidth: 1,
    };
    return [params];
}

/**
 * Parameters for a newly added activation view.
 *
 * @param {string} aId - one of the archId values.
 * @returns {Object|undefined} a fresh params object showing neurons 0 and 1 of
 *   the architecture's default layer, or undefined for an unknown architecture.
 */
export function getDefaultNewActivationParams(aId) {
    let defaultLayer;
    if (aId === archId.INCEPTION_V1) {
        defaultLayer = 'conv2d0';
    } else if (aId === archId.RESNET_18) {
        defaultLayer = 'bn0a_branch2a';
    } else {
        return undefined;
    }
    return { archId: aId, layer: defaultLayer, neurons: [0, 1], normalize: 0, poolWidth: 1 };
}

/**
 * Human-readable class name for a class index, using the label set that
 * belongs to the given architecture.
 *
 * @param {string} aId - one of the archId values.
 * @param {number} classIndex - output class index.
 * @returns {string|undefined} the label from inceptionNameFromIndex (Inception)
 *   or restrictedImnetNameFromIndex (ResNet); undefined for an unknown architecture.
 */
export function getClassName(aId, classIndex) {
    if (aId === archId.INCEPTION_V1) {
        return inceptionNameFromIndex(classIndex);
    }
    if (aId === archId.RESNET_18) {
        return restrictedImnetNameFromIndex(classIndex);
    }
    return undefined;
}