package org.tensorflow.lite.examples.transfer.api;

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ScatteringByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.tensorflow.lite.examples.transfer.api.TransferLearningModel;

/* loaded from: classes2.dex */
/**
 * Transfer-learning wrapper: a base model turns inputs into bottleneck features, a small head
 * model is trained on those features, and an inference model scores new inputs with the trained
 * head parameters.
 *
 * <p>{@link #addSample} and {@link #train} run asynchronously on an internal fixed-size thread
 * pool and return {@link Future}s; {@link #predict} runs on the calling thread.
 *
 * <p>Concurrency: {@code trainingLock} serializes sample insertion with training (a training run
 * holds it for its whole duration); {@code parameterLock} protects the {@code modelParameters}
 * reference, which training swaps after every optimizer step; {@code inferenceLock} serializes
 * {@link #predict} with {@link #close()} so the models are not released mid-inference.
 */
public final class TransferLearningModel implements Closeable {
    /** Size in bytes of one float element in the model buffers. */
    private static final int FLOAT_BYTES = 4;
    /** Worker threads: one fewer than the available processors, but at least one. */
    private static final int NUM_THREADS =
            Math.max(1, Runtime.getRuntime().availableProcessors() - 1);

    private final LiteInitializeModel initializeModel;
    private final LiteBottleneckModel bottleneckModel;
    private final LiteTrainHeadModel trainHeadModel;
    private final LiteInferenceModel inferenceModel;
    private final LiteOptimizerModel optimizerModel;

    private final int[] bottleneckShape;
    /** Class names in index order; index {@code i} is position {@code i} in labels and outputs. */
    private final String[] classesByIdx;
    /** Inverse of {@link #classesByIdx}: class name to index. Built once in the constructor. */
    private final Map<String, Integer> classes = new TreeMap<>();

    // Head parameters and optimizer state are double-buffered: each optimizer step writes into the
    // "next" arrays, which are then swapped with the current ones.
    private ByteBuffer[] modelParameters;
    private ByteBuffer[] nextModelParameters;
    private final ByteBuffer[] modelGradients;
    private ByteBuffer[] optimizerState;
    private ByteBuffer[] nextOptimizerState;

    // Reusable scratch buffers.
    private final ByteBuffer trainingBatchBottlenecks;
    private final ByteBuffer trainingBatchClasses;
    private final ByteBuffer zeroBatchClasses;
    private final ByteBuffer inferenceBottleneck;

    private final List<TrainingSample> trainingSamples = new ArrayList<>();
    private final ExecutorService executor = Executors.newFixedThreadPool(NUM_THREADS);
    private final ReentrantLock trainingLock = new ReentrantLock();
    private final ReadWriteLock parameterLock = new ReentrantReadWriteLock();
    private final Lock inferenceLock = new ReentrantLock();
    private volatile boolean isTerminating = false;

    /** Receives the mean training loss at the end of every completed epoch. */
    public interface LossConsumer {
        /**
         * @param i zero-based epoch index
         * @param f mean of the per-batch losses over that epoch
         */
        void onLoss(int i, float f);
    }

    /** A class name paired with the score {@link #predict} produced for it. */
    public static class Prediction {
        private final String className;
        private final float confidence;

        public Prediction(String str, float f) {
            this.className = str;
            this.confidence = f;
        }

        public String getClassName() {
            return this.className;
        }

        public float getConfidence() {
            return this.confidence;
        }
    }

    /** A stored training example: its bottleneck features and its class name. */
    public static class TrainingSample {
        final ByteBuffer bottleneck;
        final String className;

        TrainingSample(ByteBuffer byteBuffer, String str) {
            this.bottleneck = byteBuffer;
            this.className = str;
        }
    }

