package deepboof.impl.backward.standard;

import deepboof.DeepBoofConstants;
import deepboof.Function;
import deepboof.backward.NumericalGradient;
import deepboof.misc.TensorOps;
import deepboof.misc.TensorOps_F64;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/* loaded from: classes4.dex */
/**
 * Estimates gradients of a {@link Function} numerically using central finite differences.
 *
 * <p>For each element {@code x} of a tensor, the function is evaluated with that element set to
 * {@code x + T} and to {@code x - T}. Each output is multiplied element-wise by {@code dout} and
 * summed. The difference of the two sums divided by {@code 2T} is written as that element's
 * gradient. The tensors being differentiated are modified in place during evaluation. Each
 * element is restored to its original value afterwards, even if the function throws.
 */
public class NumericalGradient_F64 implements NumericalGradient<Tensor_F64> {
    /** Function being differentiated; set via {@link #setFunction}. */
    Function<Tensor_F64> function;
    /** Input tensor of the current {@link #differentiate} call. */
    Tensor_F64 input;
    /** Parameter tensors of the current {@link #differentiate} call. */
    List<Tensor_F64> parameters;
    /** Sampling distance used for the central difference; changed via {@link #configure}. */
    double T = DeepBoofConstants.TEST_TOL_A_F64;
    /** Work space that receives the function's output on every evaluation. */
    Tensor_F64 output = new Tensor_F64();

    /**
     * Computes the numerical gradient for every element of {@code tensor}.
     *
     * @param tensor tensor whose elements are perturbed (the input or one parameter tensor)
     * @param dout weights applied element-wise to the function output before summing
     * @param gradient receives the gradient; element {@code i} is written at
     *     {@code gradient.startIndex + i}
     */
    private void process(Tensor_F64 tensor, Tensor_F64 dout, Tensor_F64 gradient) {
        int length = tensor.length();
        for (int i = 0; i < length; i++) {
            int index = tensor.startIndex + i;
            double original = tensor.d[index];
            double sumPlus;
            double sumMinus;
            try {
                tensor.d[index] = T + original;
                sumPlus = weightedOutputSum(dout);
                tensor.d[index] = original - T;
                sumMinus = weightedOutputSum(dout);
            } finally {
                // Always undo the perturbation so a failing forward() cannot leave the
                // caller's input/parameter tensor corrupted.
                tensor.d[index] = original;
            }
            gradient.d[gradient.startIndex + i] = (sumPlus - sumMinus) / (T * 2.0d);
        }
    }

    /**
     * Runs the function on the current input and parameters, multiplies the output element-wise
     * by {@code dout} (in place, in {@link #output}) and returns the sum of the result.
     */
    private double weightedOutputSum(Tensor_F64 dout) {
        // Parameters are re-set on every evaluation so the function sees the perturbed values.
        function.setParameters(parameters);
        function.forward(input, output);
        TensorOps_F64.elementMult(output, dout, output);
        return TensorOps_F64.elementSum(output);
    }

    /**
     * Sets the sampling distance used by the central difference.
     *
     * @param T sampling distance; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code T} is not {@code > 0} (including NaN) or is
     *     infinite
     */
    @Override
    public void configure(double T) {
        // !(T > 0) rather than T <= 0 so that NaN is rejected as well.
        if (!(T > 0.0d)) {
            throw new IllegalArgumentException("T must be > 0");
        }
        if (Double.isInfinite(T)) {
            throw new IllegalArgumentException("T must be finite");
        }
        this.T = T;
    }

    /**
     * Numerically computes the gradient of the function with respect to its input and each of its
     * parameter tensors. The result is weighted by {@code dout}.
     *
     * @param input input tensor; {@code input.length(0)} is used as the leading output dimension
     * @param parameters parameter tensors passed to the function
     * @param dout weights applied element-wise to the function output
     * @param gradientInput receives the gradient with respect to {@code input}
     * @param gradientParameters receives the gradient for each tensor in {@code parameters}, in
     *     the same order; must have at least {@code parameters.size()} entries
     */
    @Override
    public void differentiate(Tensor_F64 input, List<Tensor_F64> parameters, Tensor_F64 dout,
                              Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        output.reshape(TensorOps.WI(input.length(0), function.getOutputShape()));
        this.input = input;
        this.parameters = parameters;

        process(input, dout, gradientInput);
        for (int i = 0; i < parameters.size(); i++) {
            process(parameters.get(i), dout, gradientParameters.get(i));
        }
    }

    /**
     * Specifies the function to be differentiated.
     *
     * @param function function whose gradient will be estimated
     */
    @Override
    public void setFunction(Function<Tensor_F64> function) {
        this.function = function;
    }
}
