package ru.ok.tensorflow.classification;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Pair;
import androidx.annotation.NonNull;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ru.ok.tensorflow.entity.Detection;
import ru.ok.tensorflow.entity.PalmClass;
import ru.ok.tensorflow.entity.Recognition;
import ru.ok.tensorflow.tflite.ModelDataProvider;
import ru.ok.tensorflow.tflite.TFImageData;
import ru.ok.tensorflow.util.Logger;

/* loaded from: classes6.dex */
/**
 * Classifier that reads two model outputs per detection: class scores from output tensor 0 and
 * location values from output tensor 1. The location values are mapped from crop space back into
 * frame space and attached to the detection.
 */
public class ClassifierWithRegression extends Classifier {
    /**
     * Width and height passed to {@link Detection#getTransformation} when mapping the location
     * output back to frame coordinates.
     * NOTE(review): this is fixed at 64 regardless of {@code cropWidth}/{@code cropHeight};
     * presumably it matches the space the model's location output is expressed in — confirm
     * before changing the crop size.
     */
    private static final int REGRESSION_SPACE_SIZE = 64;

    private final Logger LOGGER;
    private final TFImageData inputData;
    /** Buffer the interpreter fills with output tensor 0 (class scores) on every run. */
    protected float[][][][] outputClasses;
    /** Buffer the interpreter fills with output tensor 1 (location values) on every run. */
    private final float[][][][] outputLocations;
    /** Output-index to buffer mapping handed to the interpreter. */
    private final HashMap<Integer, Object> outputMap;

    public ClassifierWithRegression(Context context, @NonNull ModelDataProvider modelDataProvider, List<Pair<PalmClass, Float>> list, float f2, float f3) {
        super(context, modelDataProvider, list, f2, f3);
        this.LOGGER = new Logger();
        // Size both output buffers from the model's own output tensor shapes.
        this.outputClasses = allocateOutput(0);
        this.outputLocations = allocateOutput(1);
        this.outputMap = new HashMap<>();
        this.outputMap.put(0, this.outputClasses);
        this.outputMap.put(1, this.outputLocations);
        this.inputData = new TFImageData(this.cropHeight, this.cropWidth, false, false);
    }

    /**
     * Allocates a float buffer matching the shape of the given output tensor.
     *
     * @param index interpreter output tensor index
     * @return a zero-filled buffer with the tensor's dimensions
     * @throws IllegalStateException if the tensor is not rank 4
     */
    private float[][][][] allocateOutput(int index) {
        int[] shape = this.interpreterWrapper.getOutputTensor(index).d();
        if (shape.length != 4) {
            throw new IllegalStateException(
                    "Expected rank-4 output tensor " + index + ", got rank " + shape.length);
        }
        return (float[][][][]) Array.newInstance(float.class, shape[0], shape[1], shape[2], shape[3]);
    }

    /** Loads {@code bitmap} into the input buffer and runs the model, refilling both output buffers. */
    private void runNetwork(Bitmap bitmap) {
        this.inputData.fromBitmap(bitmap);
        this.interpreterWrapper.runForMultipleInputsOutputs(new Object[]{this.inputData.buffer}, this.outputMap);
    }

    /**
     * Runs the model on a crop around each detection and builds one {@link Recognition} per
     * detection, in input order.
     *
     * @param bitmap source frame the crops are extracted from
     * @param list   detections to classify
     * @return a new list with exactly one recognition per entry of {@code list}
     */
    @Override // ru.ok.tensorflow.classification.Classifier
    public List<Recognition> classify(Bitmap bitmap, List<Detection> list) {
        List<Recognition> recognitions = new ArrayList<>(list.size());
        for (Detection detection : list) {
            Bitmap extractCrop = detection.extractCrop(bitmap, this.cropWidth, this.cropHeight, this.cropTranslationFactor, this.cropScaleFactor, true);
            this.crop = extractCrop;
            runNetwork(extractCrop);
            Pair<PalmClass, Float> best = selectClass(getClasses(this.outputClasses[0][0][0], this.classesWithThresholds));
            float[] framePoints = mapLocationsToFrame(detection);
            recognitions.add(new Recognition(best.first, best.second, detection.addLocations(framePoints)));
        }
        return recognitions;
    }

    /**
     * Picks the reported class. NOT_HAND wins whenever it is present in {@code classes};
     * otherwise the superclass's {@code getMax} decides.
     */
    private Pair<PalmClass, Float> selectClass(Map<PalmClass, Float> classes) {
        if (classes.containsKey(PalmClass.NOT_HAND)) {
            return new Pair<>(PalmClass.NOT_HAND, classes.get(PalmClass.NOT_HAND));
        }
        return getMax(classes);
    }

    /**
     * Maps the latest location output from crop space into frame space for {@code detection}.
     *
     * <p>A fresh copy is returned because {@link #outputLocations} is refilled by every
     * interpreter run. Handing out the shared buffer would let a later detection overwrite the
     * coordinates an earlier recognition still references.
     */
    private float[] mapLocationsToFrame(Detection detection) {
        Matrix frameToCrop = detection.getTransformation(REGRESSION_SPACE_SIZE, REGRESSION_SPACE_SIZE, this.cropTranslationFactor, this.cropScaleFactor, true);
        Matrix cropToFrame = new Matrix();
        // If frameToCrop is not invertible, invert() returns false and leaves cropToFrame as the
        // identity, so the points pass through unmapped (unchanged from previous behavior).
        frameToCrop.invert(cropToFrame);
        float[] points = this.outputLocations[0][0][0].clone();
        cropToFrame.mapPoints(points);
        return points;
    }
}
