package edu.mit.csail.cgs.utils.regression;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.mit.csail.cgs.utils.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;

/* loaded from: input_file:edu/mit/csail/cgs/utils/regression/Regression.class */
/**
 * Static least-squares helpers built on Colt matrices: ordinary and weighted linear regression,
 * non-negative least squares, prediction, and goodness-of-fit scoring.
 *
 * <p>Regression methods take a design matrix {@code x} with one row per observation and one
 * column per predictor, plus a response matrix {@code y} with the same number of rows. They
 * return the coefficient matrix {@code beta} such that {@code x * beta} approximates {@code y}
 * (see {@link #predict}).
 */
public class Regression {
    /** Step by which the jitter amplitude grows on each retry of the noisy {@code linear} overloads. */
    private static final double NOISE_INCREMENT = 1.0E-6d;

    /** The noisy {@code linear} overloads stop retrying once the jitter amplitude reaches this value. */
    private static final double MAX_NOISE = 10 * NOISE_INCREMENT;

    /** Step by which {@link #nnls} grows the jitter on a singular passive-set subproblem. */
    private static final double NNLS_NOISE_INCREMENT = 1.0E-7d;

    /** {@link #nnls} gives up on a subproblem once its jitter exceeds this value. */
    private static final double NNLS_MAX_NOISE = 1.0E-5d;

    /** Enables diagnostic tracing to {@code System.err}. */
    private static final boolean DEBUG = false;

    /**
     * Ordinary least squares that can recover from a failed fit by jittering the design matrix.
     *
     * <p>If {@link #linear(DoubleMatrix2D, DoubleMatrix2D)} throws and {@code addNoise} is true, a
     * private copy of {@code x} receives additional uniform noise of growing amplitude (in steps of
     * {@code 1e-6}) before each retry. Retrying stops once the amplitude reaches about
     * {@code 1e-5}. The caller's {@code x} is never modified.
     *
     * @param x design matrix
     * @param y response matrix
     * @param addNoise whether to retry with jitter after a failed fit
     * @return the fitted coefficients
     * @throws RuntimeException the original failure, rethrown as-is, when {@code addNoise} is false
     * @throws IllegalArgumentException if every jittered retry fails (the last failure is the cause)
     */
    public static DoubleMatrix2D linear(DoubleMatrix2D x, DoubleMatrix2D y, boolean addNoise) {
        return linearWithRetry(x, y, null, addNoise);
    }

    /**
     * Weighted least squares with a full weight matrix that can recover from a failed fit by
     * jittering the design matrix.
     *
     * <p>The retry policy is the same as {@link #linear(DoubleMatrix2D, DoubleMatrix2D, boolean)}.
     *
     * @param x design matrix
     * @param y response matrix
     * @param weights weight matrix {@code W} (observations x observations)
     * @param addNoise whether to retry with jitter after a failed fit
     * @return the fitted coefficients
     * @throws RuntimeException the original failure, rethrown as-is, when {@code addNoise} is false
     * @throws IllegalArgumentException if every jittered retry fails (the last failure is the cause)
     */
    public static DoubleMatrix2D linear(
            DoubleMatrix2D x, DoubleMatrix2D y, DoubleMatrix2D weights, boolean addNoise) {
        return linearWithRetry(x, y, weights, addNoise);
    }

    /**
     * Shared retry loop for the noisy {@code linear} overloads.
     *
     * @param weights full weight matrix, or {@code null} for an unweighted fit
     */
    private static DoubleMatrix2D linearWithRetry(
            DoubleMatrix2D x, DoubleMatrix2D y, DoubleMatrix2D weights, boolean addNoise) {
        DoubleMatrix2D design = x;
        RuntimeException lastFailure = null;
        double noise = 0.0d;
        while (noise < MAX_NOISE) {
            try {
                return weights == null ? linear(design, y) : linear(design, y, weights);
            } catch (RuntimeException e) {
                if (!addNoise) {
                    throw e;
                }
                lastFailure = e;
                if (design == x) {
                    // Jitter a private copy so the caller's matrix is not corrupted.
                    design = x.copy();
                }
                noise += NOISE_INCREMENT;
                addCenteredNoise(design, noise);
            }
        }
        throw new IllegalArgumentException("Couldn't do linear regression", lastFailure);
    }

