/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.compression;

import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.NDArrayCompressor;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Singleton registry and facade for the {@link NDArrayCompressor} implementations discovered via
 * {@link ServiceLoader}.
 *
 * <p>Algorithm names are matched case-insensitively. Every name is upper-cased with
 * {@link Locale#ROOT} before it is stored or looked up, so lookups do not depend on the JVM's
 * default locale. For example, {@code "gzip".toUpperCase()} yields {@code "GZİP"} under a Turkish
 * default locale.
 *
 * <p>The codec map is a {@link ConcurrentHashMap}. Reads and writes of the default algorithm name
 * are synchronized on this instance.
 */
public class BasicNDArrayCompressor {
    private static final Logger log = LoggerFactory.getLogger(BasicNDArrayCompressor.class);
    private static final BasicNDArrayCompressor INSTANCE = new BasicNDArrayCompressor();
    /** Registered codecs, keyed by their normalized (upper-case, {@link Locale#ROOT}) descriptor. */
    protected Map<String, NDArrayCompressor> codecs;
    /** Normalized name of the algorithm used by the no-algorithm {@code compress} overloads. */
    protected String defaultCompression = "FLOAT16";

    private BasicNDArrayCompressor() {
        this.loadCompressors();
    }

    /**
     * Discovers all {@link NDArrayCompressor} services on the classpath and registers each under
     * its normalized descriptor.
     *
     * @throws RuntimeException if no compressor could be found
     */
    protected void loadCompressors() {
        this.codecs = new ConcurrentHashMap<>();
        ServiceLoader<NDArrayCompressor> loader = ServiceLoader.load(NDArrayCompressor.class);
        for (NDArrayCompressor compressor : loader) {
            this.codecs.put(normalize(compressor.getDescriptor()), compressor);
        }
        if (this.codecs.isEmpty()) {
            String msg = "Error loading ND4J Compressors via service loader: No compressors were found. This usually occurs when running ND4J UI from an uber-jar, which was built incorrectly (without services resource files being included)";
            log.error(msg);
            throw new RuntimeException(msg);
        }
    }

    /**
     * Returns the normalized names of all registered compressors.
     *
     * @return a read-only live view of the registered algorithm names
     */
    public Set<String> getAvailableCompressors() {
        // Read-only view: callers must not be able to unregister codecs through the key set.
        return Collections.unmodifiableSet(this.codecs.keySet());
    }

    /** Prints the registered algorithm names to standard output. */
    public void printAvailableCompressors() {
        StringBuilder builder = new StringBuilder();
        builder.append("Available compressors: ");
        for (String comp : this.codecs.keySet()) {
            builder.append("[").append(comp).append("] ");
        }
        System.out.println(builder.toString());
    }

    /**
     * Returns the process-wide singleton instance.
     *
     * @return the shared {@code BasicNDArrayCompressor}
     */
    public static BasicNDArrayCompressor getInstance() {
        return INSTANCE;
    }

    /**
     * Sets the algorithm used by the {@code compress}/{@code compressi} overloads that take no
     * algorithm name. The name is not validated here; an unknown name fails on the next
     * compression call.
     *
     * @param algorithm algorithm name (case-insensitive)
     * @return this instance, for chaining
     * @throws NullPointerException if {@code algorithm} is null
     */
    public synchronized BasicNDArrayCompressor setDefaultCompression(@NonNull String algorithm) {
        if (algorithm == null) {
            throw new NullPointerException("algorithm is marked @NonNull but is null");
        }
        this.defaultCompression = normalize(algorithm);
        return this;
    }

    /**
     * Returns the current default algorithm name.
     *
     * @return the normalized default algorithm name
     */
    public synchronized String getDefaultCompression() {
        return this.defaultCompression;
    }

    /**
     * Compresses {@code buffer} with the default algorithm.
     *
     * @param buffer buffer to compress
     * @return the compressed buffer produced by the codec
     * @throws RuntimeException if the default algorithm is not registered
     */
    public DataBuffer compress(DataBuffer buffer) {
        return this.compress(buffer, this.getDefaultCompression());
    }

    /**
     * Compresses {@code buffer} with the named algorithm.
     *
     * @param buffer buffer to compress
     * @param algorithm algorithm name (case-insensitive)
     * @return the compressed buffer produced by the codec
     * @throws RuntimeException if {@code algorithm} is not registered
     */
    public DataBuffer compress(DataBuffer buffer, String algorithm) {
        return this.resolveCodec(algorithm).compress(buffer);
    }

    /**
     * Compresses {@code array} with the default algorithm. Calls
     * {@code Nd4j.getExecutioner().commit()} first (presumably so that pending operations on the
     * array have completed; confirm against the executioner docs).
     *
     * @param array array to compress
     * @return the compressed array produced by the codec
     * @throws RuntimeException if the default algorithm is not registered
     */
    public INDArray compress(INDArray array) {
        Nd4j.getExecutioner().commit();
        return this.compress(array, this.getDefaultCompression());
    }

    /**
     * Delegates to {@link NDArrayCompressor#compressi} of the default algorithm's codec.
     *
     * @param array array to compress
     * @throws RuntimeException if the default algorithm is not registered
     */
    public void compressi(INDArray array) {
        this.compressi(array, this.getDefaultCompression());
    }

    /**
     * Compresses {@code array} with the named algorithm.
     *
     * @param array array to compress
     * @param algorithm algorithm name (case-insensitive)
     * @return the compressed array produced by the codec
     * @throws RuntimeException if {@code algorithm} is not registered
     */
    public INDArray compress(INDArray array, String algorithm) {
        return this.resolveCodec(algorithm).compress(array);
    }

    /**
     * Delegates to {@link NDArrayCompressor#compressi} of the named algorithm's codec.
     *
     * @param array array to compress
     * @param algorithm algorithm name (case-insensitive)
     * @throws RuntimeException if {@code algorithm} is not registered
     */
    public void compressi(INDArray array, String algorithm) {
        this.resolveCodec(algorithm).compressi(array);
    }

    /**
     * Decompresses a {@link DataType#COMPRESSED} buffer using the codec named in its
     * {@link CompressionDescriptor}.
     *
     * @param buffer compressed buffer
     * @param targetType data type requested for the decompressed buffer
     * @return the decompressed buffer produced by the codec
     * @throws IllegalStateException if {@code buffer} is not of type {@code COMPRESSED}
     * @throws RuntimeException if the descriptor names an unregistered algorithm
     */
    public DataBuffer decompress(DataBuffer buffer, DataType targetType) {
        if (buffer.dataType() != DataType.COMPRESSED) {
            throw new IllegalStateException("You can't decompress DataBuffer with dataType of: " + buffer.dataType());
        }
        CompressionDescriptor descriptor = ((CompressedDataBuffer) buffer).getCompressionDescriptor();
        return this.resolveCodec(descriptor.getCompressionAlgorithm()).decompress(buffer, targetType);
    }

    /**
     * Looks up a registered compressor by name.
     *
     * @param name algorithm name (case-insensitive, consistent with the other lookups)
     * @return the registered compressor, or {@code null} if none is registered under that name
     * @throws NullPointerException if {@code name} is null
     */
    public NDArrayCompressor getCompressor(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        return this.codecs.get(normalize(name));
    }

    /**
     * Returns a decompressed version of {@code array}, or {@code array} itself if its data buffer
     * is not {@link DataType#COMPRESSED}.
     *
     * @param array possibly-compressed array
     * @return the decompressed array, or the input unchanged
     * @throws RuntimeException if the descriptor names an unregistered algorithm
     */
    public INDArray decompress(INDArray array) {
        if (array.data().dataType() != DataType.COMPRESSED) {
            return array;
        }
        CompressionDescriptor descriptor = ((CompressedDataBuffer) array.data()).getCompressionDescriptor();
        return this.resolveCodec(descriptor.getCompressionAlgorithm()).decompress(array);
    }

    /**
     * Delegates to {@link NDArrayCompressor#decompressi} of the codec named in the array's
     * compression descriptor. Does nothing if the array's buffer is not
     * {@link DataType#COMPRESSED}.
     *
     * @param array possibly-compressed array
     * @throws RuntimeException if the descriptor names an unregistered algorithm
     */
    public void decompressi(INDArray array) {
        if (array.data().dataType() != DataType.COMPRESSED) {
            return;
        }
        CompressionDescriptor descriptor = ((CompressedDataBuffer) array.data()).getCompressionDescriptor();
        this.resolveCodec(descriptor.getCompressionAlgorithm()).decompressi(array);
    }

    /**
     * Applies {@link #autoDecompress(INDArray)} to each array.
     *
     * @param arrays arrays to check and decompress
     */
    public void autoDecompress(INDArray... arrays) {
        for (INDArray array : arrays) {
            this.autoDecompress(array);
        }
    }

    /**
     * Calls {@link #decompressi(INDArray)} if {@code array.isCompressed()} is true.
     *
     * @param array array to check and decompress
     */
    public void autoDecompress(INDArray array) {
        if (array.isCompressed()) {
            this.decompressi(array);
        }
    }

    /**
     * Compresses a {@code float[]} with the default algorithm.
     *
     * @param array data to compress
     * @return the compressed array produced by the codec
     * @throws RuntimeException if the default algorithm is not registered
     */
    public INDArray compress(float[] array) {
        // Go through the synchronized getter and the validating lookup, like the other overloads.
        return this.resolveCodec(this.getDefaultCompression()).compress(array);
    }

    /**
     * Compresses a {@code double[]} with the default algorithm.
     *
     * @param array data to compress
     * @return the compressed array produced by the codec
     * @throws RuntimeException if the default algorithm is not registered
     */
    public INDArray compress(double[] array) {
        return this.resolveCodec(this.getDefaultCompression()).compress(array);
    }

    /**
     * Returns the codec registered under {@code algorithm}, or fails with the class's standard
     * "non-existent algorithm" error.
     */
    private NDArrayCompressor resolveCodec(String algorithm) {
        String key = normalize(algorithm);
        NDArrayCompressor codec = this.codecs.get(key);
        if (codec == null) {
            throw new RuntimeException("Non-existent compression algorithm requested: [" + key + "]");
        }
        return codec;
    }

    /** Normalizes an algorithm name to its locale-independent upper-case registry key. */
    private static String normalize(String algorithm) {
        return algorithm.toUpperCase(Locale.ROOT);
    }
}