    /**
     * Loads the underlying models and allocates all working buffers. Head parameters are set by
     * the initialize model; optimizer state starts at zero.
     *
     * @param modelLoader source of the initialize, base, train-head, inference and optimizer models
     * @param collection class names; iteration order defines each class's index in the one-hot
     *     training labels and in the inference output
     * @throws IllegalArgumentException if {@code collection} contains the same class name twice
     * @throws RuntimeException wrapping the {@link IOException} if a model cannot be read
     */
    public TransferLearningModel(ModelLoader modelLoader, Collection<String> collection) {
        this.classesByIdx = collection.toArray(new String[0]);
        for (int classIdx = 0; classIdx < classesByIdx.length; classIdx++) {
            // A duplicate would make classes.size() disagree with the label/output layout.
            if (classes.put(classesByIdx[classIdx], classIdx) != null) {
                throw new IllegalArgumentException(
                        String.format("Duplicate class name \"%s\"", classesByIdx[classIdx]));
            }
        }
        try {
            this.initializeModel = new LiteInitializeModel(modelLoader.loadInitializeModel());
            this.bottleneckModel = new LiteBottleneckModel(modelLoader.loadBaseModel());
            this.trainHeadModel = new LiteTrainHeadModel(modelLoader.loadTrainModel());
            this.inferenceModel =
                    new LiteInferenceModel(modelLoader.loadInferenceModel(), classesByIdx.length);
            this.optimizerModel = new LiteOptimizerModel(modelLoader.loadOptimizerModel());
            this.bottleneckShape = bottleneckModel.getBottleneckShape();

            int[] parameterSizes = trainHeadModel.getParameterSizes();
            this.modelParameters = new ByteBuffer[parameterSizes.length];
            this.modelGradients = new ByteBuffer[parameterSizes.length];
            this.nextModelParameters = new ByteBuffer[parameterSizes.length];
            for (int paramIdx = 0; paramIdx < parameterSizes.length; paramIdx++) {
                int byteSize = parameterSizes[paramIdx] * FLOAT_BYTES;
                modelParameters[paramIdx] = allocateBuffer(byteSize);
                modelGradients[paramIdx] = allocateBuffer(byteSize);
                nextModelParameters[paramIdx] = allocateBuffer(byteSize);
            }
            initializeModel.initializeParameters(modelParameters);

            int[] stateElementSizes = optimizerModel.stateElementSizes();
            this.optimizerState = new ByteBuffer[stateElementSizes.length];
            this.nextOptimizerState = new ByteBuffer[stateElementSizes.length];
            for (int stateIdx = 0; stateIdx < stateElementSizes.length; stateIdx++) {
                int byteSize = stateElementSizes[stateIdx] * FLOAT_BYTES;
                optimizerState[stateIdx] = allocateBuffer(byteSize);
                nextOptimizerState[stateIdx] = allocateBuffer(byteSize);
                fillBufferWithZeros(optimizerState[stateIdx]);
            }

            int batchSize = getTrainBatchSize();
            this.trainingBatchBottlenecks =
                    allocateBuffer(batchSize * numBottleneckFeatures() * FLOAT_BYTES);
            int labelBytes = batchSize * classes.size() * FLOAT_BYTES;
            this.trainingBatchClasses = allocateBuffer(labelBytes);
            this.zeroBatchClasses = allocateBuffer(labelBytes);
            fillBufferWithZeros(zeroBatchClasses);
            this.inferenceBottleneck = allocateBuffer(numBottleneckFeatures() * FLOAT_BYTES);
        } catch (IOException e) {
            throw new RuntimeException("Couldn't read underlying models for TransferLearningModel", e);
        }
    }

