package edu.mit.csail.cgs.utils.models.data;

import Jama.Matrix;
import Jama.QRDecomposition;
import cern.jet.random.ChiSquare;
import cern.jet.random.Normal;
import cern.jet.random.engine.DRand;
import cern.jet.random.engine.RandomEngine;
import edu.mit.csail.cgs.utils.BitVector;
import edu.mit.csail.cgs.utils.Predicate;
import edu.mit.csail.cgs.utils.models.Model;
import edu.mit.csail.cgs.utils.models.data.RegressionModel;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Vector;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/mit/csail/cgs/utils/models/data/DataRegression.class */
/**
 * Ordinary least-squares regression over the rows of a {@link DataFrame}.
 *
 * <p>The regression is specified either by a formula-like statement such as
 * {@code "y ~ x + 1"} or by a {@link RegressionModel}. The fit is computed through a
 * QR decomposition of the design matrix; afterwards the coefficient estimates,
 * the residual variance estimate, R-squared and simulated coefficient bounds are
 * available.
 *
 * <p>Instances hold mutable fit state and a private random engine.
 *
 * @param <M> the model type of the rows in the underlying data frame
 */
public class DataRegression<M extends Model> {
    /** Number of simulated coefficient draws used by {@link #calculateBounds()}. */
    private static final int DEFAULT_BOUND_SAMPLES = 100;

    /** Input file read by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_INPUT_PATH =
            "C:\\Documents and Settings\\tdanford\\Desktop\\test.txt";

    /** Matches {@code <response> ~ <right-hand side>}; group 1 is the response, group 2 the rest. */
    private static final Pattern stmtPattern = Pattern.compile("\\s*([^\\s~]+)\\s*~\\s*(.*)");

    private DataFrame<M> frame;
    private Predicted<M> dataY;
    private Predictors<M> dataX;

    /** QR decomposition of the design matrix used by the most recent fit. */
    private QRDecomposition qr;
    /** Inverse of the R factor of {@link #qr}. */
    private Matrix Rinv;
    /** Coefficient estimates (k x 1) of the most recent fit; {@code null} before any fit. */
    private Matrix betaHat;
    /** {@code Rinv * Rinv^T}, i.e. the coefficient covariance before scaling by the variance. */
    private Matrix Vbeta;
    /** Residual sum of squares divided by (rows - columns) of the design matrix. */
    private double s2;
    /** Coefficient of determination of the most recent fit. */
    private double r2;

    private final String dataYVar;
    private final String[] dataXVars;

    private final RandomEngine engine = new DRand();
    /** Standard normal generator (mean 0, standard deviation 1) driven by {@link #engine}. */
    private final Normal ndist = new Normal(0.0d, 1.0d, this.engine);

