package edu.mit.csail.cgs.utils.models.data;

import Jama.Matrix;
import edu.mit.csail.cgs.utils.BitVector;
import edu.mit.csail.cgs.utils.models.Model;
import edu.mit.csail.cgs.utils.models.ModelFieldAnalysis;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.Vector;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/mit/csail/cgs/utils/models/data/Predictors.class */
/**
 * Describes which fields of a model class are used as predictors and turns the
 * rows of a {@link DataFrame} into a numeric {@link Matrix}.
 *
 * <p>Four kinds of terms are supported:
 * <ul>
 *   <li>a constant term, requested with the name {@code "1"} or {@link #addConstant()};</li>
 *   <li>numeric fields, each contributing one column holding the field value;</li>
 *   <li>factor ({@code String}) fields, each contributing one 0/1 column per distinct
 *       value except the smallest one, which is dropped and is therefore represented
 *       by a row of zeros;</li>
 *   <li>interactions, written {@code a:b}, contributing one column per retained
 *       combination of factor values, holding the product of the numeric fields.</li>
 * </ul>
 *
 * <p>Regardless of the order in which terms are added, matrix columns are laid out as
 * constant, numeric fields, factors, interactions. {@link #getColumnNames()} follows
 * that same layout.
 *
 * @param <M> the model type stored in the data frame
 */
public class Predictors<M extends Model> {
    /** Recognises an interaction term: a non-empty part, a ':', and a non-empty remainder. */
    private static final Pattern INTERACTION_PATTERN = Pattern.compile("([^:]+):(.+)");

    private final DataFrame<M> frame;
    private boolean hasConstant;
    private final Vector<Field> numeric = new Vector<>();
    private final Vector<Field> factor = new Vector<>();
    private final Vector<Interaction> interactions = new Vector<>();
    /** For each factor field, the values that own a column, in column order. */
    private final Map<Field, Vector<Object>> factorCodes = new HashMap<>();
    /** For each interaction, the factor-value combinations that own a column, in column order. */
    private final Map<Interaction, Vector<InteractionValue>> interactionValues = new HashMap<>();
    private final Vector<String> columnNames = new Vector<>();
    private int cols = 0;

    /**
     * An unordered set of fields whose joint effect forms one interaction term.
     * Two interactions over the same set of fields are equal.
     */
    public class Interaction {
        public Set<Field> fields = new HashSet<>();

        /**
         * @param fieldArr the fields taking part in the interaction; duplicates collapse
         */
        public Interaction(Field... fieldArr) {
            for (Field field : fieldArr) {
                this.fields.add(field);
            }
        }

        /** Order-independent hash over the member fields, consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            int sum = 17;
            for (Field field : this.fields) {
                sum += field.hashCode();
            }
            return sum * 37;
        }

        /**
         * Returns the field names joined by ':', with all factor fields listed before
         * all non-factor fields.
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            appendNames(sb, true);
            appendNames(sb, false);
            return sb.toString();
        }

        /** Appends the names of the fields whose factor-ness equals {@code wantFactor}. */
        private void appendNames(StringBuilder sb, boolean wantFactor) {
            for (Field field : this.fields) {
                if (isFactorField(field) == wantFactor) {
                    if (sb.length() > 0) {
                        sb.append(":");
                    }
                    sb.append(field.getName());
                }
            }
        }

        /** Two interactions are equal when they contain exactly the same fields. */
        @Override
        public boolean equals(Object obj) {
            // The wildcard form is required: the parameterized inner type is not reifiable.
            if (!(obj instanceof Predictors<?>.Interaction)) {
                return false;
            }
            Predictors<?>.Interaction other = (Predictors<?>.Interaction) obj;
            return this.fields.equals(other.fields);
        }

        private boolean isNumericField(Field field) {
            return Model.isSubclass(field.getType(), Number.class);
        }

