package org.hipparchus.stat.correlation;

import java.util.Arrays;
import org.hipparchus.linear.BlockRealMatrix;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.MathArrays;

/* loaded from: classes.dex */
/**
 * Computes Kendall's rank correlation between paired samples, with a correction for ties.
 *
 * <p>For two samples of equal length {@code n} the value returned is
 * <pre>
 *   (numPairs - tiedX - tiedY + tiedXY - 2 * swaps) / sqrt((numPairs - tiedX) * (numPairs - tiedY))
 * </pre>
 * where {@code numPairs = n * (n - 1) / 2}, {@code tiedX} / {@code tiedY} / {@code tiedXY} are the
 * numbers of pairs tied in the first sample, in the second sample, and in both, and {@code swaps}
 * is the number of inversions of the second sample once the pairs are ordered by the first.
 * This is the form usually called Kendall's tau-b.
 *
 * <p>Inversions are counted with a bottom-up merge sort, so one correlation costs
 * {@code O(n log n)} comparisons instead of the {@code O(n^2)} of a pairwise scan.
 */
public class KendallsCorrelation {

    /** Matrix computed by the data constructors; {@code null} when the no-argument constructor was used. */
    private final RealMatrix correlationMatrix;

    /**
     * Immutable (first, second) pair of doubles ordered by {@code first}, then by {@code second}.
     *
     * <p>Both components are compared with {@link Double#compare(double, double)}.
     */
    public static class DoublePair implements Comparable<DoublePair> {

        /** First component (primary sort key). */
        private final double first;

        /** Second component (secondary sort key). */
        private final double second;

        /**
         * Creates a pair.
         *
         * @param d first component
         * @param d2 second component
         */
        DoublePair(double d, double d2) {
            this.first = d;
            this.second = d2;
        }

        /**
         * Compares on the first component and falls back to the second when the first ones are equal.
         *
         * @param doublePair pair to compare against
         * @return negative, zero or positive, following the {@link Comparable} contract
         */
        @Override
        public int compareTo(DoublePair doublePair) {
            final int byFirst = Double.compare(getFirst(), doublePair.getFirst());
            return byFirst != 0 ? byFirst : Double.compare(getSecond(), doublePair.getSecond());
        }

        /**
         * @return the first component
         */
        public double getFirst() {
            return this.first;
        }

        /**
         * @return the second component
         */
        public double getSecond() {
            return this.second;
        }
    }

    /**
     * Creates an instance without data; {@link #getCorrelationMatrix()} then returns {@code null}.
     * The {@code correlation} and {@code computeCorrelationMatrix} methods remain usable.
     */
    public KendallsCorrelation() {
        this.correlationMatrix = null;
    }

    /**
     * Creates an instance and computes the correlation matrix of the columns of {@code realMatrix}.
     *
     * @param realMatrix data matrix whose columns are the variables to correlate
     */
    public KendallsCorrelation(RealMatrix realMatrix) {
        this.correlationMatrix = computeCorrelationMatrix(realMatrix);
    }

    /**
     * Creates an instance from a rectangular array, converted with
     * {@code MatrixUtils.createRealMatrix} before the column correlations are computed.
     *
     * @param dArr data array whose columns are the variables to correlate
     */
    public KendallsCorrelation(double[][] dArr) {
        this(MatrixUtils.createRealMatrix(dArr));
    }

    /**
     * Returns {@code 1 + 2 + ... + j}, computed as {@code j * (j + 1) / 2}.
     *
     * <p>Called with {@code k - 1} it yields the number of unordered pairs among {@code k} items.
     *
     * @param j upper bound of the summation
     * @return the sum of the integers from 1 to {@code j}
     */
    private static long sum(long j) {
        return (j * (1 + j)) / 2;
    }

    /**
     * Computes the correlation matrix of the columns of {@code realMatrix}.
     *
     * <p>The result is a square, symmetric matrix with one row and column per input column and
     * {@code 1.0} on the diagonal.
     *
     * @param realMatrix data matrix whose columns are the variables to correlate
     * @return matrix whose {@code (i, j)} entry is the correlation of columns {@code i} and {@code j}
     */
    public RealMatrix computeCorrelationMatrix(RealMatrix realMatrix) {
        final int nVars = realMatrix.getColumnDimension();
        final RealMatrix result = new BlockRealMatrix(nVars, nVars);
        for (int row = 0; row < nVars; row++) {
            // Only the strict lower triangle is computed; each value is mirrored.
            for (int col = 0; col < row; col++) {
                final double tau = correlation(realMatrix.getColumn(row), realMatrix.getColumn(col));
                result.setEntry(row, col, tau);
                result.setEntry(col, row, tau);
            }
            result.setEntry(row, row, 1.0d);
        }
        return result;
    }

