package org.deeplearning4j.nn.modelimport.keras;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfigurationFactory;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasLayer.class */
/**
 * Base representation of a single Keras layer during model import.
 *
 * <p>An instance holds metadata parsed from the layer's Keras config map: class name, layer name,
 * input shape, dimension ordering, inbound/outbound layer names, L1/L2 weight regularization and
 * dropout. After a subclass has built it, the instance also holds the resulting DL4J
 * {@link Layer} or {@link GraphVertex} and the imported weights.
 *
 * <p>This base implementation reports zero parameters, ignores {@link #setWeights(Map)} and cannot
 * determine an output type. Subclasses presumably override these hooks.
 *
 * <p>The static custom-layer and lambda-layer registries are plain {@link HashMap}s with no
 * synchronization.
 */
public class KerasLayer {
    private static final String LAYER_FIELD_KERAS_VERSION = "keras_version";

    /** Keras major version used when a constructor is not given one explicitly. */
    private static final int DEFAULT_KERAS_MAJOR_VERSION = 2;

    /** Keras class name of the layer, read from the config. */
    protected String className;
    /** Keras layer name, read from the config. */
    protected String layerName;
    /** Input shape from the config, or {@code null} if none was given. */
    protected int[] inputShape;
    /** Dimension ordering read from the config; {@link DimOrder#NONE} until set. */
    protected DimOrder dimOrder = DimOrder.NONE;
    /** Names of the layers feeding into this one. */
    protected List<String> inboundLayerNames = new ArrayList<>();
    /** Names of the layers this one feeds into. */
    protected List<String> outboundLayerNames = new ArrayList<>();
    /** DL4J layer built from this Keras layer, if any. */
    protected Layer layer;
    /** DL4J graph vertex built from this Keras layer, if any. */
    protected GraphVertex vertex;
    /** Imported weights keyed by DL4J parameter name, or {@code null} if none are set. */
    protected Map<String, INDArray> weights;
    protected double weightL1Regularization = 0.0d;
    protected double weightL2Regularization = 0.0d;
    /** Dropout value from the config; 1.0 is treated as "no dropout" by {@link #usesRegularization()}. */
    protected double dropout = 1.0d;
    protected Integer kerasMajorVersion = DEFAULT_KERAS_MAJOR_VERSION;
    /** Version-specific field-name configuration obtained from {@link KerasLayerConfigurationFactory}. */
    protected KerasLayerConfiguration conf;
    /** The raw config map this layer was constructed from (null for the non-map constructors). */
    protected Map<String, Object> originalLayerConfig;

    private static final Logger log = LoggerFactory.getLogger(KerasLayer.class);

    /** Registry of user-supplied layer implementations, keyed by name. */
    static final Map<String, Class<? extends KerasLayer>> customLayers = new HashMap<>();
    /** Registry of user-supplied SameDiff lambda layers, keyed by name. */
    static final Map<String, SameDiffLambdaLayer> lambdaLayers = new HashMap<>();

    /** Dimension-ordering convention of a layer, as read from its Keras config. */
    public enum DimOrder {
        NONE,
        THEANO,
        TENSORFLOW
    }

    /**
     * Creates an empty layer for the given Keras major version. Intended for use by subclasses.
     *
     * @param kerasMajorVersion Keras major version used to select the field-name configuration
     * @throws UnsupportedKerasConfigurationException if no configuration exists for that version
     */
    public KerasLayer(Integer kerasMajorVersion) throws UnsupportedKerasConfigurationException {
        this.kerasMajorVersion = kerasMajorVersion;
        this.conf = KerasLayerConfigurationFactory.get(this.kerasMajorVersion);
    }

    /**
     * Creates an empty layer using the default Keras major version (2). Intended for use by
     * subclasses.
     *
     * @throws UnsupportedKerasConfigurationException if no configuration exists for that version
     */
    public KerasLayer() throws UnsupportedKerasConfigurationException {
        this(DEFAULT_KERAS_MAJOR_VERSION);
    }

