package ai.onnxruntime;

import ai.onnxruntime.providers.CoreMLFlags;
import ai.onnxruntime.providers.NNAPIFlags;
import ai.onnxruntime.providers.OrtFlags;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;

/* loaded from: input_file:ai/onnxruntime/OrtSession.class */
public class OrtSession implements AutoCloseable {
    private final long nativeHandle;
    private final OrtAllocator allocator;
    private final long numInputs;
    private final Set<String> inputNames;
    private final long numOutputs;
    private final Set<String> outputNames;
    private OnnxModelMetadata metadata;
    private boolean closed;

    /* loaded from: input_file:ai/onnxruntime/OrtSession$Result.class */
    /**
     * The outputs of a single {@link OrtSession#run} call, addressable by output name or by position.
     *
     * <p>A {@code Result} owns the {@link OnnxValue}s it holds: {@link #close()} closes every one of
     * them, after which {@link #get(int)}, {@link #get(String)} and {@link #iterator()} throw
     * {@link IllegalStateException} ({@link #size()} stays usable). Iteration order and positional
     * indices follow the order of the names supplied at construction.
     */
    public static class Result implements AutoCloseable, Iterable<Map.Entry<String, OnnxValue>> {
        private static final Logger logger = Logger.getLogger(Result.class.getName());
        // Name -> value lookup; insertion-ordered so iteration matches positional order.
        private final Map<String, OnnxValue> map = new LinkedHashMap<>();
        // Positional view of the same values; also the authoritative list of values to close.
        private final List<OnnxValue> list = new ArrayList<>();
        private boolean closed;

        /**
         * Pairs output names with their values.
         *
         * @param strArr the output names, in positional order.
         * @param onnxValueArr the output values, aligned index-for-index with {@code strArr}.
         * @throws IllegalArgumentException if the two arrays differ in length.
         */
        Result(String[] strArr, OnnxValue[] onnxValueArr) {
            if (strArr.length != onnxValueArr.length) {
                throw new IllegalArgumentException("Expected same number of names and values, found names.length = " + strArr.length + ", values.length = " + onnxValueArr.length);
            }
            for (int i = 0; i < strArr.length; i++) {
                this.map.put(strArr[i], onnxValueArr[i]);
                this.list.add(onnxValueArr[i]);
            }
            this.closed = false;
        }

        /**
         * Closes every contained value. A second call only logs a warning.
         *
         * <p>All values are closed even if some of them fail; the first failure is rethrown with any
         * later ones attached as suppressed exceptions.
         */
        @Override
        public void close() {
            if (this.closed) {
                logger.warning("Closing an already closed Result");
                return;
            }
            this.closed = true;
            RuntimeException failure = null;
            for (OnnxValue value : this.list) {
                try {
                    value.close();
                } catch (RuntimeException e) {
                    // Keep going so one bad value cannot leak the native memory of the others.
                    if (failure == null) {
                        failure = e;
                    } else {
                        failure.addSuppressed(e);
                    }
                }
            }
            if (failure != null) {
                throw failure;
            }
        }

        /**
         * Returns a read-only iterator over the (name, value) pairs in positional order.
         *
         * <p>The iterator does not support {@code remove()}, and its entries reject {@code setValue},
         * so callers cannot desynchronize the name and position views or drop a value from
         * {@link #close()}.
         *
         * @throws IllegalStateException if this result has been closed.
         */
        @Override
        public Iterator<Map.Entry<String, OnnxValue>> iterator() {
            if (this.closed) {
                throw new IllegalStateException("Result is closed");
            }
            return Collections.unmodifiableMap(this.map).entrySet().iterator();
        }

        /**
         * Returns the value at position {@code i}.
         *
         * @param i the zero-based output position.
         * @return the value at that position.
         * @throws IllegalStateException if this result has been closed.
         * @throws IndexOutOfBoundsException if {@code i} is out of range.
         */
        public OnnxValue get(int i) {
            if (this.closed) {
                throw new IllegalStateException("Result is closed");
            }
            return this.list.get(i);
        }

        /** Returns the number of outputs held; usable even after {@link #close()}. */
        public int size() {
            return this.map.size();
        }

