package org.nd4j.autodiff.samediff.array;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.function.Supplier;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.class */
/**
 * An {@link ArrayHolder} that wraps another holder and additionally allows an array name to be
 * backed by a {@link Supplier} function instead of a stored array.
 *
 * <p>For every lookup the function map is consulted first; names that have no registered function
 * are delegated to the wrapped holder.
 */
public class OptimizedGraphArrayHolder implements ArrayHolder {
    /** Holder that stores every array which is not backed by a function. */
    private final ArrayHolder underlyingHolder;

    /** Array name to the function that produces the array on request. */
    private final Map<String, Supplier<INDArray>> functions = new HashMap<>();

    /**
     * Creates a holder that delegates to {@code arrayHolder} for all names without a function.
     *
     * @param arrayHolder the wrapped holder
     */
    public OptimizedGraphArrayHolder(ArrayHolder arrayHolder) {
        this.underlyingHolder = arrayHolder;
    }

    /**
     * Registers a function as the source of the array called {@code str}. If the wrapped holder
     * currently has an array under that name, it is removed from the wrapped holder first.
     *
     * @param str      the array name
     * @param supplier the function producing the array; must not be {@code null}
     * @throws NullPointerException if {@code supplier} is {@code null}
     */
    public void setFunction(String str, Supplier<INDArray> supplier) {
        // Validate before touching any state: a null entry would make hasArray() report true while
        // getArray() fails and removeArray() treats the name as absent from the function map.
        if (supplier == null) {
            throw new NullPointerException("Supplier for array \"" + str + "\" must not be null");
        }
        if (this.underlyingHolder.hasArray(str)) {
            this.underlyingHolder.removeArray(str);
        }
        this.functions.put(str, supplier);
    }

    /**
     * @param str the array name
     * @return {@code true} if a function is registered for the name or the wrapped holder has it
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public boolean hasArray(String str) {
        return this.functions.containsKey(str) || this.underlyingHolder.hasArray(str);
    }

    /**
     * Returns the array for {@code str}. When a function is registered for the name, the function
     * is invoked on every call and its result returned; otherwise the wrapped holder is queried.
     *
     * @param str the array name
     * @return the function's result, or whatever the wrapped holder returns for the name
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public INDArray getArray(String str) {
        Supplier<INDArray> function = this.functions.get(str);
        return function != null ? function.get() : this.underlyingHolder.getArray(str);
    }

    /**
     * Stores {@code iNDArray} in the wrapped holder under the name {@code str}.
     *
     * @param str      the array name
     * @param iNDArray the array to store
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, when a
     *                               function is already registered for the name
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public void setArray(String str, INDArray iNDArray) {
        Preconditions.checkState(!this.functions.containsKey(str), "Cannot set array when existing array is only accessible via a function");
        this.underlyingHolder.setArray(str, iNDArray);
    }

    /**
     * Removes the entry called {@code str}. If a function was registered for the name, it is
     * unregistered and invoked one last time to produce the return value; otherwise removal is
     * delegated to the wrapped holder.
     *
     * @param str the array name
     * @return the removed function's result, or the wrapped holder's removal result
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public INDArray removeArray(String str) {
        Supplier<INDArray> removed = this.functions.remove(str);
        return removed != null ? removed.get() : this.underlyingHolder.removeArray(str);
    }

    /**
     * @return the wrapped holder's size plus the number of registered functions; a name present in
     *         both places is counted twice
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public int size() {
        return this.underlyingHolder.size() + this.functions.size();
    }

    /**
     * Delegates initialisation to the wrapped holder. Registered functions are left untouched.
     *
     * @param arrayHolder the holder passed on to the wrapped holder's {@code initFrom}
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public void initFrom(ArrayHolder arrayHolder) {
        this.underlyingHolder.initFrom(arrayHolder);
    }

    /**
     * @return a new set containing the wrapped holder's array names and all function names
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public Collection<String> arrayNames() {
        Collection<String> names = new HashSet<>(this.underlyingHolder.arrayNames());
        names.addAll(this.functions.keySet());
        return names;
    }

    /**
     * Renames an entry. A function-backed entry is re-keyed inside the function map; any other
     * name is renamed by the wrapped holder.
     *
     * @param str  the current name
     * @param str2 the new name
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public void rename(String str, String str2) {
        if (this.functions.containsKey(str)) {
            // remove() runs before put(), so renaming a name onto itself keeps the entry.
            this.functions.put(str2, this.functions.remove(str));
        } else {
            this.underlyingHolder.rename(str, str2);
        }
    }
}