    /**
     * Creates a layer from its Keras config map. Equivalent to {@code KerasLayer(layerConfig, true)}.
     *
     * @param layerConfig Keras layer config; its {@code "keras_version"} entry must be an Integer
     */
    protected KerasLayer(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    /**
     * Creates a layer from its Keras config map, reading name, shape, ordering, connectivity,
     * weight regularization and dropout.
     *
     * @param layerConfig Keras layer config; its {@code "keras_version"} entry must be an Integer
     * @param enforceTrainingConfig forwarded to
     *     {@link KerasLayerUtils#checkForUnsupportedConfigurations}; presumably makes unsupported
     *     training-related settings an error rather than a warning (verify in KerasLayerUtils)
     * @throws InvalidKerasConfigurationException if the class name or layer name is missing
     * @throws UnsupportedKerasConfigurationException if the config uses unsupported features
     */
    public KerasLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.originalLayerConfig = layerConfig;
        this.kerasMajorVersion = (Integer) layerConfig.get(LAYER_FIELD_KERAS_VERSION);
        this.conf = KerasLayerConfigurationFactory.get(this.kerasMajorVersion);

        this.className = KerasLayerUtils.getClassNameFromConfig(layerConfig, this.conf);
        if (this.className == null) {
            throw new InvalidKerasConfigurationException("Keras layer class name is missing");
        }
        this.layerName = KerasLayerUtils.getLayerNameFromConfig(layerConfig, this.conf);
        if (this.layerName == null) {
            throw new InvalidKerasConfigurationException("Keras layer name is missing");
        }

        this.inputShape = KerasLayerUtils.getInputShapeFromConfig(layerConfig, this.conf);
        this.dimOrder = KerasLayerUtils.getDimOrderFromConfig(layerConfig, this.conf);
        this.inboundLayerNames = KerasLayerUtils.getInboundLayerNamesFromConfig(layerConfig, this.conf);
        this.outboundLayerNames = KerasLayerUtils.getOutboundLayerNamesFromConfig(layerConfig, this.conf);
        this.weightL1Regularization = KerasRegularizerUtils.getWeightRegularizerFromConfig(
                layerConfig, this.conf, this.conf.getLAYER_FIELD_W_REGULARIZER(), this.conf.getREGULARIZATION_TYPE_L1());
        this.weightL2Regularization = KerasRegularizerUtils.getWeightRegularizerFromConfig(
                layerConfig, this.conf, this.conf.getLAYER_FIELD_W_REGULARIZER(), this.conf.getREGULARIZATION_TYPE_L2());
        this.dropout = KerasLayerUtils.getDropoutFromConfig(layerConfig, this.conf);
        KerasLayerUtils.checkForUnsupportedConfigurations(layerConfig, enforceTrainingConfig, this.conf);
    }

    /** Registers (or replaces) a SameDiff lambda layer under the given name. */
    public static void registerLambdaLayer(String name, SameDiffLambdaLayer lambdaLayer) {
        lambdaLayers.put(name, lambdaLayer);
    }

    /** Removes all registered lambda layers. */
    public static void clearLambdaLayers() {
        lambdaLayers.clear();
    }

    /** Registers (or replaces) a custom {@code KerasLayer} implementation under the given name. */
    public static void registerCustomLayer(String name, Class<? extends KerasLayer> layerClass) {
        customLayers.put(name, layerClass);
    }

    /** Removes all registered custom layers. */
    public static void clearCustomLayers() {
        customLayers.clear();
    }

    public Integer getKerasMajorVersion() {
        return this.kerasMajorVersion;
    }

    public String getClassName() {
        return this.className;
    }

    public String getLayerName() {
        return this.layerName;
    }

    /** Returns a defensive copy of the input shape, or {@code null} if none is set. */
    public int[] getInputShape() {
        return this.inputShape == null ? null : this.inputShape.clone();
    }

    public DimOrder getDimOrder() {
        return this.dimOrder;
    }

    public void setDimOrder(DimOrder dimOrder) {
        this.dimOrder = dimOrder;
    }

    /** Returns the live list of inbound layer names, creating an empty list if it is null. */
    public List<String> getInboundLayerNames() {
        if (this.inboundLayerNames == null) {
            this.inboundLayerNames = new ArrayList<>();
        }
        return this.inboundLayerNames;
    }

    /** Replaces the inbound layer names with a copy of {@code inboundLayerNames}. */
    public void setInboundLayerNames(List<String> inboundLayerNames) {
        this.inboundLayerNames = new ArrayList<>(inboundLayerNames);
    }

