package ai.onnxruntime;

import X.x;
import ai.onnxruntime.OrtSession;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: classes.dex */
/**
 * Java wrapper around a native ONNX Runtime training session.
 *
 * <p>An instance owns its native session handle and holds the {@link OrtCheckpointState} it was
 * created with; {@link #close()} releases the session and then closes that checkpoint. After
 * {@code close()}, every instance method except the cached input/output-name getters throws
 * {@link IllegalStateException}.
 */
public final class OrtTrainingSession implements AutoCloseable {
    private final OrtAllocator allocator;
    private final OrtCheckpointState checkpoint;
    private boolean closed;
    private final Set<String> evalInputNames;
    private final Set<String> evalOutputNames;
    private final String evalPath;
    private final long nativeHandle;
    private final String optimizerPath;
    private final Set<String> trainInputNames;
    private final Set<String> trainOutputNames;
    private final String trainPath;

    /**
     * Wrapper around a native training checkpoint handle. It can be loaded from and saved to disk,
     * and it stores named float, int and String properties.
     *
     * <p>{@link #close()} is idempotent because a checkpoint may be closed both by its owner and by
     * the {@link OrtTrainingSession} it was passed to. Other instance methods throw
     * {@link IllegalStateException} once the checkpoint has been closed.
     */
    public static final class OrtCheckpointState implements AutoCloseable {
        /** Native checkpoint handle; also read by the enclosing session when it is created. */
        final long nativeHandle;

        /** True once {@link #close()} has released {@link #nativeHandle}. */
        private boolean closed;

        /**
         * Wraps an existing native checkpoint handle, which {@link #close()} will release.
         *
         * @param nativeHandle the native checkpoint handle
         */
        public OrtCheckpointState(long nativeHandle) {
            this.nativeHandle = nativeHandle;
        }

        private native void addProperty(long apiHandle, long trainingApiHandle, long checkpointHandle,
                String name, float value);

        private native void addProperty(long apiHandle, long trainingApiHandle, long checkpointHandle,
                String name, int value);

        private native void addProperty(long apiHandle, long trainingApiHandle, long checkpointHandle,
                String name, String value);

        private native void close(long trainingApiHandle, long checkpointHandle);

        private native float getFloatProperty(long apiHandle, long trainingApiHandle, long checkpointHandle,
                long allocatorHandle, String name);

        private native int getIntProperty(long apiHandle, long trainingApiHandle, long checkpointHandle,
                long allocatorHandle, String name);

        private native String getStringProperty(long apiHandle, long trainingApiHandle, long checkpointHandle,
                long allocatorHandle, String name);

        private static native long loadCheckpoint(long apiHandle, long trainingApiHandle, String path);

        private native void saveCheckpoint(long apiHandle, long trainingApiHandle, long checkpointHandle,
                String path, boolean includeOptimizerState);

        /**
         * Loads a checkpoint from disk.
         *
         * @param path checkpoint location; must not be null
         * @return the loaded checkpoint
         * @throws IllegalStateException if training is not enabled in this build of ONNX Runtime
         * @throws NullPointerException if {@code path} is null
         */
        public static OrtCheckpointState loadCheckpoint(String path) {
            if (!OnnxRuntime.trainingEnabled) {
                throw new IllegalStateException("Training is not enabled in this build of ONNX Runtime.");
            }
            Objects.requireNonNull(path, "checkpoint path must not be null");
            return new OrtCheckpointState(
                    loadCheckpoint(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, path));
        }

        /**
         * Loads a checkpoint from disk.
         *
         * @param path checkpoint location; must not be null
         * @return the loaded checkpoint
         * @throws NullPointerException if {@code path} is null
         */
        public static OrtCheckpointState loadCheckpoint(Path path) {
            Objects.requireNonNull(path, "checkpoint path must not be null");
            return loadCheckpoint(path.toString());
        }

        /** Throws if this checkpoint's native handle has already been released. */
        private void ensureOpen() {
            if (closed) {
                throw new IllegalStateException("Trying to use a closed OrtCheckpointState");
            }
        }

        /** Returns the native handle after verifying that the checkpoint is still open. */
        private long openHandle() {
            ensureOpen();
            return nativeHandle;
        }

        /** Stores a float property under {@code name}. */
        public void addProperty(String name, float value) {
            addProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(), name, value);
        }

