package com.shifthackz.aisdv1.feature.diffusion.ai.unet;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.NNAPIFlags;
import android.graphics.Bitmap;
import android.util.Pair;
import androidx.core.app.NotificationCompat;
import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion;
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor;
import com.shifthackz.aisdv1.domain.preference.PreferenceManager;
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract;
import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.ArrayExtensionsKt;
import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.TensorExtensionsKt;
import com.shifthackz.aisdv1.feature.diffusion.ai.scheduler.EulerAncestralDiscreteLocalDiffusionScheduler;
import com.shifthackz.aisdv1.feature.diffusion.ai.vae.VaeDecoder;
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionConfig;
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag;
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor;
import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider;
import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider;
import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider;
import com.shifthackz.aisdv1.feature.diffusion.extensions.LocalDiffusionPathsKt;
import com.shifthackz.aisdv1.storage.db.persistent.contract.GenerationResultContract;
import java.nio.IntBuffer;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.MapsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.text.StringsKt;
import timber.log.Timber;

/* compiled from: UNet.kt */
@Metadata(d1 = {"\u0000\u0096\u0001\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\t\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0002\n\u0002\u0010\u0011\n\u0002\u0010\u0014\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0005\b\u0000\u0018\u00002\u00020\u0001:\u0001;B/\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\t\u0012\u0006\u0010\n\u001a\u00020\u000b¢\u0006\u0004\b\f\u0010\rJ\u0006\u0010\u0019\u001a\u00020\u001aJ,\u0010\u001b\u001a\u000e\u0012\u0004\u0012\u00020\u001d\u0012\u0004\u0012\u00020\u001e0\u001c2\u0006\u0010\u001f\u001a\u00020\u001e2\u0006\u0010 
\u001a\u00020\u001e2\u0006\u0010!\u001a\u00020\u001eH\u0002J4\u0010\"\u001a\u0006\u0012\u0002\b\u00030#2\u0006\u0010$\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00172\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010%\u001a\u00020&2\u0006\u0010'\u001a\u00020(H\u0002J]\u0010)\u001a\u00020\u001a2\"\u0010*\u001a\u001e\u0012\u0010\u0012\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020,0+0+0+j\b\u0012\u0004\u0012\u00020,`-2\"\u0010.\u001a\u001e\u0012\u0010\u0012\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020,0+0+0+j\b\u0012\u0004\u0012\u00020,`-2\u0006\u0010/\u001a\u000200H\u0002¢\u0006\u0002\u00101J>\u00102\u001a\u00020\u001a2\u0006\u00103\u001a\u00020&2\u0006\u00104\u001a\u00020\u00172\u0006\u00105\u001a\u00020\u001e2\u0006\u0010/\u001a\u0002002\u0006\u0010$\u001a\u00020\u00172\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u0017J\u0012\u00106\u001a\u0002072\n\u00108\u001a\u0006\u0012\u0002\b\u00030#J\u0006\u00109\u001a\u00020\u001aJ\u0010\u0010:\u001a\u00020\u001a2\b\u0010\u0012\u001a\u0004\u0018\u00010\u0013R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n\u0000R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n\u0000R\u0010\u0010\u000e\u001a\u0004\u0018\u00010\u000fX\u0082\u000e¢\u0006\u0002\n\u0000R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n\u0000R\u0010\u0010\u0012\u001a\u0004\u0018\u00010\u0013X\u0082\u000e¢\u0006\u0002\n\u0000R\u0010\u0010\u0014\u001a\u0004\u0018\u00010\u0015X\u0082\u000e¢\u0006\u0002\n\u0000R\u000e\u0010\u0016\u001a\u00020\u0017X\u0082\u000e¢\u0006\u0002\n\u0000R\u000e\u0010\u0018\u001a\u00020\u0017X\u0082\u000e¢\u0006\u0002\n\u0000¨\u0006<"}, d2 = {"Lcom/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet;", "", "deviceNNAPIFlagProvider", 
"Lcom/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider;", "ortEnvironmentProvider", "Lcom/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider;", "fileProviderDescriptor", "Lcom/shifthackz/aisdv1/core/common/file/FileProviderDescriptor;", "localModelIdProvider", "Lcom/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider;", "preferenceManager", "Lcom/shifthackz/aisdv1/domain/preference/PreferenceManager;", "<init>", "(Lcom/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider;Lcom/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider;Lcom/shifthackz/aisdv1/core/common/file/FileProviderDescriptor;Lcom/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider;Lcom/shifthackz/aisdv1/domain/preference/PreferenceManager;)V", "decoder", "Lcom/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder;", "random", "Ljava/util/Random;", "callback", "Lcom/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet$Callback;", "session", "Lai/onnxruntime/OrtSession;", GenerationResultContract.WIDTH, "", GenerationResultContract.HEIGHT, "initialize", "", "createUNetModelInput", "", "", "Lai/onnxruntime/OnnxTensor;", "encoderHiddenStates", LocalDiffusionContract.KEY_SAMPLE, "timeStep", "generateLatentSample", "Lcom/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionTensor;", "batchSize", GenerationResultContract.SEED, "", "initNoiseSigma", "", "performGuidance", "noisePrediction", "", "", "Lcom/shifthackz/aisdv1/feature/diffusion/entity/Array3D;", "noisePredictionText", "guidanceScale", "", "([[[[F[[[[FD)V", "inference", "seedNum", "numInferenceSteps", "textEmbeddings", "decode", "Landroid/graphics/Bitmap;", "latents", "close", "setCallback", "Callback", "diffusion_release"}, k = 1, mv = {2, 1, 0}, xi = 48)
/* loaded from: classes3.dex */
public final class UNet {
    // Optional progress listener; inference() calls onStep(...) on it once per denoising step when non-null.
    private Callback callback;
    // VAE decoder used by decode() to turn latents into a Bitmap; closed and cleared by close().
    // NOTE(review): its assignment is not visible in this chunk (presumably an initialize() method) — confirm.
    private VaeDecoder decoder;
    // Constructor-injected collaborators. Those not referenced by the visible methods are presumably used
    // when the session/decoder are created (outside this chunk) — confirm.
    private final DeviceNNAPIFlagProvider deviceNNAPIFlagProvider;
    private final FileProviderDescriptor fileProviderDescriptor;
    // Output height in pixels: 384 by default, overwritten by every inference() call, read by decode().
    private int height;
    private final LocalModelIdProvider localModelIdProvider;
    // Supplies the ORT environment used whenever this class creates an OnnxTensor.
    private final OrtEnvironmentProvider ortEnvironmentProvider;
    private final PreferenceManager preferenceManager;
    // Unseeded generator; inference() draws a seed from it when called with seedNum <= 0.
    private final Random random;
    // UNet ORT session run once per denoising step in inference(); closed and cleared by close().
    private OrtSession session;
    // Output width in pixels: 384 by default, overwritten by every inference() call, read by decode().
    private int width;