    /**
     * Small demonstration: loads {@code XYPoint} rows from a file, doubles every {@code y},
     * fits {@code y ~ x + 1} and prints each coefficient with its simulated bounds.
     *
     * @param strArr optional command-line arguments; if present, the first one is the path of
     *               the input file, otherwise the historical default path is used
     */
    public static void main(String[] strArr) {
        // Previously the path was hard-coded; it remains the fallback for compatibility.
        String inputPath = (strArr != null && strArr.length > 0) ? strArr[0] : DEFAULT_INPUT_PATH;
        try {
            DataFrame dataFrame = new DataFrame(XYPoint.class, new File(inputPath));

            // Illustrates declaring a model through fields; the instance itself is not used below.
            new RegressionModel() {
                public RegressionModel.DependentVariable y;
                public RegressionModel.NumericVariable x;
                public RegressionModel.Intercept b;
            };

            DataRegression dataRegression = new DataRegression(dataFrame, "y ~ x + 1");
            dataRegression.transform(new ATransformation<XYPoint, XYPoint>(XYPoint.class, XYPoint.class) {
                @Override
                public XYPoint transform(XYPoint xYPoint) {
                    xYPoint.y = Double.valueOf(xYPoint.y.doubleValue() * 2.0d);
                    return xYPoint;
                }
            });

            Map<String, Double> coefficients = dataRegression.calculateRegression();
            Map<String, Double[]> bounds = dataRegression.calculateBounds();
            for (String name : coefficients.keySet()) {
                Double[] range = bounds.get(name);
                System.out.println(String.format("%s \t%.3f\t(%.3f, %.3f)",
                        name, coefficients.get(name), range[0], range[1]));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Creates a regression from a statement of the form {@code "response ~ term + term - term"}.
     *
     * <p>Terms and operators must be separated by whitespace. The term {@code 1} denotes the
     * intercept; it is appended automatically unless the statement mentions it explicitly
     * (either adding it or removing it with {@code - 1}).
     *
     * @param dataFrame the data to regress on
     * @param str       the regression statement
     * @throws IllegalArgumentException if the statement cannot be parsed
     */
    public DataRegression(DataFrame<M> dataFrame, String str) {
        Vector<String> parsed = parseStatement(str);
        if (parsed == null) {
            throw new IllegalArgumentException(String.format("Couldn't parse statement \"%s\"", str));
        }
        this.frame = dataFrame;
        this.dataYVar = parsed.get(0);
        this.dataXVars = parsed.subList(1, parsed.size()).toArray(new String[0]);
        this.dataY = new Predicted<>(this.frame, this.dataYVar);
        this.dataX = new Predictors<>(this.frame, this.dataXVars);
    }

    /**
     * Creates a regression whose variables are taken from a {@link RegressionModel}.
     *
     * <p>If the model declares an intercept, the predictor {@code "1"} is placed first,
     * followed by the names of the model's independent-variable fields in their given order.
     *
     * @param dataFrame       the data to regress on
     * @param regressionModel the model describing dependent and independent variables
     */
    public DataRegression(DataFrame<M> dataFrame, RegressionModel regressionModel) {
        this.frame = dataFrame;

        Field dependent = regressionModel.getDependentVariable();
        Vector<Field> independents = regressionModel.getIndependentVariables();
        boolean hasIntercept = regressionModel.hasInterceptVariable();
        int offset = hasIntercept ? 1 : 0;

        this.dataYVar = dependent.getName();
        this.dataXVars = new String[independents.size() + offset];
        if (hasIntercept) {
            this.dataXVars[0] = "1";
        }
        for (int j = 0; j < independents.size(); j++) {
            this.dataXVars[j + offset] = independents.get(j).getName();
        }

        this.dataY = new Predicted<>(this.frame, this.dataYVar);
        this.dataX = new Predictors<>(this.frame, this.dataXVars);
    }

    /**
     * Applies a filter to the underlying data frame.
     *
     * @param predicate the predicate handed to {@code DataFrame.filter}
     */
    public void filter(Predicate<M> predicate) {
        // NOTE(review): unlike transform(), the predictor/response views are not rebuilt here;
        // this presumably relies on DataFrame.filter modifying the frame in place - confirm.
        this.frame.filter(predicate);
    }

    /**
     * Replaces the underlying frame with its transformed version and rebuilds the
     * response and predictor views on top of the new frame.
     *
     * @param transformation the row transformation to apply
     */
    public void transform(Transformation<M, M> transformation) {
        this.frame = (DataFrame<M>) this.frame.transform(transformation);
        this.dataY = new Predicted<>(this.frame, this.dataYVar);
        this.dataX = new Predictors<>(this.frame, this.dataXVars);
    }

    /** @return the column names reported by the predictor view */
    public Vector<String> getPredictorNames() {
        return this.dataX.getColumnNames();
    }

    /** @return the predictor view over the current frame */
    public Predictors<M> getPredictors() {
        return this.dataX;
    }

    /** @return the response view over the current frame */
    public Predicted<M> getPredicted() {
        return this.dataY;
    }

    /** @return a freshly created design matrix from the predictor view */
    public Matrix getPredictorMatrix() {
        return this.dataX.createMatrix();
    }

    /** @return a freshly created response vector from the response view */
    public Matrix getPredictedVector() {
        return this.dataY.createVector();
    }

    /** @return the data frame currently backing this regression */
    public DataFrame<M> getFrame() {
        return this.frame;
    }

    /**
     * Fits the regression on the full data and returns the coefficients.
     *
     * @return coefficient estimates keyed by predictor column name, in column order
     */
    public Map<String, Double> calculateRegression() {
        calculate();
        return collectCoefficients();
    }

    /**
     * Collects the coefficients of the most recent fit.
     *
     * @return coefficient estimates keyed by predictor column name, in column order
     * @throws IllegalStateException if no fit has been computed yet
     */
    public Map<String, Double> collectCoefficients() {
        requireFitted();
        Map<String, Double> coefficients = new LinkedHashMap<>();
        for (int row = 0; row < this.betaHat.getRowDimension(); row++) {
            coefficients.put(this.dataX.getColumnName(row), Double.valueOf(this.betaHat.get(row, 0)));
        }
        return coefficients;
    }

    /**
     * Simulates coefficient bounds for the most recent fit using
     * {@value #DEFAULT_BOUND_SAMPLES} draws (see {@link #sampleBetaBounds(int)}).
     *
     * @return {@code {lower, upper}} pairs keyed by predictor column name, in column order
     * @throws IllegalStateException if no fit has been computed yet
     */
    public Map<String, Double[]> calculateBounds() {
        requireFitted();
        Map<String, Double[]> boundsByName = new LinkedHashMap<>();
        Vector<Double[]> bounds = sampleBetaBounds(DEFAULT_BOUND_SAMPLES);
        for (int row = 0; row < this.betaHat.getRowDimension(); row++) {
            boundsByName.put(this.dataX.getColumnName(row), bounds.get(row));
        }
        return boundsByName;
    }

    /** Fits the regression without a bit vector or per-column transformations. */
    public void calculate() {
        calculate(null);
    }

    /**
     * Fits the regression, passing the bit vector through to the matrix/vector builders.
     *
     * @param bitVector handed to {@code Predictors.createMatrix} and
     *                  {@code Predicted.createVector}; may be {@code null}
     */
    public void calculate(BitVector bitVector) {
        calculate(bitVector, (Map<String, Transformation<Double, Double>>) null);
    }

    /**
     * Fits the regression on matrices built from the current views.
     *
     * @param bitVector handed to {@code Predictors.createMatrix} and
     *                  {@code Predicted.createVector}; may be {@code null}
     * @param map       per-column transformations handed to {@code Predictors.createMatrix};
     *                  may be {@code null}
     */
    public void calculate(BitVector bitVector, Map<String, Transformation<Double, Double>> map) {
        calculate(this.dataX.createMatrix(bitVector, map), this.dataY.createVector(bitVector));
    }

    /**
     * Solves a least-squares problem through a QR decomposition.
     *
     * @param matrix  the design matrix X
     * @param matrix2 the response column vector y
     * @return the solution of {@code R * beta = Q^T * y}
     */
    public static Matrix leastSquares(Matrix matrix, Matrix matrix2) {
        QRDecomposition decomposition = new QRDecomposition(matrix);
        Matrix projected = decomposition.getQ().transpose().times(matrix2);
        return decomposition.getR().solve(projected);
    }

    /**
     * Computes the residual variance estimate for given coefficients.
     *
     * @param matrix  the design matrix X
     * @param matrix2 the response column vector y
     * @param matrix3 the coefficient column vector beta
     * @return the residual sum of squares of {@code y - X * beta} divided by
     *         (rows of X - columns of X)
     */
    public static double s2(Matrix matrix, Matrix matrix2, Matrix matrix3) {
        Matrix residuals = matrix2.minus(matrix.times(matrix3));
        double rss = residuals.transpose().times(residuals).get(0, 0);
        return rss / (matrix.getRowDimension() - matrix.getColumnDimension());
    }

    /**
     * Fits the regression for an explicit design matrix and response vector and stores
     * the coefficients, {@code R^-1}, the unscaled covariance, the residual variance
     * estimate and R-squared.
     *
     * <p>All results are computed first and committed together, so a failure (for example
     * a singular R factor) leaves the previously stored fit untouched.
     *
     * @param matrix  the design matrix X
     * @param matrix2 the response column vector y
     */
    public void calculate(Matrix matrix, Matrix matrix2) {
        QRDecomposition decomposition = new QRDecomposition(matrix);
        Matrix r = decomposition.getR();
        Matrix rInverse = r.inverse();
        Matrix unscaledCovariance = rInverse.times(rInverse.transpose());
        Matrix coefficients = r.solve(decomposition.getQ().transpose().times(matrix2));

        Matrix fitted = matrix.times(coefficients);
        Matrix residuals = matrix2.minus(fitted);
        double variance = residuals.transpose().times(residuals).get(0, 0);
        variance /= matrix.getRowDimension() - matrix.getColumnDimension();
        double rSquared = computeR2(matrix2, fitted);

        this.qr = decomposition;
        this.Rinv = rInverse;
        this.Vbeta = unscaledCovariance;
        this.betaHat = coefficients;
        this.s2 = variance;
        this.r2 = rSquared;
    }

    /**
     * Computes {@code 1 - SS_residual / SS_total} for observed and fitted column vectors.
     *
     * <p>If the observed values are all equal the total sum of squares is zero and the
     * result is not finite.
     */
    private static double computeR2(Matrix observed, Matrix fitted) {
        int rows = observed.getRowDimension();

        double sum = 0.0d;
        for (int row = 0; row < rows; row++) {
            sum += observed.get(row, 0);
        }
        double mean = sum / rows;

        double residualSumOfSquares = 0.0d;
        double totalSumOfSquares = 0.0d;
        for (int row = 0; row < rows; row++) {
            double value = observed.get(row, 0);
            double fromMean = value - mean;
            double fromFit = value - fitted.get(row, 0);
            totalSumOfSquares += fromMean * fromMean;
            residualSumOfSquares += fromFit * fromFit;
        }
        return 1.0d - (residualSumOfSquares / totalSumOfSquares);
    }

    /** @return the coefficient estimates of the most recent fit, or {@code null} before any fit */
    public Matrix getBetaHat() {
        return this.betaHat;
    }

    /** @return {@code R^-1 * (R^-1)^T} of the most recent fit, or {@code null} before any fit */
    public Matrix getVarBeta() {
        return this.Vbeta;
    }

    /** @return R-squared of the most recent fit */
    public double getR2() {
        return this.r2;
    }

    /** @return the residual variance estimate of the most recent fit */
    public double getS2() {
        return this.s2;
    }

    /** @return the size reported by the predictor view */
    public int getN() {
        return this.dataX.size();
    }

    /** @return the number of columns reported by the predictor view */
    public int getK() {
        return this.dataX.getNumColumns();
    }

    /**
     * Simulates {@code i} coefficient vectors and returns, per coefficient, the values at
     * sorted positions {@code i / 4} and {@code 3 * (i / 4)} (roughly the lower and upper
     * quartile of the draws).
     *
     * @param i the number of draws; must be at least 1
     * @return one {@code {lower, upper}} pair per predictor column, in column order
     * @throws IllegalArgumentException if {@code i} is less than 1
     * @throws IllegalStateException    if no fit has been computed yet
     */
    public Vector<Double[]> sampleBetaBounds(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of samples must be at least 1, got " + i);
        }
        requireFitted();

        int k = getK();
        double[][] draws = new double[k][i];
        for (int sample = 0; sample < i; sample++) {
            Matrix beta = sampleBeta(sampleVar());
            for (int column = 0; column < k; column++) {
                draws[column][sample] = beta.get(column, 0);
            }
        }

        int lowerIndex = i / 4;
        int upperIndex = 3 * (i / 4);
        Vector<Double[]> bounds = new Vector<>();
        for (int column = 0; column < k; column++) {
            Arrays.sort(draws[column]);
            bounds.add(new Double[] {draws[column][lowerIndex], draws[column][upperIndex]});
        }
        return bounds;
    }

    /**
     * Draws one coefficient vector {@code betaHat + sqrt(d) * R^-1 * z}, where {@code z}
     * holds independent standard normal values.
     *
     * @param d the variance to scale the draw with (for example from {@link #sampleVar()})
     * @return a k x 1 matrix of simulated coefficients
     * @throws IllegalStateException if no fit has been computed yet
     */
    public Matrix sampleBeta(double d) {
        requireFitted();
        Matrix noise = new Matrix(getK(), 1);
        for (int row = 0; row < noise.getRowDimension(); row++) {
            noise.set(row, 0, this.ndist.nextDouble());
        }
        return this.Rinv.times(Math.sqrt(d)).times(noise).plus(this.betaHat);
    }

    /**
     * Draws a variance as {@code (n - k) * s2 / X}, where {@code X} is a chi-square draw
     * with {@code n - k} degrees of freedom.
     *
     * @return the simulated variance
     */
    public double sampleVar() {
        // NOTE(review): n - k comes from the predictor view, while s2 was divided by the
        // dimensions of the matrix actually fitted; if a fit used a restricted matrix these
        // may disagree - confirm against Predictors.createMatrix(BitVector, Map).
        double degreesOfFreedom = getN() - getK();
        ChiSquare chiSquare = new ChiSquare(degreesOfFreedom, this.engine);
        return (degreesOfFreedom * this.s2) / chiSquare.nextDouble();
    }

    /**
     * Computes the residual variance estimate on the full data for the given coefficients.
     *
     * @param matrix the coefficient column vector; if {@code null}, the result of
     *               {@link #calculateBetaHat()} is used
     * @return the residual sum of squares divided by (frame size - predictor columns)
     */
    public double calculateS2(Matrix matrix) {
        double scale = 1.0d / (this.frame.size() - this.dataX.getNumColumns());
        Matrix y = this.dataY.createVector();
        Matrix x = this.dataX.createMatrix();
        Matrix beta = (matrix != null) ? matrix : calculateBetaHat();
        Matrix residuals = y.minus(x.times(beta));
        return scale * residuals.transpose().times(residuals).get(0, 0);
    }

    /**
     * Computes the coefficients on the full data through the normal equations,
     * {@code (X^T X)^-1 X^T y}, without touching the stored fit.
     *
     * @return the k x 1 coefficient vector
     */
    public Matrix calculateBetaHat() {
        Matrix x = this.dataX.createMatrix();
        Matrix xTransposed = x.transpose();
        Matrix gramInverse = xTransposed.times(x).inverse();
        return gramInverse.times(xTransposed.times(this.dataY.createVector()));
    }

    /** Fails fast with a descriptive message when results are requested before any fit. */
    private void requireFitted() {
        if (this.betaHat == null || this.Rinv == null) {
            throw new IllegalStateException(
                    "No regression has been calculated yet; call calculate() first");
        }
    }

    /**
     * Parses a statement {@code "response ~ t1 + t2 - t3 ..."} into a vector whose first
     * element is the response name and whose remaining elements are the predictor names.
     *
     * <p>The right-hand side is split on whitespace and must alternate term, operator,
     * term, ... where an operator is {@code +} or {@code -}. A term preceded by {@code -}
     * removes every earlier occurrence of that term instead of adding it. The intercept
     * term {@code "1"} is appended at the end unless the statement mentions {@code 1}
     * explicitly.
     *
     * @return the parsed names, or {@code null} if the statement is {@code null}, does not
     *         match the expected shape, has an empty term, or has an unknown operator
     */
    private static Vector<String> parseStatement(String str) {
        if (str == null) {
            return null;
        }
        Matcher matcher = stmtPattern.matcher(str);
        if (!matcher.matches()) {
            return null;
        }

        Vector<String> predictors = new Vector<>();
        boolean interceptMentioned = false;
        boolean subtract = false;
        String[] tokens = matcher.group(2).split("\\s+");
        for (int index = 0; index < tokens.length; index++) {
            String token = tokens[index];

            if (index % 2 == 1) {
                // Odd positions hold the operator that applies to the following term.
                if (token.equals("-")) {
                    subtract = true;
                } else if (token.equals("+")) {
                    subtract = false;
                } else {
                    return null;
                }
                continue;
            }

            if (token.isEmpty()) {
                // Only happens for an empty right-hand side; there is nothing to regress on.
                return null;
            }
            if (token.equals("1")) {
                interceptMentioned = true;
            }
            if (subtract) {
                // A subtracted term must not end up in the design matrix.
                while (predictors.remove(token)) {
                    // keep removing until no occurrence is left
                }
            } else {
                predictors.add(token);
            }
        }

        if (!interceptMentioned) {
            predictors.add("1");
        }

        Vector<String> names = new Vector<>();
        names.add(matcher.group(1));
        names.addAll(predictors);
        return names;
    }

    /**
     * Prints a matrix as a table with a header row of column indices and a leading
     * row index on every line.
     *
     * @param matrix      the matrix to print
     * @param printStream the destination stream
     * @param i           the number of digits after the decimal point for each value
     */
    public static void printMatrix(Matrix matrix, PrintStream printStream, int i) {
        String cellFormat = "%." + i + "f";
        int rows = matrix.getRowDimension();
        int columns = matrix.getColumnDimension();

        printStream.print("   \t");
        for (int column = 0; column < columns; column++) {
            if (column > 0) {
                printStream.print("  ");
            }
            printStream.print(String.format(" %3d", Integer.valueOf(column)));
        }
        printStream.println();

        for (int row = 0; row < rows; row++) {
            printStream.print(String.format("%3d\t", Integer.valueOf(row)));
            for (int column = 0; column < columns; column++) {
                if (column > 0) {
                    printStream.print("  ");
                }
                printStream.print(String.format(cellFormat, Double.valueOf(matrix.get(row, column))));
            }
            printStream.println();
        }
    }
}
