package edu.mit.csail.cgs.projects.dnaseq;

import edu.mit.csail.cgs.datasets.chipseq.ChipSeqAnalysisResult;
import edu.mit.csail.cgs.datasets.general.Region;
import edu.mit.csail.cgs.datasets.motifs.BackgroundModelLoader;
import edu.mit.csail.cgs.datasets.motifs.MarkovBackgroundModel;
import edu.mit.csail.cgs.datasets.motifs.WMHit;
import edu.mit.csail.cgs.datasets.motifs.WMHitStartComparator;
import edu.mit.csail.cgs.datasets.motifs.WeightMatrix;
import edu.mit.csail.cgs.ewok.verbs.SequenceGenerator;
import edu.mit.csail.cgs.projects.readdb.ClientException;
import edu.mit.csail.cgs.tools.motifs.WeightMatrixScanner;
import edu.mit.csail.cgs.tools.utils.Args;
import edu.mit.csail.cgs.utils.NotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

/* loaded from: input_file:edu/mit/csail/cgs/projects/dnaseq/HMMTrain.class */
/**
 * Collects training statistics for a DNase-hypersensitivity HMM built around a single motif.
 *
 * <p>With {@code L = motif.length()} the model has {@code 2 + 2L} states, numbered as follows:
 * <ul>
 *   <li>0: insensitive (bases outside hypersensitive regions)</li>
 *   <li>1: sensitive (bases inside a hypersensitive region, not in a bound motif)</li>
 *   <li>2 .. 1+L: positions of a bound forward-strand motif instance</li>
 *   <li>2+L .. 1+2L: positions of a bound reverse-strand motif instance</li>
 * </ul>
 * For every base of every training region, {@link #train()} records the base and its read count
 * in the chosen state and increments the transition count from the previous base's state.
 * Subclasses decide which positions are bound via {@link #hasBinding(int)}.
 */
public abstract class HMMTrain extends DNASeqEnrichmentCaller {
    protected WeightMatrix motif;
    protected SequenceGenerator seqgen;
    private float motifCutoff;
    private List<Region> trainingRegions;
    private String modelFname;
    protected int numStates;
    /** Raw transition counts: {@code transitions[from][to]}. */
    protected int[][] transitions;
    protected HMMState insensState;
    protected HMMState sensState;
    protected HMMState[] motifForwStates;
    protected HMMState[] motifRevStates;
    /** State number of the previously visited base, or -1 when there is none. */
    protected int lastStateNum = -1;
    private WMHitStartComparator hitcomp = new WMHitStartComparator();

    /**
     * Parses the superclass arguments plus the HMM-specific ones.
     *
     * <p>The HMM-specific arguments are the "modelfile" output path (default "hmm.model"), the
     * training regions, the motif (the first weight matrix given), "bgmodel", and "cutoff". The
     * cutoff is a fraction of the motif's maximum log-odds score and defaults to 0.7. It also
     * allocates the empty model.
     *
     * @throws NotFoundException if no weight matrix is specified
     */
    @Override // edu.mit.csail.cgs.projects.dnaseq.DNASeqEnrichmentCaller
    public void parseArgs(String[] strArr) throws NotFoundException, SQLException, IOException {
        super.parseArgs(strArr);
        this.modelFname = Args.parseString(strArr, "modelfile", "hmm.model");
        this.trainingRegions = Args.parseRegions(strArr);

        Iterator<WeightMatrix> motifs = Args.parseWeightMatrices(strArr).iterator();
        if (!motifs.hasNext()) {
            throw new NotFoundException("Couldn't find any motifs in the args");
        }
        this.motif = motifs.next();

        // NOTE(review): the result of this lookup is discarded, so the motif is always converted
        // with the default toLogOdds(). The call is kept because it presumably fails for an
        // unknown "bgmodel" name. Confirm whether the loaded model was meant to be passed to
        // toLogOdds(MarkovBackgroundModel).
        BackgroundModelLoader.getBackgroundModel(Args.parseString(strArr, "bgmodel", "whole genome zero order"), 1, BackgroundModelLoader.MARKOV_TYPE_STRING, this.genome.getDBID());
        this.motif.toLogOdds();

        if (motifs.hasNext()) {
            System.err.println("More than one motif specified in args.  Using the first; " + this.motif.toString());
        }
        this.motifCutoff = (float) (this.motif.getMaxScore() * Args.parseDouble(strArr, "cutoff", 0.7d));
        System.err.println("Motif Cutoff is " + this.motifCutoff);

        this.seqgen = new SequenceGenerator(this.genome);
        this.seqgen.useLocalFiles(true);
        this.seqgen.useCache(true);

        initModel();
    }

