package org.apache.tvm.contrib;

import org.apache.tvm.Function;
import org.apache.tvm.Module;
import org.apache.tvm.NDArray;
import org.apache.tvm.TVMContext;

/* loaded from: input_file:org/apache/tvm/contrib/GraphModule.class */
/**
 * Java wrapper around a TVM graph runtime module.
 *
 * <p>At construction time the runtime's packed functions ({@code set_input}, {@code run},
 * {@code get_input}, {@code get_output}, {@code load_params}, and the optional
 * {@code debug_get_output}) are looked up once and then exposed as (mostly fluent) Java methods.
 * Call {@link #release()} when finished to free the function handles and the wrapped module.
 */
public class GraphModule {
    private final Module module;
    /** Context that inputs are moved to before being handed to {@code set_input}. */
    private final TVMContext ctx;
    private final Function fsetInput;
    private final Function frun;
    private final Function fgetOutput;
    private final Function fgetInput;
    /** {@code null} when the module does not expose {@code debug_get_output}. */
    private final Function fdebugGetOutput;
    private final Function floadParams;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Binds the graph runtime functions of {@code module}.
     *
     * @param module the graph runtime module; it is released by {@link #release()}
     * @param tVMContext the context that inputs are copied to in {@code setInput} when they live
     *     elsewhere
     * @throws IllegalArgumentException presumably, if {@code module} lacks one of the required
     *     functions (this is the exception {@code Module.getFunction} is expected to throw for a
     *     missing name)
     */
    public GraphModule(Module module, TVMContext tVMContext) {
        this.module = module;
        this.ctx = tVMContext;
        this.fsetInput = module.getFunction("set_input");
        this.frun = module.getFunction("run");
        this.fgetInput = module.getFunction("get_input");
        this.fgetOutput = module.getFunction("get_output");
        this.fdebugGetOutput = findOptionalFunction(module, "debug_get_output");
        this.floadParams = module.getFunction("load_params");
    }

    /**
     * Looks up {@code name} on {@code module}. Returns {@code null} instead of failing when the
     * module does not provide it.
     */
    private static Function findOptionalFunction(Module module, String name) {
        try {
            return module.getFunction(name);
        } catch (IllegalArgumentException ignored) {
            // Deliberate best-effort lookup: debug_get_output is absent unless the runtime was
            // built with debug support, and its absence is reported lazily by debugGetOutput.
            return null;
        }
    }

    /** Releases every function handle obtained in the constructor, then the wrapped module. */
    public void release() {
        this.fsetInput.release();
        this.frun.release();
        this.fgetInput.release();
        this.fgetOutput.release();
        if (this.fdebugGetOutput != null) {
            this.fdebugGetOutput.release();
        }
        this.floadParams.release();
        this.module.release();
    }

    /**
     * Sets the input named {@code str}.
     *
     * <p>If {@code nDArray} is not on this module's context, it is first copied into a temporary
     * array on that context. The temporary is released once {@code set_input} has returned.
     *
     * @param str input name
     * @param nDArray input value; it is not modified or released by this call
     * @return this module, for chaining
     */
    public GraphModule setInput(String str, NDArray nDArray) {
        NDArray input = onModuleContext(nDArray);
        try {
            this.fsetInput.pushArg(str).pushArg(input).invoke();
        } finally {
            releaseIfTemporary(input, nDArray);
        }
        return this;
    }

    /**
     * Sets the input at position {@code i}. This behaves like {@link #setInput(String, NDArray)},
     * including the context handling.
     *
     * @param i input index
     * @param nDArray input value; it is not modified or released by this call
     * @return this module, for chaining
     */
    public GraphModule setInput(int i, NDArray nDArray) {
        NDArray input = onModuleContext(nDArray);
        try {
            this.fsetInput.pushArg(i).pushArg(input).invoke();
        } finally {
            releaseIfTemporary(input, nDArray);
        }
        return this;
    }

    /**
     * Returns {@code array} itself if it is already on this module's context. Otherwise returns a
     * freshly allocated copy on that context, which the caller must release.
     */
    private NDArray onModuleContext(NDArray array) {
        if (array.ctx().equals(this.ctx)) {
            return array;
        }
        NDArray copy = NDArray.empty(array.shape(), this.ctx);
        try {
            array.copyTo(copy);
        } catch (RuntimeException e) {
            // Do not leak the native allocation if the copy fails.
            copy.release();
            throw e;
        }
        return copy;
    }

    /** Releases {@code input} only if it is a temporary copy rather than the caller's array. */
    private static void releaseIfTemporary(NDArray input, NDArray original) {
        if (input != original) {
            input.release();
        }
    }

    /**
     * Invokes the runtime's {@code run} function.
     *
     * @return this module, for chaining
     */
    public GraphModule run() {
        this.frun.invoke();
        return this;
    }

    /**
     * Calls {@code get_input} with index {@code i} and {@code nDArray} as the destination.
     *
     * @param i input index
     * @param nDArray destination array; presumably filled in place by the runtime
     * @return {@code nDArray}
     */
    public NDArray getInput(int i, NDArray nDArray) {
        this.fgetInput.pushArg(i).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Calls {@code get_output} with index {@code i} and {@code nDArray} as the destination.
     *
     * @param i output index
     * @param nDArray destination array; presumably filled in place by the runtime
     * @return {@code nDArray}
     */
    public NDArray getOutput(int i, NDArray nDArray) {
        this.fgetOutput.pushArg(i).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Calls {@code debug_get_output} for the node named {@code str}.
     *
     * @param str node name
     * @param nDArray destination array; presumably filled in place by the runtime
     * @return {@code nDArray}
     * @throws RuntimeException if the module has no {@code debug_get_output} function
     */
    public NDArray debugGetOutput(String str, NDArray nDArray) {
        requireDebugGetOutput().pushArg(str).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Calls {@code debug_get_output} for the node at index {@code i}.
     *
     * @param i node index
     * @param nDArray destination array; presumably filled in place by the runtime
     * @return {@code nDArray}
     * @throws RuntimeException if the module has no {@code debug_get_output} function
     */
    public NDArray debugGetOutput(int i, NDArray nDArray) {
        requireDebugGetOutput().pushArg(i).pushArg(nDArray).invoke();
        return nDArray;
    }

    /** Returns the debug output function, or throws if the runtime was built without it. */
    private Function requireDebugGetOutput() {
        if (this.fdebugGetOutput == null) {
            // The function is missing because debug support is disabled, so it must be turned ON.
            throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = ON");
        }
        return this.fdebugGetOutput;
    }

    /**
     * Passes the serialized parameter blob {@code bArr} to {@code load_params}.
     *
     * @param bArr serialized parameters
     * @return this module, for chaining
     */
    public GraphModule loadParams(byte[] bArr) {
        this.floadParams.pushArg(bArr).invoke();
        return this;
    }

    /**
     * Looks up an arbitrary function on the wrapped module.
     *
     * <p>The returned handle is not tracked by {@link #release()}. Presumably the caller is
     * responsible for releasing it.
     *
     * @param str function name
     * @return the function handle
     */
    public Function getFunction(String str) {
        return this.module.getFunction(str);
    }
}
