package com.linecorp.yflkit;

import java.lang.reflect.Array;
import java.util.HashMap;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

/* loaded from: classes7.dex */
/**
 * Java wrapper around a native YFLKit trainer.
 *
 * <p>The constructor allocates a native trainer via {@code create()} and configures it through
 * {@code setup(...)} with model paths, tensor names, a {@link YFLConfiguration}, per-batch tensor
 * handles and an {@link EventListener}. Training is started with {@link #run()} or
 * {@link #runSync()} and can be stopped with {@link #stop()}. The native trainer is released by
 * {@link #close()}, so prefer try-with-resources.
 */
public class YFLTrainer implements AutoCloseable {
    private static final String TAG = "YFLKit";

    /** Key looked up in a batch's tensor map when a tensor is not registered under its own name. */
    private static final String FALLBACK_TENSOR_KEY = "tensor";

    private final String inferModelPath;
    private final String[] inputNames;
    private final long nativeHandle;
    private final String outputModelPath;
    private final String[] outputNames;
    private final String outputUploadResultPath;
    private final String trainModelPath;

    /** Becomes true once {@link #close()} has released {@link #nativeHandle}. */
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /**
     * Callbacks for training progress. The listener is handed to the native trainer in
     * {@code setup(...)}; presumably the native side invokes these methods (and possibly from a
     * non-main thread). TODO confirm threading and the meaning of each argument.
     */
    public interface EventListener {
        void onEvaluateComplete(int i15, int i16);

        void onEvaluateLoss(YFLValue yFLValue, YFLValue yFLValue2, YFLValue yFLValue3);

        void onTrainLoss(YFLValue yFLValue, YFLValue yFLValue2, YFLValue yFLValue3, int i15, int i16);

        void onTrainingComplete();

        void onTrainingError(int i15, String str);
    }

    /**
     * Creates and configures a native trainer.
     *
     * <p>All Java-side inputs are validated and converted to native tensor handles <em>before</em>
     * the native trainer is allocated. If native setup fails, the trainer is released before the
     * exception propagates, so a failed construction never leaks native resources.
     *
     * @param inferModelPath path passed to native setup as the inference model path
     * @param trainModelPath path passed to native setup as the training model path
     * @param outputModelPath path passed to native setup as the output model path
     * @param outputUploadResultPath path passed to native setup as the upload-result output path
     * @param inputNames tensor names resolved against the input batch providers; must not be null
     * @param outputNames tensor names resolved against the output batch providers; must not be null
     * @param configuration trainer configuration; must not be null
     * @param trainInputProvider batches resolved with {@code inputNames}; its batch count is passed
     *     as the primary batch count; must not be null
     * @param trainOutputProvider batches resolved with {@code outputNames}; must not be null
     * @param evalInputProvider optional batches resolved with {@code inputNames}; its batch count
     *     (0 when null) is passed as the secondary batch count. Presumably the evaluation set —
     *     TODO confirm
     * @param evalOutputProvider optional batches resolved with {@code outputNames}
     * @param eventListener listener handed to the native trainer
     * @throws NullPointerException if a required argument is null
     */
    public YFLTrainer(String inferModelPath, String trainModelPath, String outputModelPath,
            String outputUploadResultPath, String[] inputNames, String[] outputNames,
            YFLConfiguration configuration, YFLBatchProvider trainInputProvider,
            YFLBatchProvider trainOutputProvider, YFLBatchProvider evalInputProvider,
            YFLBatchProvider evalOutputProvider, EventListener eventListener) {
        Objects.requireNonNull(inputNames, "inputNames");
        Objects.requireNonNull(outputNames, "outputNames");
        Objects.requireNonNull(configuration, "configuration");
        Objects.requireNonNull(trainInputProvider, "trainInputProvider");
        Objects.requireNonNull(trainOutputProvider, "trainOutputProvider");

        this.inferModelPath = inferModelPath;
        this.trainModelPath = trainModelPath;
        this.outputModelPath = outputModelPath;
        this.outputUploadResultPath = outputUploadResultPath;
        this.inputNames = inputNames;
        this.outputNames = outputNames;

        // Do all Java-side work (which may throw) before allocating the native trainer.
        long[][] trainInputHandles = createDataHandles(inputNames, trainInputProvider);
        long[][] trainOutputHandles = createDataHandles(outputNames, trainOutputProvider);
        long[][] evalInputHandles = createDataHandles(inputNames, evalInputProvider);
        long[][] evalOutputHandles = createDataHandles(outputNames, evalOutputProvider);
        // Derive counts from the arrays actually passed, so native code never sees a count that
        // disagrees with the array length.
        int trainBatchCount = trainInputHandles.length;
        int evalBatchCount = evalInputHandles != null ? evalInputHandles.length : 0;
        long configurationHandle = configuration.getNativeHandle();

        long handle = create();
        this.nativeHandle = handle;
        try {
            setup(handle, inferModelPath, trainModelPath, outputModelPath, outputUploadResultPath,
                    inputNames, outputNames, configurationHandle,
                    trainInputHandles, trainOutputHandles, trainBatchCount,
                    evalInputHandles, evalOutputHandles, evalBatchCount, eventListener);
        } catch (RuntimeException | Error e) {
            // The object is never returned to the caller, so nobody else could release it.
            close(handle);
            throw e;
        }
    }