    /** Allocates an all-zero transition-count matrix and one fresh HMMState per model state. */
    private void initModel() {
        int motifLength = this.motif.length();
        this.numStates = 2 + (2 * motifLength);
        // Java zero-initialises int arrays, so every transition count starts at 0.
        this.transitions = new int[this.numStates][this.numStates];
        this.insensState = new HMMState();
        this.sensState = new HMMState();
        this.motifForwStates = new HMMState[motifLength];
        this.motifRevStates = new HMMState[motifLength];
        for (int pos = 0; pos < motifLength; pos++) {
            this.motifForwStates[pos] = new HMMState();
            this.motifRevStates[pos] = new HMMState();
        }
    }

    private Collection<Region> getTrainingRegions() throws SQLException {
        return this.trainingRegions;
    }

    /**
     * Hook called at the start of each training region, before its read counts are fetched.
     * Does nothing by default; subclasses presumably override it to prepare whatever
     * {@link #hasBinding(int)} needs for this region.
     */
    public void newTrainingRegion(Region region) {
    }

    /**
     * Reports whether a position should be treated as bound.
     *
     * @param i genomic coordinate within the current training region (the region last passed to
     *     {@link #newTrainingRegion(Region)})
     */
    protected abstract boolean hasBinding(int i);

    /**
     * Accumulates emission and transition statistics over every training region.
     *
     * <p>Each region is split into alternating insensitive gaps and the hypersensitive regions
     * reported by getHyperSensitiveRegions. NOTE(review): this assumes those regions are sorted by
     * start and do not overlap; confirm.
     */
    public void train() throws IOException, SQLException, ClientException {
        int numHyperSensitive = 0;
        for (Region region : getTrainingRegions()) {
            // Separate training regions are not assumed to be adjacent in the genome, so do not
            // count a transition from the last base of the previous region into this one.
            this.lastStateNum = -1;
            newTrainingRegion(region);
            System.err.println("Training on " + region);
            ReadCounts readCounts = this.reads.getReadCounts(region, getFGAlignments(), getBGAlignments());
            List<ChipSeqAnalysisResult> hyperSensitiveRegions = getHyperSensitiveRegions(region, readCounts);
            // addTraining skips each sub-region's end coordinate, so the boundary base shared by
            // consecutive sub-regions is visited exactly once.
            int gapStart = region.getStart();
            for (ChipSeqAnalysisResult hs : hyperSensitiveRegions) {
                numHyperSensitive++;
                addTraining(false, new Region(this.genome, hs.getChrom(), gapStart, hs.getStart()), readCounts);
                addTraining(true, hs, readCounts);
                gapStart = hs.getEnd();
            }
            addTraining(false, new Region(this.genome, region.getChrom(), gapStart, region.getEnd()), readCounts);
        }
        System.err.println("Trained on " + numHyperSensitive + " hypersensitive regions");
    }

    /**
     * Keeps results with more than 200 foreground reads and ln(p-value) below -20.
     * Not called anywhere in this class at present.
     */
    private List<ChipSeqAnalysisResult> filterSensitiveRegions(List<ChipSeqAnalysisResult> list) {
        List<ChipSeqAnalysisResult> kept = new ArrayList<ChipSeqAnalysisResult>();
        for (ChipSeqAnalysisResult result : list) {
            if (result.foregroundReadCount.doubleValue() > 200.0d && Math.log(result.pvalue.doubleValue()) < -20.0d) {
                kept.add(result);
            }
        }
        return kept;
    }

