package org.deeplearning4j.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.schedule.ISchedule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/util/NetworkUtils.class */
/**
 * Static helpers for {@link MultiLayerNetwork} and {@link ComputationGraph} instances:
 * conversion from a MultiLayerNetwork to a ComputationGraph, learning-rate inspection and
 * modification (with updater-state rebuilding), single-input inference dispatch, and
 * list filtering by class.
 */
public class NetworkUtils {
    private static final Logger log = LoggerFactory.getLogger(NetworkUtils.class);

    private NetworkUtils() {
    }

    /**
     * Converts a MultiLayerNetwork into an equivalent linear ComputationGraph.
     * <p>
     * The graph has a single input named {@code "in"}. Layer {@code i} of the network becomes
     * the vertex named {@code String.valueOf(i)}. Each vertex is fed by the previous vertex (or
     * the input) and uses the network's input preprocessor for index {@code i}. The last vertex
     * is the graph output. Parameters are copied. The updater state is copied when the source
     * network's updater has a non-null state view.
     *
     * @param multiLayerNetwork network to convert; its configuration is cloned, not shared
     * @return a new, initialized ComputationGraph
     */
    public static ComputationGraph toComputationGraph(MultiLayerNetwork multiLayerNetwork) {
        MultiLayerConfiguration sourceConf = multiLayerNetwork.getLayerWiseConfigurations();
        ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder()
                .dataType(sourceConf.getDataType())
                .graphBuilder();
        // NOTE(review): m36clone() appears to be a decompiler-renamed clone(); verify against upstream.
        MultiLayerConfiguration confCopy = sourceConf.m36clone();

        String previousVertex = "in";
        graphBuilder.addInputs(previousVertex);
        int layerIndex = 0;
        for (NeuralNetConfiguration layerConf : confCopy.getConfs()) {
            String vertexName = String.valueOf(layerIndex);
            graphBuilder.addLayer(vertexName, layerConf.getLayer(), confCopy.getInputPreProcess(layerIndex), previousVertex);
            previousVertex = vertexName;
            layerIndex++;
        }
        graphBuilder.setOutputs(previousVertex);

        ComputationGraph graph = new ComputationGraph(graphBuilder.build());
        graph.init();
        graph.setParams(multiLayerNetwork.params());

        INDArray updaterState = multiLayerNetwork.getUpdater().getStateViewArray();
        if (updaterState != null) {
            graph.getUpdater().getUpdaterStateViewArray().assign(updaterState);
        }
        return graph;
    }

    /**
     * Sets a fixed learning rate on every layer of the network whose updater has a learning
     * rate. The schedule argument passed to the updater is null. The updater is rebuilt once
     * afterwards.
     *
     * @param multiLayerNetwork network to modify
     * @param d                 new learning rate
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, double d) {
        setLearningRate(multiLayerNetwork, d, (ISchedule) null);
    }

    /** Applies the rate or schedule to all layers, then rebuilds the updater once. */
    private static void setLearningRate(MultiLayerNetwork multiLayerNetwork, double lr, ISchedule schedule) {
        int nLayers = multiLayerNetwork.getnLayers();
        for (int layer = 0; layer < nLayers; layer++) {
            setLearningRate(multiLayerNetwork, layer, lr, schedule, false);
        }
        refreshUpdater(multiLayerNetwork);
    }

    /**
     * Applies the rate or schedule to one layer. If {@code refresh} is true and the layer is a
     * {@link BaseLayer}, the updater is rebuilt afterwards.
     */
    private static void setLearningRate(MultiLayerNetwork multiLayerNetwork, int layerNumber, double lr,
                                        ISchedule schedule, boolean refresh) {
        Layer layerConf = multiLayerNetwork.getLayer(layerNumber).conf().getLayer();
        if (applyLearningRate(layerConf, lr, schedule) && refresh) {
            refreshUpdater(multiLayerNetwork);
        }
    }

