/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.convolution;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.factory.Nd4j;

public class Convolution {
    private Convolution() {
    }

    public static INDArray col2im(INDArray col, int[] stride, int[] padding, int height, int width) {
        return Convolution.col2im(col, stride[0], stride[1], padding[0], padding[1], height, width);
    }

    public static INDArray col2im(INDArray col, int sH, int sW, int ph, int pW, int kH, int kW) {
        if (col.rank() != 6) {
            throw new IllegalArgumentException("col2im input array must be rank 6");
        }
        INDArray output = Nd4j.create(col.dataType(), col.size(0), col.size(1), kH, kW);
        Conv2DConfig cfg = Conv2DConfig.builder().sH(sH).sW(sW).dH(1L).dW(1L).kH(kH).kW(kW).pH(ph).pW(pW).build();
        Col2Im col2Im = Col2Im.builder().inputArrays(new INDArray[]{col}).outputs(new INDArray[]{output}).conv2DConfig(cfg).build();
        Nd4j.getExecutioner().execAndReturn(col2Im);
        return col2Im.outputArguments()[0];
    }

    public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW, int dH, int dW) {
        if (col.rank() != 6) {
            throw new IllegalArgumentException("col2im input array must be rank 6");
        }
        if (z.rank() != 4) {
            throw new IllegalArgumentException("col2im output array must be rank 4");
        }
        Col2Im col2Im = Col2Im.builder().inputArrays(new INDArray[]{col}).outputs(new INDArray[]{z}).conv2DConfig(Conv2DConfig.builder().sH(sH).sW(sW).dH(dH).dW(dW).kH(kH).kW(kW).pH(pH).pW(pW).build()).build();
        Nd4j.getExecutioner().execAndReturn(col2Im);
        return z;
    }

    public static INDArray im2col(INDArray img, int[] kernel, int[] stride, int[] padding) {
        Nd4j.getCompressor().autoDecompress(img);
        return Convolution.im2col(img, kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false);
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, boolean isSameMode) {
        return Convolution.im2col(img, kh, kw, sy, sx, ph, pw, 1, 1, isSameMode);
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode) {
        Nd4j.getCompressor().autoDecompress(img);
        long outH = Convolution.outputSize(img.size(2), kh, sy, ph, dh, isSameMode);
        long outW = Convolution.outputSize(img.size(3), kw, sx, pw, dw, isSameMode);
        INDArray out = Nd4j.create(new long[]{img.size(0), img.size(1), kh, kw, outH, outW}, 'c');
        return Convolution.im2col(img, kh, kw, sy, sx, ph, pw, dh, dw, isSameMode, out);
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, boolean isSameMode, INDArray out) {
        Im2col im2col = Im2col.builder().outputs(new INDArray[]{out}).inputArrays(new INDArray[]{img}).conv2DConfig(Conv2DConfig.builder().kH(kh).pW(pw).pH(ph).sH(sy).sW(sx).kH(kh).kW(kw).dH(1L).dW(1L).isSameMode(isSameMode).build()).build();
        Nd4j.getExecutioner().execAndReturn(im2col);
        return im2col.outputArguments()[0];
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode, INDArray out) {
        Im2col im2col = Im2col.builder().outputs(new INDArray[]{out}).inputArrays(new INDArray[]{img}).conv2DConfig(Conv2DConfig.builder().pW(pw).pH(ph).sH(sy).sW(sx).kW(kw).kH(kh).dW(dW).dH(dH).isSameMode(isSameMode).build()).build();
        Nd4j.getExecutioner().execAndReturn(im2col);
        return im2col.outputArguments()[0];
    }

    public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, double extra, int virtualHeight, int virtualWidth, INDArray out) {
        Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder().dH(dh).dW(dw).extra(extra).kH(kh).kW(kw).pH(ph).pW(pw).isSameMode(isSameMode).sH(sy).sW(sx).type(type).divisor(divisor).build());
        Nd4j.getExecutioner().execAndReturn(pooling);
        return out;
    }

    public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int pval, boolean isSameMode) {
        INDArray output = null;
        if (isSameMode) {
            int oH = (int)Math.ceil((float)img.size(2) * 1.0f / (float)sy);
            int oW = (int)Math.ceil((float)img.size(3) * 1.0f / (float)sx);
            output = Nd4j.createUninitialized(img.dataType(), new long[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c');
        } else {
            long oH = (img.size(2) - (long)(kh + (kh - 1) * 0) + (long)(2 * ph)) / (long)sy + 1L;
            long oW = (img.size(3) - (long)(kw + (kw - 1) * 0) + (long)(2 * pw)) / (long)sx + 1L;
            output = Nd4j.createUninitialized(img.dataType(), new long[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c');
        }
        Im2col im2col = Im2col.builder().inputArrays(new INDArray[]{img}).outputs(new INDArray[]{output}).conv2DConfig(Conv2DConfig.builder().pW(pw).pH(ph).sH(sy).sW(sx).kW(kw).kH(kh).dW(1L).dH(1L).isSameMode(isSameMode).build()).build();
        Nd4j.getExecutioner().execAndReturn(im2col);
        return im2col.outputArguments()[0];
    }

    @Deprecated
    public static long outSize(long size, long k, long s, long p, int dilation, boolean coverAll) {
        k = Convolution.effectiveKernelSize(k, dilation);
        if (coverAll) {
            return (size + p * 2L - k + s - 1L) / s + 1L;
        }
        return (size + p * 2L - k) / s + 1L;
    }

    public static long outputSize(long size, long k, long s, long p, int dilation, boolean isSameMode) {
        k = Convolution.effectiveKernelSize(k, dilation);
        if (isSameMode) {
            return (int)Math.ceil((float)size * 1.0f / (float)s);
        }
        return (size - k + 2L * p) / s + 1L;
    }

    public static long effectiveKernelSize(long kernel, int dilation) {
        return kernel + (kernel - 1L) * (long)(dilation - 1);
    }

    public static INDArray conv2d(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().conv2d(input, kernel, type);
    }

    public static INDArray convn(INDArray input, INDArray kernel, Type type, int[] axes) {
        return Nd4j.getConvolution().convn(input, kernel, type, axes);
    }

    public static INDArray convn(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().convn(input, kernel, type);
    }

    public static enum Type {
        FULL,
        VALID,
        SAME;

    }
}