    /** Adds independent uniform noise in {@code [-amplitude/2, amplitude/2)} to every cell of {@code m}. */
    private static void addCenteredNoise(DoubleMatrix2D m, double amplitude) {
        for (int i = 0; i < m.rows(); i++) {
            for (int j = 0; j < m.columns(); j++) {
                m.setQuick(i, j, m.getQuick(i, j) + ((Math.random() - 0.5d) * amplitude));
            }
        }
    }

    /**
     * Ordinary least squares via the normal equations: returns {@code (x^T x)^-1 x^T y}.
     *
     * @param x design matrix
     * @param y response matrix with {@code x.rows()} rows
     * @return the fitted coefficients ({@code x.columns()} x {@code y.columns()})
     * @throws RuntimeException whatever Colt's {@code Algebra} raises, e.g. when {@code x^T x}
     *     cannot be inverted or the dimensions do not match
     */
    public static DoubleMatrix2D linear(DoubleMatrix2D x, DoubleMatrix2D y) {
        Algebra algebra = new Algebra();
        DoubleMatrix2D xt = algebra.transpose(x);
        DoubleMatrix2D normalInverse = algebra.inverse(algebra.mult(xt, x));
        return algebra.mult(algebra.mult(normalInverse, xt), y);
    }

    /**
     * Weighted least squares with a full weight matrix: returns {@code (x^T W x)^-1 x^T W y}.
     *
     * @param x design matrix
     * @param y response matrix with {@code x.rows()} rows
     * @param weights weight matrix {@code W} (observations x observations)
     * @return the fitted coefficients
     * @throws RuntimeException whatever Colt's {@code Algebra} raises, e.g. for a non-invertible
     *     {@code x^T W x} or mismatched dimensions
     */
    public static DoubleMatrix2D linear(DoubleMatrix2D x, DoubleMatrix2D y, DoubleMatrix2D weights) {
        Algebra algebra = new Algebra();
        DoubleMatrix2D xt = algebra.transpose(x);
        DoubleMatrix2D normalInverse = algebra.inverse(algebra.mult(xt, algebra.mult(weights, x)));
        return algebra.mult(algebra.mult(algebra.mult(normalInverse, xt), weights), y);
    }

    /**
     * Weighted least squares with one weight per observation; see
     * {@link #linear(DoubleMatrix2D, DoubleMatrix2D, double[])}.
     *
     * @param x design matrix
     * @param y response matrix with {@code x.rows()} rows
     * @param weights per-observation weights, at least {@code x.rows()} entries
     * @return the fitted coefficients
     */
    public static DoubleMatrix2D linear(DoubleMatrix2D x, DoubleMatrix2D y, DoubleMatrix1D weights) {
        return linear(x, y, weights.toArray());
    }

    /**
     * Weighted least squares with one weight per observation: returns
     * {@code (x^T W x)^-1 x^T W y} with {@code W = diag(weights)}.
     *
     * <p>Rows are scaled by the weights themselves (not their square roots), so the fit
     * minimizes {@code sum_i weights[i] * residual_i^2}. Every column of {@code y} is weighted.
     *
     * @param x design matrix
     * @param y response matrix with {@code x.rows()} rows
     * @param weights per-observation weights, at least {@code x.rows()} entries
     * @return the fitted coefficients
     * @throws IllegalArgumentException if {@code weights} is shorter than {@code x.rows()}
     * @throws RuntimeException whatever Colt's {@code Algebra} raises for a non-invertible system
     */
    public static DoubleMatrix2D linear(DoubleMatrix2D x, DoubleMatrix2D y, double[] weights) {
        int n = x.rows();
        if (weights.length < n) {
            throw new IllegalArgumentException(
                    "Size Mismatch " + weights.length + " weights vs " + n + " rows");
        }
        Algebra algebra = new Algebra();
        DoubleMatrix2D xt = algebra.transpose(x);
        DoubleMatrix2D weightedX = DoubleFactory2D.dense.make(n, x.columns());
        DoubleMatrix2D weightedY = DoubleFactory2D.dense.make(y.rows(), y.columns());
        for (int i = 0; i < n; i++) {
            double w = weights[i];
            for (int k = 0; k < y.columns(); k++) {
                weightedY.setQuick(i, k, y.getQuick(i, k) * w);
            }
            for (int j = 0; j < x.columns(); j++) {
                weightedX.setQuick(i, j, x.getQuick(i, j) * w);
            }
        }
        DoubleMatrix2D normalInverse = algebra.inverse(algebra.mult(xt, weightedX));
        return algebra.mult(algebra.mult(normalInverse, xt), weightedY);
    }

