package deepboof.impl.backward.standard;

import deepboof.backward.DFunctionDropOut;
import deepboof.misc.TensorOps_F64;
import deepboof.tensors.Tensor_F64;
import java.util.List;
import java.util.Random;

/* loaded from: classes5.dex */
/**
 * Drop-out function for {@link Tensor_F64} tensors with support for a backwards pass.
 *
 * <p>While in learning mode the forward pass builds a random 0/1 mask and multiplies the input by
 * it; the same mask is then reused by the backwards pass. Outside of learning mode no randomness
 * is involved and the input is instead scaled by {@code 1 - dropRate}.
 */
public class DFunctionDropOut_F64 extends BaseDFunction<Tensor_F64> implements DFunctionDropOut<Tensor_F64> {
    /** Threshold compared against each random draw; a draw below it zeroes the element. */
    double dropRate;
    /** Mask of 0.0 / 1.0 values written by the most recent learning-mode forward pass. */
    Tensor_F64 drops = new Tensor_F64();
    /** Source of the random draws which decide the mask. */
    Random random;

    /**
     * Creates the function with a seeded random number generator.
     *
     * @param randomSeed seed passed to {@link Random#Random(long)}
     * @param dropRate value each random draw is compared against; stored as-is, no range check
     */
    public DFunctionDropOut_F64(long randomSeed, double dropRate) {
        this.random = new Random(randomSeed);
        this.dropRate = dropRate;
    }

    /**
     * Combines {@code dout} with the mask saved by the last learning-mode forward pass and hands
     * the result to {@code gradientInput} through {@code TensorOps_F64.elementMult} (presumably an
     * element-wise product written into its third argument -- confirm against TensorOps_F64).
     *
     * <p>NOTE(review): parameter names reflect their apparent roles; verify against BaseDFunction.
     * A decompiler note on the original states this method's access modifier was changed from
     * protected.
     *
     * @param input not read by this implementation
     * @param dout tensor multiplied by the mask
     * @param gradientInput tensor passed as the destination argument
     * @param gradientParameters not read by this implementation
     */
    @Override
    public void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput,
                           List<Tensor_F64> gradientParameters) {
        TensorOps_F64.elementMult(dout, drops, gradientInput);
    }

    /**
     * Applies drop-out to {@code input}, writing into {@code output}.
     *
     * <p>In learning mode the mask is reshaped to the input's shape and refilled; each mask element
     * becomes 0.0 when a fresh {@code random.nextDouble()} draw is below {@code dropRate} and 1.0
     * otherwise. Otherwise the work is delegated to {@code TensorOps_F64.elementMult} with the
     * scalar {@code 1 - dropRate}.
     *
     * @param input tensor read starting at its {@code startIndex}
     * @param output tensor written starting at its {@code startIndex}
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (learningMode) {
            drops.reshape(input.shape);
            final int total = drops.length();

            // The mask is indexed from zero, while input and output are walked from their own
            // start offsets.
            int inIdx = input.startIndex;
            int outIdx = output.startIndex;
            for (int maskIdx = 0; maskIdx < total; maskIdx++, inIdx++, outIdx++) {
                final double keep = (random.nextDouble() < dropRate) ? 0.0d : 1.0d;
                drops.d[maskIdx] = keep;
                output.d[outIdx] = input.d[inIdx] * keep;
            }
        } else {
            TensorOps_F64.elementMult(input, 1.0d - dropRate, output);
        }
    }

    /** Sets the output shape to a copy of the input shape. */
    @Override
    public void _initialize() {
        shapeOutput = (int[]) shapeInput.clone();
    }

    /**
     * Intentionally empty: the supplied parameter list is ignored.
     *
     * @param parameters ignored
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        // nothing to store
    }

    /**
     * @return the drop rate supplied at construction
     */
    @Override
    public double getDropRate() {
        return dropRate;
    }

    /**
     * @return {@code Tensor_F64.class}
     */
    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }
}