    /**
     * Replaces the network's updater with a freshly built one and carries the old updater state
     * over. The state is rearranged when the updater block layout has changed.
     */
    private static void refreshUpdater(MultiLayerNetwork multiLayerNetwork) {
        MultiLayerUpdater oldUpdater = (MultiLayerUpdater) multiLayerNetwork.getUpdater();
        INDArray oldState = oldUpdater.getStateViewArray();
        // After clearing, getUpdater() is relied on to return a new, non-null updater built from
        // the current (modified) layer configurations.
        multiLayerNetwork.setUpdater(null);
        MultiLayerUpdater newUpdater = (MultiLayerUpdater) multiLayerNetwork.getUpdater();
        newUpdater.setStateViewArray(
                rebuildUpdaterStateArray(oldState, oldUpdater.getUpdaterBlocks(), newUpdater.getUpdaterBlocks()));
    }

    /**
     * Sets a learning-rate schedule on every layer of the network whose updater has a learning
     * rate. The fixed-rate argument passed to the updater is NaN. The updater is rebuilt once
     * afterwards.
     *
     * @param multiLayerNetwork network to modify
     * @param iSchedule         learning-rate schedule to apply
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, ISchedule iSchedule) {
        setLearningRate(multiLayerNetwork, Double.NaN, iSchedule);
    }

    /**
     * Sets a fixed learning rate on a single layer, then rebuilds the updater if the layer is a
     * {@link BaseLayer}.
     *
     * @param multiLayerNetwork network to modify
     * @param i                 layer index
     * @param d                 new learning rate
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, int i, double d) {
        setLearningRate(multiLayerNetwork, i, d, (ISchedule) null, true);
    }

    /**
     * Sets a learning-rate schedule on a single layer, then rebuilds the updater if the layer is
     * a {@link BaseLayer}.
     *
     * @param multiLayerNetwork network to modify
     * @param i                 layer index
     * @param iSchedule         learning-rate schedule to apply
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, int i, ISchedule iSchedule) {
        setLearningRate(multiLayerNetwork, i, Double.NaN, iSchedule, true);
    }

    /**
     * Returns the learning rate of one layer, evaluated at the network's current iteration and
     * epoch counts.
     *
     * @param multiLayerNetwork network to inspect
     * @param i                 layer index
     * @return the learning rate, or null if the layer is not a {@link BaseLayer}, has no updater,
     *         its updater has no learning rate, or the evaluated rate is NaN
     */
    public static Double getLearningRate(MultiLayerNetwork multiLayerNetwork, int i) {
        Layer layerConf = multiLayerNetwork.getLayer(i).conf().getLayer();
        int iteration = multiLayerNetwork.getIterationCount();
        int epoch = multiLayerNetwork.getEpochCount();
        return learningRateOf(layerConf, iteration, epoch);
    }

    /**
     * Sets a fixed learning rate on every layer of the graph whose updater has a learning rate.
     * The schedule argument passed to the updater is null. The updater is rebuilt once
     * afterwards.
     *
     * @param computationGraph graph to modify
     * @param d                new learning rate
     */
    public static void setLearningRate(ComputationGraph computationGraph, double d) {
        setLearningRate(computationGraph, d, (ISchedule) null);
    }

    /** Applies the rate or schedule to all layers, then rebuilds the updater once. */
    private static void setLearningRate(ComputationGraph computationGraph, double lr, ISchedule schedule) {
        for (org.deeplearning4j.nn.api.Layer layer : computationGraph.getLayers()) {
            setLearningRate(computationGraph, layer.conf().getLayer().getLayerName(), lr, schedule, false);
        }
        refreshUpdater(computationGraph);
    }

    /**
     * Applies the rate or schedule to one named layer. If {@code refresh} is true and the layer
     * is a {@link BaseLayer}, the updater is rebuilt afterwards.
     */
    private static void setLearningRate(ComputationGraph computationGraph, String layerName, double lr,
                                        ISchedule schedule, boolean refresh) {
        Layer layerConf = computationGraph.getLayer(layerName).conf().getLayer();
        if (applyLearningRate(layerConf, lr, schedule) && refresh) {
            refreshUpdater(computationGraph);
        }
    }

