package org.deeplearning4j.nn.modelimport.keras.utils;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.class */
public class KerasModelUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasModelUtils.class);

    /**
     * Leading run of 1 to 9 ASCII digits of a Keras version string (the major version).
     * The negative lookahead rejects longer digit runs, so {@link Integer#parseInt} cannot overflow.
     */
    private static final Pattern MAJOR_VERSION = Pattern.compile("^(\\d{1,9})(?!\\d)");

    /** A leading underscore left over after the layer name has been removed from a parameter name. */
    private static final Pattern LEADING_UNDERSCORE = Pattern.compile("^_(.+)$");

    /** Trailing ":<digits>" suffix; presumably a TensorFlow tensor output index (TODO confirm). */
    private static final Pattern COLON_DIGITS_SUFFIX = Pattern.compile(":\\d+$");

    /** Trailing "_<digits>" suffix on a parameter name. */
    private static final Pattern UNDERSCORE_DIGITS_SUFFIX = Pattern.compile("_\\d+$");

    /** Target type for parsed JSON/YAML model configurations. */
    private static final TypeReference<HashMap<String, Object>> CONFIG_MAP_TYPE =
            new TypeReference<HashMap<String, Object>>() {
            };

    /** Shared JSON reader: ObjectMapper is costly to build and safe to share once configured. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    /** Holder idiom so the YAML factory is only initialised when YAML is actually parsed. */
    private static final class YamlMapperHolder {
        static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory());
    }

    /**
     * Copies the data format of a Keras layer's DL4J layer configuration onto a
     * {@link ReshapePreprocessor}.
     *
     * <p>This does nothing unless {@code inputPreProcessor} is a {@code ReshapePreprocessor} and
     * {@code kerasLayer} is a layer with a non-null dim ordering. The format is chosen by layer type:
     * <ul>
     *   <li>{@code Convolution3D} and {@code Deconvolution3D} use their 3D data format;</li>
     *   <li>any other {@code ConvolutionLayer} uses its 2D CNN format;</li>
     *   <li>{@code BaseRecurrentLayer} uses its RNN format;</li>
     *   <li>any other layer type leaves the preprocessor unchanged.</li>
     * </ul>
     *
     * @param inputPreProcessor preprocessor to update (ignored unless it is a ReshapePreprocessor)
     * @param kerasLayer        Keras layer whose DL4J configuration supplies the format
     */
    public static void setDataFormatIfNeeded(InputPreProcessor inputPreProcessor, KerasLayer kerasLayer) {
        if (!(inputPreProcessor instanceof ReshapePreprocessor)) {
            return;
        }
        if (!kerasLayer.isLayer() || kerasLayer.getDimOrder() == null) {
            return;
        }
        ReshapePreprocessor reshapePreprocessor = (ReshapePreprocessor) inputPreProcessor;
        // Fully qualified: the simple name Layer is imported as the runtime nn.api.Layer in this file.
        org.deeplearning4j.nn.conf.layers.Layer confLayer = kerasLayer.getLayer();
        // The 3D variants are tested before the generic ConvolutionLayer so their own format wins.
        if (confLayer instanceof Convolution3D) {
            reshapePreprocessor.setFormat(((Convolution3D) confLayer).getDataFormat());
        } else if (confLayer instanceof Deconvolution3D) {
            reshapePreprocessor.setFormat(((Deconvolution3D) confLayer).getDataFormat());
        } else if (confLayer instanceof ConvolutionLayer) {
            reshapePreprocessor.setFormat(((ConvolutionLayer) confLayer).getCnn2dDataFormat());
        } else if (confLayer instanceof BaseRecurrentLayer) {
            reshapePreprocessor.setFormat(((BaseRecurrentLayer) confLayer).getRnnDataFormat());
        }
    }

    /**
     * Copies the weights held by the Keras layers into the layers of a DL4J model, matched by
     * layer name.
     *
     * @param model a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @param map   Keras layers keyed by layer name
     * @return {@code model}, with the weights copied in
     * @throws InvalidKerasConfigurationException if a model layer has no Keras layer of the same name,
     *         or if a Keras layer with parameters has no model layer of the same name
     * @throws IllegalArgumentException if {@code model} is neither a MultiLayerNetwork nor a
     *         ComputationGraph
     */
    public static Model copyWeightsToModel(Model model, Map<String, KerasLayer> map) throws InvalidKerasConfigurationException {
        Layer[] layers;
        if (model instanceof MultiLayerNetwork) {
            layers = ((MultiLayerNetwork) model).getLayers();
        } else if (model instanceof ComputationGraph) {
            layers = ((ComputationGraph) model).getLayers();
        } else {
            throw new IllegalArgumentException("Cannot copy Keras weights into model of type "
                    + (model == null ? "null" : model.getClass().getName())
                    + "; expected MultiLayerNetwork or ComputationGraph");
        }

        // Keras layer names not yet matched to a model layer; checked after the copy loop.
        HashSet<String> unmatched = new HashSet<>(map.keySet());
        for (Layer layer : layers) {
            String layerName = layer.conf().getLayer().getLayerName();
            KerasLayer kerasLayer = map.get(layerName);
            if (kerasLayer == null) {
                throw new InvalidKerasConfigurationException("No weights found for layer in model (named " + layerName + ")");
            }
            kerasLayer.copyWeightsToLayer(layer);
            unmatched.remove(layerName);
        }

        // Unmatched Keras layers are tolerated only if they carry no parameters.
        for (String name : unmatched) {
            KerasLayer leftover = map.get(name);
            if (leftover != null && leftover.getNumParams() > 0) {
                throw new InvalidKerasConfigurationException("Attempting to copy weights for layer not in model (named " + name + ")");
            }
        }
        return model;
    }

    /**
     * Reads the Keras major version from a parsed model configuration.
     *
     * <p>The major version is the leading run of ASCII digits in the version value, so
     * {@code "2.2.4"} gives 2 and {@code "10.0"} gives 10. If the version field is absent, a
     * warning is logged and 1 is assumed.
     *
     * @param map                      parsed model configuration
     * @param kerasModelConfiguration  supplies the name of the version field
     * @return the Keras major version
     * @throws InvalidKerasConfigurationException if the field is present but does not start with
     *         a readable major version number
     */
    public static int determineKerasMajorVersion(Map<String, Object> map, KerasModelConfiguration kerasModelConfiguration) throws InvalidKerasConfigurationException {
        String fieldName = kerasModelConfiguration.getFieldKerasVersion();
        if (!map.containsKey(fieldName)) {
            log.warn("Could not read keras version used (no {} field found) \nassuming keras version is 1.0.7 or earlier.", fieldName);
            return 1;
        }
        Object rawVersion = map.get(fieldName);
        // toString() tolerates a non-String value (e.g. a YAML scalar parsed as a number).
        String version = rawVersion == null ? "" : rawVersion.toString();
        Matcher matcher = MAJOR_VERSION.matcher(version);
        if (!matcher.find()) {
            throw new InvalidKerasConfigurationException("Keras version was not readable (" + fieldName + " = '" + version + "' provided)");
        }
        return Integer.parseInt(matcher.group(1));
    }

    /**
     * Reads the Keras backend name from a parsed model configuration.
     *
     * @param map                      parsed model configuration
     * @param kerasModelConfiguration  supplies the name of the backend field
     * @return the backend value, or {@code null} if the field is absent (a warning is logged then)
     */
    public static String determineKerasBackend(Map<String, Object> map, KerasModelConfiguration kerasModelConfiguration) {
        String fieldName = kerasModelConfiguration.getFieldBackend();
        if (!map.containsKey(fieldName)) {
            log.warn("Could not read keras backend used (no {} field found) \n", fieldName);
            return null;
        }
        return (String) map.get(fieldName);
    }

    /**
     * Derives a bare parameter name from a stored weight name.
     *
     * <p>The steps are applied in order:
     * <ol>
     *   <li>Remove the first literal occurrence of the last layer-name fragment.</li>
     *   <li>Drop one leading underscore, if anything follows it.</li>
     *   <li>Strip a trailing {@code ":<digits>"} suffix.</li>
     *   <li>Strip a trailing {@code "_<digits>"} suffix.</li>
     * </ol>
     *
     * @param str    stored weight/parameter name
     * @param strArr layer name split into fragments; the last fragment is removed from {@code str}
     * @return the cleaned parameter name
     */
    private static String findParameterName(String str, String[] strArr) {
        String name = str;
        if (strArr.length > 0) {
            // Quote the fragment: layer names are literal text and may contain regex metacharacters.
            String layerFragment = strArr[strArr.length - 1];
            name = Pattern.compile(Pattern.quote(layerFragment)).matcher(name).replaceFirst("");
        }
        Matcher underscore = LEADING_UNDERSCORE.matcher(name);
        if (underscore.find()) {
            name = underscore.group(1);
        }
        // Order matters: the ":<digits>" suffix must go before "_<digits>" can become the suffix.
        name = COLON_DIGITS_SUFFIX.matcher(name).replaceFirst("");
        name = UNDERSCORE_DIGITS_SUFFIX.matcher(name).replaceFirst("");
        return name;
    }

    /**
     * Imports weights from an HDF5 archive into the given Keras layers.
     *
     * <p>NOTE(review): the body of this method could not be recovered from bytecode by the
     * decompiler. As it stands, it always throws {@link UnsupportedOperationException}. The
     * implementation must be restored from the original source before this method can be used.
     */
    public static void importWeights(org.deeplearning4j.nn.modelimport.keras.Hdf5Archive r8, java.lang.String r9, java.util.Map<java.lang.String, org.deeplearning4j.nn.modelimport.keras.KerasLayer> r10, int r11, java.lang.String r12) throws org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException, org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException {
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.importWeights(org.deeplearning4j.nn.modelimport.keras.Hdf5Archive, java.lang.String, java.util.Map, int, java.lang.String):void");
    }

    /**
     * Parses a Keras model configuration supplied as JSON or YAML.
     *
     * @param str  JSON configuration; takes precedence when non-null
     * @param str2 YAML configuration; used only when {@code str} is null
     * @return the parsed configuration
     * @throws IOException if parsing fails
     * @throws InvalidKerasConfigurationException if both arguments are null
     */
    public static Map<String, Object> parseModelConfig(String str, String str2) throws IOException, InvalidKerasConfigurationException {
        if (str != null) {
            return parseJsonString(str);
        }
        if (str2 != null) {
            return parseYamlString(str2);
        }
        throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
    }

    /**
     * Parses a JSON object into a mutable map.
     *
     * @param str JSON text whose root is an object
     * @return the parsed map
     * @throws IOException if the text is not valid JSON or cannot be mapped
     */
    public static Map<String, Object> parseJsonString(String str) throws IOException {
        return JSON_MAPPER.readValue(str, CONFIG_MAP_TYPE);
    }

    /**
     * Parses a YAML mapping into a mutable map.
     *
     * @param str YAML text whose root is a mapping
     * @return the parsed map
     * @throws IOException if the text is not valid YAML or cannot be mapped
     */
    public static Map<String, Object> parseYamlString(String str) throws IOException {
        return YamlMapperHolder.YAML_MAPPER.readValue(str, CONFIG_MAP_TYPE);
    }
}