    /**
     * Non-negative least squares: minimizes {@code ||a x - b||} subject to {@code x >= 0}, using
     * the Lawson-Hanson active-set method.
     *
     * <p>Each unconstrained subproblem restricted to the passive columns is solved with
     * {@link #linear(DoubleMatrix2D, DoubleMatrix2D)}. If a subproblem is singular, it is retried
     * with small non-negative jitter added to those columns.
     *
     * @param a design matrix
     * @param b response column with {@code a.rows()} rows and one column
     * @return an {@code a.columns()} x 1 coefficient column (non-negative up to rounding)
     * @throws IllegalArgumentException if the iteration count exceeds {@code 20 * a.columns()}, or
     *     a subproblem stays singular even after jittering
     * @throws RuntimeException if no entering ({@code "t=-1"}) or leaving ({@code "q==-1"}) column
     *     can be chosen, e.g. because of NaN data
     */
    public static DoubleMatrix2D nnls(DoubleMatrix2D a, DoubleMatrix2D b) {
        Algebra algebra = new Algebra();
        int columns = a.columns();
        // P: columns whose coefficient may be positive. Z: columns pinned at zero.
        TreeSet<Integer> passive = new TreeSet<>();
        TreeSet<Integer> zero = new TreeSet<>();
        DoubleMatrix2D x = DoubleFactory2D.dense.make(columns, 1);
        for (int j = 0; j < columns; j++) {
            zero.add(j);
            x.setQuick(j, 0, 0.0d);
        }
        DoubleMatrix2D w = residualCorrelation(algebra, a, b, x);
        // True when the previous pass shrank P, so the next pass re-solves without adding a column.
        boolean skipToSix = false;
        if (DEBUG) {
            printSystem("trying nnls on ", a, b);
        }
        int iterations = 0;
        while (!zero.isEmpty()) {
            if (iterations++ > 20 * columns) {
                throw new IllegalArgumentException("nnls is looping");
            }
            if (DEBUG) {
                System.err.println("\nP is " + passive + "  Z is " + zero + " skiptosix is " + skipToSix + " w=" + w);
                System.err.println(" x is " + x + "\n");
            }
            if (!skipToSix) {
                // Stop when no pinned column would reduce the residual; else free the best one.
                if (allNonPositive(w, zero)) {
                    break;
                }
                int t = argMax(w, zero);
                if (t == -1) {
                    throw new RuntimeException("t=-1");
                }
                zero.remove(Integer.valueOf(t));
                passive.add(Integer.valueOf(t));
                if (DEBUG) {
                    System.err.println("Moving " + t + " to P");
                }
            }
            DoubleMatrix2D z = solvePassive(a, b, passive);
            if (DEBUG) {
                System.err.println("Solved linreg as " + z);
            }
            if (allPositive(z, passive.size())) {
                // The unconstrained solution is feasible: adopt it and recompute w.
                for (int j : zero) {
                    x.setQuick(j, 0, 0.0d);
                }
                int k = 0;
                for (int j : passive) {
                    x.setQuick(j, 0, z.getQuick(k++, 0));
                }
                w = residualCorrelation(algebra, a, b, x);
                skipToSix = false;
            } else {
                // Step from x toward z only as far as keeps every passive coefficient >= 0.
                double alpha = Double.POSITIVE_INFINITY;
                int q = -1;
                int k = 0;
                for (int j : passive) {
                    double zk = z.getQuick(k++, 0);
                    // Written as !(zk > 0) so NaN entries are treated like non-positive ones.
                    if (!(zk > 0.0d)) {
                        double xj = x.getQuick(j, 0);
                        double ratio = xj / (xj - zk);
                        if (ratio < alpha) {
                            q = j;
                            alpha = ratio;
                        }
                    }
                }
                if (q == -1) {
                    throw new RuntimeException("q==-1");
                }
                if (DEBUG) {
                    System.err.println("  q=" + q + "  alpha=" + alpha);
                }
                for (int j : zero) {
                    x.setQuick(j, 0, x.getQuick(j, 0) * (1.0d - alpha));
                }
                k = 0;
                for (int j : passive) {
                    x.setQuick(j, 0, (x.getQuick(j, 0) * (1.0d - alpha)) + (alpha * z.getQuick(k++, 0)));
                }
                // Force the blocking coefficient to exactly zero so it leaves P below.
                x.setQuick(q, 0, 0.0d);
                if (DEBUG) {
                    System.err.println("  updated x to be : " + x);
                }
                List<Integer> hitZero = new ArrayList<>();
                for (int j : passive) {
                    if (x.getQuick(j, 0) == 0.0d) {
                        hitZero.add(j);
                    }
                }
                if (DEBUG) {
                    System.err.println("Moving " + hitZero + " from P to Z");
                }
                passive.removeAll(hitZero);
                zero.addAll(hitZero);
                skipToSix = true;
            }
        }
        if (DEBUG) {
            System.err.println("Returning " + x);
        }
        return x;
    }