    private static native void close(long j15);

    private static native long create();

    /**
     * Resolves, for every batch of {@code provider}, the native handle of each tensor in
     * {@code names}.
     *
     * <p>Each name is looked up in the batch's tensor map. If the name is absent, the entry under
     * {@code "tensor"} is used instead. If neither key is present, the slot stays {@code 0}.
     *
     * @param names tensor names; defines the column order of each row
     * @param provider batch source, may be null
     * @return {@code [batch][nameIndex]} native handles, or {@code null} when {@code provider} is
     *     null
     */
    private long[][] createDataHandles(String[] names, YFLBatchProvider provider) {
        if (provider == null) {
            return null;
        }
        int batchCount = provider.getCount();
        // Rows are allocated per batch below; no need to pre-allocate inner arrays.
        long[][] handles = new long[batchCount][];
        for (int batch = 0; batch < batchCount; batch++) {
            HashMap<String, YFLTensor> tensorMap = provider.getTensorMap(batch);
            long[] row = new long[names.length];
            for (int column = 0; column < names.length; column++) {
                String name = names[column];
                if (tensorMap.containsKey(name)) {
                    row[column] = tensorMap.get(name).getNativeHandle();
                } else if (tensorMap.containsKey(FALLBACK_TENSOR_KEY)) {
                    row[column] = tensorMap.get(FALLBACK_TENSOR_KEY).getNativeHandle();
                }
                // NOTE(review): an unresolved tensor leaves handle 0 for native code — confirm
                // that native setup tolerates this.
            }
            handles[batch] = row;
        }
        return handles;
    }

    private static native void setup(long j15, String str, String str2, String str3, String str4, String[] strArr, String[] strArr2, long j16, long[][] jArr, long[][] jArr2, int i15, long[][] jArr3, long[][] jArr4, int i16, Object obj);

    private static native void startTraining(long j15, boolean z15);

    private static native void stopTraining(long j15);

    /**
     * Returns the native handle, or fails fast if this trainer has been closed, rather than
     * handing a released handle to native code.
     */
    private long requireOpenHandle() {
        if (closed.get()) {
            throw new IllegalStateException("YFLTrainer has already been closed");
        }
        return nativeHandle;
    }

    /**
     * Releases the native trainer. Calling this more than once is a no-op. Idempotent close is
     * recommended by {@link AutoCloseable}, and a repeated native close on the same handle would
     * presumably free it twice.
     */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        if (closed.compareAndSet(false, true)) {
            close(nativeHandle);
        }
    }

    /**
     * Starts training by calling native {@code startTraining} with its flag set to {@code false}.
     * Presumably returns without waiting for training to finish — TODO confirm.
     *
     * @throws IllegalStateException if this trainer has been closed
     */
    public void run() {
        startTraining(requireOpenHandle(), false);
    }

    /**
     * Starts training by calling native {@code startTraining} with its flag set to {@code true}.
     * Presumably blocks until training finishes — TODO confirm.
     *
     * @throws IllegalStateException if this trainer has been closed
     */
    public void runSync() {
        startTraining(requireOpenHandle(), true);
    }

    /** Requests the native trainer to stop. Does nothing if the trainer has already been closed. */
    public void stop() {
        if (!closed.get()) {
            stopTraining(nativeHandle);
        }
    }
}