        /**
         * Looks up an output by name.
         *
         * @param str the output name.
         * @return the value, or {@link Optional#empty()} if no output has that name.
         * @throws IllegalStateException if this result has been closed.
         */
        public Optional<OnnxValue> get(String str) {
            if (this.closed) {
                throw new IllegalStateException("Result is closed");
            }
            return Optional.ofNullable(this.map.get(str));
        }
    }

    /* loaded from: input_file:ai/onnxruntime/OrtSession$RunOptions.class */
    /**
     * Per-call options passed to {@link OrtSession#run(Map, Set, RunOptions)}, backed by a native
     * run-options object.
     *
     * <p>Instances hold a native handle that {@link #close()} releases; every other method throws
     * {@link IllegalStateException} once closed.
     */
    public static class RunOptions implements AutoCloseable {
        private boolean closed = false;
        private final long nativeHandle;

        /**
         * Creates a new native run-options object.
         *
         * @throws OrtException if the native call fails.
         */
        public RunOptions() throws OrtException {
            // Must live in an explicit constructor: JLS 11.2.3 forbids a field initializer from
            // throwing a checked exception unless an explicitly declared constructor declares it.
            this.nativeHandle = createRunOptions(OnnxRuntime.ortApiHandle);
        }

        /** Sets the native log level from {@code ortLoggingLevel.getValue()}. */
        public void setLogLevel(OrtLoggingLevel ortLoggingLevel) throws OrtException {
            checkClosed();
            setLogLevel(OnnxRuntime.ortApiHandle, this.nativeHandle, ortLoggingLevel.getValue());
        }

        /** Returns the native log level, mapped back through {@code OrtLoggingLevel.mapFromInt}. */
        public OrtLoggingLevel getLogLevel() throws OrtException {
            checkClosed();
            return OrtLoggingLevel.mapFromInt(getLogLevel(OnnxRuntime.ortApiHandle, this.nativeHandle));
        }

