package org.broadinstitute.gatk.utils;

import cern.jet.math.Arithmetic;
import cern.jet.random.Normal;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import java.io.Serializable;
import java.util.Comparator;
import java.util.Iterator;
import java.util.TreeSet;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.NormalDistribution;
import org.apache.commons.math.distribution.NormalDistributionImpl;
import org.broadinstitute.gatk.engine.GenomeAnalysisEngine;
import org.broadinstitute.gatk.utils.collections.Pair;
import org.broadinstitute.gatk.utils.exceptions.GATKException;

/* loaded from: input_file:org/broadinstitute/gatk/utils/MannWhitneyU.class */
/**
 * Mann-Whitney U (Wilcoxon rank-sum) test over two sets of numeric observations.
 *
 * <p>Observations are kept in a value-ordered {@link TreeSet}. U statistics are obtained by counting,
 * in that order, how many elements of one set precede elements of the other. P-values come from
 * one of three sources, chosen by set sizes in {@link #calculateP}:
 * <ul>
 *   <li>a normal approximation, for larger sets;</li>
 *   <li>a table lookup (currently just the normal approximation);</li>
 *   <li>an exact recursive computation, for small sets.</li>
 * </ul>
 */
public class MannWhitneyU {
    /** Standard normal (Colt), used for CDF evaluation in the normal approximation. */
    private static final Normal STANDARD_NORMAL = new Normal(0.0d, 1.0d, null);
    /** Standard normal (commons-math), used to invert cumulative probabilities; 0.01 is the inverse-CDF accuracy. */
    private static final NormalDistribution APACHE_NORMAL = new NormalDistributionImpl(0.0d, 1.0d, 0.01d);
    /** ln(sqrt(2 * pi)): log of the normalizing constant of a unit-variance normal density. */
    private static final double LNSQRT2PI = Math.log(Math.sqrt(2.0d * Math.PI));

    /** Observations from both sets, ordered by value; ties are handled by the comparator chosen at construction. */
    private final TreeSet<Pair<Number, USet>> observations;
    /** Number of SET1 observations actually stored in {@link #observations}. */
    private int sizeSet1;
    /** Number of SET2 observations actually stored in {@link #observations}. */
    private int sizeSet2;
    /** How exact (small-sample) p-values are computed. */
    private final ExactMode exactMode;

    /**
     * Orders observations by value and breaks exact ties with a coin flip from the GATK random generator.
     *
     * <p>Because ties never compare as 0, the TreeSet keeps every observation, including duplicates.
     * This comparator is intentionally not consistent (compare(a, a) is random). It must only be used
     * to build the ordered set, never for lookups.
     */
    private static class DitheringComparator implements Comparator<Pair<Number, USet>>, Serializable {
        @Override // java.util.Comparator
        public int compare(Pair<Number, USet> left, Pair<Number, USet> right) {
            final int byValue = Double.compare(left.first.doubleValue(), right.first.doubleValue());
            if (byValue != 0) {
                return byValue > 0 ? 1 : -1;
            }
            // Only consume a random draw on an exact tie.
            return GenomeAnalysisEngine.getRandomGenerator().nextBoolean() ? -1 : 1;
        }
    }

    /** Selects how {@link #calculatePRecursively} turns the exact U distribution into a p-value. */
    public enum ExactMode {
        /** p = P(U == u). */
        POINT,
        /** p = P(U <= u). */
        CUMULATIVE
    }

    /**
     * Orders observations purely by value.
     *
     * <p>Equal values compare as 0, so a TreeSet using this comparator rejects any observation whose
     * value equals one already stored, regardless of which set it belongs to.
     */
    private static class NumberedPairComparator implements Comparator<Pair<Number, USet>>, Serializable {
        @Override // java.util.Comparator
        public int compare(Pair<Number, USet> left, Pair<Number, USet> right) {
            return Double.compare(left.first.doubleValue(), right.first.doubleValue());
        }
    }

    /** Identifies which of the two samples an observation belongs to. */
    public enum USet {
        SET1,
        SET2
    }