    /**
     * Replaces the graph's updater with a freshly built one and carries the old updater state
     * over. The state is rearranged when the updater block layout has changed.
     */
    private static void refreshUpdater(ComputationGraph computationGraph) {
        ComputationGraphUpdater oldUpdater = computationGraph.getUpdater();
        INDArray oldState = oldUpdater.getStateViewArray();
        // After clearing, getUpdater() is relied on to return a new, non-null updater built from
        // the current (modified) layer configurations.
        computationGraph.setUpdater(null);
        ComputationGraphUpdater newUpdater = computationGraph.getUpdater();
        newUpdater.setStateViewArray(
                rebuildUpdaterStateArray(oldState, oldUpdater.getUpdaterBlocks(), newUpdater.getUpdaterBlocks()));
    }

    /**
     * Sets a learning-rate schedule on every layer of the graph whose updater has a learning
     * rate. The fixed-rate argument passed to the updater is NaN. The updater is rebuilt once
     * afterwards.
     *
     * @param computationGraph graph to modify
     * @param iSchedule        learning-rate schedule to apply
     */
    public static void setLearningRate(ComputationGraph computationGraph, ISchedule iSchedule) {
        setLearningRate(computationGraph, Double.NaN, iSchedule);
    }

    /**
     * Sets a fixed learning rate on a single named layer, then rebuilds the updater if the layer
     * is a {@link BaseLayer}.
     *
     * @param computationGraph graph to modify
     * @param str              layer name
     * @param d                new learning rate
     */
    public static void setLearningRate(ComputationGraph computationGraph, String str, double d) {
        setLearningRate(computationGraph, str, d, (ISchedule) null, true);
    }

    /**
     * Sets a learning-rate schedule on a single named layer, then rebuilds the updater if the
     * layer is a {@link BaseLayer}.
     *
     * @param computationGraph graph to modify
     * @param str              layer name
     * @param iSchedule        learning-rate schedule to apply
     */
    public static void setLearningRate(ComputationGraph computationGraph, String str, ISchedule iSchedule) {
        setLearningRate(computationGraph, str, Double.NaN, iSchedule, true);
    }

    /**
     * Returns the learning rate of one named layer, evaluated at the iteration and epoch counts
     * stored in the graph's configuration.
     *
     * @param computationGraph graph to inspect
     * @param str              layer name
     * @return the learning rate, or null if the layer is not a {@link BaseLayer}, has no updater,
     *         its updater has no learning rate, or the evaluated rate is NaN
     */
    public static Double getLearningRate(ComputationGraph computationGraph, String str) {
        Layer layerConf = computationGraph.getLayer(str).conf().getLayer();
        int iteration = computationGraph.getConfiguration().getIterationCount();
        int epoch = computationGraph.getConfiguration().getEpochCount();
        return learningRateOf(layerConf, iteration, epoch);
    }

    /**
     * Applies a fixed learning rate or a schedule to one layer configuration, if its updater has
     * a learning rate. A non-null schedule wins: the fixed-rate argument is then passed as NaN.
     *
     * @return true if {@code layer} is a {@link BaseLayer}, whether or not its updater has a
     *         learning rate; callers use this to decide whether to refresh the updater
     */
    private static boolean applyLearningRate(Layer layer, double lr, ISchedule schedule) {
        if (!(layer instanceof BaseLayer)) {
            return false;
        }
        IUpdater updater = ((BaseLayer) layer).getIUpdater();
        if (updater != null && updater.hasLearningRate()) {
            if (schedule != null) {
                updater.setLrAndSchedule(Double.NaN, schedule);
            } else {
                updater.setLrAndSchedule(lr, (ISchedule) null);
            }
        }
        return true;
    }

