package edu.mit.csail.cgs.utils.models.bns;

import edu.mit.csail.cgs.utils.models.Model;
import edu.mit.csail.cgs.utils.probability.FiniteDistribution;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/* loaded from: input_file:edu/mit/csail/cgs/utils/models/bns/BNCpd.class */
/**
 * Conditional probability table for a single Bayesian-network variable.
 *
 * <p>For every joint configuration of the parent variables (as enumerated by
 * {@link BNValuesIterator}) the table holds one {@link FiniteDistribution} over the
 * child's encoded values. Configurations are kept in a {@link LinkedHashMap}, so
 * {@link #print()} and {@link #learn(Iterator)} visit them in enumeration order.
 */
public class BNCpd {
    /** Parent variables; a private copy of the array handed to the constructor. */
    private final BNVar[] parents;
    /** The variable whose conditional distribution this table describes. */
    private final BNVar child;
    /** One distribution over the child's encoded values per parent-value configuration. */
    private final Map<BNValues, FiniteDistribution> cpd =
            new LinkedHashMap<BNValues, FiniteDistribution>();

    /**
     * Creates a table with one fresh {@code FiniteDistribution(child.size())} for every
     * configuration of {@code parents}.
     *
     * @param parents the parent variables; the array is cloned, so later changes to the
     *                caller's array do not affect this table
     * @param child   the variable being conditioned
     */
    public BNCpd(BNVar[] parents, BNVar child) {
        this.parents = parents.clone();
        this.child = child;
        int childSize = this.child.size();
        BNValuesIterator configurations = new BNValuesIterator(this.parents);
        while (configurations.hasNext()) {
            this.cpd.put(configurations.next(), new FiniteDistribution(childSize));
        }
    }

    /**
     * Writes the table to standard output: a {@code CPD: <child name>} header, then one
     * {@code <parent values> -> <distribution>} line per parent configuration.
     */
    public void print() {
        System.out.println(String.format("CPD: %s", this.child.getName()));
        for (Map.Entry<BNValues, FiniteDistribution> entry : this.cpd.entrySet()) {
            System.out.println(String.format("%s -> %s",
                    entry.getKey().toString(), entry.getValue().toString()));
        }
    }

    /**
     * Sums {@link #logLikelihood(Model)} over every model produced by the iterator.
     *
     * @param it models to score; consumed by this call
     * @return the total natural-log likelihood, or {@code 0.0} if the iterator is empty
     */
    public double logLikelihood(Iterator<? extends Model> it) {
        double total = 0.0d;
        while (it.hasNext()) {
            total += logLikelihood(it.next());
        }
        return total;
    }

    /**
     * Returns the natural logarithm of the probability this table assigns to the model's
     * child value, given the model's parent values.
     *
     * @param model the model whose parent and child values are read
     * @return {@code Math.log} of the conditional probability; {@code -Infinity} if that
     *         probability is zero
     */
    public double logLikelihood(Model model) {
        FiniteDistribution distribution = this.cpd.get(new BNValues(model, this.parents));
        int childIndex = this.child.encode(this.child.findValue(model)).intValue();
        return Math.log(distribution.getProb(childIndex));
    }

    /**
     * Returns the number of entries in the table: {@code child.size()} multiplied by the
     * size of each parent.
     *
     * <p>NOTE(review): this counts every entry, including the one per configuration that
     * normalization fixes. If this feeds a BIC/AIC-style penalty, the usual free-parameter
     * count would use {@code child.size() - 1}; confirm the intended meaning with callers.
     *
     * @return the total number of table entries
     */
    public int countParameters() {
        int count = this.child.size();
        for (BNVar parent : this.parents) {
            count *= parent.size();
        }
        return count;
    }

    /**
     * Draws a child value from the distribution for the given parent configuration.
     *
     * @param bNValues the parent configuration to condition on
     * @return the decoded child value
     */
    public Object sample(BNValues bNValues) {
        return this.child.decode(Integer.valueOf(this.cpd.get(bNValues).sampleValue()));
    }

    /**
     * Replaces the model's child value with a sample drawn conditionally on the model's
     * current parent values.
     *
     * @param model the model to update in place
     */
    public void resample(Model model) {
        this.child.setValue(model, sample(new BNValues(model, this.parents)));
    }

    /**
     * Re-estimates the table from the given models.
     *
     * <p>Every distribution is cleared first. Each model's encoded child value is then
     * added to the distribution for that model's parent configuration, and finally every
     * distribution is normalized. NOTE(review): configurations that receive no models are
     * still normalized; how {@code FiniteDistribution.normalize()} handles an empty
     * distribution is not visible here.
     *
     * @param it training models; consumed by this call
     */
    public void learn(Iterator<? extends Model> it) {
        for (FiniteDistribution distribution : this.cpd.values()) {
            distribution.clear();
        }
        while (it.hasNext()) {
            Model model = it.next();
            BNValues parentValues = new BNValues(model, this.parents);
            this.cpd.get(parentValues).addValue(this.child.encode(this.child.findValue(model)));
        }
        for (FiniteDistribution distribution : this.cpd.values()) {
            distribution.normalize();
        }
    }
}