    /** Number of trainable parameter arrays; 0 in this base class. */
    public int getNumParams() {
        return 0;
    }

    /** Returns true if L1/L2 weight regularization is positive or dropout is below 1.0. */
    public boolean usesRegularization() {
        return this.weightL1Regularization > 0.0d || this.weightL2Regularization > 0.0d || this.dropout < 1.0d;
    }

    /** Accepts imported weights; a no-op in this base class. */
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    }

    public Map<String, INDArray> getWeights() {
        return this.weights;
    }

    /**
     * Copies this layer's imported weights into a DL4J layer's parameters. This does nothing when
     * {@link #getNumParams()} is not positive.
     *
     * <p>The parameter names must match exactly in both directions. When a stored array has a
     * different shape but the same number of elements as the DL4J parameter, it is reshaped to the
     * DL4J shape first.
     *
     * @param layer DL4J layer whose parameters are overwritten
     * @throws InvalidKerasConfigurationException if weights are missing, parameter names do not
     *     match, or setting a parameter fails (the original failure is attached as the cause)
     */
    public void copyWeightsToLayer(org.deeplearning4j.nn.api.Layer layer) throws InvalidKerasConfigurationException {
        if (getNumParams() <= 0) {
            return;
        }
        String errorPrefix = "Error when attempting to copy weights from Keras layer " + getLayerName()
                + " to DL4J layer " + layer.conf().getLayer().getLayerName();

        // Read the weights once, so the null check and the later lookups use the same map.
        Map<String, INDArray> kerasWeights = getWeights();
        if (kerasWeights == null) {
            throw new InvalidKerasConfigurationException(errorPrefix + " (weights is null)");
        }

        Map<String, INDArray> paramTable = layer.paramTable();

        Set<String> missingWeights = new HashSet<>(paramTable.keySet());
        missingWeights.removeAll(kerasWeights.keySet());
        if (!missingWeights.isEmpty()) {
            throw new InvalidKerasConfigurationException(errorPrefix
                    + " (no stored weights for parameters: " + StringUtils.join(missingWeights, ", ") + ")");
        }

        Set<String> unknownWeights = new HashSet<>(kerasWeights.keySet());
        unknownWeights.removeAll(paramTable.keySet());
        if (!unknownWeights.isEmpty()) {
            throw new InvalidKerasConfigurationException(errorPrefix
                    + " (found no parameters named: " + StringUtils.join(unknownWeights, ", ") + ")");
        }

        for (String paramName : paramTable.keySet()) {
            INDArray kerasValue = kerasWeights.get(paramName);
            try {
                long[] dl4jShape = paramTable.get(paramName).shape();
                long[] kerasShape = kerasValue.shape();
                // Reshape only when the shapes differ but hold the same number of elements. Any other
                // mismatch is passed through unchanged, so setParam reports it.
                if (!Arrays.equals(dl4jShape, kerasShape) && ArrayUtil.prod(dl4jShape) == ArrayUtil.prod(kerasShape)) {
                    layer.setParam(paramName, kerasValue.reshape(dl4jShape));
                } else {
                    layer.setParam(paramName, kerasValue);
                }
            } catch (Exception e) {
                log.error("Failed to set weights for parameter {} of Keras layer {}", paramName, getLayerName(), e);
                throw new InvalidKerasConfigurationException(e.getMessage()
                        + "\nTried to set weights for layer with name " + getLayerName()
                        + ", of " + layer.conf().getLayer().getClass()
                        + ".\nFailed to set weights for parameter " + paramName
                        + "\nExpected shape for this parameter: " + layer.getParam(paramName).shapeInfoToString()
                        + ", \ngot: " + kerasValue.shapeInfoToString(), e);
            }
        }
    }

    public boolean isLayer() {
        return this.layer != null;
    }

    public Layer getLayer() {
        return this.layer;
    }

    public void setLayer(Layer layer) {
        this.layer = layer;
    }

    public boolean isVertex() {
        return this.vertex != null;
    }

    public GraphVertex getVertex() {
        return this.vertex;
    }

    /** Whether this Keras layer maps to a DL4J input preprocessor; false in this base class. */
    public boolean isInputPreProcessor() {
        return false;
    }

    /**
     * Finds an input size by walking upstream from this layer's first inbound layer. At each step
     * it follows the first inbound name, and returns the first positive {@code getNOut()} it finds.
     *
     * <p>The walk stops after {@code previousLayers.size() + 1} steps. It also stops when a name is
     * not in {@code previousLayers} or a layer has no inbound layers, so cycles and dangling
     * references cannot make it loop forever.
     *
     * @param previousLayers already-imported layers keyed by layer name
     * @return the first positive nOut found upstream
     * @throws UnsupportedKerasConfigurationException if no positive nOut can be found
     */
    protected long getNInFromConfig(Map<String, ? extends KerasLayer> previousLayers)
            throws UnsupportedKerasConfigurationException {
        List<String> ownInbound = getInboundLayerNames();
        String currentName = ownInbound.isEmpty() ? null : ownInbound.get(0);

        for (int step = 0; step <= previousLayers.size() && currentName != null; step++) {
            KerasLayer inbound = previousLayers.get(currentName);
            if (inbound == null) {
                break;
            }
            Layer inboundLayer = inbound.getLayer();
            if (inboundLayer != null) {
                try {
                    long nOut = inboundLayer.getNOut();
                    if (nOut > 0) {
                        return nOut;
                    }
                } catch (Exception e) {
                    // This layer cannot report nOut; keep walking upstream, as the original code did.
                    log.debug("Could not read nOut from layer {}", currentName, e);
                }
            }
            List<String> upstream = inbound.getInboundLayerNames();
            currentName = upstream.isEmpty() ? null : upstream.get(0);
        }
        throw new UnsupportedKerasConfigurationException(
                "Could not determine number of input channels for depthwise convolution.");
    }

    /**
     * Returns the DL4J preprocessor the wrapped layer needs for the given input type(s). Returns
     * {@code null} when no DL4J layer is set.
     *
     * <p>A single input is passed through as-is. With several inputs, the null entries are ignored
     * and the remaining ones must all be equal.
     *
     * @throws InvalidKerasConfigurationException if a layer is set and no usable input is given, or
     *     if several different input types are given
     */
    public InputPreProcessor getInputPreprocessor(InputType... inputTypes) throws InvalidKerasConfigurationException {
        if (this.layer == null) {
            return null;
        }
        if (inputTypes == null || inputTypes.length == 0) {
            throw new InvalidKerasConfigurationException(
                    "Keras layer of type \"" + this.className + "\" did not have any inputs!");
        }
        if (inputTypes.length == 1) {
            return this.layer.getPreProcessorForInputType(inputTypes[0]);
        }

        InputType common = null;
        for (InputType candidate : inputTypes) {
            if (candidate == null) {
                continue;
            }
            if (common == null) {
                common = candidate;
            } else if (!common.equals(candidate)) {
                throw new InvalidKerasConfigurationException(
                        "Keras layer of type \"" + this.className + "\" accepts only one input");
            }
        }
        if (common == null) {
            throw new InvalidKerasConfigurationException(
                    "Keras layer of type \"" + this.className + "\" did not have any inputs!");
        }
        return this.layer.getPreProcessorForInputType(common);
    }

    /** Always throws {@link UnsupportedOperationException} in this base class. */
    public InputType getOutputType(InputType... inputTypes)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        throw new UnsupportedOperationException("Cannot determine output type for Keras layer of type " + this.className);
    }

    /**
     * Returns true if this layer can act as an inbound layer. That is the case when it has a DL4J
     * layer, a vertex or a preprocessor, or when it is the Keras input layer.
     */
    public boolean isValidInboundLayer() throws InvalidKerasConfigurationException {
        return getLayer() != null
                || getVertex() != null
                || getInputPreprocessor(new InputType[0]) != null
                || this.className.equals(this.conf.getLAYER_CLASS_NAME_INPUT());
    }

    public List<String> getOutboundLayerNames() {
        return this.outboundLayerNames;
    }

    public KerasLayerConfiguration getConf() {
        return this.conf;
    }

    public Map<String, Object> getOriginalLayerConfig() {
        return this.originalLayerConfig;
    }
}