    /* compiled from: UNet.kt */
    @Metadata(d1 = {"\u0000\u001e\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0000\bf\u0018\u00002\u00020\u0001J\u0018\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00052\u0006\u0010\u0006\u001a\u00020\u0005H&J\u001a\u0010\u0007\u001a\u00020\u00032\u0006\u0010\b\u001a\u00020\u00052\b\u0010\t\u001a\u0004\u0018\u00010\nH&¨\u0006\u000b"}, d2 = {"Lcom/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet$Callback;", "", "onStep", "", "maxStep", "", "step", "onBuildImage", NotificationCompat.CATEGORY_STATUS, "bitmap", "Landroid/graphics/Bitmap;", "diffusion_release"}, k = 1, mv = {2, 1, 0}, xi = 48)
    /* loaded from: classes3.dex */
    public interface Callback {
        /**
         * Reports denoising progress; invoked before the UNet session runs for a step.
         *
         * @param maxStep total number of scheduler time steps in the current run
         * @param step    zero-based index of the step that is about to run
         */
        void onStep(int maxStep, int step);

        /**
         * Delivers an image-building update.
         *
         * <p>NOTE(review): no invocation is visible in this chunk, so the meaning of {@code status}
         * and whether {@code bitmap} may be null cannot be confirmed here.
         *
         * @param status implementation-defined status code
         * @param bitmap the produced image, if any
         */
        void onBuildImage(int status, Bitmap bitmap);
    }

    /**
     * Creates a UNet runner wired to its environment collaborators.
     *
     * <p>Output size starts at 384x384 until {@code inference} supplies real dimensions.
     *
     * @throws NullPointerException if any collaborator is null (checked in parameter order)
     */
    public UNet(DeviceNNAPIFlagProvider deviceNNAPIFlagProvider, OrtEnvironmentProvider ortEnvironmentProvider, FileProviderDescriptor fileProviderDescriptor, LocalModelIdProvider localModelIdProvider, PreferenceManager preferenceManager) {
        // Kotlin non-null contract: the first null argument (in declaration order) is reported.
        Intrinsics.checkNotNullParameter(deviceNNAPIFlagProvider, "deviceNNAPIFlagProvider");
        Intrinsics.checkNotNullParameter(ortEnvironmentProvider, "ortEnvironmentProvider");
        Intrinsics.checkNotNullParameter(fileProviderDescriptor, "fileProviderDescriptor");
        Intrinsics.checkNotNullParameter(localModelIdProvider, "localModelIdProvider");
        Intrinsics.checkNotNullParameter(preferenceManager, "preferenceManager");

        // Default output size, replaced on each inference() call.
        this.width = this.height = 384;
        this.random = new Random();

        this.preferenceManager = preferenceManager;
        this.localModelIdProvider = localModelIdProvider;
        this.fileProviderDescriptor = fileProviderDescriptor;
        this.ortEnvironmentProvider = ortEnvironmentProvider;
        this.deviceNNAPIFlagProvider = deviceNNAPIFlagProvider;
    }

