package edu.mit.csail.cgs.utils.models.bns;

import edu.mit.csail.cgs.utils.graphs.DirectedAlgorithms;
import edu.mit.csail.cgs.utils.graphs.DirectedGraph;
import edu.mit.csail.cgs.utils.models.Model;
import edu.mit.csail.cgs.utils.models.data.DataFrame;
import edu.mit.csail.cgs.utils.probability.FiniteDistribution;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.Vector;

/* loaded from: input_file:edu/mit/csail/cgs/utils/models/bns/BN.class */
/**
 * A Bayesian-network model over named variables whose values come from the fields of a
 * {@link DataFrame}.
 *
 * <p>Each vertex of {@link #graph} is a variable name. Each variable has a {@link BNVar} built
 * from the data frame's field values. After {@link #learnCPDs()} runs, it also has a
 * {@link BNCpd} conditioned on its parents in the graph. {@code learnCPDs()} rejects cyclic
 * graphs.
 *
 * <p>Note: the {@link #BN(DataFrame, String...)} constructor only creates variables. Call
 * {@link #learnCPDs()} before {@link #print()} or {@link #sample()}.
 *
 * @param <X> the model type of the rows in the backing data frame
 */
public class BN<X extends Model> {
    /** Network structure; vertices are variable names. */
    public DirectedGraph graph;

    /** Variables keyed by name (TreeMap, so iteration is in ascending name order). */
    private final Map<String, BNVar> vars;

    /** CPD per variable name; populated by {@link #learnCPDs()}. */
    private final Map<String, BNCpd> cpds;

    /** Data the variables' value sets and the CPDs are derived from. */
    private final DataFrame<X> data;

    /**
     * Creates a copy of {@code bn}.
     *
     * <p>The variable and CPD maps are new maps, and the graph is copied with
     * {@code DirectedGraph}'s copy constructor. The data frame and the individual
     * {@link BNVar}/{@link BNCpd} objects are shared with {@code bn}.
     *
     * @param bn the network to copy (raw type kept for backward compatibility)
     */
    @SuppressWarnings("unchecked") // bn is raw, so its fields are seen with erased types
    public BN(BN bn) {
        this.data = bn.data;
        this.vars = new TreeMap<String, BNVar>(bn.vars);
        this.cpds = new TreeMap<String, BNCpd>(bn.cpds);
        this.graph = new DirectedGraph(bn.graph);
    }

    /**
     * Creates a network with one variable (and one graph vertex) per name and no edges.
     *
     * <p>CPDs are NOT learned here. Add edges as needed, then call {@link #learnCPDs()}.
     *
     * @param dataFrame data supplying each variable's values
     * @param strArr variable (field) names; must be distinct
     * @throws IllegalArgumentException if a name appears more than once (message is the name)
     */
    public BN(DataFrame<X> dataFrame, String... strArr) {
        this.data = dataFrame;
        this.graph = new DirectedGraph();
        this.vars = new TreeMap<String, BNVar>();
        this.cpds = new TreeMap<String, BNCpd>();
        for (String name : strArr) {
            if (this.vars.containsKey(name)) {
                throw new IllegalArgumentException(name);
            }
            this.vars.put(name, new BNVar(name, this.data.fieldValues(name)));
            this.graph.addVertex(name);
        }
    }

    /**
     * Creates a network with the given structure and immediately learns its CPDs from
     * {@code dataFrame}.
     *
     * <p>The graph is stored by reference, not copied.
     *
     * @param dataFrame data supplying variable values and CPD training rows
     * @param directedGraph structure; each vertex must name a field of {@code dataFrame}
     * @throws IllegalStateException if the graph has a cycle
     */
    public BN(DataFrame<X> dataFrame, DirectedGraph directedGraph) {
        this.data = dataFrame;
        this.graph = directedGraph;
        this.vars = new TreeMap<String, BNVar>();
        this.cpds = new TreeMap<String, BNCpd>();
        for (String name : this.graph.getVertices()) {
            this.vars.put(name, new BNVar(name, this.data.fieldValues(name)));
        }
        // NOTE(review): calls an overridable method from the constructor.
        learnCPDs();
    }

    /**
     * Builds a {@link FiniteDistribution} for variable {@code str} from the value observed
     * in {@code x}.
     *
     * <p>The distribution is constructed from the variable's size and the encoded index of
     * the observed value. Presumably this is a point mass on that value; confirm against
     * {@code FiniteDistribution}.
     *
     * @param x the model instance to read the value from
     * @param str the variable name
     * @return the distribution, or {@code null} if no value for the variable is found in {@code x}
     * @throws IllegalArgumentException if the network has no variable named {@code str}
     */
    public FiniteDistribution posterior(X x, String str) {
        BNVar variable = requireVar(str);
        Object value = variable.findValue(x);
        if (value == null) {
            return null;
        }
        return new FiniteDistribution(variable.size(), variable.encode(value).intValue());
    }

    /**
     * Prints the graph, then each variable's CPD in topological order, to {@code System.out}.
     *
     * @throws IllegalStateException if a variable has no CPD (for example, if
     *     {@link #learnCPDs()} was never called)
     */
    public void print() {
        // The ordering is computed before any output, as in the original implementation.
        Vector<String> order = new DirectedAlgorithms(this.graph).getTopologicalOrdering();
        this.graph.printGraph(System.out);
        for (String name : order) {
            requireCPD(name).print();
            System.out.println();
        }
    }