    /** Allocates a direct buffer of {@code size} bytes in native byte order. */
    private static ByteBuffer allocateBuffer(int size) {
        return ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder());
    }

    /** Copies {@code values} into a new direct native-order buffer, rewound to position 0. */
    private static ByteBuffer toFloatBuffer(float[] values) {
        ByteBuffer buffer = allocateBuffer(values.length * FLOAT_BYTES);
        for (float value : values) {
            buffer.putFloat(value);
        }
        buffer.rewind();
        return buffer;
    }

    /** @throws IllegalStateException once {@link #close()} has started */
    private void checkNotTerminating() {
        if (this.isTerminating) {
            throw new IllegalStateException("Cannot operate on terminating model");
        }
    }

    /**
     * Overwrites every byte of {@code buffer} (from index 0 to its capacity) with zero and leaves
     * it rewound to position 0. Safe for zero-capacity buffers.
     */
    private static void fillBufferWithZeros(ByteBuffer buffer) {
        buffer.clear();
        while (buffer.remaining() >= 8) {
            buffer.putLong(0L);
        }
        while (buffer.hasRemaining()) {
            buffer.put((byte) 0);
        }
        buffer.rewind();
    }

    /**
     * Orders predictions by descending confidence. Public only because it was generated from the
     * sort lambda in {@link #predict}; not intended to be called directly.
     */
    public static int lambda$predict$2(Prediction prediction, Prediction prediction2) {
        return -Float.compare(prediction.confidence, prediction2.confidence);
    }

    /** Number of scalar features in one bottleneck (the product of its shape dimensions). */
    private int numBottleneckFeatures() {
        int count = 1;
        for (int dim : this.bottleneckShape) {
            count *= dim;
        }
        return count;
    }

    /**
     * Shuffles {@link #trainingSamples} in place and returns an iterable over batches of exactly
     * {@link #getTrainBatchSize()} samples (views into the list). The caller must hold
     * {@code trainingLock} for as long as it iterates.
     */
    private Iterable<List<TrainingSample>> trainingBatches() {
        if (!trainingLock.isHeldByCurrentThread()) {
            throw new RuntimeException("Thread calling trainingBatches() must hold the training lock");
        }
        Collections.shuffle(trainingSamples);
        return new Iterable<List<TrainingSample>>() {
            @Override
            public Iterator<List<TrainingSample>> iterator() {
                return m1851x21c75b89();
            }
        };
    }

    /**
     * Asynchronously converts {@code fArr} into bottleneck features and stores them as a training
     * sample for class {@code str}.
     *
     * @param fArr raw model input, written as floats in native byte order
     * @param str class name; must be one of the names given to the constructor
     * @return a future that completes once the sample has been stored (or the task was interrupted)
     * @throws IllegalStateException if {@link #close()} has started
     * @throws IllegalArgumentException if {@code str} is not a known class
     */
    public Future<Void> addSample(final float[] fArr, final String str) {
        checkNotTerminating();
        if (!this.classes.containsKey(str)) {
            throw new IllegalArgumentException(String.format("Class \"%s\" is not one of the classes recognized by the model", str));
        }
        return this.executor.submit(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                return m1849x91710103(fArr, str);
            }
        });
    }

    /**
     * Stops background work and releases the underlying models. Waits up to 5 seconds for queued
     * or running tasks; waits for any in-flight {@link #predict} call via {@code inferenceLock}.
     * If interrupted while waiting, the interrupt flag is restored and the models are left open,
     * since background tasks may still be using them.
     *
     * @throws RuntimeException if the thread pool does not terminate in time
     */
    @Override
    public void close() {
        this.isTerminating = true;
        this.executor.shutdownNow();
        this.inferenceLock.lock();
        try {
            if (!this.executor.awaitTermination(5L, TimeUnit.SECONDS)) {
                throw new RuntimeException("Model thread pool failed to terminate");
            }
            this.initializeModel.close();
            this.bottleneckModel.close();
            this.trainHeadModel.close();
            this.inferenceModel.close();
            this.optimizerModel.close();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            this.inferenceLock.unlock();
        }
    }

    /**
     * Returns the number of stored training samples. Not synchronized with {@link #addSample}
     * tasks, so a concurrently added sample may not be reflected yet.
     */
    public int getNumberSamples() {
        return this.trainingSamples.size();
    }

    /** Batch size expected by the train-head model. */
    public int getTrainBatchSize() {
        return this.trainHeadModel.getBatchSize();
    }

    /**
     * Body of the task submitted by {@link #addSample}. Public only because it was generated from
     * a lambda; not intended to be called directly.
     */
    public Void m1849x91710103(float[] fArr, String str) throws Exception {
        ByteBuffer input = toFloatBuffer(fArr);
        if (Thread.interrupted()) {
            return null;
        }
        // NOTE(review): null output buffer presumably makes the model allocate a fresh one, which
        // the stored sample then owns — confirm in LiteBottleneckModel.
        ByteBuffer bottleneck = this.bottleneckModel.generateBottleneck(input, null);
        this.trainingLock.lockInterruptibly();
        try {
            this.trainingSamples.add(new TrainingSample(bottleneck, str));
        } finally {
            this.trainingLock.unlock();
        }
        return null;
    }

    /**
     * Body of the task submitted by {@link #train}: runs {@code i} epochs over the stored samples
     * while holding {@code trainingLock}. Public only because it was generated from a lambda; not
     * intended to be called directly.
     *
     * @param i number of epochs
     * @param lossConsumer optional; told the mean loss after each completed epoch
     */
    public Void m1850x2ceb0ba1(int i, LossConsumer lossConsumer) throws Exception {
        this.trainingLock.lock();
        try {
            int numClasses = this.classes.size();
            epochLoop:
            for (int epoch = 0; epoch < i; epoch++) {
                float totalLoss = 0.0f;
                int numBatches = 0;
                for (List<TrainingSample> batch : trainingBatches()) {
                    if (Thread.interrupted()) {
                        // Cancelled (e.g. by close()): stop without reporting the partial epoch.
                        break epochLoop;
                    }

                    // Reset labels to all zeros, then write this batch's bottlenecks and one-hots.
                    this.trainingBatchClasses.put(this.zeroBatchClasses);
                    this.trainingBatchClasses.rewind();
                    this.zeroBatchClasses.rewind();
                    for (int sampleIdx = 0; sampleIdx < batch.size(); sampleIdx++) {
                        TrainingSample sample = batch.get(sampleIdx);
                        this.trainingBatchBottlenecks.put(sample.bottleneck);
                        sample.bottleneck.rewind();
                        int labelIdx = sampleIdx * numClasses + this.classes.get(sample.className);
                        this.trainingBatchClasses.putFloat(labelIdx * FLOAT_BYTES, 1.0f);
                    }
                    this.trainingBatchBottlenecks.rewind();

                    totalLoss += this.trainHeadModel.calculateGradients(
                            this.trainingBatchBottlenecks, this.trainingBatchClasses,
                            this.modelParameters, this.modelGradients);
                    numBatches++;

                    this.optimizerModel.performStep(this.modelParameters, this.modelGradients,
                            this.optimizerState, this.nextModelParameters, this.nextOptimizerState);

                    // Optimizer state is only used by this method, so it is swapped without a lock.
                    ByteBuffer[] swap = this.optimizerState;
                    this.optimizerState = this.nextOptimizerState;
                    this.nextOptimizerState = swap;

                    // modelParameters is also read by predict()/saveParameters(): swap under lock.
                    this.parameterLock.writeLock().lock();
                    try {
                        swap = this.modelParameters;
                        this.modelParameters = this.nextModelParameters;
                        this.nextModelParameters = swap;
                    } finally {
                        this.parameterLock.writeLock().unlock();
                    }
                }
                if (lossConsumer != null) {
                    lossConsumer.onLoss(epoch, totalLoss / numBatches);
                }
            }
            return null;
        } finally {
            this.trainingLock.unlock();
        }
    }

    /**
     * Iterator behind {@link #trainingBatches()}. Every batch has exactly
     * {@link #getTrainBatchSize()} samples; when the sample count is not a multiple of the batch
     * size, the last batch overlaps the previous one. Public only because it was generated from a
     * lambda; not intended to be called directly.
     */
    public Iterator<List<TrainingSample>> m1851x21c75b89() {
        return new Iterator<List<TrainingSample>>() {
            private int nextIndex = 0;

            @Override
            public boolean hasNext() {
                return nextIndex < trainingSamples.size();
            }

            @Override
            public List<TrainingSample> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int batchSize = getTrainBatchSize();
                int numSamples = trainingSamples.size();
                int fromIndex = nextIndex;
                int toIndex = fromIndex + batchSize;
                nextIndex = toIndex;
                if (toIndex >= numSamples) {
                    return trainingSamples.subList(numSamples - batchSize, numSamples);
                }
                return trainingSamples.subList(fromIndex, toIndex);
            }
        };
    }

    /**
     * Overwrites the head parameters with bytes read from {@code scatteringByteChannel}, in the
     * same layout {@link #saveParameters} writes. Keeps reading until every parameter buffer is
     * full; expects a blocking channel.
     *
     * @throws IOException if the channel fails or ends before all parameters are read; parameters
     *     may then be partially overwritten
     */
    public void loadParameters(ScatteringByteChannel scatteringByteChannel) throws IOException {
        this.parameterLock.writeLock().lock();
        try {
            // A single read() may return before filling every buffer.
            while (anyRemaining(this.modelParameters)) {
                if (scatteringByteChannel.read(this.modelParameters) < 0) {
                    throw new IOException("Unexpected end of stream while loading model parameters");
                }
            }
        } finally {
            for (ByteBuffer byteBuffer : this.modelParameters) {
                byteBuffer.rewind();
            }
            this.parameterLock.writeLock().unlock();
        }
    }

    /**
     * Runs inference on {@code fArr} with the current head parameters.
     *
     * @param fArr raw model input, written as floats in native byte order
     * @return one prediction per class, sorted by descending confidence; {@code null} if
     *     {@link #close()} started while this call was waiting for the inference lock
     * @throws IllegalStateException if {@link #close()} had already started
     */
    public Prediction[] predict(float[] fArr) {
        checkNotTerminating();
        this.inferenceLock.lock();
        try {
            if (this.isTerminating) {
                return null;
            }
            ByteBuffer bottleneck =
                    this.bottleneckModel.generateBottleneck(toFloatBuffer(fArr), this.inferenceBottleneck);

            float[] confidences;
            this.parameterLock.readLock().lock();
            try {
                confidences = this.inferenceModel.runInference(bottleneck, this.modelParameters);
            } finally {
                this.parameterLock.readLock().unlock();
            }

            Prediction[] predictions = new Prediction[this.classes.size()];
            for (int classIdx = 0; classIdx < predictions.length; classIdx++) {
                predictions[classIdx] = new Prediction(this.classesByIdx[classIdx], confidences[classIdx]);
            }
            Arrays.sort(predictions, new Comparator<Prediction>() {
                @Override
                public int compare(Prediction a, Prediction b) {
                    return lambda$predict$2(a, b);
                }
            });
            return predictions;
        } finally {
            this.inferenceLock.unlock();
        }
    }

    /**
     * Writes all head parameters to {@code gatheringByteChannel}. Writes through duplicate views,
     * so the shared parameter buffers' positions are never moved (other readers hold the same read
     * lock concurrently). Keeps writing until everything is written; expects a blocking channel.
     */
    public void saveParameters(GatheringByteChannel gatheringByteChannel) throws IOException {
        this.parameterLock.readLock().lock();
        try {
            ByteBuffer[] views = new ByteBuffer[this.modelParameters.length];
            for (int idx = 0; idx < views.length; idx++) {
                views[idx] = this.modelParameters[idx].duplicate();
                views[idx].rewind();
            }
            // A single write() may return before draining every buffer.
            while (anyRemaining(views)) {
                gatheringByteChannel.write(views);
            }
        } finally {
            this.parameterLock.readLock().unlock();
        }
    }

    /** True if any buffer in {@code buffers} still has bytes between position and limit. */
    private static boolean anyRemaining(ByteBuffer[] buffers) {
        for (ByteBuffer buffer : buffers) {
            if (buffer.hasRemaining()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Asynchronously trains the head for {@code i} epochs over all stored samples.
     *
     * @param i number of epochs
     * @param lossConsumer optional; told the mean loss after each completed epoch
     * @return a future that completes when training ends or is interrupted
     * @throws IllegalStateException if {@link #close()} has started
     * @throws RuntimeException if fewer samples than one training batch are stored
     */
    public Future<Void> train(final int i, final LossConsumer lossConsumer) {
        checkNotTerminating();
        int batchSize = getTrainBatchSize();
        int numSamples = this.trainingSamples.size();
        if (numSamples < batchSize) {
            throw new RuntimeException(String.format("Too few samples to start training: need %d, got %d", batchSize, numSamples));
        }
        return this.executor.submit(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                return m1850x2ceb0ba1(i, lossConsumer);
            }
        });
    }
}
