package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.util.Objects;

/**
 * Early-stopping score calculator that scores a model with an ROC-family evaluation.
 *
 * <p>The evaluation class is selected by {@link ROCType}: {@link ROC}, {@link ROCBinary} or
 * {@link ROCMultiClass}. The final score is either the area under the ROC curve
 * ({@link Metric#AUC}) or the area under the precision-recall curve ({@link Metric#AUPRC}). For
 * the binary and multi-class variants, the averaged value reported by the evaluation is used.
 *
 * <p>Higher scores are better: {@link #minimizeScore()} returns {@code false}.
 */
public class ROCScoreCalculator extends BaseIEvaluationScoreCalculator<Model, IEvaluation> {

    /** Selects which ROC evaluation implementation is created and scored. */
    public enum ROCType {
        /** Evaluate with {@code org.nd4j.evaluation.classification.ROC}. */
        ROC,
        /** Evaluate with {@code ROCBinary}; the score is its averaged value. */
        BINARY,
        /** Evaluate with {@code ROCMultiClass}; the score is its averaged value. */
        MULTICLASS
    }

    /** Selects which area-under-curve quantity is reported as the score. */
    public enum Metric {
        /** Area under the ROC curve. */
        AUC,
        /** Area under the precision-recall curve. */
        AUPRC
    }

    /** Evaluation implementation to use; never {@code null}. */
    protected final ROCType type;
    /** Quantity reported as the final score; never {@code null}. */
    protected final Metric metric;

    /**
     * Creates a calculator that reports {@link Metric#AUC}.
     *
     * @param type     which ROC evaluation to use; must not be {@code null}
     * @param iterator data used for scoring
     * @throws NullPointerException if {@code type} is {@code null}
     */
    public ROCScoreCalculator(ROCType type, DataSetIterator iterator) {
        this(type, Metric.AUC, iterator);
    }

    /**
     * Creates a calculator that reports {@link Metric#AUC}.
     *
     * @param type     which ROC evaluation to use; must not be {@code null}
     * @param iterator multi-dataset data used for scoring
     * @throws NullPointerException if {@code type} is {@code null}
     */
    public ROCScoreCalculator(ROCType type, MultiDataSetIterator iterator) {
        this(type, Metric.AUC, iterator);
    }

    /**
     * Creates a calculator with an explicit metric.
     *
     * @param type     which ROC evaluation to use; must not be {@code null}
     * @param metric   which area-under-curve value to report; must not be {@code null}
     * @param iterator data used for scoring
     * @throws NullPointerException if {@code type} or {@code metric} is {@code null}
     */
    public ROCScoreCalculator(ROCType type, Metric metric, DataSetIterator iterator) {
        super(iterator);
        // Fail fast: a null type would otherwise only surface later as an NPE in newEval(),
        // and a null metric would silently be scored as AUPRC in finalScore().
        this.type = Objects.requireNonNull(type, "ROCType must not be null");
        this.metric = Objects.requireNonNull(metric, "Metric must not be null");
    }

    /**
     * Creates a calculator with an explicit metric.
     *
     * @param type     which ROC evaluation to use; must not be {@code null}
     * @param metric   which area-under-curve value to report; must not be {@code null}
     * @param iterator multi-dataset data used for scoring
     * @throws NullPointerException if {@code type} or {@code metric} is {@code null}
     */
    public ROCScoreCalculator(ROCType type, Metric metric, MultiDataSetIterator iterator) {
        super(iterator);
        this.type = Objects.requireNonNull(type, "ROCType must not be null");
        this.metric = Objects.requireNonNull(metric, "Metric must not be null");
    }

    /**
     * Creates a new, empty evaluation instance matching {@link #type}.
     *
     * @return a new {@link ROC}, {@link ROCBinary} or {@link ROCMultiClass}
     * @throws IllegalStateException if {@link #type} is not a known {@link ROCType}
     */
    @Override
    protected IEvaluation newEval() {
        // Switch on the enum constant directly rather than via ordinal-indexed lookup tables.
        switch (type) {
            case ROC:
                return new ROC();
            case BINARY:
                return new ROCBinary();
            case MULTICLASS:
                return new ROCMultiClass();
            default:
                throw new IllegalStateException("Unknown type: " + this.type);
        }
    }

    /**
     * Extracts the configured {@link #metric} from a completed evaluation.
     *
     * @param eval an evaluation of the class produced by {@link #newEval()} for {@link #type}
     * @return the AUC or AUPRC value. For {@link ROCType#BINARY} and
     *         {@link ROCType#MULTICLASS}, this is the averaged value reported by the evaluation.
     * @throws IllegalStateException if {@link #type} is not a known {@link ROCType}
     */
    @Override
    protected double finalScore(IEvaluation eval) {
        switch (type) {
            case ROC: {
                ROC roc = (ROC) eval;
                return metric == Metric.AUC ? roc.calculateAUC() : roc.calculateAUCPR();
            }
            case BINARY: {
                ROCBinary rocBinary = (ROCBinary) eval;
                return metric == Metric.AUC
                        ? rocBinary.calculateAverageAuc()
                        : rocBinary.calculateAverageAUCPR();
            }
            case MULTICLASS: {
                ROCMultiClass rocMultiClass = (ROCMultiClass) eval;
                return metric == Metric.AUC
                        ? rocMultiClass.calculateAverageAUC()
                        : rocMultiClass.calculateAverageAUCPR();
            }
            default:
                throw new IllegalStateException("Unknown type: " + this.type);
        }
    }

    /**
     * Indicates the optimisation direction for early stopping.
     *
     * @return {@code false}: larger AUC/AUPRC values are better
     */
    @Override
    public boolean minimizeScore() {
        return false;
    }
}