    /**
     * Builds the named-input map passed to the UNet ORT session for one denoising step.
     *
     * <p>The tensors are only referenced, not copied or closed; ownership stays with the caller.
     *
     * @param encoderHiddenStates text embeddings supplied to {@code inference}
     * @param sample              the scaled latent model input for this step
     * @param timeStep            single-element int tensor holding the scheduler time step
     * @return a new mutable map keyed by the {@link LocalDiffusionContract} input names
     */
    private final Map<String, OnnxTensor> createUNetModelInput(OnnxTensor encoderHiddenStates, OnnxTensor sample, OnnxTensor timeStep) {
        // Capacity 4 (threshold 3) holds all three inputs without rehashing.
        Map<String, OnnxTensor> inputs = new HashMap<>(4);
        inputs.put(LocalDiffusionContract.KEY_ENCODER_HIDDEN_STATES, encoderHiddenStates);
        inputs.put(LocalDiffusionContract.KEY_SAMPLE, sample);
        inputs.put(LocalDiffusionContract.KEY_TIME_STEP, timeStep);
        return inputs;
    }

    /**
     * Allocates the initial noisy latent for the denoising loop.
     *
     * <p>Fills a {@code [batchSize][4][height / 8][width / 8]} array with Gaussian samples.
     * <ul>
     *   <li>Samples come from the Box-Muller transform over a {@link Random} seeded with {@code seed}.</li>
     *   <li>Each sample is scaled by {@code initNoiseSigma}.</li>
     *   <li>The array is wrapped in an ORT tensor created from the shared environment.</li>
     * </ul>
     * Samples are drawn in batch, channel, row, column order, so a given seed always yields the
     * same latent.
     *
     * @param batchSize      number of latents in the batch
     * @param height         image height in pixels (latent height is {@code height / 8})
     * @param width          image width in pixels (latent width is {@code width / 8})
     * @param seed           seed for the noise generator
     * @param initNoiseSigma scheduler's initial noise standard deviation
     * @return the latent tensor together with its backing array and shape
     */
    private final LocalDiffusionTensor<?> generateLatentSample(int batchSize, int height, int width, long seed, float initNoiseSigma) {
        final int channels = 4;
        final int latentHeight = height / 8;
        final int latentWidth = width / 8;
        final double twoPi = 2.0 * Math.PI;

        Random random = new Random(seed);
        float[][][][] latents = new float[batchSize][channels][latentHeight][latentWidth];
        for (float[][][] sample : latents) {
            for (float[][] channel : sample) {
                for (float[] row : channel) {
                    for (int x = 0; x < row.length; x++) {
                        // Random.nextDouble() is in [0, 1). log(0) would yield an infinite/NaN element,
                        // so redraw in that (rare) case. For every other value the draw sequence, and
                        // therefore the seeded output, is unchanged.
                        double u1;
                        do {
                            u1 = random.nextDouble();
                        } while (u1 == 0.0);
                        double u2 = random.nextDouble();
                        row[x] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(twoPi * u2) * initNoiseSigma);
                    }
                }
            }
        }

