package org.nd4j.linalg.dataset.api.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/iterator/KFoldIterator.class */
/**
 * Splits a single {@link DataSet} into {@code k} folds for k-fold cross-validation.
 *
 * <p>Each call to {@link #next()} advances to the next fold and returns the training portion
 * (all examples outside the current fold); the matching held-out portion is then available
 * from {@link #testFold()}.
 *
 * <p>Fold sizes are {@code N / k}, except that the first {@code N % k} folds hold one extra
 * example so that every example belongs to exactly one fold.
 */
public class KFoldIterator implements DataSetIterator {
    private static final long serialVersionUID = 6130298603412865817L;

    /** Number of folds used by {@link #KFoldIterator(DataSet)}. */
    private static final int DEFAULT_K = 10;

    /** The complete data set that is partitioned into folds. */
    protected DataSet allData;
    /** Number of folds. */
    protected int k;
    /** Total number of examples in {@link #allData}, captured at construction time. */
    protected int N;
    /**
     * Fold boundaries as example indices; fold {@code f} spans
     * {@code intervalBoundaries[f]} to {@code intervalBoundaries[f + 1]}. Length is {@code k + 1}.
     */
    protected int[] intervalBoundaries;
    /** Index of the fold that the next call to {@link #next()} will produce. */
    protected int kCursor;
    /** Held-out portion of the most recently produced fold; {@code null} before the first {@link #next()}. */
    protected DataSet test;
    /** Training portion of the most recently produced fold; {@code null} before the first {@link #next()}. */
    protected DataSet train;
    /** Pre-processor registered through {@link #setPreProcessor(DataSetPreProcessor)}. */
    protected DataSetPreProcessor preProcessor;

    /**
     * Creates an iterator over 10 folds of the given data.
     *
     * @param dataSet the data to partition
     */
    public KFoldIterator(DataSet dataSet) {
        this(DEFAULT_K, dataSet);
    }

    /**
     * Creates an iterator over {@code k} folds of the given data.
     *
     * @param k       number of folds; must be greater than 1
     * @param dataSet the data to partition
     * @throws IllegalArgumentException if {@code k <= 1}
     */
    public KFoldIterator(int k, DataSet dataSet) {
        this.kCursor = 0;
        if (k <= 1) {
            throw new IllegalArgumentException("Number of folds must be greater than 1, got k=" + k);
        }
        this.k = k;
        this.N = dataSet.numExamples();
        this.allData = dataSet;

        int baseFoldSize = this.N / k;
        int foldsWithExtraExample = this.N % k;

        // Cumulative boundaries: the first (N % k) folds absorb the remainder, one example each.
        this.intervalBoundaries = new int[k + 1];
        this.intervalBoundaries[0] = 0;
        for (int fold = 1; fold <= k; fold++) {
            int foldSize = (fold <= foldsWithExtraExample) ? baseFoldSize + 1 : baseFoldSize;
            this.intervalBoundaries[fold] = this.intervalBoundaries[fold - 1] + foldSize;
        }
    }

    /**
     * Not supported: fold sizes are fixed by {@code k}, so a caller-chosen batch size is meaningless.
     *
     * @param num ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public DataSet next(int num) throws UnsupportedOperationException {
        throw new UnsupportedOperationException(
                "KFoldIterator does not support next(int); use next() to obtain the next training fold");
    }

    /**
     * @return the number of examples in the full data set, as captured at construction time
     */
    public int totalExamples() {
        return this.N;
    }

    /**
     * @return size of dimension 1 of the full data set's feature array
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public int inputColumns() {
        return (int) this.allData.getFeatures().size(1);
    }

    /**
     * @return size of dimension 1 of the full data set's label array
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public int totalOutcomes() {
        return (int) this.allData.getLabels().size(1);
    }

    /**
     * @return {@code true}; see {@link #reset()}
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public boolean resetSupported() {
        return true;
    }

    /**
     * @return {@code false}
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Shuffles the underlying data set and rewinds to the first fold, so the next pass
     * iterates over a fresh partitioning. The previously produced train/test folds are
     * left untouched until {@link #next()} is called again.
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public void reset() {
        this.allData.shuffle();
        this.kCursor = 0;
    }

    /**
     * Returns the size of the fold that the next call to {@link #next()} will hold out.
     *
     * @return number of examples in the upcoming test fold
     * @throws ArrayIndexOutOfBoundsException if all folds have already been consumed
     *         (i.e. {@link #hasNext()} is {@code false})
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public int batch() {
        return this.intervalBoundaries[this.kCursor + 1] - this.intervalBoundaries[this.kCursor];
    }

    /**
     * Stores the given pre-processor.
     *
     * <p>NOTE(review): nothing in this class invokes the stored pre-processor on the folds
     * it returns — confirm whether callers are expected to apply it themselves.
     *
     * @param dataSetPreProcessor the pre-processor to store; may be {@code null}
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /**
     * @return the pre-processor last passed to {@link #setPreProcessor(DataSetPreProcessor)},
     *         or {@code null} if none was set
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /**
     * @return the label names reported by the full data set
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.DataSetIterator
    public List<String> getLabels() {
        return this.allData.getLabelNamesList();
    }

    /**
     * @return {@code true} while at least one fold has not yet been produced
     */
    @Override // java.util.Iterator
    public boolean hasNext() {
        return this.kCursor < this.k;
    }

    /**
     * Advances to the next fold and returns its training portion. The held-out portion
     * of the same fold is available afterwards from {@link #testFold()}.
     *
     * @return the training data for the fold just produced
     * @throws NoSuchElementException if all {@code k} folds have already been produced
     */
    @Override // java.util.Iterator
    public DataSet next() {
        if (!hasNext()) {
            // Without this guard nextFold() would index past intervalBoundaries.
            throw new NoSuchElementException(
                    "All " + this.k + " folds have been consumed; call reset() to iterate again");
        }
        nextFold();
        return this.train;
    }

    /**
     * Does nothing; removal of folds is not meaningful for this iterator.
     */
    @Override // java.util.Iterator
    public void remove() {
        // Intentionally a no-op.
    }

    /**
     * Builds the train/test split for the fold at {@link #kCursor} and advances the cursor.
     *
     * <p>The test fold is the range between the fold's two boundaries; the training data is
     * everything before and after it. Assumes {@code getRange(from, to)} treats {@code to}
     * as exclusive — TODO confirm against the DataSet API.
     */
    protected void nextFold() {
        int testStart = this.intervalBoundaries[this.kCursor];
        int testEnd = this.intervalBoundaries[this.kCursor + 1];

        if (testEnd < totalExamples()) {
            // There is data after the test fold; there may or may not be data before it.
            List<DataSet> trainParts = new ArrayList<DataSet>();
            if (testStart > 0) {
                trainParts.add((DataSet) this.allData.getRange(0, testStart));
            }
            trainParts.add((DataSet) this.allData.getRange(testEnd, totalExamples()));
            this.train = DataSet.merge(trainParts);
        } else {
            // Test fold reaches the end of the data: training data is everything before it.
            this.train = (DataSet) this.allData.getRange(0, testStart);
        }

        this.test = (DataSet) this.allData.getRange(testStart, testEnd);
        this.kCursor++;
    }

    /**
     * @return the held-out portion of the fold most recently produced by {@link #next()},
     *         or {@code null} if {@link #next()} has not been called yet
     */
    public DataSet testFold() {
        return this.test;
    }
}