    /** Returns {@code w = a^T (b - a x)}, the negative gradient of {@code 0.5 * ||b - a x||^2}. */
    private static DoubleMatrix2D residualCorrelation(
            Algebra algebra, DoubleMatrix2D a, DoubleMatrix2D b, DoubleMatrix2D x) {
        DoubleMatrix2D residual = DoubleFactory2D.dense.make(b.toArray());
        residual.assign(algebra.mult(a, x), Functions.minus);
        return algebra.mult(algebra.transpose(a), residual);
    }

    /** True when {@code column.getQuick(j, 0) <= 0} for every index {@code j} in {@code indices}. */
    private static boolean allNonPositive(DoubleMatrix2D column, Iterable<Integer> indices) {
        for (int j : indices) {
            if (!(column.getQuick(j, 0) <= 0.0d)) {
                return false;
            }
        }
        return true;
    }

    /** True when the first {@code count} rows of column 0 are all strictly positive (NaN is not). */
    private static boolean allPositive(DoubleMatrix2D column, int count) {
        for (int i = 0; i < count; i++) {
            if (!(column.getQuick(i, 0) > 0.0d)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the index in {@code indices} with the largest value in column 0. Iteration order
     * breaks ties in favour of the first index. Returns -1 if no value exceeds negative infinity.
     */
    private static int argMax(DoubleMatrix2D column, Iterable<Integer> indices) {
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int j : indices) {
            double value = column.getQuick(j, 0);
            if (value > bestValue) {
                best = j;
                bestValue = value;
            }
        }
        return best;
    }

    /**
     * Solves the unconstrained least-squares subproblem on the columns of {@code a} listed in
     * {@code passive}, in ascending order.
     *
     * <p>The first attempt adds no noise. Each time {@link #linear(DoubleMatrix2D, DoubleMatrix2D)}
     * reports a singular system (IllegalArgumentException), the columns are rebuilt with
     * non-negative jitter of a slightly larger amplitude and the solve is retried.
     */
    private static DoubleMatrix2D solvePassive(DoubleMatrix2D a, DoubleMatrix2D b, TreeSet<Integer> passive) {
        DoubleMatrix2D subset = DoubleFactory2D.dense.make(a.rows(), passive.size(), 0.0d);
        double noise = 0.0d;
        while (true) {
            int col = 0;
            for (int j : passive) {
                for (int i = 0; i < subset.rows(); i++) {
                    subset.setQuick(i, col, a.getQuick(i, j) + (Math.random() * noise));
                }
                col++;
            }
            if (DEBUG) {
                printSystem("trying linreg on ", subset, b);
            }
            try {
                return linear(subset, b);
            } catch (IllegalArgumentException e) {
                noise += NNLS_NOISE_INCREMENT;
                if (noise > NNLS_MAX_NOISE) {
                    throw new IllegalArgumentException("Couldn't do linreg", e);
                }
            }
        }
    }

    /** Debug helper: prints each row as {@code rhs = m[i][0] m[i][1] ...} to {@code System.err}. */
    private static void printSystem(String header, DoubleMatrix2D m, DoubleMatrix2D rhs) {
        System.err.println(header);
        for (int i = 0; i < m.rows(); i++) {
            System.err.print(" " + rhs.get(i, 0) + " = ");
            for (int j = 0; j < m.columns(); j++) {
                System.err.print("  " + m.get(i, j));
            }
            System.err.println();
        }
    }

    /**
     * Applies fitted coefficients to a design matrix.
     *
     * @param x design matrix
     * @param beta coefficients, e.g. from {@link #linear(DoubleMatrix2D, DoubleMatrix2D)}
     * @return {@code x * beta}
     */
    public static DoubleMatrix2D predict(DoubleMatrix2D x, DoubleMatrix2D beta) {
        return new Algebra().mult(x, beta);
    }

    /**
     * Scores predictions against observations.
     *
     * @param predicted predicted values
     * @param observed observed values, same length as {@code predicted}
     * @return a pair with two entries. The first is the residual sum of squares
     *     {@code sum (observed - predicted)^2}. The second is
     *     {@code 1 - residual / sum (observed - mean(observed))^2}, which is NaN or infinite when
     *     the total sum of squares is zero.
     * @throws IllegalArgumentException if the lengths differ
     */
    public static Pair<Double, Double> score(double[] predicted, double[] observed) {
        if (predicted.length != observed.length) {
            throw new IllegalArgumentException("Size Mismatch " + predicted.length + " vs " + observed.length);
        }
        double sum = 0.0d;
        for (double value : observed) {
            sum += value;
        }
        double mean = sum / observed.length;
        double residual = 0.0d;
        double total = 0.0d;
        for (int i = 0; i < predicted.length; i++) {
            residual += Math.pow(observed[i] - predicted[i], 2.0d);
            total += Math.pow(observed[i] - mean, 2.0d);
        }
        return new Pair<>(Double.valueOf(residual), Double.valueOf(1.0d - (residual / total)));
    }

    /**
     * List variant of {@link #score(double[], double[])}; walks both lists with iterators so it
     * is linear-time for any {@code List} implementation.
     *
     * @param predicted predicted values (no nulls)
     * @param observed observed values (no nulls), same size as {@code predicted}
     * @return (residual sum of squares, {@code 1 - residual / total sum of squares})
     * @throws IllegalArgumentException if the sizes differ
     */
    public static Pair<Double, Double> score(List<Double> predicted, List<Double> observed) {
        if (predicted.size() != observed.size()) {
            throw new IllegalArgumentException("Size Mismatch " + predicted.size() + " vs " + observed.size());
        }
        double observedSum = 0.0d;
        for (double value : observed) {
            observedSum += value;
        }
        double mean = observedSum / observed.size();
        double residual = 0.0d;
        double total = 0.0d;
        Iterator<Double> predictedIt = predicted.iterator();
        for (double obs : observed) {
            double pred = predictedIt.next();
            residual += Math.pow(obs - pred, 2.0d);
            total += Math.pow(obs - mean, 2.0d);
        }
        if (DEBUG) {
            double predictedSum = 0.0d;
            for (double value : predicted) {
                predictedSum += value;
            }
            double predictedMean = predictedSum / predicted.size();
            System.err.println(String.format("mean = %.2f predmean = %.2f, ess = %.2f, tss = %.2f",
                    Double.valueOf(mean), Double.valueOf(predictedMean), Double.valueOf(residual), Double.valueOf(total)));
        }
        return new Pair<>(Double.valueOf(residual), Double.valueOf(1.0d - (residual / total)));
    }

    /**
     * Weighted score using the diagonal of {@code weights} as per-observation weights; see
     * {@link #score(DoubleMatrix2D, DoubleMatrix2D, DoubleMatrix1D)} for the formula.
     *
     * @param predicted predicted values, a single column
     * @param observed observed values, a single column with the same row count
     * @param weights matrix whose diagonal holds the weights; off-diagonal entries are ignored
     * @return (weighted residual sum of squares, {@code 1 - residual / total})
     * @throws IllegalArgumentException on a row-count mismatch or a non-single-column input
     */
    public static Pair<Double, Double> score(
            DoubleMatrix2D predicted, DoubleMatrix2D observed, DoubleMatrix2D weights) {
        checkColumnVectors(predicted, observed);
        double[] w = new double[observed.rows()];
        for (int i = 0; i < w.length; i++) {
            w[i] = weights.get(i, i);
        }
        return weightedScore(predicted, observed, w);
    }

    /**
     * Weighted score. Residual = {@code sum (w_i (observed_i - predicted_i))^2}; total =
     * {@code sum (w_i (observed_i - m))^2} with {@code m = sum(w_i observed_i) / n}.
     *
     * <p>NOTE(review): {@code m} divides by the observation count rather than the total weight,
     * and the weights are squared along with the residuals, whereas the weighted
     * {@code linear} overloads minimize {@code sum w_i r_i^2}. Confirm this convention is intended.
     *
     * @param predicted predicted values, a single column
     * @param observed observed values, a single column with the same row count
     * @param weights per-observation weights
     * @return (weighted residual sum of squares, {@code 1 - residual / total})
     * @throws IllegalArgumentException on a row-count mismatch or a non-single-column input
     */
    public static Pair<Double, Double> score(
            DoubleMatrix2D predicted, DoubleMatrix2D observed, DoubleMatrix1D weights) {
        checkColumnVectors(predicted, observed);
        double[] w = new double[observed.rows()];
        for (int i = 0; i < w.length; i++) {
            w[i] = weights.get(i);
        }
        return weightedScore(predicted, observed, w);
    }

    /**
     * Unweighted score of single-column predictions against single-column observations.
     *
     * @param predicted predicted values, a single column
     * @param observed observed values, a single column with the same row count
     * @return (residual sum of squares, {@code 1 - residual / total sum of squares})
     * @throws IllegalArgumentException on a row-count mismatch or a non-single-column input
     */
    public static Pair<Double, Double> score(DoubleMatrix2D predicted, DoubleMatrix2D observed) {
        checkColumnVectors(predicted, observed);
        return weightedScore(predicted, observed, null);
    }

    /** Validates that both arguments are single columns with the same number of rows. */
    private static void checkColumnVectors(DoubleMatrix2D predicted, DoubleMatrix2D observed) {
        if (predicted.rows() != observed.rows()) {
            throw new IllegalArgumentException("Size Mismatch " + predicted.rows() + " vs " + observed.rows());
        }
        if (predicted.columns() != 1) {
            throw new IllegalArgumentException("predicted columns != 1");
        }
        if (observed.columns() != 1) {
            throw new IllegalArgumentException("observed columns != 1");
        }
    }

    /**
     * Shared scoring loop for the matrix {@code score} overloads.
     *
     * @param weights per-observation weights, or {@code null} for weight 1.0 everywhere (which
     *     reproduces the unweighted arithmetic exactly)
     */
    private static Pair<Double, Double> weightedScore(
            DoubleMatrix2D predicted, DoubleMatrix2D observed, double[] weights) {
        int n = observed.rows();
        double sum = 0.0d;
        for (int i = 0; i < n; i++) {
            double w = weights == null ? 1.0d : weights[i];
            sum += w * observed.get(i, 0);
        }
        double mean = sum / n;
        double residual = 0.0d;
        double total = 0.0d;
        for (int i = 0; i < n; i++) {
            double w = weights == null ? 1.0d : weights[i];
            double obs = observed.get(i, 0);
            residual += Math.pow(w * (obs - predicted.get(i, 0)), 2.0d);
            total += Math.pow(w * (obs - mean), 2.0d);
        }
        return new Pair<>(Double.valueOf(residual), Double.valueOf(1.0d - (residual / total)));
    }
}