        OnnxTensor tensor = OnnxTensor.createTensor(this.ortEnvironmentProvider.getEnvironment(), latents);
        Intrinsics.checkNotNullExpressionValue(tensor, "createTensor(...)");
        return new LocalDiffusionTensor<>(tensor, latents, new long[]{batchSize, channels, latentHeight, latentWidth});
    }

    /**
     * Blends two noise predictions in place.
     *
     * <p>Every element {@code p} of {@code noisePrediction} becomes
     * {@code p + (float) guidanceScale * (t - p)}, where {@code t} is the matching element of
     * {@code noisePredictionText}. This is the classifier-free-guidance form, with
     * {@code noisePrediction} acting as the base (presumably unconditional) prediction — confirm at
     * the call site.
     *
     * <p>Iteration bounds come from {@code ArrayExtensionsKt.getSizes(noisePrediction)}.
     *
     * @param noisePrediction     base prediction; overwritten with the guided result
     * @param noisePredictionText text-conditioned prediction; only read
     * @param guidanceScale       guidance weight (narrowed to float before use)
     */
    private final void performGuidance(float[][][][] noisePrediction, float[][][][] noisePredictionText, double guidanceScale) {
        final long[] dims = ArrayExtensionsKt.getSizes(noisePrediction);
        final float scale = (float) guidanceScale;
        for (long b = 0; b < dims[0]; b++) {
            final int bi = (int) b;
            for (long c = 0; c < dims[1]; c++) {
                final int ci = (int) c;
                for (long y = 0; y < dims[2]; y++) {
                    final int yi = (int) y;
                    for (long x = 0; x < dims[3]; x++) {
                        final int xi = (int) x;
                        float base = noisePrediction[bi][ci][yi][xi];
                        noisePrediction[bi][ci][yi][xi] = base + scale * (noisePredictionText[bi][ci][yi][xi] - base);
                    }
                }
            }
        }
    }

    /**
     * Releases the UNet ORT session and the VAE decoder held by this instance.
     *
     * <p>Both fields are detached before anything is closed, and the decoder is closed in a
     * {@code finally} block. As a result, a failure while closing the session no longer leaks the
     * decoder or leaves half-closed objects referenced by this instance. Any exception raised while
     * closing still propagates to the caller. A repeated call is a no-op apart from logging, because
     * both fields are already null.
     */
    public final void close() {
        // Simple class name without any "$..." suffix, used as the Timber tag for both log lines.
        String name = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name, "getName(...)");
        String tag = StringsKt.substringAfterLast$default(name, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) tag, (CharSequence) "$", false, 2, (Object) null)) {
            tag = StringsKt.substringBefore$default(tag, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(tag).d("{LocalDiffusion} {uNet} {close} Closing session...", new Object[0]);

        OrtSession ortSession = this.session;
        VaeDecoder vaeDecoder = this.decoder;
        // Detach first so this object never keeps references to resources that may be half-closed.
        this.session = null;
        this.decoder = null;
        try {
            if (ortSession != null) {
                ortSession.close();
            }
        } finally {
            // Always release the decoder, even if closing the session failed.
            if (vaeDecoder != null) {
                vaeDecoder.close();
            }
        }

        Timber.INSTANCE.tag(tag).d("{LocalDiffusion} {uNet} {close} Session closed successfully!", new Object[0]);
    }

    /**
     * Decodes a latent tensor into a {@link Bitmap} using the VAE decoder.
     *
     * <p>Steps:
     * <ol>
     *   <li>Multiply the latent values by {@code 5.4899807f}.</li>
     *   <li>Run the decoder with the scaled tensor under the
     *       {@code LocalDiffusionContract.KEY_LATENT_SAMPLE} input name.</li>
     *   <li>Convert the decoder's {@code float[][][][]} output to an image of size
     *       {@code this.width} x {@code this.height}. These are the last dimensions passed to
     *       {@code inference}, or 384x384 by default.</li>
     * </ol>
     *
     * <p>The scaled tensor is created by this call and is closed before returning, including on
     * failure. The caller's {@code latents} tensor is not closed.
     *
     * @param latents latents produced by the denoising loop
     * @return the decoded image
     * @throws NullPointerException if the decoder is not initialized or returns null
     */
    public final Bitmap decode(LocalDiffusionTensor<?> latents) {
        Intrinsics.checkNotNullParameter(latents, "latents");
        // Simple class name without any "$..." suffix, used as the Timber tag for both log lines.
        String name = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name, "getName(...)");
        String tag = StringsKt.substringAfterLast$default(name, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) tag, (CharSequence) "$", false, 2, (Object) null)) {
            tag = StringsKt.substringBefore$default(tag, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(tag).d("{LocalDiffusion} {uNet} {decode} Trying to decode latents: " + latents.hashCode(), new Object[0]);

        float[] array = latents.getTensor().getFloatBuffer().array();
        Intrinsics.checkNotNullExpressionValue(array, "array(...)");
        // 5.4899807f ~= 1 / 0.18215, presumably the Stable Diffusion VAE latent scaling factor — confirm.
        LocalDiffusionTensor<?> scaled = TensorExtensionsKt.multipleTensorsByFloat(array, 5.4899807f, latents.getShape());

        VaeDecoder vaeDecoder;
        float[][][][] pixels;
        try {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put(LocalDiffusionContract.KEY_LATENT_SAMPLE, scaled.getTensor());
            vaeDecoder = this.decoder;
            Intrinsics.checkNotNull(vaeDecoder);
            Object decoded = vaeDecoder.decode(MapsKt.toMap(inputs));
            Intrinsics.checkNotNull(decoded, "null cannot be cast to non-null type kotlin.Array<kotlin.Array<kotlin.Array<kotlin.FloatArray>>>");
            pixels = (float[][][][]) decoded;
        } finally {
            // The scaled tensor is allocated by this call and is no longer needed once the decoder
            // has returned a plain Java array, so release its native memory here (also on failure).
            // NOTE(review): assumes multipleTensorsByFloat returns a fresh, unshared tensor — confirm.
            scaled.getTensor().close();
        }

        Bitmap convertToImage = vaeDecoder.convertToImage(pixels, this.width, this.height);
        Timber.INSTANCE.tag(tag).d("{LocalDiffusion} {uNet} {decode} Bitmap generated successfully: " + convertToImage.hashCode(), new Object[0]);
        return convertToImage;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public final void inference(long seedNum, int numInferenceSteps, OnnxTensor textEmbeddings, double guidanceScale, int batchSize, int width, int height) {
        Intrinsics.checkNotNullParameter(textEmbeddings, "textEmbeddings");
        String name = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name, "getName(...)");
        LocalDiffusionConfig localDiffusionConfig = null;
        Object[] objArr = 0;
        String substringAfterLast$default = StringsKt.substringAfterLast$default(name, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default = StringsKt.substringBefore$default(substringAfterLast$default, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default).d("{LocalDiffusion} {uNet} {inference} Trying to start inference:", new Object[0]);
        String str = "{LocalDiffusion} {uNet} {inference} - seed: " + seedNum;
        String name2 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name2, "getName(...)");
        String substringAfterLast$default2 = StringsKt.substringAfterLast$default(name2, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default2, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default2 = StringsKt.substringBefore$default(substringAfterLast$default2, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default2).d(str, new Object[0]);
        String str2 = "{LocalDiffusion} {uNet} {inference} - numInferenceSteps: " + numInferenceSteps;
        String name3 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name3, "getName(...)");
        String substringAfterLast$default3 = StringsKt.substringAfterLast$default(name3, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default3, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default3 = StringsKt.substringBefore$default(substringAfterLast$default3, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default3).d(str2, new Object[0]);
        String str3 = "{LocalDiffusion} {uNet} {inference} - textEmbeddings: " + textEmbeddings;
        String name4 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name4, "getName(...)");
        String substringAfterLast$default4 = StringsKt.substringAfterLast$default(name4, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default4, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default4 = StringsKt.substringBefore$default(substringAfterLast$default4, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default4).d(str3, new Object[0]);
        String str4 = "{LocalDiffusion} {uNet} {inference} - guidanceScale: " + guidanceScale;
        String name5 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name5, "getName(...)");
        String substringAfterLast$default5 = StringsKt.substringAfterLast$default(name5, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default5, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default5 = StringsKt.substringBefore$default(substringAfterLast$default5, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default5).d(str4, new Object[0]);
        String str5 = "{LocalDiffusion} {uNet} {inference} - batchSize: " + batchSize;
        String name6 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name6, "getName(...)");
        String substringAfterLast$default6 = StringsKt.substringAfterLast$default(name6, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default6, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default6 = StringsKt.substringBefore$default(substringAfterLast$default6, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default6).d(str5, new Object[0]);
        String str6 = "{LocalDiffusion} {uNet} {inference} - size: " + width + "x" + height;
        String name7 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name7, "getName(...)");
        String substringAfterLast$default7 = StringsKt.substringAfterLast$default(name7, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default7, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default7 = StringsKt.substringBefore$default(substringAfterLast$default7, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default7).d(str6, new Object[0]);
        this.width = width;
        this.height = height;
        EulerAncestralDiscreteLocalDiffusionScheduler eulerAncestralDiscreteLocalDiffusionScheduler = new EulerAncestralDiscreteLocalDiffusionScheduler(localDiffusionConfig, 1, objArr == true ? 1 : 0);
        String str7 = "{LocalDiffusion} {uNet} {inference} Initialized scheduler: " + eulerAncestralDiscreteLocalDiffusionScheduler;
        String name8 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name8, "getName(...)");
        String substringAfterLast$default8 = StringsKt.substringAfterLast$default(name8, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default8, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default8 = StringsKt.substringBefore$default(substringAfterLast$default8, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default8).d(str7, new Object[0]);
        int[] timeSteps = eulerAncestralDiscreteLocalDiffusionScheduler.setTimeSteps(numInferenceSteps);
        EulerAncestralDiscreteLocalDiffusionScheduler eulerAncestralDiscreteLocalDiffusionScheduler2 = eulerAncestralDiscreteLocalDiffusionScheduler;
        int i = 2;
        LocalDiffusionTensor<?> generateLatentSample = generateLatentSample(batchSize, height, width, seedNum <= 0 ? this.random.nextLong() : seedNum, (float) eulerAncestralDiscreteLocalDiffusionScheduler.getInitNoiseSigma());
        String str8 = "{LocalDiffusion} {uNet} {inference} Got latents: " + generateLatentSample.hashCode();
        String name9 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name9, "getName(...)");
        String substringAfterLast$default9 = StringsKt.substringAfterLast$default(name9, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default9, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default9 = StringsKt.substringBefore$default(substringAfterLast$default9, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default9).d(str8, new Object[0]);
        long j = height / 8;
        long j2 = width / 8;
        long[] jArr = {2, 4, j, j2};
        String str9 = "{LocalDiffusion} {uNet} {inference} Got shape: " + jArr;
        String name10 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name10, "getName(...)");
        String substringAfterLast$default10 = StringsKt.substringAfterLast$default(name10, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default10, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default10 = StringsKt.substringBefore$default(substringAfterLast$default10, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default10).d(str9, new Object[0]);
        String str10 = "{LocalDiffusion} {uNet} {inference} Starting steps processing! Total : " + timeSteps.length;
        String name11 = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(name11, "getName(...)");
        String substringAfterLast$default11 = StringsKt.substringAfterLast$default(name11, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        if (StringsKt.contains$default((CharSequence) substringAfterLast$default11, (CharSequence) "$", false, 2, (Object) null)) {
            substringAfterLast$default11 = StringsKt.substringBefore$default(substringAfterLast$default11, "$", (String) null, 2, (Object) null);
        }
        Timber.INSTANCE.tag(substringAfterLast$default11).d(str10, new Object[0]);
        int length = timeSteps.length;
        LocalDiffusionTensor<?> localDiffusionTensor = generateLatentSample;
        int i2 = 0;
        while (i2 < length) {
            float[] array = localDiffusionTensor.getTensor().getFloatBuffer().array();
            Intrinsics.checkNotNullExpressionValue(array, "array(...)");
            EulerAncestralDiscreteLocalDiffusionScheduler eulerAncestralDiscreteLocalDiffusionScheduler3 = eulerAncestralDiscreteLocalDiffusionScheduler2;
            LocalDiffusionTensor<?> scaleModelInput = eulerAncestralDiscreteLocalDiffusionScheduler3.scaleModelInput(TensorExtensionsKt.duplicate(array, jArr), i2);
            int i3 = length;
            String str11 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} ------------------";
            long[] jArr2 = jArr;
            String name12 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name12, "getName(...)");
            LocalDiffusionTensor<?> localDiffusionTensor2 = localDiffusionTensor;
            String substringAfterLast$default12 = StringsKt.substringAfterLast$default(name12, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            long j3 = j2;
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default12, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default12 = StringsKt.substringBefore$default(substringAfterLast$default12, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default12).d(str11, new Object[0]);
            String str12 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Latent model input: " + scaleModelInput;
            String name13 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name13, "getName(...)");
            String substringAfterLast$default13 = StringsKt.substringAfterLast$default(name13, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default13, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default13 = StringsKt.substringBefore$default(substringAfterLast$default13, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default13).d(str12, new Object[0]);
            String str13 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Notifying callback about step.";
            String name14 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name14, "getName(...)");
            String substringAfterLast$default14 = StringsKt.substringAfterLast$default(name14, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default14, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default14 = StringsKt.substringBefore$default(substringAfterLast$default14, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default14).d(str13, new Object[0]);
            Callback callback = this.callback;
            if (callback != null) {
                callback.onStep(timeSteps.length, i2);
                Unit unit = Unit.INSTANCE;
            }
            OnnxTensor tensor = scaleModelInput.getTensor();
            OnnxTensor createTensor = OnnxTensor.createTensor(this.ortEnvironmentProvider.getEnvironment(), IntBuffer.wrap(new int[]{timeSteps[i2]}), new long[]{1});
            Intrinsics.checkNotNullExpressionValue(createTensor, "createTensor(...)");
            i = 2;
            Map<String, OnnxTensor> createUNetModelInput = createUNetModelInput(textEmbeddings, tensor, createTensor);
            String str14 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Got uNet model input: " + createUNetModelInput;
            String name15 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name15, "getName(...)");
            String substringAfterLast$default15 = StringsKt.substringAfterLast$default(name15, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            int[] iArr = timeSteps;
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default15, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default15 = StringsKt.substringBefore$default(substringAfterLast$default15, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default15).d(str14, new Object[0]);
            OrtSession ortSession = this.session;
            Intrinsics.checkNotNull(ortSession);
            OrtSession.Result run = ortSession.run(createUNetModelInput);
            String str15 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Got result from uNet session: " + run;
            String name16 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name16, "getName(...)");
            String substringAfterLast$default16 = StringsKt.substringAfterLast$default(name16, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default16, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default16 = StringsKt.substringBefore$default(substringAfterLast$default16, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default16).d(str15, new Object[0]);
            Object value = run.get(0).getValue();
            Intrinsics.checkNotNull(value, "null cannot be cast to non-null type kotlin.Array<kotlin.Array<kotlin.Array<kotlin.FloatArray>>>");
            float[][][][] fArr = (float[][][][]) value;
            String str16 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Trying to close ORT session in: " + run;
            String name17 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name17, "getName(...)");
            String substringAfterLast$default17 = StringsKt.substringAfterLast$default(name17, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default17, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default17 = StringsKt.substringBefore$default(substringAfterLast$default17, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default17).d(str16, new Object[0]);
            run.close();
            Pair<float[][][][], float[][][][]> splitTensor = TensorExtensionsKt.splitTensor(fArr, new long[]{1, 4, j, j3});
            float[][][][] fArr2 = (float[][][][]) splitTensor.first;
            float[][][][] fArr3 = (float[][][][]) splitTensor.second;
            String str17 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Got split tensors with prediction:";
            String name18 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name18, "getName(...)");
            String substringAfterLast$default18 = StringsKt.substringAfterLast$default(name18, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            long j4 = j;
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default18, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default18 = StringsKt.substringBefore$default(substringAfterLast$default18, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default18).d(str17, new Object[0]);
            String str18 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} - splitTensors: " + splitTensor;
            String name19 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name19, "getName(...)");
            String substringAfterLast$default19 = StringsKt.substringAfterLast$default(name19, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default19, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default19 = StringsKt.substringBefore$default(substringAfterLast$default19, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default19).d(str18, new Object[0]);
            String str19 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} - noisePrediction: " + fArr2;
            String name20 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name20, "getName(...)");
            String substringAfterLast$default20 = StringsKt.substringAfterLast$default(name20, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default20, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default20 = StringsKt.substringBefore$default(substringAfterLast$default20, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default20).d(str19, new Object[0]);
            String str20 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} - noisePredictionText: " + fArr3;
            String name21 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name21, "getName(...)");
            String substringAfterLast$default21 = StringsKt.substringAfterLast$default(name21, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default21, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default21 = StringsKt.substringBefore$default(substringAfterLast$default21, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default21).d(str20, new Object[0]);
            String str21 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Trying to preform guidance...";
            String name22 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name22, "getName(...)");
            String substringAfterLast$default22 = StringsKt.substringAfterLast$default(name22, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default22, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default22 = StringsKt.substringBefore$default(substringAfterLast$default22, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default22).d(str21, new Object[0]);
            Intrinsics.checkNotNull(fArr2);
            Intrinsics.checkNotNull(fArr3);
            performGuidance(fArr2, fArr3, guidanceScale);
            String str22 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Guidance performed successfully!";
            String name23 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name23, "getName(...)");
            String substringAfterLast$default23 = StringsKt.substringAfterLast$default(name23, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default23, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default23 = StringsKt.substringBefore$default(substringAfterLast$default23, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default23).d(str22, new Object[0]);
            OnnxTensor createTensor2 = OnnxTensor.createTensor(this.ortEnvironmentProvider.getEnvironment(), fArr2);
            Intrinsics.checkNotNullExpressionValue(createTensor2, "createTensor(...)");
            localDiffusionTensor = eulerAncestralDiscreteLocalDiffusionScheduler3.step(new LocalDiffusionTensor<>(createTensor2, fArr2, ArrayExtensionsKt.getSizes(fArr2)), i2, localDiffusionTensor2);
            String str23 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} Finalized latents: " + localDiffusionTensor;
            String name24 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name24, "getName(...)");
            String substringAfterLast$default24 = StringsKt.substringAfterLast$default(name24, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default24, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default24 = StringsKt.substringBefore$default(substringAfterLast$default24, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default24).d(str23, new Object[0]);
            String str24 = "{LocalDiffusion} {uNet} {inference} {Step_" + i2 + "} ------------------";
            String name25 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name25, "getName(...)");
            String substringAfterLast$default25 = StringsKt.substringAfterLast$default(name25, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default25, (CharSequence) "$", false, 2, (Object) null)) {
                substringAfterLast$default25 = StringsKt.substringBefore$default(substringAfterLast$default25, "$", (String) null, 2, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default25).d(str24, new Object[0]);
            i2++;
            length = i3;
            jArr = jArr2;
            eulerAncestralDiscreteLocalDiffusionScheduler2 = eulerAncestralDiscreteLocalDiffusionScheduler3;
            timeSteps = iArr;
            j2 = j3;
            j = j4;
        }
        LocalDiffusionTensor<?> localDiffusionTensor3 = localDiffusionTensor;
        int[] iArr2 = timeSteps;
        Callback callback2 = this.callback;
        if (callback2 != null) {
            String name26 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name26, "getName(...)");
            String substringAfterLast$default26 = StringsKt.substringAfterLast$default(name26, BuildVersion.DELIMITER_VERSION, (String) null, i, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default26, (CharSequence) "$", false, i, (Object) null)) {
                substringAfterLast$default26 = StringsKt.substringBefore$default(substringAfterLast$default26, "$", (String) null, i, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default26).d("{LocalDiffusion} {uNet} {inference} Finalization / Flushing image...", new Object[0]);
            Callback callback3 = this.callback;
            if (callback3 != null) {
                callback3.onStep(iArr2.length, iArr2.length);
                Unit unit2 = Unit.INSTANCE;
            }
            Bitmap decode = decode(localDiffusionTensor3);
            String str25 = "{LocalDiffusion} {uNet} {inference} Finalization / Decoded bitmap: " + decode.hashCode();
            String name27 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name27, "getName(...)");
            String substringAfterLast$default27 = StringsKt.substringAfterLast$default(name27, BuildVersion.DELIMITER_VERSION, (String) null, i, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default27, (CharSequence) "$", false, i, (Object) null)) {
                substringAfterLast$default27 = StringsKt.substringBefore$default(substringAfterLast$default27, "$", (String) null, i, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default27).d(str25, new Object[0]);
            callback2.onBuildImage(0, decode);
            String name28 = getClass().getName();
            Intrinsics.checkNotNullExpressionValue(name28, "getName(...)");
            String substringAfterLast$default28 = StringsKt.substringAfterLast$default(name28, BuildVersion.DELIMITER_VERSION, (String) null, i, (Object) null);
            if (StringsKt.contains$default((CharSequence) substringAfterLast$default28, (CharSequence) "$", false, i, (Object) null)) {
                substringAfterLast$default28 = StringsKt.substringBefore$default(substringAfterLast$default28, "$", (String) null, i, (Object) null);
            }
            Timber.INSTANCE.tag(substringAfterLast$default28).d("{LocalDiffusion} {uNet} {inference} Finalization / Notifying callback and closing session.", new Object[0]);
            close();
        }
    }

    /**
     * Lazily prepares the uNet stage: creates the VAE decoder and opens the uNet ORT session.
     * <p>
     * Does nothing if {@code session} is already non-null. The NNAPI flag is read from
     * {@code deviceNNAPIFlagProvider} exactly once and that single value is used both to build the
     * {@link VaeDecoder} and to decide whether NNAPI is added to the uNet session options.
     * The session is loaded from {@code <modelPathPrefix>/unet/model.ort}.
     */
    public final void initialize() {
        if (this.session != null) {
            return;
        }
        // Snapshot the flag once: reading the provider twice could configure the decoder and the
        // uNet session differently if the underlying preference changes between the reads.
        int nnApiFlag = this.deviceNNAPIFlagProvider.get();
        this.decoder = new VaeDecoder(this.ortEnvironmentProvider, this.fileProviderDescriptor, this.localModelIdProvider, this.preferenceManager, nnApiFlag);

        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        sessionOptions.addConfigEntry(LocalDiffusionContract.ORT_KEY_MODEL_FORMAT, LocalDiffusionContract.ORT);
        if (nnApiFlag == LocalDiffusionFlag.NN_API.getValue()) {
            // Route execution through NNAPI and forbid NNAPI's own CPU fallback device.
            sessionOptions.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED));
        }

        String modelPath = LocalDiffusionPathsKt.modelPathPrefix(this.preferenceManager, this.fileProviderDescriptor, this.localModelIdProvider) + "/unet/model.ort";
        this.session = this.ortEnvironmentProvider.getEnvironment().createSession(modelPath, sessionOptions);
    }

    /**
     * Replaces the listener that receives step progress and the final decoded image.
     * <p>
     * A debug line is logged first, carrying the new callback's hash code (or {@code 0} when the
     * callback is being cleared with {@code null}). The log tag is derived from this class's name.
     *
     * @param callback the new listener, or {@code null} to detach the current one
     */
    public final void setCallback(Callback callback) {
        int callbackHash = callback == null ? 0 : callback.hashCode();
        String className = getClass().getName();
        Intrinsics.checkNotNullExpressionValue(className, "getName(...)");
        String shortName = StringsKt.substringAfterLast$default(className, BuildVersion.DELIMITER_VERSION, (String) null, 2, (Object) null);
        // Nested/anonymous class names carry a "$suffix"; keep only the outer part for the tag.
        String logTag = StringsKt.contains$default((CharSequence) shortName, (CharSequence) "$", false, 2, (Object) null)
                ? StringsKt.substringBefore$default(shortName, "$", (String) null, 2, (Object) null)
                : shortName;
        Timber.INSTANCE.tag(logTag).d("{LocalDiffusion} {uNet} Setting new result callback " + callbackHash, new Object[0]);
        this.callback = callback;
    }
}