        private boolean isFactorField(Field field) {
            return Model.isSubclass(field.getType(), String.class);
        }

        /**
         * Reads the factor fields of {@code obj}, in this interaction's field order.
         *
         * @param obj the object to read the fields from
         * @return the factor values of {@code obj} for this interaction
         * @throws IllegalStateException if a field cannot be accessed
         */
        public Predictors<M>.InteractionValue calculateInteractionValue(Object obj) {
            InteractionValue result = new InteractionValue();
            for (Field field : this.fields) {
                if (isFactorField(field)) {
                    try {
                        result.values.add((String) field.get(obj));
                    } catch (IllegalAccessException e) {
                        // Continuing would silently produce a wrong combination.
                        throw inaccessible(field, e);
                    }
                }
            }
            return result;
        }

        /**
         * Multiplies together the numeric fields of {@code obj}.
         *
         * @param obj the object to read the fields from
         * @return the product of the numeric fields, or 1.0 if there are none
         * @throws IllegalStateException if a field cannot be accessed
         */
        public Double calculatePredictor(Object obj) {
            double product = 1.0d;
            for (Field field : this.fields) {
                if (isNumericField(field)) {
                    try {
                        product *= ((Number) field.get(obj)).doubleValue();
                    } catch (IllegalAccessException e) {
                        // Continuing would silently drop this factor from the product.
                        throw inaccessible(field, e);
                    }
                }
            }
            return Double.valueOf(product);
        }

        /** @return the names of the factor fields, in this interaction's field order */
        public String[] findFactorFieldNames() {
            Vector<String> names = new Vector<>();
            for (Field field : this.fields) {
                if (isFactorField(field)) {
                    names.add(field.getName());
                }
            }
            return names.toArray(new String[0]);
        }

        /** @return the names of the numeric fields, in this interaction's field order */
        public String[] findNumericFieldNames() {
            Vector<String> names = new Vector<>();
            for (Field field : this.fields) {
                if (isNumericField(field)) {
                    names.add(field.getName());
                }
            }
            return names.toArray(new String[0]);
        }

        /**
         * Enumerates every combination of the factor fields' values found in the frame
         * and drops the first combination, which is then represented by zeros.
         *
         * <p>NOTE(review): an interaction without any factor field yields an empty result
         * and so contributes no column at all — confirm that this is intended.
         *
         * @param strArr currently ignored; the factor field names are always taken from
         *               {@link #findFactorFieldNames()}
         * @return the retained combinations, in column order
         */
        public Vector<Predictors<M>.InteractionValue> allInteractionValues(String[] strArr) {
            Vector<Predictors<M>.InteractionValue> combos = new Vector<>();
            combos.add(new InteractionValue());
            for (String name : findFactorFieldNames()) {
                combos = appendValues(combos, Predictors.this.frame.fieldValues(name));
            }
            // A factor without any value (e.g. an empty frame) leaves nothing to drop.
            if (!combos.isEmpty()) {
                combos.remove(0);
            }
            return combos;
        }

        /** Extends every combination in {@code prefixes} with every value in {@code set}. */
        private Vector<Predictors<M>.InteractionValue> appendValues(
                Vector<Predictors<M>.InteractionValue> prefixes, Set<?> set) {
            Vector<Predictors<M>.InteractionValue> extended = new Vector<>();
            for (InteractionValue prefix : prefixes) {
                extended.addAll(prefix.extend(set));
            }
            return extended;
        }
    }

    /**
     * An ordered list of factor values, one per factor field of an interaction.
     * Equality is element-wise and order-sensitive.
     */
    public class InteractionValue {
        public Vector<Object> values;

        /** Creates an empty combination. */
        public InteractionValue() {
            this.values = new Vector<>();
        }

        /** Creates a copy of {@code interactionValue}. */
        public InteractionValue(Predictors<M>.InteractionValue interactionValue) {
            this.values = new Vector<>(interactionValue.values);
        }