    /**
     * @param exactMode how exact p-values are computed for small sets
     * @param dither    if true, exact ties are broken randomly so all observations are kept; if false,
     *                  an observation equal in value to a stored one is dropped
     */
    public MannWhitneyU(ExactMode exactMode, boolean dither) {
        if (dither) {
            this.observations = new TreeSet<>(new DitheringComparator());
        } else {
            this.observations = new TreeSet<>(new NumberedPairComparator());
        }
        this.sizeSet1 = 0;
        this.sizeSet2 = 0;
        this.exactMode = exactMode;
    }

    /** Creates a test using {@link ExactMode#POINT} with dithering enabled. */
    public MannWhitneyU() {
        this(ExactMode.POINT, true);
    }

    /** Creates a test using {@link ExactMode#POINT}. */
    public MannWhitneyU(boolean dither) {
        this(ExactMode.POINT, dither);
    }

    /** Creates a test with dithering enabled. */
    public MannWhitneyU(ExactMode exactMode) {
        this(exactMode, true);
    }

    /**
     * Adds one observation to the given set.
     *
     * <p>Only observations actually stored by the TreeSet are counted. Without dithering, a value equal
     * to a stored one is rejected by {@link TreeSet#add}. Counting it anyway would make the set sizes
     * disagree with the observations used to compute U.
     *
     * @param n   the observed value
     * @param set the sample it belongs to
     */
    public void add(Number n, USet set) {
        if (!this.observations.add(new Pair<>(n, set))) {
            return;
        }
        if (set == USet.SET1) {
            this.sizeSet1++;
        } else {
            this.sizeSet2++;
        }
    }

    /**
     * Returns the rank sums (R1, R2) of SET1 and SET2, with ranks 1..N taken from the stored ordering.
     *
     * <p>R1 = U1 + n1(n1+1)/2, where U1 counts (SET2, SET1) pairs with the SET2 element first.
     * U2 = n1*n2 - U1 and R2 = U2 + n2(n2+1)/2, so R1 + R2 = N(N+1)/2.
     */
    public Pair<Long, Long> getR1R2() {
        final long u1 = calculateOneSidedU(this.observations, USet.SET1);
        final long n1 = this.sizeSet1; // widen before multiplying to avoid int overflow
        final long n2 = this.sizeSet2;
        final long r1 = u1 + (n1 * (n1 + 1)) / 2;
        final long u2 = n1 * n2 - u1;
        final long r2 = u2 + (n2 * (n2 + 1)) / 2;
        return new Pair<>(Long.valueOf(r1), Long.valueOf(r2));
    }