        /** Sets the native log verbosity level to {@code i}. */
        public void setLogVerbosityLevel(int i) throws OrtException {
            checkClosed();
            setLogVerbosityLevel(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /** Returns the native log verbosity level. */
        public int getLogVerbosityLevel() throws OrtException {
            checkClosed();
            return getLogVerbosityLevel(OnnxRuntime.ortApiHandle, this.nativeHandle);
        }

        /** Sets the native run tag to {@code str}. */
        public void setRunTag(String str) throws OrtException {
            checkClosed();
            setRunTag(OnnxRuntime.ortApiHandle, this.nativeHandle, str);
        }

        /** Returns the native run tag. */
        public String getRunTag() throws OrtException {
            checkClosed();
            return getRunTag(OnnxRuntime.ortApiHandle, this.nativeHandle);
        }

        /** Sets the native terminate flag to {@code z}. */
        public void setTerminate(boolean z) throws OrtException {
            checkClosed();
            setTerminate(OnnxRuntime.ortApiHandle, this.nativeHandle, z);
        }

        /** Guards against touching the native handle after {@link #close()}. */
        private void checkClosed() {
            if (this.closed) {
                throw new IllegalStateException("Trying to use a closed RunOptions");
            }
        }

        /**
         * Releases the native run-options object.
         *
         * @throws IllegalStateException if already closed.
         */
        @Override
        public void close() {
            if (this.closed) {
                throw new IllegalStateException("Trying to close an already closed RunOptions");
            }
            close(OnnxRuntime.ortApiHandle, this.nativeHandle);
            this.closed = true;
        }

        private static native long createRunOptions(long j) throws OrtException;

        private native void setLogLevel(long j, long j2, int i) throws OrtException;

        private native int getLogLevel(long j, long j2) throws OrtException;

        private native void setLogVerbosityLevel(long j, long j2, int i) throws OrtException;

        private native int getLogVerbosityLevel(long j, long j2) throws OrtException;

        private native void setRunTag(long j, long j2, String str) throws OrtException;

        private native String getRunTag(long j, long j2) throws OrtException;

        private native void setTerminate(long j, long j2, boolean z) throws OrtException;

        private static native void close(long j, long j2);
    }

    /* loaded from: input_file:ai/onnxruntime/OrtSession$SessionOptions.class */
    /**
     * Configuration applied when creating an {@link OrtSession}, backed by a native session-options
     * object.
     *
     * <p>Instances own native resources: the options object itself and the handles of any custom-op
     * libraries registered through {@link #registerCustomOpLibrary(String)}. {@link #close()} releases
     * them; every other method throws {@link IllegalStateException} once closed.
     */
    public static class SessionOptions implements AutoCloseable {
        private boolean closed = false;
        private final long nativeHandle = createOptions(OnnxRuntime.ortApiHandle);
        // Handles returned by registerCustomOpLibrary; released in close().
        private final List<Long> customLibraryHandles = new ArrayList<>();
        // Mirror of every entry successfully passed to addConfigEntry, exposed read-only.
        private final Map<String, String> configEntries = new LinkedHashMap<>();

        /** Execution modes; {@link #getID()} is the integer handed to the native layer. */
        public enum ExecutionMode {
            SEQUENTIAL(0),
            PARALLEL(1);

            private final int id;

            ExecutionMode(int i) {
                this.id = i;
            }

            /** Returns the integer passed to the native {@code setExecutionMode}. */
            public int getID() {
                return this.id;
            }
        }

        /** Optimization levels; {@link #getID()} is the integer handed to the native layer. */
        public enum OptLevel {
            NO_OPT(0),
            BASIC_OPT(1),
            EXTENDED_OPT(2),
            ALL_OPT(99);

            private final int id;

            OptLevel(int i) {
                this.id = i;
            }

            /** Returns the integer passed to the native {@code setOptimizationLevel}. */
            public int getID() {
                return this.id;
            }
        }

        /**
         * Releases any registered custom-op library handles, then the native options object.
         *
         * @throws IllegalStateException if already closed.
         */
        @Override
        public void close() {
            if (this.closed) {
                throw new IllegalStateException("Trying to close a closed SessionOptions.");
            }
            if (!this.customLibraryHandles.isEmpty()) {
                long[] libraryHandles = new long[this.customLibraryHandles.size()];
                for (int i = 0; i < libraryHandles.length; i++) {
                    libraryHandles[i] = this.customLibraryHandles.get(i);
                }
                closeCustomLibraries(libraryHandles);
            }
            closeOptions(OnnxRuntime.ortApiHandle, this.nativeHandle);
            this.closed = true;
        }

        /** Guards against touching the native handle after {@link #close()}. */
        private void checkClosed() {
            if (this.closed) {
                throw new IllegalStateException("Trying to use a closed SessionOptions");
            }
        }

        /** Forwards {@code executionMode.getID()} to the native options. */
        public void setExecutionMode(ExecutionMode executionMode) throws OrtException {
            checkClosed();
            setExecutionMode(OnnxRuntime.ortApiHandle, this.nativeHandle, executionMode.getID());
        }

        /** Forwards {@code optLevel.getID()} to the native options. */
        public void setOptimizationLevel(OptLevel optLevel) throws OrtException {
            checkClosed();
            setOptimizationLevel(OnnxRuntime.ortApiHandle, this.nativeHandle, optLevel.getID());
        }

        /** Sets the native inter-op thread count option to {@code i}. */
        public void setInterOpNumThreads(int i) throws OrtException {
            checkClosed();
            setInterOpNumThreads(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /** Sets the native intra-op thread count option to {@code i}. */
        public void setIntraOpNumThreads(int i) throws OrtException {
            checkClosed();
            setIntraOpNumThreads(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /** Sets the native optimized-model file path option to {@code str}. */
        public void setOptimizedModelFilePath(String str) throws OrtException {
            checkClosed();
            setOptimizationModelFilePath(OnnxRuntime.ortApiHandle, this.nativeHandle, str);
        }

        /** Sets the native logger id to {@code str}. */
        public void setLoggerId(String str) throws OrtException {
            checkClosed();
            setLoggerId(OnnxRuntime.ortApiHandle, this.nativeHandle, str);
        }

        /** Enables native profiling, passing {@code str} to the native call. */
        public void enableProfiling(String str) throws OrtException {
            checkClosed();
            enableProfiling(OnnxRuntime.ortApiHandle, this.nativeHandle, str);
        }

        /** Disables native profiling. */
        public void disableProfiling() throws OrtException {
            checkClosed();
            disableProfiling(OnnxRuntime.ortApiHandle, this.nativeHandle);
        }

        /** Sets the native memory-pattern optimization flag to {@code z}. */
        public void setMemoryPatternOptimization(boolean z) throws OrtException {
            checkClosed();
            setMemoryPatternOptimization(OnnxRuntime.ortApiHandle, this.nativeHandle, z);
        }

        /** Sets the native CPU arena allocator flag to {@code z}. */
        public void setCPUArenaAllocator(boolean z) throws OrtException {
            checkClosed();
            setCPUArenaAllocator(OnnxRuntime.ortApiHandle, this.nativeHandle, z);
        }

        /** Sets the native session log level from {@code ortLoggingLevel.getValue()}. */
        public void setSessionLogLevel(OrtLoggingLevel ortLoggingLevel) throws OrtException {
            checkClosed();
            setSessionLogLevel(OnnxRuntime.ortApiHandle, this.nativeHandle, ortLoggingLevel.getValue());
        }

        /** Sets the native session log verbosity level to {@code i}. */
        public void setSessionLogVerbosityLevel(int i) throws OrtException {
            checkClosed();
            setSessionLogVerbosityLevel(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /**
         * Registers the custom-op library at path {@code str}; the returned native handle is kept and
         * released by {@link #close()}.
         */
        public void registerCustomOpLibrary(String str) throws OrtException {
            checkClosed();
            this.customLibraryHandles.add(registerCustomOpLibrary(OnnxRuntime.ortApiHandle, this.nativeHandle, str));
        }

        /** Overrides the free (symbolic) dimension named {@code str} with the value {@code j}. */
        public void setSymbolicDimensionValue(String str, long j) throws OrtException {
            checkClosed();
            addFreeDimensionOverrideByName(OnnxRuntime.ortApiHandle, this.nativeHandle, str, j);
        }

        /** Calls the native {@code disablePerSessionThreads} on these options. */
        public void disablePerSessionThreads() throws OrtException {
            checkClosed();
            disablePerSessionThreads(OnnxRuntime.ortApiHandle, this.nativeHandle);
        }

        /**
         * Adds a native config entry; it is recorded locally only after the native call succeeds.
         *
         * @param str the config key.
         * @param str2 the config value.
         */
        public void addConfigEntry(String str, String str2) throws OrtException {
            checkClosed();
            addConfigEntry(OnnxRuntime.ortApiHandle, this.nativeHandle, str, str2);
            this.configEntries.put(str, str2);
        }

        /** Returns a read-only, insertion-ordered view of the entries added via {@link #addConfigEntry}. */
        public Map<String, String> getConfigEntries() {
            checkClosed();
            return Collections.unmodifiableMap(this.configEntries);
        }

        /** Equivalent to {@code addCUDA(0)}. */
        public void addCUDA() throws OrtException {
            addCUDA(0);
        }

        /** Adds the CUDA provider, passing {@code i} to the native call. */
        public void addCUDA(int i) throws OrtException {
            checkClosed();
            addCUDA(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /** Adds the CPU provider; {@code z} is passed to native code as 1 or 0. */
        public void addCPU(boolean z) throws OrtException {
            checkClosed();
            addCPU(OnnxRuntime.ortApiHandle, this.nativeHandle, z ? 1 : 0);
        }

        /** Adds the DNNL provider; {@code z} is passed to native code as 1 or 0. */
        public void addDnnl(boolean z) throws OrtException {
            checkClosed();
            addDnnl(OnnxRuntime.ortApiHandle, this.nativeHandle, z ? 1 : 0);
        }

        /** Adds the OpenVINO provider, passing {@code str} to the native call. */
        public void addOpenVINO(String str) throws OrtException {
            checkClosed();
            addOpenVINO(OnnxRuntime.ortApiHandle, this.nativeHandle, str);
        }

        /** Adds the TensorRT provider, passing {@code i} to the native call. */
        public void addTensorrt(int i) throws OrtException {
            checkClosed();
            addTensorrt(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /** Equivalent to {@code addNnapi(EnumSet.noneOf(NNAPIFlags.class))}. */
        public void addNnapi() throws OrtException {
            addNnapi(EnumSet.noneOf(NNAPIFlags.class));
        }

        /** Adds the NNAPI provider with {@code enumSet} aggregated to an int by {@code OrtFlags}. */
        public void addNnapi(EnumSet<NNAPIFlags> enumSet) throws OrtException {
            checkClosed();
            addNnapi(OnnxRuntime.ortApiHandle, this.nativeHandle, OrtFlags.aggregateToInt(enumSet));
        }

        /** Adds the Nuphar provider; {@code z} is passed to native code as 1 or 0 alongside {@code str}. */
        public void addNuphar(boolean z, String str) throws OrtException {
            checkClosed();
            addNuphar(OnnxRuntime.ortApiHandle, this.nativeHandle, z ? 1 : 0, str);
        }

        /** Adds the DirectML provider, passing {@code i} to the native call. */
        public void addDirectML(int i) throws OrtException {
            checkClosed();
            addDirectML(OnnxRuntime.ortApiHandle, this.nativeHandle, i);
        }

        /** Adds the ACL provider; {@code z} is passed to native code as 1 or 0. */
        public void addACL(boolean z) throws OrtException {
            checkClosed();
            addACL(OnnxRuntime.ortApiHandle, this.nativeHandle, z ? 1 : 0);
        }

        /** Adds the ArmNN provider; {@code z} is passed to native code as 1 or 0. */
        public void addArmNN(boolean z) throws OrtException {
            checkClosed();
            addArmNN(OnnxRuntime.ortApiHandle, this.nativeHandle, z ? 1 : 0);
        }

        /** Adds the ROCM provider, passing {@code i} and {@code j} to the native call. */
        public void addROCM(int i, long j) throws OrtException {
            checkClosed();
            addROCM(OnnxRuntime.ortApiHandle, this.nativeHandle, i, j);
        }

        /** Equivalent to {@code addCoreML(EnumSet.noneOf(CoreMLFlags.class))}. */
        public void addCoreML() throws OrtException {
            addCoreML(EnumSet.noneOf(CoreMLFlags.class));
        }

        /** Adds the CoreML provider with {@code enumSet} aggregated to an int by {@code OrtFlags}. */
        public void addCoreML(EnumSet<CoreMLFlags> enumSet) throws OrtException {
            checkClosed();
            addCoreML(OnnxRuntime.ortApiHandle, this.nativeHandle, OrtFlags.aggregateToInt(enumSet));
        }

        private native void setExecutionMode(long j, long j2, int i) throws OrtException;

        private native void setOptimizationLevel(long j, long j2, int i) throws OrtException;

        private native void setInterOpNumThreads(long j, long j2, int i) throws OrtException;

        private native void setIntraOpNumThreads(long j, long j2, int i) throws OrtException;

        private native void setOptimizationModelFilePath(long j, long j2, String str) throws OrtException;

        private native long createOptions(long j);

        private native void setLoggerId(long j, long j2, String str) throws OrtException;

        private native void enableProfiling(long j, long j2, String str) throws OrtException;

        private native void disableProfiling(long j, long j2) throws OrtException;

        private native void setMemoryPatternOptimization(long j, long j2, boolean z) throws OrtException;

        private native void setCPUArenaAllocator(long j, long j2, boolean z) throws OrtException;

        private native void setSessionLogLevel(long j, long j2, int i) throws OrtException;

        private native void setSessionLogVerbosityLevel(long j, long j2, int i) throws OrtException;

        private native long registerCustomOpLibrary(long j, long j2, String str) throws OrtException;

        private native void closeCustomLibraries(long[] jArr);

        private native void closeOptions(long j, long j2);

        private native void addFreeDimensionOverrideByName(long j, long j2, String str, long j3) throws OrtException;

        private native void disablePerSessionThreads(long j, long j2) throws OrtException;

        private native void addConfigEntry(long j, long j2, String str, String str2) throws OrtException;

        private native void addCPU(long j, long j2, int i) throws OrtException;

        private native void addCUDA(long j, long j2, int i) throws OrtException;

        private native void addDnnl(long j, long j2, int i) throws OrtException;

        private native void addOpenVINO(long j, long j2, String str) throws OrtException;

        private native void addTensorrt(long j, long j2, int i) throws OrtException;

        private native void addNnapi(long j, long j2, int i) throws OrtException;

        private native void addNuphar(long j, long j2, int i, String str) throws OrtException;

        private native void addDirectML(long j, long j2, int i) throws OrtException;

        private native void addACL(long j, long j2, int i) throws OrtException;

        private native void addArmNN(long j, long j2, int i) throws OrtException;

        private native void addROCM(long j, long j2, int i, long j3) throws OrtException;

        private native void addCoreML(long j, long j2, int i) throws OrtException;
    }

    /**
     * Creates a session from the model at path {@code str}.
     *
     * <p>NOTE(review): decompiler metadata says this constructor was package-private in the original
     * source; it is left public here to stay backward compatible.
     *
     * @throws IllegalStateException if {@code sessionOptions} has already been closed.
     * @throws OrtException if native session creation or metadata lookup fails.
     */
    public OrtSession(OrtEnvironment ortEnvironment, String str, OrtAllocator ortAllocator, SessionOptions sessionOptions) throws OrtException {
        this(createSession(OnnxRuntime.ortApiHandle, ortEnvironment.nativeHandle, str, openOptionsHandle(sessionOptions)), ortAllocator);
    }

    /**
     * Creates a session from the model bytes in {@code bArr}.
     *
     * <p>NOTE(review): decompiler metadata says this constructor was package-private in the original
     * source; it is left public here to stay backward compatible.
     *
     * @throws IllegalStateException if {@code sessionOptions} has already been closed.
     * @throws OrtException if native session creation or metadata lookup fails.
     */
    public OrtSession(OrtEnvironment ortEnvironment, byte[] bArr, OrtAllocator ortAllocator, SessionOptions sessionOptions) throws OrtException {
        this(createSession(OnnxRuntime.ortApiHandle, ortEnvironment.nativeHandle, bArr, openOptionsHandle(sessionOptions)), ortAllocator);
    }

    /** Returns the native handle of {@code sessionOptions}, rejecting options that were already closed. */
    private static long openOptionsHandle(SessionOptions sessionOptions) {
        // A closed SessionOptions has freed its native object; handing that handle to createSession
        // would be a use-after-free in native code.
        sessionOptions.checkClosed();
        return sessionOptions.nativeHandle;
    }

    private OrtSession(long j, OrtAllocator ortAllocator) throws OrtException {
        this.closed = false;
        this.nativeHandle = j;
        this.allocator = ortAllocator;
        this.numInputs = getNumInputs(OnnxRuntime.ortApiHandle, j);
        this.inputNames = new LinkedHashSet(Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, j, ortAllocator.handle)));
        this.numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, j);
        this.outputNames = new LinkedHashSet(Arrays.asList(getOutputNames(OnnxRuntime.ortApiHandle, j, ortAllocator.handle)));
    }

    /**
     * Returns the input count read from the native session at construction.
     *
     * @throws IllegalStateException if this session has been closed.
     */
    public long getNumInputs() {
        if (!this.closed) {
            return this.numInputs;
        }
        throw new IllegalStateException("Asking for inputs from a closed OrtSession.");
    }

    /**
     * Returns the output count read from the native session at construction.
     *
     * @throws IllegalStateException if this session has been closed.
     */
    public long getNumOutputs() {
        if (!this.closed) {
            return this.numOutputs;
        }
        throw new IllegalStateException("Asking for outputs from a closed OrtSession.");
    }

    /**
     * Returns the input names read at construction, in native order.
     *
     * <p>The set is a read-only view: {@link #run} validates requests against it, so callers must
     * not be able to add or remove names.
     *
     * @throws IllegalStateException if this session has been closed.
     */
    public Set<String> getInputNames() {
        if (this.closed) {
            throw new IllegalStateException("Asking for inputs from a closed OrtSession.");
        }
        return Collections.unmodifiableSet(this.inputNames);
    }

    /**
     * Returns the output names read at construction, in native order.
     *
     * <p>The set is a read-only view: {@link #run} validates requests against it, so callers must
     * not be able to add or remove names.
     *
     * @throws IllegalStateException if this session has been closed.
     */
    public Set<String> getOutputNames() {
        if (this.closed) {
            throw new IllegalStateException("Asking for outputs from a closed OrtSession.");
        }
        return Collections.unmodifiableSet(this.outputNames);
    }

    /**
     * Queries native code for input node descriptions, keyed by node name in native order.
     *
     * @throws IllegalStateException if this session has been closed.
     * @throws OrtException if the native query fails.
     */
    public Map<String, NodeInfo> getInputInfo() throws OrtException {
        if (!this.closed) {
            NodeInfo[] infos = getInputInfo(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocator.handle);
            return wrapInMap(infos);
        }
        throw new IllegalStateException("Asking for inputs from a closed OrtSession.");
    }

    /**
     * Queries native code for output node descriptions, keyed by node name in native order.
     *
     * @throws IllegalStateException if this session has been closed.
     * @throws OrtException if the native query fails.
     */
    public Map<String, NodeInfo> getOutputInfo() throws OrtException {
        if (!this.closed) {
            NodeInfo[] infos = getOutputInfo(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocator.handle);
            return wrapInMap(infos);
        }
        throw new IllegalStateException("Asking for outputs from a closed OrtSession.");
    }

    /** Scores {@code map}, requesting every output and using default run options. */
    public Result run(Map<String, OnnxTensor> map) throws OrtException {
        Set<String> everyOutput = this.outputNames;
        return run(map, everyOutput);
    }

    /** Scores {@code map}, requesting every output with the supplied {@code runOptions}. */
    public Result run(Map<String, OnnxTensor> map, RunOptions runOptions) throws OrtException {
        Set<String> everyOutput = this.outputNames;
        return run(map, everyOutput, runOptions);
    }

    /** Scores {@code map}, requesting only the outputs in {@code set} and using default run options. */
    public Result run(Map<String, OnnxTensor> map, Set<String> set) throws OrtException {
        RunOptions noOptions = null;
        return run(map, set, noOptions);
    }

    /**
     * Scores the model on the tensors in {@code map}, computing only the outputs named in {@code set}.
     *
     * @param map input tensors keyed by input name; must be non-empty, contain at most
     *     {@link #getNumInputs()} entries, and use only known input names.
     * @param set output names to compute; must be non-empty, contain at most
     *     {@link #getNumOutputs()} entries, and use only known output names.
     * @param runOptions per-call options, or {@code null} to pass a 0 handle to native code.
     * @return a {@link Result} keyed by the requested output names, in {@code set}'s iteration order;
     *     the caller is responsible for closing it.
     * @throws IllegalStateException if this session or {@code runOptions} has been closed.
     * @throws OrtException if the input/output counts or names are invalid, or the native run fails.
     */
    public Result run(Map<String, OnnxTensor> map, Set<String> set, RunOptions runOptions) throws OrtException {
        if (this.closed) {
            throw new IllegalStateException("Trying to score a closed OrtSession.");
        }
        if (runOptions != null) {
            // A closed RunOptions has freed its native object; forwarding its handle would be a
            // use-after-free in native code.
            runOptions.checkClosed();
        }
        // Both bounds are inclusive: supplying every input / requesting every output is valid.
        if (map.isEmpty() || map.size() > this.numInputs) {
            throw new OrtException("Unexpected number of inputs, expected [1," + this.numInputs + "] found " + map.size());
        }
        if (set.isEmpty() || set.size() > this.numOutputs) {
            throw new OrtException("Unexpected number of requestedOutputs, expected [1," + this.numOutputs + "] found " + set.size());
        }
        String[] inputNameArray = new String[map.size()];
        long[] inputHandles = new long[map.size()];
        int i = 0;
        for (Map.Entry<String, OnnxTensor> entry : map.entrySet()) {
            String name = entry.getKey();
            if (!this.inputNames.contains(name)) {
                throw new OrtException("Unknown input name " + name + ", expected one of " + this.inputNames.toString());
            }
            inputNameArray[i] = name;
            inputHandles[i] = entry.getValue().getNativeHandle();
            i++;
        }
        String[] outputNameArray = new String[set.size()];
        int j = 0;
        for (String name : set) {
            if (!this.outputNames.contains(name)) {
                throw new OrtException("Unknown output name " + name + ", expected one of " + this.outputNames.toString());
            }
            outputNameArray[j] = name;
            j++;
        }
        long runOptionsHandle = runOptions == null ? 0L : runOptions.nativeHandle;
        OnnxValue[] outputs = run(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocator.handle, inputNameArray, inputHandles, inputNameArray.length, outputNameArray, outputNameArray.length, runOptionsHandle);
        return new Result(outputNameArray, outputs);
    }

    /**
     * Returns the model metadata, building it through native code on first call and caching it.
     *
     * @throws IllegalStateException if this session has been closed (the native handle is no longer valid).
     * @throws OrtException if the native metadata lookup fails.
     */
    public OnnxModelMetadata getMetadata() throws OrtException {
        if (this.closed) {
            throw new IllegalStateException("Asking for metadata from a closed OrtSession.");
        }
        if (this.metadata == null) {
            this.metadata = constructMetadata(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocator.handle);
        }
        return this.metadata;
    }

    /**
     * Returns the value of the native profiling start-time query.
     *
     * @throws IllegalStateException if this session has been closed (the native handle is no longer valid).
     * @throws OrtException if the native call fails.
     */
    public long getProfilingStartTimeInNs() throws OrtException {
        if (this.closed) {
            throw new IllegalStateException("Asking for profiling start time from a closed OrtSession.");
        }
        return getProfilingStartTimeInNs(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    /**
     * Ends native profiling and returns the string the native call produces.
     *
     * @throws IllegalStateException if this session has been closed (the native handle is no longer valid).
     * @throws OrtException if the native call fails.
     */
    public String endProfiling() throws OrtException {
        if (this.closed) {
            throw new IllegalStateException("Trying to end profiling on a closed OrtSession.");
        }
        return endProfiling(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocator.handle);
    }

    /** Summarizes the cached input and output counts; works even after {@link #close()}. */
    @Override
    public String toString() {
        return new StringBuilder("OrtSession(numInputs=")
                .append(this.numInputs)
                .append(",numOutputs=")
                .append(this.numOutputs)
                .append(')')
                .toString();
    }

    /**
     * Releases the native session. The session is marked closed only after the native call returns.
     *
     * @throws IllegalStateException if already closed.
     * @throws OrtException if the native close fails.
     */
    @Override
    public void close() throws OrtException {
        if (!this.closed) {
            closeSession(OnnxRuntime.ortApiHandle, this.nativeHandle);
            this.closed = true;
            return;
        }
        throw new IllegalStateException("Trying to close an already closed OrtSession.");
    }

    /**
     * Indexes node descriptions by {@link NodeInfo#getName()}, preserving array order.
     *
     * @param nodeInfoArr the descriptions returned by native code.
     * @return an insertion-ordered map; a later entry with a repeated name replaces an earlier one.
     */
    private static Map<String, NodeInfo> wrapInMap(NodeInfo[] nodeInfoArr) {
        Map<String, NodeInfo> byName = new LinkedHashMap<>();
        for (NodeInfo nodeInfo : nodeInfoArr) {
            byName.put(nodeInfo.getName(), nodeInfo);
        }
        return byName;
    }

    // JNI entry points. In every call site the first argument is OnnxRuntime.ortApiHandle, the second
    // is the environment handle (createSession) or this session's native handle, and a third long
    // (where present) is the allocator handle.

    private static native long createSession(long apiHandle, long envHandle, String modelPath, long optionsHandle) throws OrtException;

    private static native long createSession(long apiHandle, long envHandle, byte[] modelBytes, long optionsHandle) throws OrtException;

    private native long getNumInputs(long apiHandle, long sessionHandle) throws OrtException;

    private native String[] getInputNames(long apiHandle, long sessionHandle, long allocatorHandle) throws OrtException;

    private native NodeInfo[] getInputInfo(long apiHandle, long sessionHandle, long allocatorHandle) throws OrtException;

    private native long getNumOutputs(long apiHandle, long sessionHandle) throws OrtException;

    private native String[] getOutputNames(long apiHandle, long sessionHandle, long allocatorHandle) throws OrtException;

    private native NodeInfo[] getOutputInfo(long apiHandle, long sessionHandle, long allocatorHandle) throws OrtException;

    // runOptionsHandle is 0 when the caller supplied no RunOptions.
    private native OnnxValue[] run(long apiHandle, long sessionHandle, long allocatorHandle, String[] inputNameArray, long[] inputTensorHandles, long inputCount, String[] outputNameArray, long outputCount, long runOptionsHandle) throws OrtException;

    private native long getProfilingStartTimeInNs(long apiHandle, long sessionHandle) throws OrtException;

    private native String endProfiling(long apiHandle, long sessionHandle, long allocatorHandle) throws OrtException;

    private native void closeSession(long apiHandle, long sessionHandle) throws OrtException;

    private native OnnxModelMetadata constructMetadata(long apiHandle, long sessionHandle, long allocatorHandle) throws OrtException;

    // Runs OnnxRuntime.init() when the class is initialized, before any of its native methods can be
    // called; an IOException is surfaced as an unchecked failure.
    static {
        try {
            OnnxRuntime.init();
        } catch (IOException loadFailure) {
            throw new RuntimeException("Failed to load onnx-runtime library", loadFailure);
        }
    }
}