        /** Creates a copy of {@code interactionValue} with {@code obj} appended. */
        public InteractionValue(Predictors<M>.InteractionValue interactionValue, Object obj) {
            this.values = new Vector<>(interactionValue.values);
            this.values.add(obj);
        }

        /**
         * @param set the values to append
         * @return one new combination per element of {@code set}, in the set's iteration order
         */
        public Vector<Predictors<M>.InteractionValue> extend(Set<?> set) {
            Vector<Predictors<M>.InteractionValue> result = new Vector<>();
            for (Object value : set) {
                result.add(new InteractionValue(this, value));
            }
            return result;
        }

        /** Returns the values joined by '_'. */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < this.values.size(); i++) {
                if (i > 0) {
                    sb.append("_");
                }
                // append(Object) tolerates a null value where toString() would throw.
                sb.append(this.values.get(i));
            }
            return sb.toString();
        }

        /** Order-sensitive hash over the values; a null value contributes 0. */
        @Override
        public int hashCode() {
            int hash = 17;
            for (Object value : this.values) {
                hash = (hash + (value == null ? 0 : value.hashCode())) * 37;
            }
            return hash;
        }

        /** Two combinations are equal when they hold equal values in the same order. */
        @Override
        public boolean equals(Object obj) {
            // The wildcard form is required: the parameterized inner type is not reifiable.
            if (!(obj instanceof Predictors<?>.InteractionValue)) {
                return false;
            }
            Predictors<?>.InteractionValue other = (Predictors<?>.InteractionValue) obj;
            // List equality compares sizes and elements pairwise, and is null-safe.
            return this.values.equals(other.values);
        }
    }

    /**
     * Builds the predictor set from a list of term names.
     *
     * @param dataFrame the data the predictors are read from
     * @param strArr    term names: {@code "1"} for the constant, a field name for a
     *                  numeric or factor field, or {@code a:b} for an interaction
     * @throws IllegalArgumentException if a name is repeated, unknown, or refers to a
     *                                  field that is neither numeric nor a String
     */
    public Predictors(DataFrame<M> dataFrame, String... strArr) {
        this.frame = dataFrame;
        Set<String> seen = new HashSet<>();
        ModelFieldAnalysis modelFieldAnalysis = new ModelFieldAnalysis(dataFrame.getModelClass());
        for (String name : strArr) {
            if (!seen.add(name)) {
                throw new IllegalArgumentException(String.format("Duplicate field name: %s", name));
            }
            if (name.equals("1")) {
                addConstant();
                continue;
            }
            try {
                Field field = modelFieldAnalysis.findField(name);
                if (field == null) {
                    if (!INTERACTION_PATTERN.matcher(name).matches()) {
                        throw new IllegalArgumentException(String.format("Unknown field name: %s", name));
                    }
                    addInteraction(name.split(":"));
                } else if (Model.isSubclass(field.getType(), Number.class)) {
                    addPredictor(name);
                } else if (Model.isSubclass(field.getType(), String.class)) {
                    addFactor(name);
                } else {
                    throw new IllegalArgumentException(
                            String.format("Field %s is not a regression-ready predictor", name));
                }
            } catch (NoSuchFieldException e) {
                // Keep the cause: it may carry a more specific message than "unknown".
                throw new IllegalArgumentException(String.format("Unknown field name: %s", name), e);
            }
        }
    }

    /** @return the number of rows in the underlying data frame */
    public int size() {
        return this.frame.size();
    }

    /** Adds the constant "(Intercept)" column as the first column; does nothing if already present. */
    public void addConstant() {
        if (this.hasConstant) {
            return;
        }
        this.hasConstant = true;
        // Also bumps the column count, which createMatrix relies on for its width.
        refreshColumns();
    }

    /**
     * Adds an interaction between the named fields.
     *
     * @param strArr the names of the interacting fields
     * @throws NoSuchFieldException     if one of the names cannot be resolved
     * @throws IllegalArgumentException if the same interaction was already added
     */
    public void addInteraction(String... strArr) throws NoSuchFieldException {
        Field[] fields = new ModelFieldAnalysis(this.frame.getModelClass()).findFields(strArr);
        for (int i = 0; i < fields.length; i++) {
            if (fields[i] == null) {
                throw new NoSuchFieldException(strArr[i]);
            }
        }
        Interaction interaction = new Interaction(fields);
        if (this.interactions.contains(interaction)) {
            throw new IllegalArgumentException(
                    String.format("Cannot add the same interaction %s twice.", interaction.toString()));
        }
        // Compute the values first so a failure leaves this object untouched.
        Vector<InteractionValue> combos = interaction.allInteractionValues(interaction.findFactorFieldNames());
        this.interactions.add(interaction);
        this.interactionValues.put(interaction, combos);
        refreshColumns();
    }

    /**
     * Adds a numeric field as a single column. Only Double and Integer fields (as
     * judged by {@code Model.isSubclass}) are accepted.
     *
     * @param str the public field's name
     * @throws NoSuchFieldException if the field does not exist or has another type
     */
    public void addPredictor(String str) throws NoSuchFieldException {
        Field field = this.frame.getModelClass().getField(str);
        Class<?> type = field.getType();
        if (!Model.isSubclass(type, Double.class) && !Model.isSubclass(type, Integer.class)) {
            throw new NoSuchFieldException(
                    String.format("%s is not a numeric field (%s)", str, type.getName()));
        }
        this.numeric.add(field);
        refreshColumns();
    }

    /**
     * Collects the distinct values of a factor field in sorted order and drops the
     * smallest one.
     *
     * @param str the field name
     * @return the sorted values without the smallest; empty if the field has no values
     */
    public Set<String> findFactorValues(String str) {
        TreeSet<String> sorted = new TreeSet<>();
        for (Object value : this.frame.fieldValues(str)) {
            sorted.add((String) value);
        }
        // pollFirst() is a no-op on an empty set, whereas first() would throw.
        sorted.pollFirst();
        return sorted;
    }

    /**
     * Adds a String field as a factor, with one 0/1 column per retained value.
     *
     * @param str the public field's name
     * @throws NoSuchFieldException     if the field does not exist
     * @throws IllegalArgumentException if the field is not a String field
     */
    public void addFactor(String str) throws NoSuchFieldException {
        Field field = this.frame.getModelClass().getField(str);
        if (!Model.isSubclass(field.getType(), String.class)) {
            throw new IllegalArgumentException(String.format("%s is not a valid factor-field.", str));
        }
        Set<String> levels = findFactorValues(str);
        this.factor.add(field);
        this.factorCodes.put(field, new Vector<Object>(levels));
        refreshColumns();
    }

    /** @return the number of columns {@link #createMatrix()} produces */
    public int getNumColumns() {
        return this.cols;
    }

    /**
     * @param i a zero-based column index
     * @return the name of matrix column {@code i}
     */
    public String getColumnName(int i) {
        return this.columnNames.get(i);
    }

    /** Creates the matrix over all rows, without transformations. */
    public Matrix createMatrix() {
        return createMatrix(null);
    }

    /** Creates the matrix over the rows switched on in {@code bitVector} (all rows if null). */
    public Matrix createMatrix(BitVector bitVector) {
        return createMatrix(bitVector, null);
    }

    /**
     * Creates the predictor matrix.
     *
     * @param bitVector selects the frame rows to include; {@code null} selects all
     * @param map       optional transformations applied to numeric columns, keyed by
     *                  field name; may be {@code null}
     * @return a matrix with one row per selected frame row and {@link #getNumColumns()} columns
     * @throws IllegalStateException if a field cannot be read reflectively
     */
    public Matrix createMatrix(BitVector bitVector, Map<String, Transformation<Double, Double>> map) {
        int frameRows = this.frame.size();
        Matrix matrix = new Matrix(bitVector != null ? bitVector.countOnBits() : frameRows, this.cols);
        int col = 0;

        if (this.hasConstant) {
            int row = 0;
            for (int i = 0; i < frameRows; i++) {
                if (isSelected(bitVector, i)) {
                    matrix.set(row++, 0, 1.0d);
                }
            }
            col = 1;
        }

        for (Field field : this.numeric) {
            Transformation<Double, Double> transformation = map != null ? map.get(field.getName()) : null;
            int row = 0;
            for (int i = 0; i < frameRows; i++) {
                if (!isSelected(bitVector, i)) {
                    continue;
                }
                try {
                    Double value = Double.valueOf(((Number) field.get(this.frame.object(i))).doubleValue());
                    if (transformation != null) {
                        value = transformation.transform(value);
                    }
                    matrix.set(row++, col, value.doubleValue());
                } catch (IllegalAccessException e) {
                    throw inaccessible(field, e);
                }
            }
            col++;
        }

        for (Field field : this.factor) {
            Vector<Object> codes = this.factorCodes.get(field);
            int row = 0;
            for (int i = 0; i < frameRows; i++) {
                if (!isSelected(bitVector, i)) {
                    continue;
                }
                try {
                    int code = codes.indexOf(field.get(this.frame.object(i)));
                    // -1 marks the dropped baseline value: the row stays untouched.
                    if (code != -1) {
                        matrix.set(row, col + code, 1.0d);
                    }
                    row++;
                } catch (IllegalAccessException e) {
                    throw inaccessible(field, e);
                }
            }
            col += codes.size();
        }

        for (Interaction interaction : this.interactions) {
            Vector<InteractionValue> combos = this.interactionValues.get(interaction);
            int row = 0;
            for (int i = 0; i < frameRows; i++) {
                if (!isSelected(bitVector, i)) {
                    continue;
                }
                M object = this.frame.object(i);
                InteractionValue combo = interaction.calculateInteractionValue(object);
                Double product = interaction.calculatePredictor(object);
                int code = combos.indexOf(combo);
                // -1 marks the dropped baseline combination: the row stays untouched.
                if (code != -1) {
                    matrix.set(row, col + code, product.doubleValue());
                }
                row++;
            }
            col += combos.size();
        }
        return matrix;
    }

    /**
     * @return the live list of column names, in matrix column order; callers should
     *         treat it as read-only
     */
    public Vector<String> getColumnNames() {
        return this.columnNames;
    }

    /** Tells whether frame row {@code row} is included; a null mask includes every row. */
    private static boolean isSelected(BitVector mask, int row) {
        return mask == null || mask.isOn(row);
    }

    /** Wraps a reflective access failure, keeping the original exception as the cause. */
    private static IllegalStateException inaccessible(Field field, IllegalAccessException e) {
        return new IllegalStateException(
                String.format("Couldn't access field %s: %s", field.getName(), e.getMessage()), e);
    }

    /**
     * Rebuilds the column names and the column count in the exact order used by
     * {@link #createMatrix(BitVector, Map)}, so names and columns cannot drift apart.
     */
    private void refreshColumns() {
        this.columnNames.clear();
        if (this.hasConstant) {
            this.columnNames.add("(Intercept)");
        }
        for (Field field : this.numeric) {
            this.columnNames.add(field.getName());
        }
        for (Field field : this.factor) {
            for (Object code : this.factorCodes.get(field)) {
                this.columnNames.add(String.format("%s(%s)", field.getName(), code));
            }
        }
        for (Interaction interaction : this.interactions) {
            String label = interaction.toString();
            for (InteractionValue combo : this.interactionValues.get(interaction)) {
                this.columnNames.add(String.format("%s(%s)", label, combo.toString()));
            }
        }
        this.cols = this.columnNames.size();
    }
}
