package deepboof.impl.backward.standard;

import deepboof.backward.DFunctionLinear;
import deepboof.impl.forward.standard.FunctionLinear_F64;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/* loaded from: classes3.dex */
public class DFunctionLinear_F64 extends BaseDFunction<Tensor_F64> implements DFunctionLinear<Tensor_F64> {
    /** Flattened length of one input sample, computed from the input shape in {@link #_initialize()}. */
    protected int D;
    /** Number of outputs produced per sample. */
    protected int M;
    /** Bias parameter; declared with shape (M) in {@link #_initialize()}. */
    Tensor_F64 bias;
    /** Weight parameter; declared with shape (M, D) and read row-major (row j starts at j*D). */
    Tensor_F64 weight;

    /**
     * Creates a linear (fully connected) function.
     *
     * @param i number of outputs M
     */
    public DFunctionLinear_F64(int i) {
        this.M = i;
    }

    /**
     * Computes the gradients of the linear function for the whole mini-batch.
     *
     * <p>For every sample {@code n}, output {@code j} and input element {@code k}, it accumulates:
     * <pre>
     *   gradientInput[n][k] += weight[j][k] * dout[n][j]
     *   gradWeight[j][k]    += input[n][k]  * dout[n][j]
     *   gradBias[j]         += dout[n][j]
     * </pre>
     * All three gradient tensors are zeroed first, so the results are overwritten and not summed
     * across calls.
     *
     * <p>{@code input}, {@code gradientInput} and the weight tensors are addressed directly through
     * their backing arrays, with row offsets of {@code startIndex + row * D}. This assumes each row is
     * stored contiguously. {@code dout} is read through {@code get(n, j)}.
     *
     * <p>NOTE: the decompiler reports that the original access modifier was {@code protected}. It is
     * kept {@code public} here so existing callers are not broken.
     *
     * @param input input tensor; row {@code n} holds the D elements of sample {@code n}
     * @param dout gradient flowing back from the output, read as (miniBatch, M)
     * @param gradientInput output: gradient with respect to {@code input}, same row layout as input
     * @param gradientParameters output: index 0 is the weight gradient (M, D), index 1 is the bias
     *     gradient (M)
     */
    @Override
    public void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput,
                           List<Tensor_F64> gradientParameters) {
        final Tensor_F64 gradWeight = gradientParameters.get(0);
        final Tensor_F64 gradBias = gradientParameters.get(1);

        gradientInput.zero();
        gradWeight.zero();
        gradBias.zero();

        // One coordinate buffer is reused for every dout lookup. The original code allocated a new
        // int[2] for each (sample, output) pair.
        final int[] coordinate = new int[2];

        for (int sample = 0; sample < this.miniBatchSize; sample++) {
            // These row offsets depend only on the sample, so compute them outside the output loop.
            final int inputRow = input.startIndex + sample * this.D;
            final int gradInputRow = gradientInput.startIndex + sample * this.D;
            coordinate[0] = sample;

            for (int out = 0; out < this.M; out++) {
                coordinate[1] = out;
                final double g = dout.get(coordinate);

                final int weightRow = this.weight.startIndex + out * this.D;
                final int gradWeightRow = gradWeight.startIndex + out * this.D;

                for (int k = 0; k < this.D; k++) {
                    gradientInput.d[gradInputRow + k] += this.weight.d[weightRow + k] * g;
                    gradWeight.d[gradWeightRow + k] += input.d[inputRow + k] * g;
                }
                gradBias.d[gradBias.startIndex + out] += g;
            }
        }
    }

    /**
     * Runs the forward pass by delegating to {@code FunctionLinear_F64.forwards}, using the current
     * weight, bias, mini-batch size, D and M.
     *
     * @param tensor_F64 input tensor
     * @param tensor_F642 output tensor written by the forward pass
     */
    @Override
    public void _forward(Tensor_F64 tensor_F64, Tensor_F64 tensor_F642) {
        FunctionLinear_F64.forwards(tensor_F64, tensor_F642, this.weight, this.bias, this.miniBatchSize, this.D, this.M);
    }

    /**
     * Computes D from the input shape and declares the parameter and output shapes.
     *
     * <p>D is {@code TensorOps.tensorLength(shapeInput)}, presumably the product of the input
     * dimensions (TODO confirm). This method registers the parameter shapes, weight (M, D) followed
     * by bias (M), and sets the per-sample output shape to (M).
     *
     * @throws IllegalArgumentException if the input shape has no dimensions
     */
    @Override
    public void _initialize() {
        if (this.shapeInput.length < 1) {
            throw new IllegalArgumentException("Input tensor shape must have a dimension of at least 1");
        }
        this.D = TensorOps.tensorLength(this.shapeInput);
        this.shapeParameters.add(new int[]{this.M, this.D});
        this.shapeParameters.add(new int[]{this.M});
        this.shapeOutput = new int[]{this.M};
    }

    /**
     * Stores references to the parameter tensors. They are not copied.
     *
     * @param list index 0 is the weight tensor (M, D), index 1 is the bias tensor (M)
     */
    @Override
    public void _setParameters(List<Tensor_F64> list) {
        this.weight = list.get(0);
        this.bias = list.get(1);
    }

    /**
     * Returns the number of outputs, M, that was passed to the constructor.
     *
     * @return number of outputs M
     */
    @Override
    public int getNumberOfOutputs() {
        return this.M;
    }

    /**
     * Returns the tensor type this function operates on.
     *
     * @return {@code Tensor_F64.class}
     */
    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }
}
