package org.nd4j.linalg.api.ops.impl.reduce3;

import java.util.Arrays;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.util.SameDiffUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.class */
/**
 * Reduce3 op that computes the Euclidean distance between two inputs, optionally along the
 * given dimensions.
 *
 * <p>The op is registered under the name {@value #OP_NAME} with legacy op number 1. The gradient
 * (see {@link #doDiff(List)}) is {@code (x - y) / dist(x, y)} for the first input and its
 * negation for the second.
 */
public class EuclideanDistance extends BaseReduce3Op {
    public static final String OP_NAME = "euclidean";

    /** SameDiff constructor for a single input variable reduced along {@code dimensions}. */
    public EuclideanDistance(SameDiff sameDiff, SDVariable x, int[] dimensions) {
        super(sameDiff, x, dimensions);
    }

    /** SameDiff constructor for two input variables reduced along {@code dimensions}. */
    public EuclideanDistance(SameDiff sameDiff, SDVariable x, SDVariable y, int[] dimensions) {
        super(sameDiff, x, y, dimensions);
    }

    /** SameDiff constructor for two input variables with no explicit dimensions. */
    public EuclideanDistance(SameDiff sameDiff, SDVariable x, SDVariable y) {
        super(sameDiff, x, y);
    }

    /** SameDiff constructor where the reduction dimensions are supplied as a graph variable. */
    public EuclideanDistance(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable dimensions) {
        super(sameDiff, x, y, dimensions);
    }

    /** No-arg constructor. */
    public EuclideanDistance() {
    }

    /** Array constructor with no preallocated output. */
    public EuclideanDistance(INDArray x, INDArray y, int... dimensions) {
        this(x, y, (INDArray) null, dimensions);
    }

    /** Array constructor with an output array and no explicit dimensions. */
    public EuclideanDistance(INDArray x, INDArray y, INDArray z) {
        this(x, y, z, (int[]) null);
    }

    /** Array constructor with an (optionally null) output array {@code z}. */
    public EuclideanDistance(INDArray x, INDArray y, INDArray z, int... dimensions) {
        super(x, y, z, dimensions);
        this.extraArgs = defaultExtraArgs();
    }

    /**
     * Array constructor without an output array.
     *
     * @param allDistances forwarded as the second boolean flag of {@code BaseReduce3Op}
     *     (presumably "compute all pairwise distances" — confirm against the base class)
     */
    public EuclideanDistance(INDArray x, INDArray y, boolean allDistances, int... dimensions) {
        // The delegated constructor already initializes extraArgs.
        this(x, y, (INDArray) null, allDistances, dimensions);
    }

    /**
     * Array constructor with keepDims fixed to {@code false}.
     *
     * @param allDistances forwarded as the second boolean flag of {@code BaseReduce3Op}
     */
    public EuclideanDistance(INDArray x, INDArray y, INDArray z, boolean allDistances, int... dimensions) {
        // The delegated constructor already initializes extraArgs.
        this(x, y, z, false, allDistances, dimensions);
    }

    /**
     * Fully specified array constructor.
     *
     * @param keepDims     forwarded to {@code BaseReduce3Op}; also read by {@link #doDiff(List)}
     * @param allDistances forwarded to {@code BaseReduce3Op}
     */
    public EuclideanDistance(INDArray x, INDArray y, INDArray z, boolean keepDims, boolean allDistances,
                             int... dimensions) {
        super(x, y, z, keepDims, allDistances, dimensions);
        this.extraArgs = defaultExtraArgs();
    }

    /** Fully specified SameDiff constructor. */
    public EuclideanDistance(SameDiff sameDiff, SDVariable x, SDVariable y, boolean keepDims, boolean isComplex,
                             int[] dimensions) {
        super(sameDiff, x, y, keepDims, isComplex, dimensions);
        this.extraArgs = defaultExtraArgs();
    }

    /** Array constructor without an output array, with explicit keepDims / second flag. */
    public EuclideanDistance(INDArray x, INDArray y, boolean keepDims, boolean isComplex, int[] dimensions) {
        super(x, y, (INDArray) null, keepDims, isComplex, dimensions);
        this.extraArgs = defaultExtraArgs();
    }

    /**
     * Returns a fresh copy of the default extra arguments {@code {0.0f, 0.0f}}.
     * A new array is returned on every call because {@code extraArgs} is a mutable per-instance field.
     */
    private static Object[] defaultExtraArgs() {
        return new Object[]{0.0f, 0.0f};
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int opNum() {
        return 1;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return OP_NAME;
    }

    /**
     * Builds the gradients with respect to both inputs.
     *
     * <p>{@code d dist(x, y) / dx = (x - y) / dist(x, y)} and {@code d dist / dy = -d dist / dx}.
     * Where the distance is zero, the division yields non-finite values.
     *
     * @param gradients incoming gradients; only element 0 (gradient of this op's output) is used
     * @return a two-element list: gradient for x, then gradient for y
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        SDVariable distance = outputVariables()[0];
        SDVariable difference = larg().sub(rarg());
        SDVariable scaledGrad = gradients.get(0).div(distance);

        // null/empty dimensions and the single sentinel Integer.MAX_VALUE are all treated as a
        // full-array reduction (presumably "reduce over everything"), which needs no reshape.
        boolean fullReduction = this.dimensions == null
                || this.dimensions.length == 0
                || (this.dimensions.length == 1 && this.dimensions[0] == Integer.MAX_VALUE);
        if (!this.keepDims && !fullReduction) {
            // Reduced dimensions were dropped from the output, so restore them (as size-1 axes)
            // before broadcasting against the input-shaped difference.
            scaledGrad = SameDiffUtils.reductionBroadcastableWithOrigShape(
                    arg(), this.sameDiff.constant(Nd4j.createFromArray(this.dimensions)), scaledGrad);
        }

        SDVariable gradX = difference.mul(scaledGrad);
        SDVariable gradY = this.sameDiff.math.neg(gradX);
        return Arrays.asList(gradX, gradY);
    }
}