    /**
     * Evaluates a layer's learning rate at the given iteration and epoch.
     *
     * @return the rate, or null when the layer is not a {@link BaseLayer}, has no updater, its
     *         updater has no learning rate, or the rate is NaN
     */
    private static Double learningRateOf(Layer layer, int iteration, int epoch) {
        if (!(layer instanceof BaseLayer)) {
            return null;
        }
        IUpdater updater = ((BaseLayer) layer).getIUpdater();
        if (updater == null || !updater.hasLearningRate()) {
            return null;
        }
        double lr = updater.getLearningRate(iteration, epoch);
        if (Double.isNaN(lr)) {
            return null;
        }
        return Double.valueOf(lr);
    }

    /**
     * Runs single-input inference on a MultiLayerNetwork ({@code output}) or ComputationGraph
     * ({@code outputSingle}).
     *
     * @param model    network to run
     * @param iNDArray input features
     * @return the network output
     * @throws UnsupportedOperationException for any other {@link Model} implementation
     */
    public static INDArray output(Model model, INDArray iNDArray) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).output(iNDArray);
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).outputSingle(iNDArray);
        }
        String className = model.getClass().getName();
        String message = className.startsWith("org.deeplearning4j")
                ? className + " models are not yet supported and pull requests are welcome: https://github.com/eclipse/deeplearning4j"
                : className + " models are unsupported.";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Removes, in place, every element of {@code list} that is an instance of {@code cls}.
     * Null elements are kept.
     *
     * @param list list to filter; null or empty lists are left untouched
     * @param cls  class (or supertype) whose instances are removed
     */
    public static void removeInstances(List<?> list, Class<?> cls) {
        removeInstancesWithWarning(list, cls, null);
    }

    /**
     * Removes, in place, every element of {@code list} that is an instance of {@code cls}.
     * If {@code str} is non-null, it is logged at WARN level once per removed element.
     * Null elements are kept.
     *
     * @param list list to filter; null or empty lists are left untouched
     * @param cls  class (or supertype) whose instances are removed
     * @param str  optional warning message; may be null
     */
    public static void removeInstancesWithWarning(List<?> list, Class<?> cls, String str) {
        if (list == null || list.isEmpty()) {
            return;
        }
        Iterator<?> it = list.iterator();
        while (it.hasNext()) {
            // isInstance matches isAssignableFrom(getClass()) for non-null elements but is
            // false for null, so null elements no longer cause a NullPointerException.
            if (cls.isInstance(it.next())) {
                if (str != null) {
                    log.warn(str);
                }
                it.remove();
            }
        }
    }

    /**
     * Rearranges an updater state array laid out for {@code list} (the old updater blocks) into
     * the layout expected by {@code list2} (the new updater blocks).
     * <p>
     * If both block lists contain equal parameter lists in the same order, the original array is
     * returned unchanged. Otherwise the state is split per parameter and sub-state index, then
     * re-concatenated in new-block order. Within a block, entries are interleaved as
     * {@code [a0, b0, c0][a1, b1, c1]} for params a, b, c.
     *
     * @param iNDArray original updater state; null is returned as-is
     * @param list     updater blocks the original state was laid out for
     * @param list2    updater blocks of the new updater
     * @return the state array in the new layout
     * @throws IllegalStateException if a parameter of the new blocks has no state in the old
     *                               blocks, or the rebuilt array's rank or length is unexpected
     */
    protected static INDArray rebuildUpdaterStateArray(INDArray iNDArray, List<UpdaterBlock> list, List<UpdaterBlock> list2) {
        if (iNDArray == null) {
            return null;
        }
        if (sameBlockStructure(list, list2)) {
            return iNDArray;
        }

        Map<String, List<INDArray>> stateViewsPerParam = splitStatePerParam(list);

        List<INDArray> toConcat = new ArrayList<>();
        for (UpdaterBlock block : list2) {
            List<UpdaterBlock.ParamState> params = block.getLayersAndVariablesInBlock();
            // The first param's sub-state count determines how many rounds to interleave.
            int nSubStates = stateViewsFor(stateViewsPerParam, params.get(0)).size();
            for (int sub = 0; sub < nSubStates; sub++) {
                for (UpdaterBlock.ParamState param : params) {
                    toConcat.add(stateViewsFor(stateViewsPerParam, param).get(sub));
                }
            }
        }

        INDArray rebuilt = Nd4j.hstack(toConcat);
        Preconditions.checkState(rebuilt.rank() == 2, "Expected rank 2");
        Preconditions.checkState(iNDArray.length() == rebuilt.length(),
                "Updater state array lengths should be equal: got %s vs. %s", iNDArray.length(), rebuilt.length());
        return rebuilt;
    }

    /** True if both lists have the same size and pairwise-equal per-block parameter lists. */
    private static boolean sameBlockStructure(List<UpdaterBlock> oldBlocks, List<UpdaterBlock> newBlocks) {
        if (oldBlocks.size() != newBlocks.size()) {
            return false;
        }
        for (int i = 0; i < oldBlocks.size(); i++) {
            if (!oldBlocks.get(i).getLayersAndVariablesInBlock()
                    .equals(newBlocks.get(i).getLayersAndVariablesInBlock())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Splits each block's updater view into per-parameter sub-views, keyed by
     * {@link #paramKey}. The i-th list entry is the parameter's slice of the i-th sub-state.
     */
    private static Map<String, List<INDArray>> splitStatePerParam(List<UpdaterBlock> blocks) {
        Map<String, List<INDArray>> stateViewsPerParam = new HashMap<>();
        for (UpdaterBlock block : blocks) {
            int paramCount = block.getParamOffsetEnd() - block.getParamOffsetStart();
            // Number of sub-states: updater view length divided by the block's param count.
            int nSubStates = (block.getUpdaterViewOffsetEnd() - block.getUpdaterViewOffsetStart()) / paramCount;
            long nParamsInBlock = paramCount;
            INDArray updaterView = block.getUpdaterView();

            long subStateOffset = 0;
            for (int sub = 0; sub < nSubStates; sub++) {
                INDArray subState = columnRange(updaterView, subStateOffset, nParamsInBlock);
                long offsetWithinSub = 0;
                for (UpdaterBlock.ParamState param : block.getLayersAndVariablesInBlock()) {
                    long nParamsThisParam = param.getParamView().length();
                    stateViewsPerParam.computeIfAbsent(paramKey(param), k -> new ArrayList<>())
                            .add(columnRange(subState, offsetWithinSub, nParamsThisParam));
                    offsetWithinSub += nParamsThisParam;
                }
                subStateOffset += nParamsInBlock;
            }
        }
        return stateViewsPerParam;
    }

    /** Looks up a parameter's state views, failing with a descriptive message if absent. */
    private static List<INDArray> stateViewsFor(Map<String, List<INDArray>> stateViewsPerParam,
                                                UpdaterBlock.ParamState param) {
        String key = paramKey(param);
        List<INDArray> views = stateViewsPerParam.get(key);
        if (views == null) {
            throw new IllegalStateException("No updater state found in the original updater blocks for parameter " + key);
        }
        return views;
    }

    /** Row 0 (inclusive interval [0, 0]) and columns [from, from + length) of {@code view}. */
    private static INDArray columnRange(INDArray view, long from, long length) {
        return view.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, from + length));
    }

    /** Map key identifying a parameter: "<layer or vertex index>_<param name>". */
    private static String paramKey(UpdaterBlock.ParamState param) {
        return getId(param.getLayer()) + "_" + param.getParamName();
    }

    /** Vertex index for graph vertices, otherwise the layer index. */
    private static int getId(Trainable trainable) {
        return trainable instanceof GraphVertex ? ((GraphVertex) trainable).getVertexIndex() : ((org.deeplearning4j.nn.api.Layer) trainable).getIndex();
    }
}
