/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.OrthogonalDistribution;
import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.nn.weights.WeightInitIdentity;
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanAvg;
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanIn;
import org.deeplearning4j.nn.weights.WeightInitVarScalingNormalFanOut;
import org.deeplearning4j.nn.weights.WeightInitVarScalingUniformFanAvg;
import org.deeplearning4j.nn.weights.WeightInitVarScalingUniformFanIn;
import org.deeplearning4j.nn.weights.WeightInitVarScalingUniformFanOut;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasInitilizationUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasInitilizationUtils.class);

    /** Key under which a Keras 2 initializer entry stores its class name. */
    private static final String CLASS_NAME_KEY = "class_name";
    /** Key under which a Keras 2 initializer entry stores its argument map. */
    private static final String CONFIG_KEY = "config";

    /**
     * Maps a Keras weight initializer to the equivalent DL4J {@link IWeightInit}.
     *
     * <p>Numeric initializer arguments may be stored as any JSON number type (integer or floating
     * point); they are all read through {@link Number#doubleValue()}.
     *
     * @param kerasInit         Keras initializer name (class name or alias, depending on version)
     * @param conf              version-specific Keras configuration supplying the recognised
     *                          initializer names and argument field keys
     * @param initConfig        initializer arguments; when {@code kerasMajorVersion != 2} this is
     *                          the layer config itself (see {@link #getWeightInitFromConfig})
     * @param kerasMajorVersion Keras major version the config was produced by
     * @return the matching DL4J weight initializer
     * @throws IllegalStateException                  if {@code kerasInit} is null
     * @throws UnsupportedKerasConfigurationException if {@code kerasInit} is not a known initializer
     * @throws InvalidKerasConfigurationException     if a required argument is missing or not
     *                                                numeric, or a VarianceScaling mode/distribution
     *                                                is not recognised
     */
    public static IWeightInit mapWeightInitialization(String kerasInit, KerasLayerConfiguration conf,
            Map<String, Object> initConfig, int kerasMajorVersion)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (kerasInit == null) {
            throw new IllegalStateException("Error getting Keras weight initialization");
        }
        if (matches(kerasInit, conf.getINIT_GLOROT_NORMAL(), conf.getINIT_GLOROT_NORMAL_ALIAS())) {
            return WeightInit.XAVIER.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_GLOROT_UNIFORM(), conf.getINIT_GLOROT_UNIFORM_ALIAS())) {
            return WeightInit.XAVIER_UNIFORM.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_LECUN_NORMAL(), conf.getINIT_LECUN_NORMAL_ALIAS())) {
            return WeightInit.LECUN_NORMAL.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_LECUN_UNIFORM(), conf.getINIT_LECUN_UNIFORM_ALIAS())) {
            return WeightInit.LECUN_UNIFORM.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_HE_NORMAL(), conf.getINIT_HE_NORMAL_ALIAS())) {
            return WeightInit.RELU.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_HE_UNIFORM(), conf.getINIT_HE_UNIFORM_ALIAS())) {
            return WeightInit.RELU_UNIFORM.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_ONE(), conf.getINIT_ONES(), conf.getINIT_ONES_ALIAS())) {
            return WeightInit.ONES.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_ZERO(), conf.getINIT_ZEROS(), conf.getINIT_ZEROS_ALIAS())) {
            return WeightInit.ZERO.getWeightInitFunction();
        }
        if (matches(kerasInit, conf.getINIT_UNIFORM(), conf.getINIT_RANDOM_UNIFORM(),
                conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
            if (kerasMajorVersion == 2) {
                double minVal = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_MINVAL(), kerasInit);
                double maxVal = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_MAXVAL(), kerasInit);
                return new WeightInitDistribution(new UniformDistribution(minVal, maxVal));
            }
            // Pre-Keras-2 configs use a symmetric range [-scale, scale], defaulting to 0.05.
            double scale = optionalDouble(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 0.05, kerasInit);
            return new WeightInitDistribution(new UniformDistribution(-scale, scale));
        }
        if (matches(kerasInit, conf.getINIT_NORMAL(), conf.getINIT_RANDOM_NORMAL(),
                conf.getINIT_RANDOM_NORMAL_ALIAS())) {
            if (kerasMajorVersion == 2) {
                double mean = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_MEAN(), kerasInit);
                double stdDev = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_STDDEV(), kerasInit);
                return new WeightInitDistribution(new NormalDistribution(mean, stdDev));
            }
            // Pre-Keras-2 configs use a zero-mean normal whose std dev is "scale" (default 0.05).
            double scale = optionalDouble(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 0.05, kerasInit);
            return new WeightInitDistribution(new NormalDistribution(0.0, scale));
        }
        if (matches(kerasInit, conf.getINIT_CONSTANT(), conf.getINIT_CONSTANT_ALIAS())) {
            double value = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_VALUE(), kerasInit);
            return new WeightInitDistribution(new ConstantDistribution(value));
        }
        if (matches(kerasInit, conf.getINIT_ORTHOGONAL(), conf.getINIT_ORTHOGONAL_ALIAS())) {
            if (kerasMajorVersion == 2) {
                double gain = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_GAIN(), kerasInit);
                return new WeightInitDistribution(new OrthogonalDistribution(gain));
            }
            double scale = optionalDouble(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 1.1, kerasInit);
            return new WeightInitDistribution(new OrthogonalDistribution(scale));
        }
        if (matches(kerasInit, conf.getINIT_TRUNCATED_NORMAL(), conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
            double mean = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_MEAN(), kerasInit);
            double stdDev = requireDouble(initConfig, conf.getLAYER_FIELD_INIT_STDDEV(), kerasInit);
            return new WeightInitDistribution(new TruncatedNormalDistribution(mean, stdDev));
        }
        if (matches(kerasInit, conf.getINIT_IDENTITY(), conf.getINIT_IDENTITY_ALIAS())) {
            // Keras 2 stores the multiplier as "gain"; older configs store an optional "scale".
            double gain = kerasMajorVersion == 2
                    ? requireDouble(initConfig, conf.getLAYER_FIELD_INIT_GAIN(), kerasInit)
                    : optionalDouble(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 1.0, kerasInit);
            // A unit multiplier maps to the plain identity initializer (this previously threw
            // for Keras 2 configs).
            if (gain == 1.0) {
                return new WeightInitIdentity();
            }
            return new WeightInitIdentity(Double.valueOf(gain));
        }
        if (matches(kerasInit, conf.getINIT_VARIANCE_SCALING())) {
            return mapVarianceScaling(kerasInit, conf, initConfig);
        }
        throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
    }

    /**
     * Reads the weight initializer stored under {@code initField} of a Keras layer config and maps
     * it to DL4J.
     *
     * <p>For Keras 2 the field normally holds a map with {@code class_name} and {@code config}
     * entries; a missing {@code config} is treated as an empty argument map, and a plain string is
     * accepted as an initializer name without arguments. For other versions the field holds the
     * initializer name and its arguments are read from the layer config itself.
     *
     * @param layerConfig           Keras layer configuration
     * @param initField             name of the field holding the initializer
     * @param enforceTrainingConfig whether an unknown initializer should fail instead of falling
     *                              back to Xavier
     * @param conf                  version-specific Keras configuration
     * @param kerasMajorVersion     Keras major version the config was produced by
     * @return the mapped weight initializer, or {@link WeightInitXavier} when the initializer is
     *         unknown and {@code enforceTrainingConfig} is false
     * @throws InvalidKerasConfigurationException     if the field is missing or malformed
     * @throws UnsupportedKerasConfigurationException if the initializer class is absent, or unknown
     *                                                while {@code enforceTrainingConfig} is true
     */
    public static IWeightInit getWeightInitFromConfig(Map<String, Object> layerConfig, String initField,
            boolean enforceTrainingConfig, KerasLayerConfiguration conf, int kerasMajorVersion)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(initField)) {
            throw new InvalidKerasConfigurationException("Keras layer is missing " + initField + " field");
        }
        Object rawInit = innerConfig.get(initField);
        String kerasInit;
        Map<String, Object> initMap;
        if (kerasMajorVersion != 2) {
            kerasInit = (String) rawInit;
            initMap = innerConfig;
        } else if (rawInit instanceof Map) {
            @SuppressWarnings("unchecked") // JSON object parsed into a String-keyed map
            Map<String, Object> fullInitMap = (Map<String, Object>) rawInit;
            if (!fullInitMap.containsKey(CLASS_NAME_KEY)) {
                throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
            }
            kerasInit = (String) fullInitMap.get(CLASS_NAME_KEY);
            initMap = initializerArguments(fullInitMap.get(CONFIG_KEY), initField);
        } else if (rawInit instanceof String) {
            kerasInit = (String) rawInit;
            initMap = new HashMap<String, Object>();
        } else {
            throw new InvalidKerasConfigurationException("Keras layer field " + initField
                    + " must be an initializer name or configuration map but was " + rawInit);
        }
        try {
            return mapWeightInitialization(kerasInit, conf, initMap, kerasMajorVersion);
        } catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            log.warn("Unknown weight initializer {} (Using XAVIER instead).", kerasInit);
            return new WeightInitXavier();
        }
    }

    /** Maps a Keras VarianceScaling initializer (scale, mode, distribution) to DL4J. */
    private static IWeightInit mapVarianceScaling(String kerasInit, KerasLayerConfiguration conf,
            Map<String, Object> initConfig) throws InvalidKerasConfigurationException {
        Double scale = Double.valueOf(requireDouble(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), kerasInit));
        Object mode = initConfig.get(conf.getLAYER_FIELD_INIT_MODE());
        boolean normal = isNormalDistribution(initConfig.get(conf.getLAYER_FIELD_INIT_DISTRIBUTION()));
        if ("fan_in".equals(mode)) {
            if (normal) {
                return new WeightInitVarScalingNormalFanIn(scale);
            }
            return new WeightInitVarScalingUniformFanIn(scale);
        }
        if ("fan_out".equals(mode)) {
            if (normal) {
                return new WeightInitVarScalingNormalFanOut(scale);
            }
            return new WeightInitVarScalingUniformFanOut(scale);
        }
        if ("fan_avg".equals(mode)) {
            if (normal) {
                return new WeightInitVarScalingNormalFanAvg(scale);
            }
            return new WeightInitVarScalingUniformFanAvg(scale);
        }
        throw new InvalidKerasConfigurationException(
                "Initialization argument 'mode' has to be either fan_in, fan_out or fan_avg");
    }

    /**
     * Decides whether a VarianceScaling distribution maps to DL4J's normal or uniform variant.
     * Only normal and uniform variance-scaling initializers are available here, so all normal-family
     * Keras distributions map to the normal variant (an approximation for the truncated/untruncated
     * forms).
     */
    private static boolean isNormalDistribution(Object distribution) throws InvalidKerasConfigurationException {
        if ("normal".equals(distribution) || "truncated_normal".equals(distribution)
                || "untruncated_normal".equals(distribution)) {
            return true;
        }
        if ("uniform".equals(distribution)) {
            return false;
        }
        throw new InvalidKerasConfigurationException("Initialization argument 'distribution' has to be either "
                + "normal, truncated_normal, untruncated_normal or uniform but was " + distribution);
    }

    /** Returns the Keras 2 initializer argument map, treating an absent entry as empty. */
    private static Map<String, Object> initializerArguments(Object config, String initField)
            throws InvalidKerasConfigurationException {
        if (config == null) {
            return new HashMap<String, Object>();
        }
        if (config instanceof Map) {
            @SuppressWarnings("unchecked") // JSON object parsed into a String-keyed map
            Map<String, Object> arguments = (Map<String, Object>) config;
            return arguments;
        }
        throw new InvalidKerasConfigurationException(
                "Keras initializer config for " + initField + " must be a map but was " + config);
    }

    /** Returns true when {@code kerasInit} equals any of {@code names}; null names never match. */
    private static boolean matches(String kerasInit, String... names) {
        for (String name : names) {
            if (kerasInit.equals(name)) {
                return true;
            }
        }
        return false;
    }

    /** Reads a required numeric initializer argument, accepting any {@link Number} type. */
    private static double requireDouble(Map<String, Object> initConfig, String field, String kerasInit)
            throws InvalidKerasConfigurationException {
        Object value = initConfig == null ? null : initConfig.get(field);
        if (value == null) {
            throw new InvalidKerasConfigurationException(
                    "Keras initializer " + kerasInit + " is missing numeric field " + field);
        }
        return toDouble(value, field, kerasInit);
    }

    /** Reads an optional numeric initializer argument, returning {@code defaultValue} when absent or null. */
    private static double optionalDouble(Map<String, Object> initConfig, String field, double defaultValue,
            String kerasInit) throws InvalidKerasConfigurationException {
        Object value = initConfig == null ? null : initConfig.get(field);
        return value == null ? defaultValue : toDouble(value, field, kerasInit);
    }

    /** Converts a parsed JSON value to double, rejecting non-numeric values with a descriptive error. */
    private static double toDouble(Object value, String field, String kerasInit)
            throws InvalidKerasConfigurationException {
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException("Keras initializer " + kerasInit + " field " + field
                    + " must be numeric but was " + value);
        }
        return ((Number) value).doubleValue();
    }
}

