package org.apache.tvm.rpc;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.tvm.Function;
import org.apache.tvm.Module;
import org.apache.tvm.TVMContext;

/* loaded from: input_file:org/apache/tvm/rpc/RPCSession.class */
public class RPCSession {
    /**
     * Multiplier used to fold this session's table index into a device type code.
     * A remote device type is {@code localDeviceType + (tblIndex + 1) * RPC_SESS_MASK}.
     */
    private static final int RPC_SESS_MASK = 128;

    /** Underlying RPC session module; every remote call is routed through it. */
    private final Module session;
    /** Index of this session as returned by the {@code SessTableIndex} RPC API. */
    private final int tblIndex;
    /** Lazily populated cache of remote helper functions (HashMap-backed, not synchronized). */
    private final Map<String, Function> remoteFuncs = new HashMap<String, Function>();

    /**
     * Wraps an RPC session module and queries its session table index.
     *
     * <p>The decompiler reports this constructor as package-private in the original build; it
     * is kept {@code public} here so existing callers keep compiling.
     *
     * @param module the RPC session module to wrap
     */
    public RPCSession(Module module) {
        this.session = module;
        this.tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(this.session).invoke().asLong();
    }

    /**
     * Looks up a function by name on the remote session.
     *
     * @param str name of the remote function
     * @return the function handle returned by the session module
     */
    public Function getFunction(String str) {
        return this.session.getFunction(str);
    }

    /**
     * Creates a remote context from a device name.
     *
     * @param str device name, resolved to a device type by {@link TVMContext}
     * @param i device id
     * @return a context addressing that device on this remote session
     */
    public TVMContext context(String str, int i) {
        TVMContext local = new TVMContext(str, i);
        return remoteContext(local.deviceType, i);
    }

    /**
     * Creates a remote context for device id 0 from a device name.
     *
     * @param str device name
     * @return a context addressing device 0 of that type on this remote session
     */
    public TVMContext context(String str) {
        return context(str, 0);
    }

    /**
     * Creates a remote context from a numeric device type.
     *
     * @param i local device type code
     * @param i2 device id
     * @return a context addressing that device on this remote session
     */
    public TVMContext context(int i, int i2) {
        return remoteContext(i, i2);
    }

    /**
     * Creates a remote context for device id 0 from a numeric device type.
     *
     * @param i local device type code
     * @return a context addressing device 0 of that type on this remote session
     */
    public TVMContext context(int i) {
        return context(i, 0);
    }

    /** Returns the remote CPU context (device type 1) with the given device id. */
    public TVMContext cpu(int i) {
        return context(1, i);
    }

    /** Returns the remote CPU context with device id 0. */
    public TVMContext cpu() {
        return cpu(0);
    }

    /** Returns the remote GPU context (device type 2) with the given device id. */
    public TVMContext gpu(int i) {
        return context(2, i);
    }

    /** Returns the remote GPU context with device id 0. */
    public TVMContext gpu() {
        return gpu(0);
    }

    /** Returns the remote OpenCL context (device type 4) with the given device id. */
    public TVMContext cl(int i) {
        return context(4, i);
    }

    /** Returns the remote OpenCL context with device id 0. */
    public TVMContext cl() {
        return cl(0);
    }

    /** Returns the remote Vulkan context (device type 7) with the given device id. */
    public TVMContext vulkan(int i) {
        return context(7, i);
    }

    /** Returns the remote Vulkan context with device id 0. */
    public TVMContext vulkan() {
        return vulkan(0);
    }

    /** Returns the remote Metal context (device type 8) with the given device id. */
    public TVMContext metal(int i) {
        return context(8, i);
    }

    /** Returns the remote Metal context with device id 0. */
    public TVMContext metal() {
        return metal(0);
    }

    /**
     * Sends bytes to the remote server via its {@code tvm.rpc.server.upload} function.
     *
     * @param bArr data to upload
     * @param str target name on the remote side; must not be null
     * @throws IllegalArgumentException if {@code str} is null
     */
    public void upload(byte[] bArr, String str) {
        if (str == null) {
            throw new IllegalArgumentException("Please specify the upload target");
        }
        Function function = getCachedFunction("upload", "tvm.rpc.server.upload");
        function.pushArg(str).pushArg(bArr).invoke();
    }

    /**
     * Reads a local file fully and uploads its contents under the given target name.
     *
     * @param file local file to read
     * @param str target name on the remote side
     * @throws IOException if the file cannot be read completely or is too large
     */
    public void upload(File file, String str) throws IOException {
        upload(getBytesFromFile(file), str);
    }

    /**
     * Uploads a local file using its own file name as the remote target name.
     *
     * @param file local file to read
     * @throws IOException if the file cannot be read completely or is too large
     */
    public void upload(File file) throws IOException {
        upload(file, file.getName());
    }

    /**
     * Fetches a file from the remote server via its {@code tvm.rpc.server.download} function.
     *
     * @param str remote file name
     * @return the returned value converted to bytes
     */
    public byte[] download(String str) {
        Function function = getCachedFunction("download", "tvm.rpc.server.download");
        return function.pushArg(str).invoke().asBytes();
    }

    /**
     * Loads a module on the remote side through the {@code LoadRemoteModule} RPC API.
     *
     * @param str remote path of the module
     * @return a handle to the loaded remote module
     */
    public Module loadModule(String str) {
        return RPC.getApi("LoadRemoteModule").pushArg(this.session).pushArg(str).invoke().asModule();
    }

    /** Encodes this session's table index into {@code deviceType} and builds the remote context. */
    private TVMContext remoteContext(int deviceType, int devId) {
        return new TVMRemoteContext(deviceType + ((this.tblIndex + 1) * RPC_SESS_MASK), devId, this);
    }

    /** Returns the cached remote function for {@code key}, looking up {@code name} on first use. */
    private Function getCachedFunction(String key, String name) {
        Function function = this.remoteFuncs.get(key);
        if (function == null) {
            function = getFunction(name);
            this.remoteFuncs.put(key, function);
        }
        return function;
    }

    /**
     * Reads the full contents of a file into a byte array.
     *
     * @param file file to read
     * @return the file contents
     * @throws IOException if the file is larger than {@link Integer#MAX_VALUE} bytes, cannot be
     *     opened, or ends before {@code file.length()} bytes were read
     */
    private static byte[] getBytesFromFile(File file) throws IOException {
        long length = file.length();
        if (length > Integer.MAX_VALUE) {
            throw new IOException("File " + file.getName() + " is too large!");
        }
        byte[] bytes = new byte[(int) length];
        int offset = 0;
        FileInputStream in = new FileInputStream(file);
        // A single try/finally around the whole loop: the stream must stay open across reads
        // and must be closed even when no read happens (empty file) or a read throws.
        try {
            int numRead;
            while (offset < bytes.length
                    && (numRead = in.read(bytes, offset, bytes.length - offset)) >= 0) {
                offset += numRead;
            }
        } finally {
            in.close();
        }
        if (offset < bytes.length) {
            throw new IOException("Could not completely read file " + file.getName());
        }
        return bytes;
    }
}