    /**
     * Runs a one-sided test using U = number of pairs in which an element of the other set precedes an
     * element of {@code lessThanOther}.
     *
     * <p>In the normal-approximation regime the p-value is the lower-tail probability of z. It is
     * therefore small when {@code lessThanOther}'s values tend to be smaller than the other set's.
     *
     * @return (z, p), or (NaN, NaN) if either set is empty
     */
    @Ensures({"validateObservations(observations) || Double.isNaN(result.getFirst())", "result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    @Requires({"lessThanOther != null"})
    public Pair<Double, Double> runOneSidedTest(USet lessThanOther) {
        final long u = calculateOneSidedU(this.observations, lessThanOther);
        final int n = lessThanOther == USet.SET1 ? this.sizeSet1 : this.sizeSet2;
        final int m = lessThanOther == USet.SET1 ? this.sizeSet2 : this.sizeSet1;
        if (n == 0 || m == 0) {
            return new Pair<>(Double.valueOf(Double.NaN), Double.valueOf(Double.NaN));
        }
        return calculateP(n, m, u, false, this.exactMode);
    }

    /**
     * Runs a two-sided test using the smaller of the two U statistics (see {@link #calculateTwoSidedU}).
     *
     * @return (z, p), or (NaN, NaN) if either set is empty
     */
    @Ensures({"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    public Pair<Double, Double> runTwoSidedTest() {
        final Pair<Long, USet> uAndSet = calculateTwoSidedU(this.observations);
        final long u = uAndSet.first.longValue();
        final int n = uAndSet.second == USet.SET1 ? this.sizeSet1 : this.sizeSet2;
        final int m = uAndSet.second == USet.SET1 ? this.sizeSet2 : this.sizeSet1;
        if (n == 0 || m == 0) {
            return new Pair<>(Double.valueOf(Double.NaN), Double.valueOf(Double.NaN));
        }
        return calculateP(n, m, u, true, this.exactMode);
    }

    /**
     * Chooses a p-value method based on set sizes.
     * <ul>
     *   <li>Both sizes above 8: normal approximation.</li>
     *   <li>n above 5 and m above 7: normal approximation. The sum-of-uniforms approximation exists
     *       but is not wired in.</li>
     *   <li>Either size above 8 otherwise: {@link #calculatePFromTable}.</li>
     *   <li>Everything else: exact recursion.</li>
     * </ul>
     *
     * @return (z, p)
     */
    @Ensures({"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    @Requires({"m > 0", "n > 0"})
    protected static Pair<Double, Double> calculateP(int n, int m, long u, boolean twoSided, ExactMode exactMode) {
        if (n > 8 && m > 8) {
            return calculatePNormalApproximation(n, m, u, twoSided);
        }
        if (n > 5 && m > 7) {
            return calculatePNormalApproximation(n, m, u, twoSided);
        }
        if (n > 8 || m > 8) {
            return calculatePFromTable(n, m, u, twoSided);
        }
        return calculatePRecursively(n, m, u, twoSided, exactMode);
    }

    /** No lookup table is implemented; this delegates to {@link #calculatePNormalApproximation}. */
    public static Pair<Double, Double> calculatePFromTable(int n, int m, long u, boolean twoSided) {
        return calculatePNormalApproximation(n, m, u, twoSided);
    }

    /**
     * Normal approximation to the U distribution.
     *
     * @return (z, p), where p is 2 * min(Phi(z), 1 - Phi(z)) if {@code twoSided}, else Phi(z)
     */
    @Ensures({"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    @Requires({"m > 0", "n > 0"})
    public static Pair<Double, Double> calculatePNormalApproximation(int n, int m, long u, boolean twoSided) {
        final double z = getZApprox(n, m, u);
        final double p;
        if (twoSided) {
            p = 2.0d * (z < 0.0d ? STANDARD_NORMAL.cdf(z) : 1.0d - STANDARD_NORMAL.cdf(z));
        } else {
            p = STANDARD_NORMAL.cdf(z);
        }
        return new Pair<>(Double.valueOf(z), Double.valueOf(p));
    }

    /**
     * z-score of u under the normal approximation.
     *
     * <p>The mean is (nm+1)/2 and the variance is nm(n+m+1)/12. The product nm is formed in double
     * precision so large set sizes cannot overflow {@code int}.
     */
    @Ensures({"! Double.isNaN(result)", "! Double.isInfinite(result)"})
    @Requires({"m > 0", "n > 0"})
    private static double getZApprox(int n, int m, long u) {
        final double nm = (double) n * m;
        final double mean = (nm + 1.0d) / 2.0d;
        final double variance = (nm * (n + m + 1.0d)) / 12.0d;
        return (u - mean) / Math.sqrt(variance);
    }

    /**
     * Sum-of-uniforms approximation: maps u to a statistic z and evaluates the Irwin-Hall CDF F(z) of
     * order n.
     *
     * @return 1.0 if z < 0; 0.0 if z > n; F(z) if z <= n/2; otherwise 1 - F(z)
     */
    public static double calculatePUniformApproximation(int n, int m, long u) {
        final long rankSum = u + ((long) n * (n + 1)) / 2;
        final double a = Math.sqrt((double) m * (n + m + 1));
        // Real division: the integer quotient (n+m+1)/m silently truncated before.
        final double b = (n / 2.0d) * (1.0d - Math.sqrt((n + m + 1.0d) / m));
        final double z = b + rankSum / a;
        if (z < 0.0d) {
            return 1.0d;
        }
        if (z > n) {
            return 0.0d;
        }
        final double cdf = (1.0d / Arithmetic.factorial(n)) * uniformSumHelper(z, (int) Math.floor(z), n, 0);
        return z > n / 2.0d ? 1.0d - cdf : cdf;
    }

    /**
     * Returns sum_{k=start..m} (-1)^k * C(n,k) * (z-k)^n.
     *
     * <p>Terms are accumulated from the highest k downward, the same association order as a
     * right-recursive sum.
     */
    private static double uniformSumHelper(double z, int m, int n, int start) {
        double sum = 0.0d;
        for (int k = m; k >= start; k--) {
            sum = ((k % 2 == 0 ? 1 : -1) * Arithmetic.binomial(n, k) * Math.pow(z - k, n)) + sum;
        }
        return sum;
    }

    /**
     * Computes both directional U statistics and returns the smaller one.
     * <ul>
     *   <li>The SET1-labelled count is the number of (SET1, SET2) pairs with the SET1 element first.</li>
     *   <li>The SET2-labelled count is the number of (SET2, SET1) pairs with the SET2 element first.</li>
     * </ul>
     * The two counts always sum to n1 * n2.
     *
     * @return (smaller U, the set whose label goes with it); SET2 on ties
     */
    @Ensures({"result != null", "result.first >= 0"})
    @Requires({"observed != null", "observed.size() > 0"})
    public static Pair<Long, USet> calculateTwoSidedU(TreeSet<Pair<Number, USet>> observed) {
        int set1SeenSoFar = 0;
        int set2SeenSoFar = 0;
        long set1BeforeSet2 = 0;
        long set2BeforeSet1 = 0;
        for (Pair<Number, USet> dataPoint : observed) {
            if (dataPoint.second == USet.SET1) {
                set1SeenSoFar++;
                set2BeforeSet1 += set2SeenSoFar;
            } else {
                set2SeenSoFar++;
                set1BeforeSet2 += set1SeenSoFar;
            }
        }
        return set1BeforeSet2 < set2BeforeSet1
                ? new Pair<>(Long.valueOf(set1BeforeSet2), USet.SET1)
                : new Pair<>(Long.valueOf(set2BeforeSet1), USet.SET2);
    }

    /** Counts pairs (x, y) with x not in {@code dominator}, y in {@code dominator}, and x ordered before y. */
    @Ensures({"result >= 0"})
    @Requires({"observed != null", "dominator != null", "observed.size() > 0"})
    public static long calculateOneSidedU(TreeSet<Pair<Number, USet>> observed, USet dominator) {
        long otherBeforeDominator = 0;
        int otherSeenSoFar = 0;
        for (Pair<Number, USet> dataPoint : observed) {
            if (dataPoint.second != dominator) {
                otherSeenSoFar++;
            } else {
                otherBeforeDominator += otherSeenSoFar;
            }
        }
        return otherBeforeDominator;
    }

    /**
     * Exact p-value from the U distribution.
     *
     * @param mode POINT gives p = P(U == u) and z from inverting a normal density. CUMULATIVE gives
     *             p = P(U <= u) and z from the inverse normal CDF.
     * @return (z, p); p is doubled when {@code twoSided}
     * @throws GATKException if m > 8 and n > 5 (too large for the exact recursion), or if the inverse CDF fails
     */
    @Ensures({"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
    @Requires({"m > 0", "n > 0", "u >= 0"})
    public static Pair<Double, Double> calculatePRecursively(int n, int m, long u, boolean twoSided, ExactMode mode) {
        if (m > 8 && n > 5) {
            throw new GATKException(String.format("Please use the appropriate (normal or sum of uniform) approximation. Values n: %d, m: %d", Integer.valueOf(n), Integer.valueOf(m)));
        }
        final double p = mode == ExactMode.POINT ? cpr(n, m, u) : cumulativeCPR(n, m, u);
        final double z;
        try {
            // NOTE(review): in CUMULATIVE mode p == 1.0 when u == n*m, which maps to an infinite z.
            z = mode == ExactMode.CUMULATIVE ? APACHE_NORMAL.inverseCumulativeProbability(p) : zFromPointProbability(n, m, u, p);
        } catch (MathException e) {
            throw new GATKException("A math exception occurred in inverting the probability", e);
        }
        return new Pair<>(Double.valueOf(z), Double.valueOf(twoSided ? 2.0d * p : p));
    }

    /**
     * Treats p as the density of a zero-mean normal with a (biased) U standard deviation, and solves for |z|.
     *
     * <p>The density equation is p = exp(-z^2/2) / (sd * sqrt(2 pi)). If p exceeds the peak density,
     * there is no real solution and z is 0. The sign of z is positive when u >= nm/2.
     */
    private static double zFromPointProbability(int n, int m, long u, double p) {
        final double sd = Math.sqrt((((1.0d + (1.0d / ((1 + n) + m))) * (n * m)) * ((1.0d + n) + m)) / 12.0d);
        if (p > 1.0d / Math.sqrt(((sd * sd) * 2.0d) * Math.PI)) {
            return 0.0d;
        }
        final double magnitude = Math.sqrt((-2.0d) * (Math.log(sd) + Math.log(p) + LNSQRT2PI));
        return u >= ((long) ((n * m) / 2)) ? magnitude : -magnitude;
    }

    /** Exact P(U == u) with no size checks; the recursion cost grows exponentially with n + m. */
    protected static double calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(int n, int m, long u) {
        return cpr(n, m, u);
    }

    /**
     * Counts orderings of n SET1 and m SET2 elements in which SET2-before-SET1 pairs total exactly u.
     *
     * <p>This follows the same recursion as {@link #cpr}, without the probability weights.
     */
    protected static long countSequences(int n, int m, long u) {
        if (u < 0) {
            return 0L;
        }
        if (m == 0 || n == 0) {
            return u == 0 ? 1L : 0L;
        }
        return countSequences(n - 1, m, u - m) + countSequences(n, m - 1, u);
    }

    /**
     * Exact P(U == u) for n SET1 and m SET2 elements in uniformly random order.
     *
     * <p>The recursion conditions on the last element:
     * <ul>
     *   <li>With probability n/(n+m) it is from SET1 and contributes m to U.</li>
     *   <li>With probability m/(n+m) it is from SET2 and contributes nothing.</li>
     * </ul>
     * The weights use floating-point division; the previous integer division made both weights 0
     * whenever n and m were positive. The un-memoized recursion branches twice per level.
     */
    private static double cpr(int n, int m, long u) {
        if (u < 0) {
            return 0.0d;
        }
        if (m == 0 || n == 0) {
            return u == 0 ? 1.0d : 0.0d;
        }
        final double total = n + m;
        return ((n / total) * cpr(n - 1, m, u - m)) + ((m / total) * cpr(n, m - 1, u));
    }

    /**
     * Exact P(U <= u).
     *
     * <p>The code uses the symmetry P(U == k) == P(U == nm - k) to sum as few terms as possible.
     * <ul>
     *   <li>Left half (u <= nm/2): sum P(U == k) for k = 0..u inclusive.</li>
     *   <li>Right half: 1 - P(U > u), with P(U > u) = P(U < nm - u).</li>
     * </ul>
     * Previously the left half excluded k == u and returned P(U < u). That was inconsistent with the
     * right half and gave p = 0 at u = 0.
     */
    private static double cumulativeCPR(int n, int m, long u) {
        final long half = (long) ((n * m) / 2);
        double p = 0.0d;
        if (u <= half) {
            for (long k = 0; k <= u; k++) {
                p += cpr(n, m, k);
            }
            return p;
        }
        final long mirrored = ((long) n * m) - u;
        for (long k = 0; k < mirrored; k++) {
            p += cpr(n, m, k);
        }
        return 1.0d - p;
    }

    /** Exposes the ordered observation set (not a copy). */
    protected TreeSet<Pair<Number, USet>> getObservations() {
        return this.observations;
    }

    /** Returns (number of stored SET1 observations, number of stored SET2 observations). */
    protected Pair<Integer, Integer> getSetSizes() {
        return new Pair<>(Integer.valueOf(this.sizeSet1), Integer.valueOf(this.sizeSet2));
    }

    /** True iff both sets are represented and no observed value is NaN or infinite. */
    protected static boolean validateObservations(TreeSet<Pair<Number, USet>> observed) {
        boolean seenSet1 = false;
        boolean seenSet2 = false;
        for (Pair<Number, USet> dataPoint : observed) {
            final double value = dataPoint.getFirst().doubleValue();
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                return false;
            }
            if (dataPoint.getSecond() == USet.SET1) {
                seenSet1 = true;
            }
            if (dataPoint.getSecond() == USet.SET2) {
                seenSet2 = true;
            }
        }
        return seenSet1 && seenSet2;
    }
}
