package org.tensorflow.lite.examples.transfer.api;

import java.io.Closeable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.TreeMap;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;

/* loaded from: classes2.dex */
/**
 * Thin wrapper around a TensorFlow Lite model that computes gradients for a set of parameter
 * tensors.
 *
 * <p>Tensor layout assumed by the methods below:
 *
 * <ul>
 *   <li>inputs: index 0 and 1 are data batches (index 0's first dimension is the batch size),
 *       indices 2.. are the parameter tensors;
 *   <li>outputs: index 0 is a single float32 value, indices 1.. are gradient tensors, one per
 *       parameter tensor.
 * </ul>
 */
class LiteTrainHeadModel implements Closeable {
    /** Size in bytes of one float32 value. */
    private static final int FLOAT_BYTES = 4;

    /** Number of leading model inputs that carry batch data rather than parameters. */
    private static final int NUM_DATA_INPUTS = 2;

    private final LiteModelWrapper modelWrapper;

    /**
     * Creates a wrapper around the given model.
     *
     * @param liteModelWrapper wrapper owning the interpreter; closed by {@link #close()}
     */
    public LiteTrainHeadModel(LiteModelWrapper liteModelWrapper) {
        this.modelWrapper = liteModelWrapper;
    }

    /**
     * Runs the model once and writes one gradient tensor per parameter tensor.
     *
     * <p>After the run, the position of every passed-in buffer is rewound to zero.
     *
     * @param bottleneckBatch model input 0 (presumably bottleneck features for the batch — confirm)
     * @param classBatch model input 1 (presumably the class labels for the batch — confirm)
     * @param modelParameters parameter tensors, fed as model inputs 2..
     * @param modelGradients destination buffers for model outputs 1..; must have the same length as
     *     {@code modelParameters}
     * @return the float32 written by the model to output 0 (presumably the loss)
     * @throws IllegalArgumentException if the parameter and gradient arrays differ in length, or if
     *     their length does not match the number of gradient outputs of the model
     */
    public float calculateGradients(
            ByteBuffer bottleneckBatch,
            ByteBuffer classBatch,
            ByteBuffer[] modelParameters,
            ByteBuffer[] modelGradients) {
        if (modelParameters.length != modelGradients.length) {
            throw new IllegalArgumentException(String.format(
                    "Parameter array size (%d) is different from gradient array size (%d)",
                    modelParameters.length, modelGradients.length));
        }

        Interpreter interpreter = modelWrapper.getInterpreter();
        // Output 0 is the scalar result; every remaining output is one gradient tensor, so the
        // model expects exactly (outputCount - 1) parameter tensors. Report that same value in the
        // error message (previously it was derived from the input count and was off by one).
        int expectedParameterCount = interpreter.getOutputTensorCount() - 1;
        if (expectedParameterCount != modelParameters.length) {
            throw new IllegalArgumentException(String.format(
                    "Model expected %d parameter tensors, but got %d",
                    expectedParameterCount, modelParameters.length));
        }

        // Native byte order so getFloat() below reads the value exactly as the runtime wrote it.
        ByteBuffer lossBuffer = ByteBuffer.allocateDirect(FLOAT_BYTES);
        lossBuffer.order(ByteOrder.nativeOrder());

        Map<Integer, Object> outputs = new TreeMap<>();
        outputs.put(0, lossBuffer);
        for (int gradientIndex = 0; gradientIndex < modelGradients.length; gradientIndex++) {
            outputs.put(gradientIndex + 1, modelGradients[gradientIndex]);
        }

        Object[] inputs = new Object[NUM_DATA_INPUTS + modelParameters.length];
        inputs[0] = bottleneckBatch;
        inputs[1] = classBatch;
        System.arraycopy(modelParameters, 0, inputs, NUM_DATA_INPUTS, modelParameters.length);

        interpreter.runForMultipleInputsOutputs(inputs, outputs);

        // Reset positions so callers can read or reuse every buffer from the start.
        bottleneckBatch.rewind();
        classBatch.rewind();
        for (ByteBuffer parameter : modelParameters) {
            parameter.rewind();
        }
        for (ByteBuffer gradient : modelGradients) {
            gradient.rewind();
        }
        lossBuffer.rewind();
        return lossBuffer.getFloat();
    }

    /** Closes the underlying model wrapper. */
    @Override
    public void close() {
        modelWrapper.close();
    }

    /**
     * Returns the batch size the model was built for.
     *
     * @return the first dimension of the shape of model input 0
     */
    public int getBatchSize() {
        return modelWrapper.getInterpreter().getInputTensor(0).shape()[0];
    }

    /**
     * Returns the shape of every parameter tensor (model inputs 2..).
     *
     * @return one freshly allocated shape array per parameter tensor, in input order
     */
    int[][] getParameterShapes() {
        Interpreter interpreter = modelWrapper.getInterpreter();
        int parameterCount = interpreter.getInputTensorCount() - NUM_DATA_INPUTS;
        int[][] shapes = new int[parameterCount][];
        for (int p = 0; p < parameterCount; p++) {
            Tensor tensor = interpreter.getInputTensor(p + NUM_DATA_INPUTS);
            int rank = tensor.numDimensions();
            // Copy so callers cannot mutate an array that may be shared with the tensor.
            shapes[p] = new int[rank];
            System.arraycopy(tensor.shape(), 0, shapes[p], 0, rank);
        }
        return shapes;
    }

    /**
     * Returns the element count of every parameter tensor (model inputs 2..).
     *
     * @return one element count per parameter tensor, in input order
     */
    public int[] getParameterSizes() {
        Interpreter interpreter = modelWrapper.getInterpreter();
        int parameterCount = interpreter.getInputTensorCount() - NUM_DATA_INPUTS;
        int[] sizes = new int[parameterCount];
        for (int p = 0; p < parameterCount; p++) {
            sizes[p] = interpreter.getInputTensor(p + NUM_DATA_INPUTS).numElements();
        }
        return sizes;
    }
}
