package org.deeplearning4j.nn.layers.mkldnn;

import java.util.Collections;
import java.util.Map;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.class */
/**
 * {@link SubsamplingHelper} implementation that runs 2D max/average pooling through ND4J's
 * {@code MaxPooling2D}, {@code AvgPooling2D} and {@code Pooling2DDerivative} ops.
 *
 * <p>Only {@link PoolingType#MAX} and {@link PoolingType#AVG} are handled. For every other pooling
 * type both {@link #activate} and {@link #backpropGradient} return {@code null}; the caller is
 * presumably expected to fall back to another implementation in that case (confirm against the
 * calling layer).
 */
public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
    /** Op context used by {@link #activate}; created lazily on the first call and reused afterwards. */
    protected OpContext context;

    /**
     * @param dataType not used by this implementation
     */
    public MKLDNNSubsamplingHelper(DataType dataType) {
    }

    /**
     * @return whatever {@code BaseMKLDNNHelper.mklDnnEnabled()} reports
     */
    @Override // org.deeplearning4j.nn.layers.LayerHelper
    public boolean checkSupported() {
        return BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Computes the gradient with respect to the pooling input.
     *
     * @param iNDArray          pooling input
     * @param iNDArray2         gradient with respect to the pooling output (epsilon)
     * @param iArr              kernel size, {@code [height, width]}
     * @param iArr2             strides, {@code [height, width]}
     * @param iArr3             padding, {@code [height, width]}; recomputed (the caller's array is not
     *                          modified) when {@code convolutionMode} is {@link ConvolutionMode#Same}
     * @param poolingType       pooling type; only MAX and AVG are handled
     * @param convolutionMode   convolution mode
     * @param iArr4             dilation, {@code [height, width]}
     * @param cNN2DFormat       data format of the input
     * @param layerWorkspaceMgr workspace manager used to allocate the returned gradient array
     * @return a pair of an empty {@link DefaultGradient} and the gradient at the input, or
     *         {@code null} when {@code poolingType} is SUM or PNORM
     */
    @Override // org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, int[] iArr, int[] iArr2, int[] iArr3, PoolingType poolingType, ConvolutionMode convolutionMode, int[] iArr4, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (poolingType == PoolingType.SUM || poolingType == PoolingType.PNORM) {
            return null;
        }
        INDArray gradAtInput = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray.dataType(), iNDArray.shape());

        // Spatial dimensions sit at indices (1, 2) for NHWC and (2, 3) otherwise.
        boolean nhwc = cNN2DFormat == CNN2DFormat.NHWC;
        int hIdx = nhwc ? 1 : 2;
        int wIdx = nhwc ? 2 : 3;

        int[] padding = iArr3;
        if (convolutionMode == ConvolutionMode.Same) {
            int[] epsilonHW = {(int) iNDArray2.size(hIdx), (int) iNDArray2.size(wIdx)};
            int[] inputHW = {(int) iNDArray.size(hIdx), (int) iNDArray.size(wIdx)};
            padding = ConvolutionUtils.getSameModeTopLeftPadding(epsilonHW, inputHW, iArr, iArr2, iArr4);
        }

        Pooling2DConfig config = Pooling2DConfig.builder()
                .paddingMode(ConvolutionMode.mapToMode(convolutionMode))
                .kH(iArr[0]).kW(iArr[1])
                .sH(iArr2[0]).sW(iArr2[1])
                .dH(iArr4[0]).dW(iArr4[1])
                .pH(padding[0]).pW(padding[1])
                .isNHWC(nhwc)
                .build();

        switch (poolingType) {
            case MAX:
                config.setType(Pooling2D.Pooling2DType.MAX);
                break;
            case AVG:
                config.setType(Pooling2D.Pooling2DType.AVG);
                break;
            default:
                // SUM and PNORM were rejected above; leave the builder's type untouched otherwise.
                break;
        }

        Nd4j.exec(new Pooling2DDerivative(iNDArray, iNDArray2, gradAtInput, config));
        return new Pair<>(new DefaultGradient(), gradAtInput);
    }

    /**
     * Runs the forward pooling pass.
     *
     * @param iNDArray          pooling input
     * @param z                 not used by this implementation
     * @param iArr              kernel size, {@code [height, width]}
     * @param iArr2             strides, {@code [height, width]}
     * @param iArr3             padding, {@code [height, width]}; recomputed (the caller's array is not
     *                          modified) when {@code convolutionMode} is {@link ConvolutionMode#Same}
     * @param poolingType       pooling type; only MAX and AVG are handled
     * @param convolutionMode   convolution mode
     * @param iArr4             dilation, {@code [height, width]}
     * @param cNN2DFormat       data format of the input
     * @param layerWorkspaceMgr workspace manager used to allocate the output array
     * @return the pooled activations, or {@code null} when {@code poolingType} is neither MAX nor AVG
     */
    @Override // org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper
    public INDArray activate(INDArray iNDArray, boolean z, int[] iArr, int[] iArr2, int[] iArr3, PoolingType poolingType, ConvolutionMode convolutionMode, int[] iArr4, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr) {
        // Spatial dimensions sit at indices (1, 2) for NHWC and (2, 3) otherwise.
        boolean nhwc = cNN2DFormat == CNN2DFormat.NHWC;
        int hIdx = nhwc ? 1 : 2;
        int wIdx = nhwc ? 2 : 3;

        int[] padding = iArr3;
        int[] outputSize;
        if (convolutionMode == ConvolutionMode.Same) {
            outputSize = ConvolutionUtils.getOutputSize(iNDArray, iArr, iArr2, null, convolutionMode, iArr4, cNN2DFormat);
            int[] inputHW = {(int) iNDArray.size(hIdx), (int) iNDArray.size(wIdx)};
            padding = ConvolutionUtils.getSameModeTopLeftPadding(outputSize, inputHW, iArr, iArr2, iArr4);
        } else {
            outputSize = ConvolutionUtils.getOutputSize(iNDArray, iArr, iArr2, padding, convolutionMode, iArr4, cNN2DFormat);
        }

        long[] outputShape = cNN2DFormat == CNN2DFormat.NCHW
                ? new long[]{iNDArray.size(0), iNDArray.size(1), outputSize[0], outputSize[1]}
                : new long[]{iNDArray.size(0), outputSize[0], outputSize[1], iNDArray.size(3)};
        INDArray output = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, iNDArray.dataType(), outputShape);

        if (this.context == null) {
            // NOTE(review): the integer arguments are set only when the context is first created, so
            // later calls reuse the values from the first call. Presumably safe because a helper
            // serves a single layer with a fixed configuration - confirm before sharing an instance.
            this.context = Nd4j.getExecutioner().buildContext();
            this.context.setIArguments(new long[]{
                    iArr[0], iArr[1],
                    iArr2[0], iArr2[1],
                    padding[0], padding[1],
                    iArr4[0], iArr4[1],
                    ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
                    0,
                    cNN2DFormat == CNN2DFormat.NCHW ? 0L : 1L
            });
        }

        switch (poolingType) {
            case MAX:
                bindArrays(iNDArray, output);
                Nd4j.exec(new MaxPooling2D(), this.context);
                return output;
            case AVG:
                bindArrays(iNDArray, output);
                Nd4j.exec(new AvgPooling2D(), this.context);
                return output;
            default:
                // SUM, PNORM and anything else: not handled here.
                return null;
        }
    }

    /** Purges the cached context and attaches the given input and output arrays at index 0. */
    private void bindArrays(INDArray input, INDArray output) {
        this.context.purge();
        this.context.setInputArray(0, input);
        this.context.setOutputArray(0, output);
    }

    /**
     * @return an empty map; this helper reports no memory use
     */
    @Override // org.deeplearning4j.nn.layers.LayerHelper
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }
}
