package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.Map;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.class */
public class KerasActivationUtils {

    /**
     * Maps a Keras activation function name to the corresponding DL4J {@link Activation}.
     *
     * <p>The name is compared (exact, case-sensitive {@link String#equals}) against the activation
     * names exposed by {@code kerasLayerConfiguration}, in this order: softmax, softplus, softsign,
     * relu, relu6, elu, selu, tanh, sigmoid, hard_sigmoid, linear, swish. The first match wins.
     * The Keras "linear" activation maps to {@link Activation#IDENTITY}.
     *
     * @param str                     Keras activation function name, e.g. as read from a layer config
     * @param kerasLayerConfiguration configuration object supplying the recognised activation names
     * @return the matching DL4J activation
     * @throws UnsupportedKerasConfigurationException if {@code str} is {@code null} or does not match
     *                                                any supported activation name
     */
    public static Activation mapToActivation(String str, KerasLayerConfiguration kerasLayerConfiguration) throws UnsupportedKerasConfigurationException {
        if (str == null) {
            // Without this guard a null name escapes as an undeclared NullPointerException from
            // str.equals(...); report it through the declared checked exception instead.
            throw new UnsupportedKerasConfigurationException("Unknown Keras activation function " + str);
        }
        // str is non-null here, so str.equals(x) is safe even if a configuration getter returns null
        // (String.equals(null) is simply false).
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_SOFTMAX())) {
            return Activation.SOFTMAX;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_SOFTPLUS())) {
            return Activation.SOFTPLUS;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_SOFTSIGN())) {
            return Activation.SOFTSIGN;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_RELU())) {
            return Activation.RELU;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_RELU6())) {
            return Activation.RELU6;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_ELU())) {
            return Activation.ELU;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_SELU())) {
            return Activation.SELU;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_TANH())) {
            return Activation.TANH;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_SIGMOID())) {
            return Activation.SIGMOID;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_HARD_SIGMOID())) {
            return Activation.HARDSIGMOID;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_LINEAR())) {
            return Activation.IDENTITY;
        }
        if (str.equals(kerasLayerConfiguration.getKERAS_ACTIVATION_SWISH())) {
            return Activation.SWISH;
        }
        throw new UnsupportedKerasConfigurationException("Unknown Keras activation function " + str);
    }

    /**
     * Maps a Keras activation function name to a DL4J {@link IActivation} instance.
     *
     * @param str                     Keras activation function name
     * @param kerasLayerConfiguration configuration object supplying the recognised activation names
     * @return the activation function obtained from {@link #mapToActivation}
     * @throws UnsupportedKerasConfigurationException if the name is null or not supported
     */
    public static IActivation mapToIActivation(String str, KerasLayerConfiguration kerasLayerConfiguration) throws UnsupportedKerasConfigurationException {
        return mapToActivation(str, kerasLayerConfiguration).getActivationFunction();
    }

    /**
     * Reads the activation of a Keras layer config and returns it as a DL4J {@link IActivation}.
     *
     * @param map                     Keras layer configuration
     * @param kerasLayerConfiguration configuration object supplying field and activation names
     * @return the activation function obtained from {@link #getActivationFromConfig}
     * @throws InvalidKerasConfigurationException     if the layer config has no activation field
     *                                                (or if the inner-config lookup rejects it)
     * @throws UnsupportedKerasConfigurationException if the activation value is not a supported name
     */
    public static IActivation getIActivationFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getActivationFromConfig(map, kerasLayerConfiguration).getActivationFunction();
    }

    /**
     * Reads the activation of a Keras layer config and maps it to a DL4J {@link Activation}.
     *
     * <p>The inner layer config is obtained via {@code KerasLayerUtils.getInnerLayerConfigFromConfig}.
     * The value stored under {@code kerasLayerConfiguration.getLAYER_FIELD_ACTIVATION()} is then
     * passed to {@link #mapToActivation}.
     *
     * @param map                     Keras layer configuration
     * @param kerasLayerConfiguration configuration object supplying field and activation names
     * @return the matching DL4J activation
     * @throws InvalidKerasConfigurationException     if the inner config lacks the activation field
     *                                                (or if the inner-config lookup rejects it)
     * @throws UnsupportedKerasConfigurationException if the activation value is null, is not a
     *                                                String, or is not a supported activation name
     */
    public static Activation getActivationFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerLayerConfigFromConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String activationField = kerasLayerConfiguration.getLAYER_FIELD_ACTIVATION();
        if (!innerLayerConfigFromConfig.containsKey(activationField)) {
            throw new InvalidKerasConfigurationException("Keras layer is missing " + activationField + " field");
        }
        Object activationValue = innerLayerConfigFromConfig.get(activationField);
        if (activationValue != null && !(activationValue instanceof String)) {
            // A blind (String) cast would surface as an uninformative ClassCastException; report the
            // unexpected shape through the declared exception instead.
            throw new UnsupportedKerasConfigurationException("Unsupported Keras activation configuration of type "
                    + activationValue.getClass().getName() + " in field " + activationField + ": " + activationValue);
        }
        // A null value is rejected by mapToActivation with UnsupportedKerasConfigurationException.
        return mapToActivation((String) activationValue, kerasLayerConfiguration);
    }
}