    /**
     * Adds every base in {@code [region.getStart(), region.getEnd())} to the model.
     *
     * <p>Each base is assigned a state as follows:
     * <ul>
     *   <li>Inside a sensitive region, a base within a motif instance (a hit above the cutoff)
     *       whose centre is bound goes to the matching strand/position motif state.</li>
     *   <li>Any other base in a sensitive region goes to the sensitive state.</li>
     *   <li>Everything else goes to the insensitive state.</li>
     * </ul>
     * The base and its read count are added to that state, and the transition from the previous
     * base's state is counted.
     *
     * @param sensitive whether {@code region} is a hypersensitive region
     */
    private void addTraining(boolean sensitive, Region region, ReadCounts readCounts) throws SQLException {
        final int regionStart = region.getStart();
        final int motifLength = this.motif.length();
        char[] seq = this.seqgen.execute(region).toCharArray();

        List<WMHit> hits = WeightMatrixScanner.scanSequence(this.motif, this.motifCutoff, seq);
        Collections.sort(hits, this.hitcomp);
        int[] hitStarts = new int[hits.size()];
        for (int h = 0; h < hits.size(); h++) {
            hitStarts[h] = hits.get(h).getStart() + regionStart;
            if (h > 0 && hitStarts[h] - hitStarts[h - 1] < motifLength) {
                System.err.println("Overlapping motifs " + hitStarts[h - 1] + "," + hitStarts[h]);
            }
        }

        int motifPos = -1;      // offset within the current motif instance, or -1 if none
        boolean forward = true; // strand of the current motif instance
        boolean bound = false;
        for (int pos = regionStart; pos < region.getEnd(); pos++) {
            int hitIndex = Arrays.binarySearch(hitStarts, pos);
            if (motifPos >= 0) {
                motifPos++;
                if (motifPos == motifLength) {
                    motifPos = -1;
                }
            }
            if (motifPos == -1 && hitIndex >= 0) {
                // A new motif instance starts here. Take the strand only now, so that a hit
                // overlapping an instance in progress cannot flip that instance's strand.
                motifPos = 0;
                forward = hits.get(hitIndex).getStrand().equals("+");
                bound = hasBinding(pos + (motifLength / 2));
            } else if (motifPos == -1) {
                bound = hasBinding(pos);
            }

            int count = readCounts.getCount(pos);
            HMMState state;
            int stateNum;
            if (bound && sensitive && motifPos >= 0) {
                state = forward ? this.motifForwStates[motifPos] : this.motifRevStates[motifPos];
                stateNum = 2 + (forward ? 0 : motifLength) + motifPos;
            } else if (sensitive) {
                state = this.sensState;
                stateNum = 1;
            } else {
                state = this.insensState;
                stateNum = 0;
            }
            if (this.lastStateNum != -1) {
                this.transitions[this.lastStateNum][stateNum]++;
            }
            this.lastStateNum = stateNum;
            state.addData(seq[pos - regionStart], count);
        }
    }

    /** Formats one row of the transition-count matrix as tab-terminated integers. */
    private String formatTransitionRow(int from) {
        StringBuilder row = new StringBuilder();
        for (int to = 0; to < this.numStates; to++) {
            // append(int) always produces ASCII digits, independent of the default locale.
            row.append(this.transitions[from][to]).append('\t');
        }
        return row.toString();
    }

    private void printTransitions() {
        for (int from = 0; from < this.numStates; from++) {
            System.out.println(formatTransitionRow(from));
        }
    }

    private void printState(HMMState hMMState, String str) {
        System.out.println("\t" + str);
        System.out.println(hMMState.toString());
    }

    /**
     * Writes the model to the configured model file.
     *
     * <p>The lines are written in this order:
     * <ol>
     *   <li>the number of states</li>
     *   <li>the initial-state probabilities</li>
     *   <li>one serialized state per line: insensitive, sensitive, forward motif states, then
     *       reverse motif states</li>
     *   <li>the raw transition-count matrix, one row per line</li>
     * </ol>
     *
     * @throws IOException if the file cannot be opened or any write fails
     */
    public void saveModel() throws IOException {
        PrintWriter out = new PrintWriter(this.modelFname);
        try {
            out.println(this.numStates);
            // Initial probabilities: 0.98 insensitive, 0.01 sensitive, and the remaining 0.01
            // split evenly over all motif states. Locale.ROOT keeps '.' as the decimal separator.
            out.print(".98\t.01");
            int numMotifStates = 2 * this.motifForwStates.length;
            String motifStartProb = String.format(Locale.ROOT, "\t%.4f", Double.valueOf(0.01d / numMotifStates));
            for (int i = 0; i < numMotifStates; i++) {
                out.print(motifStartProb);
            }
            out.println();
            out.println(this.insensState.serialize());
            out.println(this.sensState.serialize());
            for (HMMState state : this.motifForwStates) {
                out.println(state.serialize());
            }
            for (HMMState state : this.motifRevStates) {
                out.println(state.serialize());
            }
            for (int from = 0; from < this.numStates; from++) {
                out.println(formatTransitionRow(from));
            }
            // PrintWriter never throws on write; surface any failure explicitly.
            if (out.checkError()) {
                throw new IOException("Error writing HMM model to " + this.modelFname);
            }
        } finally {
            out.close();
        }
    }

    /** Prints the transition counts and every state to standard output. */
    public void printModel() {
        printTransitions();
        printState(this.insensState, "non sensitive");
        printState(this.sensState, "sensitive");
        for (int i = 0; i < this.motifForwStates.length; i++) {
            printState(this.motifForwStates[i], "motif " + i);
        }
        for (int i = 0; i < this.motifRevStates.length; i++) {
            printState(this.motifRevStates[i], "reverse motif " + i);
        }
    }
}