    /** Returns the backing data frame. */
    public DataFrame<X> getData() {
        return this.data;
    }

    /** Returns the graph's vertex set (the variable names), as returned by the graph. */
    public Set<String> varNames() {
        return this.graph.getVertices();
    }

    /**
     * Draws one instance by creating a fresh model object and letting each CPD resample its
     * variable into it, in topological order (so parents are set before children).
     *
     * <p>NOTE(review): {@code Class.newInstance()} also propagates any exception thrown by the
     * model's nullary constructor.
     *
     * @return the sampled instance, or {@code null} if the model class cannot be instantiated
     *     (the stack trace is printed)
     * @throws IllegalStateException if a variable has no CPD (for example, if
     *     {@link #learnCPDs()} was never called)
     */
    public X sample() {
        try {
            X instance = this.data.getModelClass().newInstance();
            for (String name : new DirectedAlgorithms(this.graph).getTopologicalOrdering()) {
                requireCPD(name).resample(instance);
            }
            return instance;
        } catch (IllegalAccessException e) {
            e.printStackTrace();
            return null;
        } catch (InstantiationException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Returns the sum of every CPD's log-likelihood for {@code model}.
     *
     * <p>CPDs are visited in ascending variable-name order.
     *
     * @param model the instance to score
     * @return the summed log-likelihood ({@code 0.0} if no CPDs have been learned)
     */
    public double logLikelihood(Model model) {
        double total = 0.0d;
        for (BNCpd cpd : this.cpds.values()) {
            total += cpd.logLikelihood(model);
        }
        return total;
    }

    /**
     * Returns the sum of {@link #logLikelihood(Model)} over every remaining element of
     * {@code it}.
     *
     * @param it instances to score; consumed by this call
     * @return the total log-likelihood
     */
    public double logLikelihood(Iterator<? extends Model> it) {
        double total = 0.0d;
        while (it.hasNext()) {
            total += logLikelihood(it.next());
        }
        return total;
    }

    /** Returns the total log-likelihood of every row in the backing data frame. */
    public double logLikelihood() {
        return logLikelihood(this.data.iterator());
    }

    /** Returns the variable named {@code str}, or {@code null} if there is none. */
    public BNVar getVar(String str) {
        return this.vars.get(str);
    }

    /** Returns the CPD for variable {@code str}, or {@code null} if none has been learned. */
    public BNCpd getCPD(String str) {
        return this.cpds.get(str);
    }

    /** Returns the sum of {@code countParameters()} over all learned CPDs. */
    public int countParameters() {
        int total = 0;
        for (BNCpd cpd : this.cpds.values()) {
            total += cpd.countParameters();
        }
        return total;
    }

    /**
     * Returns the product of the sizes of variable {@code str} and every variable in
     * {@code strArr}.
     *
     * <p>This is presumably the size of a full table over the variable and those parents.
     * The product is plain {@code int} arithmetic and is not checked for overflow.
     *
     * @param str the child variable name
     * @param strArr the parent variable names
     * @return the product of the variable sizes
     * @throws IllegalArgumentException if any name is not a variable of this network
     */
    public int countParameters(String str, String... strArr) {
        int count = requireVar(str).size();
        for (String parent : strArr) {
            count *= requireVar(parent).size();
        }
        return count;
    }

    /**
     * (Re)learns a CPD for every variable from the backing data. Existing CPDs for those
     * variables are replaced.
     *
     * @throws IllegalStateException if the graph contains a cycle
     */
    public void learnCPDs() {
        if (new DirectedAlgorithms(this.graph).hasCycle()) {
            throw new IllegalStateException("Graph has a cycle.");
        }
        for (String name : this.vars.keySet()) {
            this.cpds.put(name, learnCPD(name));
        }
    }

    /**
     * Builds a CPD for {@code str} conditioned on its graph parents and trains it on the
     * backing data.
     */
    private BNCpd learnCPD(String str) {
        Set<String> parents = this.graph.getParents(str);
        BNVar[] parentVars = new BNVar[parents.size()];
        int index = 0;
        for (String parent : parents) {
            parentVars[index++] = this.vars.get(parent);
        }
        BNCpd cpd = new BNCpd(parentVars, this.vars.get(str));
        cpd.learn(this.data.iterator());
        return cpd;
    }

    /**
     * Returns the variable named {@code name}, failing with a descriptive message when the
     * network has no such variable.
     */
    private BNVar requireVar(String name) {
        BNVar variable = this.vars.get(name);
        if (variable == null) {
            throw new IllegalArgumentException("Unknown variable: " + name);
        }
        return variable;
    }

    /**
     * Returns the CPD for {@code name}, failing with a descriptive message when none has
     * been learned.
     */
    private BNCpd requireCPD(String name) {
        BNCpd cpd = this.cpds.get(name);
        if (cpd == null) {
            throw new IllegalStateException(
                    "No CPD for variable " + name + "; call learnCPDs() first.");
        }
        return cpd;
    }
}