    /**
     * Computes the correlation matrix of the columns of the given rectangular array.
     *
     * @param dArr data array whose columns are the variables to correlate
     * @return matrix whose {@code (i, j)} entry is the correlation of columns {@code i} and {@code j}
     */
    public RealMatrix computeCorrelationMatrix(double[][] dArr) {
        return computeCorrelationMatrix(new BlockRealMatrix(dArr));
    }

    /**
     * Computes the tie-corrected Kendall rank correlation between two samples.
     *
     * <p>The lengths are validated by {@code MathArrays.checkEqualLength}. A single observation
     * gives {@code 0 / sqrt(0)}, i.e. {@code NaN}. Zero-length input is not guarded: the first
     * sorted pair is read unconditionally, which fails with an
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param dArr first sample
     * @param dArr2 second sample
     * @return the correlation between the two samples
     */
    public double correlation(double[] dArr, double[] dArr2) {
        MathArrays.checkEqualLength(dArr, dArr2);

        final int n = dArr.length;
        final long numPairs = sum(n - 1);

        DoublePair[] pairs = new DoublePair[n];
        for (int k = 0; k < n; k++) {
            pairs[k] = new DoublePair(dArr[k], dArr2[k]);
        }
        // Order by x, then by y, so that equal x values (and equal (x, y) pairs) are adjacent.
        Arrays.sort(pairs);

        // Pass 1: count pairs tied in x, and pairs tied in both x and y, from runs of equal values.
        long tiedXPairs = 0;
        long tiedXYPairs = 0;
        long consecutiveXTies = 1;
        long consecutiveXYTies = 1;
        DoublePair prev = pairs[0];
        for (int k = 1; k < n; k++) {
            final DoublePair curr = pairs[k];
            if (Double.compare(curr.getFirst(), prev.getFirst()) == 0) {
                consecutiveXTies++;
                if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
                    consecutiveXYTies++;
                } else {
                    tiedXYPairs += sum(consecutiveXYTies - 1);
                    consecutiveXYTies = 1;
                }
            } else {
                tiedXPairs += sum(consecutiveXTies - 1);
                consecutiveXTies = 1;
                tiedXYPairs += sum(consecutiveXYTies - 1);
                consecutiveXYTies = 1;
            }
            prev = curr;
        }
        // Flush the runs that were still open at the end of the array.
        tiedXPairs += sum(consecutiveXTies - 1);
        tiedXYPairs += sum(consecutiveXYTies - 1);

        // Pass 2: bottom-up merge sort on y, counting how many positions right-run elements jump.
        long swaps = 0;
        DoublePair[] pairsDestination = new DoublePair[n];
        for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
            for (int offset = 0; offset < n; offset += 2 * segmentSize) {
                int i = offset;
                final int iEnd = FastMath.min(i + segmentSize, n);
                int j = iEnd;
                final int jEnd = FastMath.min(j + segmentSize, n);

                int copyLocation = offset;
                while (i < iEnd || j < jEnd) {
                    if (i < iEnd) {
                        if (j < jEnd) {
                            if (Double.compare(pairs[i].getSecond(), pairs[j].getSecond()) <= 0) {
                                pairsDestination[copyLocation] = pairs[i];
                                i++;
                            } else {
                                // The right element overtakes every remaining left element.
                                pairsDestination[copyLocation] = pairs[j];
                                j++;
                                swaps += iEnd - i;
                            }
                        } else {
                            pairsDestination[copyLocation] = pairs[i];
                            i++;
                        }
                    } else {
                        pairsDestination[copyLocation] = pairs[j];
                        j++;
                    }
                    copyLocation++;
                }
            }
            // Ping-pong the buffers: the merged output is the input of the next round.
            final DoublePair[] pairsTemp = pairs;
            pairs = pairsDestination;
            pairsDestination = pairsTemp;
        }

        // Pass 3: pairs are now ordered by y, so pairs tied in y can be counted from runs.
        long tiedYPairs = 0;
        long consecutiveYTies = 1;
        prev = pairs[0];
        for (int k = 1; k < n; k++) {
            final DoublePair curr = pairs[k];
            if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
                consecutiveYTies++;
            } else {
                tiedYPairs += sum(consecutiveYTies - 1);
                consecutiveYTies = 1;
            }
            prev = curr;
        }
        tiedYPairs += sum(consecutiveYTies - 1);

        final long nonTiedXPairs = numPairs - tiedXPairs;
        final long concordantMinusDiscordant = ((nonTiedXPairs - tiedYPairs) + tiedXYPairs) - (swaps * 2);
        // Multiply as doubles so the product of the two pair counts cannot overflow a long.
        final double nonTiedPairsMultiplied = ((double) nonTiedXPairs) * ((double) (numPairs - tiedYPairs));
        return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
    }

    /**
     * Returns the correlation matrix computed at construction time.
     *
     * @return the stored correlation matrix, or {@code null} if this instance was created without data
     */
    public RealMatrix getCorrelationMatrix() {
        return this.correlationMatrix;
    }
}
