/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tvm;

import java.util.HashMap;
import java.util.Map;
import org.apache.tvm.APIInternal;
import org.apache.tvm.Base;
import org.apache.tvm.TVMValue;
import org.apache.tvm.TVMValueLong;

public class TVMContext {
    /** Device-type code to canonical device name; written only in the static initializer. */
    private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
    /** Device name (including aliases such as "cuda" and "cl") to device-type code. */
    private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();

    // Device-type codes used by the factory methods and the name tables below.
    private static final int CPU = 1;
    private static final int GPU = 2;
    private static final int OPENCL = 4;
    private static final int VULKAN = 7;
    private static final int METAL = 8;
    private static final int VPI = 9;
    private static final int HEXAGON = 14;

    /**
     * Device types at or above this value denote a remote device. {@link #toString()} decodes
     * them as table id {@code type / REMOTE_DEVICE_BASE - 1} and local device type
     * {@code type % REMOTE_DEVICE_BASE}.
     */
    private static final int REMOTE_DEVICE_BASE = 128;

    // Attribute kinds passed as the third argument to the internal "_GetDeviceAttr" function.
    private static final int ATTR_EXIST = 0;
    private static final int ATTR_MAX_THREADS_PER_BLOCK = 1;
    private static final int ATTR_WARP_SIZE = 2;

    static {
        MASK2STR.put(CPU, "cpu");
        MASK2STR.put(GPU, "gpu");
        MASK2STR.put(OPENCL, "opencl");
        MASK2STR.put(VULKAN, "vulkan");
        MASK2STR.put(METAL, "metal");
        MASK2STR.put(VPI, "vpi");
        MASK2STR.put(HEXAGON, "hexagon");
        STR2MASK.put("cpu", CPU);
        STR2MASK.put("gpu", GPU);
        STR2MASK.put("cuda", GPU);
        STR2MASK.put("cl", OPENCL);
        STR2MASK.put("opencl", OPENCL);
        STR2MASK.put("vulkan", VULKAN);
        STR2MASK.put("metal", METAL);
        STR2MASK.put("vpi", VPI);
        STR2MASK.put("hexagon", HEXAGON);
    }

    /** Integer device-type code (see the factory methods for the known codes). */
    public final int deviceType;
    /** Index of the device within its device type. */
    public final int deviceId;

    /** Returns a CPU context (device type 1) for device {@code devId}. */
    public static TVMContext cpu(int devId) {
        return new TVMContext(CPU, devId);
    }

    /** Returns a CPU context for device 0. */
    public static TVMContext cpu() {
        return TVMContext.cpu(0);
    }

    /** Returns a GPU context (device type 2, also named "cuda") for device {@code devId}. */
    public static TVMContext gpu(int devId) {
        return new TVMContext(GPU, devId);
    }

    /** Returns a GPU context for device 0. */
    public static TVMContext gpu() {
        return TVMContext.gpu(0);
    }

    /** Returns an OpenCL context (device type 4) for device {@code devId}. */
    public static TVMContext opencl(int devId) {
        return new TVMContext(OPENCL, devId);
    }

    /** Returns an OpenCL context for device 0. */
    public static TVMContext opencl() {
        return TVMContext.opencl(0);
    }

    /** Returns a Vulkan context (device type 7) for device {@code devId}. */
    public static TVMContext vulkan(int devId) {
        return new TVMContext(VULKAN, devId);
    }

    /** Returns a Vulkan context for device 0. */
    public static TVMContext vulkan() {
        return TVMContext.vulkan(0);
    }

    /** Returns a Metal context (device type 8) for device {@code devId}. */
    public static TVMContext metal(int devId) {
        return new TVMContext(METAL, devId);
    }

    /** Returns a Metal context for device 0. */
    public static TVMContext metal() {
        return TVMContext.metal(0);
    }

    /** Returns a VPI context (device type 9) for device {@code devId}. */
    public static TVMContext vpi(int devId) {
        return new TVMContext(VPI, devId);
    }

    /** Returns a VPI context for device 0. */
    public static TVMContext vpi() {
        return TVMContext.vpi(0);
    }

