package org.nd4j.autodiff.listeners.debugging;

import java.util.Arrays;
import java.util.List;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.ScalarOp;

/* loaded from: input_file:org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.class */
/**
 * Listener that prints information about every op to {@code System.out} just before the op is executed.
 *
 * <p>The amount of detail is selected with {@link PrintMode}. Output can be limited to the first
 * {@code maxIterations} iterations observed by this listener.
 */
public class ExecDebuggingListener extends BaseListener {
    private final PrintMode printMode;
    private final int maxIterations;
    private final boolean logIter;

    /** Number of distinct iteration changes seen so far; compared against {@link #maxIterations}. */
    private long printIterations = 0;
    /** Iteration number reported by the most recent call, or -1 before the first call. */
    private int lastIter = -1;
    /** Index of the next op within the current iteration; reset whenever the iteration changes. */
    private int stepThisIter = 0;

    /** Level of detail printed for each op. */
    public enum PrintMode {
        /** Print only the op index and the op class name. */
        OPS_ONLY,
        /** Additionally print op arguments and the shape info of input/output arrays. */
        SHAPES_ONLY,
        /** Print Java source text intended to re-create the op with its arrays and execute it. */
        REPRODUCE
    }

    /**
     * @param printMode     level of detail to print for each op
     * @param maxIterations maximum number of iterations to print for; values {@code <= 0} mean no limit
     * @param logIter       if true, prefix each entry with the current iteration and epoch
     */
    public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter) {
        this.printMode = printMode;
        this.maxIterations = maxIterations;
        this.logIter = logIter;
    }

    /** This listener is active for every operation type. */
    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    /**
     * Builds the debug text for the op about to be executed and prints it to {@code System.out}.
     * Nothing is printed once more than {@code maxIterations} iterations have been seen
     * (when {@code maxIterations > 0}).
     */
    @Override
    public void preOpExecution(SameDiff sameDiff, At at, SameDiffOp sameDiffOp, OpContext opContext) {
        if (lastIter != at.iteration()) {
            lastIter = at.iteration();
            stepThisIter = 0;
            printIterations++;
        }
        if (maxIterations > 0 && printIterations > maxIterations) {
            return;
        }

        StringBuilder sb = new StringBuilder();
        if (logIter) {
            sb.append("(iter=").append(at.iteration()).append(",epoch=").append(at.epoch()).append(",");
        }
        sb.append("op=").append(stepThisIter++).append(logIter ? ") " : " - ");

        Object op = sameDiffOp.getOp();
        sb.append(op.getClass().getName());
        CustomOp customOp = op instanceof CustomOp ? (CustomOp) op : null;
        Op legacyOp = op instanceof Op ? (Op) op : null;

        if (printMode == PrintMode.OPS_ONLY) {
            sb.append("\n");
        } else if (printMode == PrintMode.SHAPES_ONLY) {
            appendShapes(sb, customOp, legacyOp);
            sb.append("\n");
        } else if (printMode == PrintMode.REPRODUCE) {
            sb.append("\n");
            appendReproduceCode(sb, customOp, legacyOp);
        }
        System.out.print(sb.toString());
    }

    /** Appends op arguments and array shape info for SHAPES_ONLY mode. */
    private static void appendShapes(StringBuilder sb, CustomOp customOp, Op legacyOp) {
        if (customOp != null) {
            if (customOp.iArgs() != null && customOp.iArgs().length > 0) {
                sb.append("\n\tiArgs=").append(Arrays.toString(customOp.iArgs()));
            }
            if (customOp.bArgs() != null && customOp.bArgs().length > 0) {
                sb.append("\n\tbArgs=").append(Arrays.toString(customOp.bArgs()));
            }
            if (customOp.tArgs() != null && customOp.tArgs().length > 0) {
                sb.append("\n\ttArgs=").append(Arrays.toString(customOp.tArgs()));
            }
            List<INDArray> inputs = customOp.inputArguments();
            List<INDArray> outputs = customOp.outputArguments();
            if (inputs != null) {
                for (int i = 0; i < inputs.size(); i++) {
                    sb.append("\n\tInput[").append(i).append("]=").append(inputs.get(i).shapeInfoToString());
                }
            }
            if (outputs != null) {
                for (int i = 0; i < outputs.size(); i++) {
                    sb.append("\n\tOutputs[").append(i).append("]=").append(outputs.get(i).shapeInfoToString());
                }
            }
        } else if (legacyOp != null) {
            // Guarded: an op that is neither CustomOp nor Op simply gets no shape details.
            if (legacyOp.x() != null) {
                sb.append("\n\tx: ").append(legacyOp.x().shapeInfoToString());
            }
            if (legacyOp.y() != null) {
                sb.append("\n\ty: ").append(legacyOp.y().shapeInfoToString());
            }
            if (legacyOp.z() != null) {
                sb.append("\n\tz: ").append(legacyOp.z().shapeInfoToString());
            }
            if (legacyOp instanceof ScalarOp) {
                INDArray scalar = ((ScalarOp) legacyOp).scalar();
                if (scalar != null) {
                    sb.append("\n\tscalar: ").append(scalar.shapeInfoToString());
                }
            }
        }
    }

    /** Appends Java source text that re-creates and executes the op, for REPRODUCE mode. */
    private static void appendReproduceCode(StringBuilder sb, CustomOp customOp, Op legacyOp) {
        if (customOp != null) {
            sb.append("DynamicCustomOp op = new ").append(customOp.getClass().getName()).append("();\n");
            if (customOp.iArgs() != null && customOp.iArgs().length > 0) {
                sb.append("op.addIArgument(").append(stripBrackets(Arrays.toString(customOp.iArgs()))).append(");\n");
            }
            if (customOp.bArgs() != null && customOp.bArgs().length > 0) {
                sb.append("op.addBArgument(").append(stripBrackets(Arrays.toString(customOp.bArgs()))).append(");\n");
            }
            if (customOp.tArgs() != null && customOp.tArgs().length > 0) {
                sb.append("op.addTArgument(").append(stripBrackets(Arrays.toString(customOp.tArgs()))).append(");\n");
            }
            appendArrayDeclarations(sb, "inputs", "op.addInputArgument(inputs);\n", customOp.inputArguments());
            appendArrayDeclarations(sb, "outputs", "op.addOutputArgument(outputs);\n", customOp.outputArguments());
        } else if (legacyOp != null) {
            // Use the class of the op being executed, not the class of the SameDiffOp wrapper.
            sb.append("Op op = new ").append(legacyOp.getClass().getName()).append("();\n");
            if (legacyOp.x() != null) {
                sb.append("op.setX(").append(createString(legacyOp.x())).append(");\n");
            }
            if (legacyOp.y() != null) {
                sb.append("op.setY(").append(createString(legacyOp.y())).append(");\n");
            }
            if (legacyOp.z() != null) {
                sb.append("op.setZ(").append(createString(legacyOp.z())).append(");\n");
            }
            if (legacyOp instanceof ScalarOp) {
                INDArray scalar = ((ScalarOp) legacyOp).scalar();
                if (scalar != null) {
                    sb.append("((ScalarOp)op).setScalar(").append(createString(scalar)).append(");\n");
                }
            }
        } else {
            // No "op" variable was declared, so emitting Nd4j.exec(op) would be misleading.
            sb.append("// Unable to generate reproduction code: op is neither a CustomOp nor an Op\n");
            return;
        }
        sb.append("Nd4j.exec(op);\n");
    }

    /**
     * Appends a declaration of an {@code INDArray[]} named {@code varName}, one assignment per
     * element, and finally {@code addCall}. Does nothing when {@code arrays} is null.
     */
    private static void appendArrayDeclarations(StringBuilder sb, String varName, String addCall, List<INDArray> arrays) {
        if (arrays == null) {
            return;
        }
        sb.append("INDArray[] ").append(varName).append(" = new INDArray[").append(arrays.size()).append("];\n");
        for (int i = 0; i < arrays.size(); i++) {
            sb.append(varName).append("[").append(i).append("] = ");
            sb.append(createString(arrays.get(i))).append(";\n");
        }
        sb.append(addCall);
    }

    /**
     * Returns a Java expression (without a trailing semicolon) that re-creates the given array:
     * either {@code Nd4j.empty(DataType.X)} or {@code Nd4j.createFromArray(values).reshape(shape)},
     * followed by {@code .cast(DataType.X)} for data types that have no matching Java literal form.
     */
    private static String createString(INDArray iNDArray) {
        StringBuilder sb = new StringBuilder();
        if (iNDArray.isEmpty()) {
            // No trailing ';' here: callers append their own statement terminator.
            sb.append("Nd4j.empty(DataType.").append(iNDArray.dataType()).append(")");
            return sb.toString();
        }

        DataType dataType = iNDArray.dataType();
        sb.append("Nd4j.createFromArray(");
        switch (dataType) {
            case DOUBLE:
                sb.append(stripBrackets(Arrays.toString(iNDArray.dup().data().asDouble())));
                break;
            case FLOAT:
            case HALF:
            case BFLOAT16:
                sb.append(withLiteralSuffix(Arrays.toString(iNDArray.dup().data().asFloat()), "f"));
                break;
            case LONG:
            case UINT32:
            case UINT64:
                sb.append(withLiteralSuffix(Arrays.toString(iNDArray.dup().data().asLong()), "L"));
                break;
            case INT:
            case SHORT:
            case UBYTE:
            case BYTE:
            case UINT16:
            case BOOL:
                sb.append(stripBrackets(Arrays.toString(iNDArray.dup().data().asInt())));
                break;
            default:
                // Other data types: no values are emitted (same as before).
                break;
        }
        sb.append(").reshape(").append(stripBrackets(Arrays.toString(iNDArray.shape()))).append(")");
        if (needsCast(dataType)) {
            sb.append(".cast(DataType.").append(dataType).append(")");
        }
        return sb.toString();
    }

    /** True for data types whose values are emitted as a wider Java type and then cast back. */
    private static boolean needsCast(DataType dataType) {
        return dataType == DataType.HALF || dataType == DataType.BFLOAT16
                || dataType == DataType.UINT32 || dataType == DataType.UINT64
                || dataType == DataType.SHORT || dataType == DataType.UBYTE
                || dataType == DataType.BYTE || dataType == DataType.UINT16
                || dataType == DataType.BOOL;
    }

    /** Removes every '[' and ']' character, e.g. {@code "[1, 2]"} becomes {@code "1, 2"}. */
    private static String stripBrackets(String s) {
        return s.replace("[", "").replace("]", "");
    }

    /**
     * Turns the output of {@code Arrays.toString} into a list of suffixed literals,
     * e.g. {@code "[1.0, 2.0]"} with suffix {@code "f"} becomes {@code "1.0f, 2.0f"}.
     */
    private static String withLiteralSuffix(String arrayString, String suffix) {
        return stripBrackets(arrayString.replace(",", suffix + ",").replace("]", suffix));
    }
}