        /** Stores an int property under {@code name}. */
        public void addProperty(String name, int value) {
            addProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(), name, value);
        }

        /** Stores a String property under {@code name}. */
        public void addProperty(String name, String value) {
            addProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(), name, value);
        }

        /**
         * Releases the native checkpoint. Calling this more than once has no further effect, so
         * closing both the checkpoint and the session that owns it does not free the handle twice.
         */
        @Override // java.lang.AutoCloseable
        public void close() {
            if (!closed) {
                close(OnnxRuntime.ortTrainingApiHandle, nativeHandle);
                closed = true;
            }
        }

        /** Reads the float property {@code name}, using {@code allocator} for the native call. */
        public float getFloatProperty(OrtAllocator allocator, String name) {
            return getFloatProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(),
                    allocator.handle, name);
        }

        /** Reads the int property {@code name}, using {@code allocator} for the native call. */
        public int getIntProperty(OrtAllocator allocator, String name) {
            return getIntProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(),
                    allocator.handle, name);
        }

        /** Reads the String property {@code name}, using {@code allocator} for the native call. */
        public String getStringProperty(OrtAllocator allocator, String name) {
            return getStringProperty(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(),
                    allocator.handle, name);
        }

        /**
         * Writes this checkpoint to disk.
         *
         * @param path destination; must not be null
         * @param includeOptimizerState forwarded unchanged to the native save call (presumably
         *     whether optimizer state is written as well; confirm against the ORT training API)
         * @throws NullPointerException if {@code path} is null
         */
        public void saveCheckpoint(Path path, boolean includeOptimizerState) {
            Objects.requireNonNull(path, "checkpoint path must not be null");
            saveCheckpoint(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, openHandle(),
                    path.toString(), includeOptimizerState);
        }
    }

    static {
        // The native library must be loaded before any native method of this class is invoked.
        try {
            OnnxRuntime.init();
        } catch (IOException e) {
            throw new RuntimeException("Failed to load onnx-runtime library", e);
        }
    }

    /**
     * Wraps an already-created native session and caches its input/output names.
     *
     * <p>If reading the names fails, the native session is released before the exception
     * propagates. Otherwise nothing else would hold the handle and it would leak.
     */
    private OrtTrainingSession(long sessionHandle, OrtAllocator allocator, OrtCheckpointState checkpoint,
            String trainPath, String evalPath, String optimizerPath) {
        this.nativeHandle = sessionHandle;
        this.allocator = allocator;
        this.checkpoint = checkpoint;
        this.trainPath = trainPath;
        this.evalPath = evalPath;
        this.optimizerPath = optimizerPath;
        try {
            long api = OnnxRuntime.ortApiHandle;
            long trainingApi = OnnxRuntime.ortTrainingApiHandle;
            long allocatorHandle = allocator.handle;
            this.trainInputNames = toNameSet(getTrainInputNames(api, trainingApi, sessionHandle, allocatorHandle));
            this.trainOutputNames = toNameSet(getTrainOutputNames(api, trainingApi, sessionHandle, allocatorHandle));
            this.evalInputNames = toNameSet(getEvalInputNames(api, trainingApi, sessionHandle, allocatorHandle));
            this.evalOutputNames = toNameSet(getEvalOutputNames(api, trainingApi, sessionHandle, allocatorHandle));
        } catch (Throwable t) {
            closeSession(OnnxRuntime.ortTrainingApiHandle, sessionHandle);
            throw t;
        }
    }

    /**
     * Creates a native training session and wraps it.
     *
     * @param env the ONNX Runtime environment
     * @param allocator allocator whose handle is passed to the native name, step and property calls
     * @param options options for the native session
     * @param checkpoint checkpoint state; it is closed when this session is closed
     * @param trainPath training model path, forwarded to the native session
     * @param evalPath evaluation model path, forwarded to the native session
     * @param optimizerPath optimizer model path, forwarded to the native session
     * @throws IllegalStateException if {@code checkpoint} has already been closed
     */
    public OrtTrainingSession(OrtEnvironment env, OrtAllocator allocator, OrtSession.SessionOptions options,
            OrtCheckpointState checkpoint, String trainPath, String evalPath, String optimizerPath) {
        this(createTrainingSession(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle,
                        env.getNativeHandle(), options.getNativeHandle(), checkpoint.openHandle(),
                        trainPath, evalPath, optimizerPath),
                allocator, checkpoint, trainPath, evalPath, optimizerPath);
    }

    /** Copies a native name array into an insertion-ordered, unmodifiable set. */
    private static Set<String> toNameSet(String[] names) {
        return Collections.unmodifiableSet(new LinkedHashSet<>(Arrays.asList(names)));
    }

    /** Returns the native handle of {@code runOptions}, or 0 when none was supplied. */
    private static long runOptionsHandle(OrtSession.RunOptions runOptions) {
        return runOptions == null ? 0L : runOptions.getNativeHandle();
    }

    /**
     * Validates the number of supplied inputs and requested outputs.
     *
     * <p>Between 1 and {@code validInputs.size()} inputs are accepted, or none if the model declares
     * no inputs. Between 1 and {@code validOutputs.size()} outputs may be requested.
     */
    private static void validateCounts(Map<String, ? extends OnnxTensorLike> inputs, Set<String> validInputs,
            Set<String> requestedOutputs, Set<String> validOutputs) {
        if ((inputs.isEmpty() && !validInputs.isEmpty()) || inputs.size() > validInputs.size()) {
            int minInputs = validInputs.isEmpty() ? 0 : 1;
            throw new OrtException("Unexpected number of inputs, expected [" + minInputs + ","
                    + validInputs.size() + "] found " + inputs.size());
        }
        if (requestedOutputs.isEmpty() || requestedOutputs.size() > validOutputs.size()) {
            throw new OrtException("Unexpected number of requestedOutputs, expected [1,"
                    + validOutputs.size() + "] found " + requestedOutputs.size());
        }
    }

    /**
     * Checks every input name against {@code validNames}. Fills {@code names} and {@code handles}
     * with parallel entries in the map's iteration order.
     */
    private static void marshalInputs(Map<String, ? extends OnnxTensorLike> inputs, Set<String> validNames,
            String[] names, long[] handles) {
        int i = 0;
        for (Map.Entry<String, ? extends OnnxTensorLike> entry : inputs.entrySet()) {
            String name = entry.getKey();
            if (!validNames.contains(name)) {
                throw new OrtException("Unknown input name " + name + ", expected one of " + validNames);
            }
            names[i] = name;
            handles[i] = entry.getValue().getNativeHandle();
            i++;
        }
    }

    /** Checks every requested output against {@code validNames} and returns them as an array. */
    private static String[] marshalOutputs(Set<String> requested, Set<String> validNames) {
        String[] names = new String[requested.size()];
        int i = 0;
        for (String name : requested) {
            if (!validNames.contains(name)) {
                throw new OrtException("Unknown output name " + name + ", expected one of " + validNames);
            }
            names[i++] = name;
        }
        return names;
    }

    private void checkClosed() {
        if (this.closed) {
            throw new IllegalStateException("Trying to use a closed OrtTrainingSession");
        }
    }

    private native void closeSession(long trainingApiHandle, long sessionHandle);

    private static native long createTrainingSession(long apiHandle, long trainingApiHandle, long envHandle,
            long optionsHandle, long checkpointHandle, String trainPath, String evalPath, String optimizerPath);

    private native OnnxValue[] evalStep(long apiHandle, long trainingApiHandle, long sessionHandle,
            long allocatorHandle, String[] inputNames, long[] inputHandles, long numInputs,
            String[] outputNames, long numOutputs, long runOptionsHandle);

    private native void exportModelForInference(long apiHandle, long trainingApiHandle, long sessionHandle,
            String outputPath, long numOutputs, String[] outputNames);

    private native String[] getEvalInputNames(long apiHandle, long trainingApiHandle, long sessionHandle,
            long allocatorHandle);

    private native String[] getEvalOutputNames(long apiHandle, long trainingApiHandle, long sessionHandle,
            long allocatorHandle);

    private native float getLearningRate(long apiHandle, long trainingApiHandle, long sessionHandle);

    private native String[] getTrainInputNames(long apiHandle, long trainingApiHandle, long sessionHandle,
            long allocatorHandle);

    private native String[] getTrainOutputNames(long apiHandle, long trainingApiHandle, long sessionHandle,
            long allocatorHandle);

    private native void lazyResetGrad(long apiHandle, long trainingApiHandle, long sessionHandle);

    private native void optimizerStep(long apiHandle, long trainingApiHandle, long sessionHandle,
            long runOptionsHandle);

    private native void registerLinearLRScheduler(long apiHandle, long trainingApiHandle, long sessionHandle,
            long warmupStepCount, long totalStepCount, float initialLearningRate);

    private native void schedulerStep(long apiHandle, long trainingApiHandle, long sessionHandle);

    private native void setLearningRate(long apiHandle, long trainingApiHandle, long sessionHandle,
            float learningRate);

    /** Sets the seed used by the native training API. This is static and not tied to a session. */
    public static void setSeed(long seed) {
        setSeed(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, seed);
    }

    private static native void setSeed(long apiHandle, long trainingApiHandle, long seed);

    private native OnnxValue[] trainStep(long apiHandle, long trainingApiHandle, long sessionHandle,
            long allocatorHandle, String[] inputNames, long[] inputHandles, long numInputs,
            String[] outputNames, long numOutputs, long runOptionsHandle);

    /** Stores a float property in this session's checkpoint. */
    public void addProperty(String name, float value) {
        checkClosed();
        this.checkpoint.addProperty(name, value);
    }

    /** Stores an int property in this session's checkpoint. */
    public void addProperty(String name, int value) {
        checkClosed();
        this.checkpoint.addProperty(name, value);
    }

    /** Stores a String property in this session's checkpoint. */
    public void addProperty(String name, String value) {
        checkClosed();
        this.checkpoint.addProperty(name, value);
    }

    /**
     * Releases the native session and then closes its checkpoint.
     *
     * @throws IllegalStateException if this session has already been closed
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (this.closed) {
            throw new IllegalStateException("Trying to close an already closed OrtTrainingSession.");
        }
        closeSession(OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
        this.checkpoint.close();
        this.closed = true;
    }

    /** Runs an evaluation step and returns all evaluation outputs. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs) {
        return evalStep(inputs, this.evalOutputNames, null);
    }

    /** Runs an evaluation step with {@code runOptions} and returns all evaluation outputs. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs,
            OrtSession.RunOptions runOptions) {
        return evalStep(inputs, this.evalOutputNames, runOptions);
    }

    /** Runs an evaluation step and returns only {@code requestedOutputs}. */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) {
        return evalStep(inputs, requestedOutputs, null);
    }

    /**
     * Runs an evaluation step.
     *
     * @param inputs input tensors keyed by evaluation input name
     * @param requestedOutputs evaluation outputs to return; must be non-empty
     * @param runOptions optional run options; may be null
     * @return the requested outputs, in {@code requestedOutputs} iteration order
     * @throws OrtException if an input/output name is unknown or a count is out of range
     * @throws IllegalStateException if this session is closed
     */
    public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs,
            OrtSession.RunOptions runOptions) {
        checkClosed();
        validateCounts(inputs, this.evalInputNames, requestedOutputs, this.evalOutputNames);
        String[] inputNames = new String[inputs.size()];
        long[] inputHandles = new long[inputs.size()];
        marshalInputs(inputs, this.evalInputNames, inputNames, inputHandles);
        String[] outputNames = marshalOutputs(requestedOutputs, this.evalOutputNames);
        OnnxValue[] outputs = evalStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle,
                this.nativeHandle, this.allocator.handle, inputNames, inputHandles, inputNames.length,
                outputNames, outputNames.length, runOptionsHandle(runOptions));
        return new OrtSession.Result(outputNames, outputs);
    }

    /**
     * Exports an inference model containing the given outputs.
     *
     * @param path destination file; must not be null
     * @param outputNames graph outputs to keep; must be non-empty
     * @throws IllegalArgumentException if {@code outputNames} is empty
     * @throws IllegalStateException if this session is closed
     */
    public void exportModelForInference(Path path, String[] outputNames) {
        checkClosed();
        Objects.requireNonNull(path, "output path must not be null");
        if (outputNames.length == 0) {
            throw new IllegalArgumentException("Requires at least one output name");
        }
        exportModelForInference(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle,
                path.toString(), outputNames.length, outputNames);
    }

    /** Returns the cached, unmodifiable evaluation input names (usable after close). */
    public Set<String> getEvalInputNames() {
        return this.evalInputNames;
    }

    /** Returns the cached, unmodifiable evaluation output names (usable after close). */
    public Set<String> getEvalOutputNames() {
        return this.evalOutputNames;
    }

    /** Reads a float property from this session's checkpoint. */
    public float getFloatProperty(String name) {
        checkClosed();
        return this.checkpoint.getFloatProperty(this.allocator, name);
    }

    /** Reads an int property from this session's checkpoint. */
    public int getIntProperty(String name) {
        checkClosed();
        return this.checkpoint.getIntProperty(this.allocator, name);
    }

    /** Returns the learning rate reported by the native session. */
    public float getLearningRate() {
        checkClosed();
        return getLearningRate(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
    }

    /** Reads a String property from this session's checkpoint. */
    public String getStringProperty(String name) {
        checkClosed();
        return this.checkpoint.getStringProperty(this.allocator, name);
    }

    /** Returns the cached, unmodifiable training input names (usable after close). */
    public Set<String> getTrainInputNames() {
        return this.trainInputNames;
    }

    /** Returns the cached, unmodifiable training output names (usable after close). */
    public Set<String> getTrainOutputNames() {
        return this.trainOutputNames;
    }

    /** Invokes the native lazy gradient reset for this session. */
    public void lazyResetGrad() {
        checkClosed();
        lazyResetGrad(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
    }

    /** Invokes the native optimizer step without run options. */
    public void optimizerStep() {
        optimizerStep(null);
    }

    /** Invokes the native optimizer step; {@code runOptions} may be null. */
    public void optimizerStep(OrtSession.RunOptions runOptions) {
        checkClosed();
        optimizerStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle,
                runOptionsHandle(runOptions));
    }

    /**
     * Registers a linear learning-rate scheduler with the native session. Arguments are forwarded
     * unchanged; their meaning follows the ORT training API's RegisterLinearLRScheduler (confirm).
     */
    public void registerLinearLRScheduler(long warmupStepCount, long totalStepCount, float initialLearningRate) {
        checkClosed();
        registerLinearLRScheduler(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle,
                warmupStepCount, totalStepCount, initialLearningRate);
    }

    /**
     * Saves this session's checkpoint to {@code path}. {@code includeOptimizerState} is forwarded
     * unchanged to {@link OrtCheckpointState#saveCheckpoint(Path, boolean)}.
     */
    public void saveCheckpoint(Path path, boolean includeOptimizerState) {
        checkClosed();
        this.checkpoint.saveCheckpoint(path, includeOptimizerState);
    }

    /** Invokes the native learning-rate scheduler step. */
    public void schedulerStep() {
        checkClosed();
        schedulerStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle);
    }

    /** Sets the learning rate on the native session. */
    public void setLearningRate(float learningRate) {
        checkClosed();
        setLearningRate(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, this.nativeHandle,
                learningRate);
    }

    /** Runs a training step and returns all training outputs. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs) {
        return trainStep(inputs, this.trainOutputNames, null);
    }

    /** Runs a training step with {@code runOptions} and returns all training outputs. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs,
            OrtSession.RunOptions runOptions) {
        return trainStep(inputs, this.trainOutputNames, runOptions);
    }

    /** Runs a training step and returns only {@code requestedOutputs}. */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) {
        return trainStep(inputs, requestedOutputs, null);
    }

    /**
     * Runs a training step.
     *
     * @param inputs input tensors keyed by training input name
     * @param requestedOutputs training outputs to return; must be non-empty
     * @param runOptions optional run options; may be null
     * @return the requested outputs, in {@code requestedOutputs} iteration order
     * @throws OrtException if an input/output name is unknown or a count is out of range
     * @throws IllegalStateException if this session is closed
     */
    public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs,
            OrtSession.RunOptions runOptions) {
        checkClosed();
        validateCounts(inputs, this.trainInputNames, requestedOutputs, this.trainOutputNames);
        String[] inputNames = new String[inputs.size()];
        long[] inputHandles = new long[inputs.size()];
        marshalInputs(inputs, this.trainInputNames, inputNames, inputHandles);
        String[] outputNames = marshalOutputs(requestedOutputs, this.trainOutputNames);
        OnnxValue[] outputs = trainStep(OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle,
                this.nativeHandle, this.allocator.handle, inputNames, inputHandles, inputNames.length,
                outputNames, outputNames.length, runOptionsHandle(runOptions));
        return new OrtSession.Result(outputNames, outputs);
    }
}