    /** Returns a Hexagon context (device type 14) for device {@code devId}. */
    public static TVMContext hexagon(int devId) {
        return new TVMContext(HEXAGON, devId);
    }

    /** Returns a Hexagon context for device 0. */
    public static TVMContext hexagon() {
        return TVMContext.hexagon(0);
    }

    /**
     * Creates a context from a numeric device type.
     *
     * @param deviceType integer device-type code; not validated
     * @param deviceId index of the device within its type
     */
    public TVMContext(int deviceType, int deviceId) {
        this.deviceType = deviceType;
        this.deviceId = deviceId;
    }

    /**
     * Creates a context from a device name such as "cpu", "gpu"/"cuda" or "cl"/"opencl".
     *
     * @param deviceType device name; must be one of the registered names
     * @param deviceId index of the device within its type
     * @throws IllegalArgumentException if {@code deviceType} is null or not a known name
     */
    public TVMContext(String deviceType, int deviceId) {
        this(parseDeviceType(deviceType), deviceId);
    }

    /**
     * Maps a device name to its device-type code.
     *
     * <p>Previously an unknown name caused an unboxing NullPointerException with no context.
     */
    private static int parseDeviceType(String deviceType) {
        Integer mask = STR2MASK.get(deviceType);
        if (mask == null) {
            throw new IllegalArgumentException("Unknown device type: " + deviceType);
        }
        return mask;
    }

    /**
     * Queries whether this device exists, via "_GetDeviceAttr" attribute kind 0.
     *
     * @return true if the native query returned a non-zero value
     */
    public boolean exist() {
        return getDeviceAttr(ATTR_EXIST) != 0L;
    }

    /**
     * Queries the maximum threads per block, via "_GetDeviceAttr" attribute kind 1.
     *
     * @return the value reported by the native side
     */
    public long maxThreadsPerBlock() {
        return getDeviceAttr(ATTR_MAX_THREADS_PER_BLOCK);
    }

    /**
     * Queries the warp size, via "_GetDeviceAttr" attribute kind 2.
     *
     * @return the value reported by the native side
     */
    public long warpSize() {
        return getDeviceAttr(ATTR_WARP_SIZE);
    }

    /** Invokes "_GetDeviceAttr" for this device with the given attribute kind. */
    private long getDeviceAttr(int attrKind) {
        TVMValue ret = APIInternal.get("_GetDeviceAttr")
                .pushArg(this.deviceType)
                .pushArg(this.deviceId)
                .pushArg(attrKind)
                .invoke();
        // Assumes the native function always returns an integer value; any other
        // TVMValue subtype fails here with ClassCastException.
        return ((TVMValueLong) ret).value;
    }

    /**
     * Calls the native {@code tvmSynchronize} for this device and checks its status code with
     * {@code Base.checkCall}. Presumably this blocks until pending work on the device finishes;
     * confirm against the native runtime.
     */
    public void sync() {
        Base.checkCall(Base._LIB.tvmSynchronize(this.deviceType, this.deviceId));
    }

    @Override
    public int hashCode() {
        return this.deviceType << 16 | this.deviceId;
    }

    /** Two contexts are equal when both their device type and device id match. */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof TVMContext)) {
            return false;
        }
        TVMContext that = (TVMContext) other;
        return this.deviceId == that.deviceId && this.deviceType == that.deviceType;
    }

    /**
     * Formats the context as {@code name(id)}, or {@code remote[tbl]:name(id)} for device types
     * at or above {@link #REMOTE_DEVICE_BASE}. Unknown device types are printed numerically.
     */
    @Override
    public String toString() {
        if (this.deviceType >= REMOTE_DEVICE_BASE) {
            int tblId = this.deviceType / REMOTE_DEVICE_BASE - 1;
            int devType = this.deviceType % REMOTE_DEVICE_BASE;
            return String.format("remote[%d]:%s(%d)", tblId, deviceName(devType), this.deviceId);
        }
        return String.format("%s(%d)", deviceName(this.deviceType), this.deviceId);
    }

    /** Returns the canonical name for a device-type code, or the code itself if unregistered. */
    private static String deviceName(int mask) {
        String name = MASK2STR.get(mask);
        return name != null ? name : String.valueOf(mask);
    }
}

