package edu.mit.csail.cgs.tools.motifs;

import edu.mit.csail.cgs.datasets.motifs.WMHit;
import edu.mit.csail.cgs.tools.utils.Args;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.broadinstitute.gatk.utils.jna.lsf.v7_0_6.LibBat;

/* loaded from: input_file:edu/mit/csail/cgs/tools/motifs/SVMCombinatorial.class */
/**
 * Trains and evaluates a linear C-SVC (via libsvm) that separates foreground from
 * background sequences, using one feature per weight matrix.
 *
 * <p>Each key of the inherited {@code fghits} / {@code bghits} maps becomes one example.
 * Feature {@code i} of an example is the score of its hit for matrix {@code i} divided by
 * that matrix's {@code getMaxScore()}, or {@code 0} when there is no hit. Foreground
 * examples are labelled {@code +1}, background examples {@code -1}.
 *
 * <p>Examples are split at random into a training set of
 * {@code (int) (total * trainfrac)} examples and a test set holding the remainder.
 */
public class SVMCombinatorial extends CombinatorialEnrichment {
    /** Labels (+1 / -1) of the training examples, parallel to {@link #trainX}. */
    private double[] trainY;
    /** Labels (+1 / -1) of the test examples, parallel to {@link #testX}. */
    private double[] testY;
    /** Training feature vectors; one row per example, one column per weight matrix. */
    private svm_node[][] trainX;
    /** Test feature vectors; one row per example, one column per weight matrix. */
    private svm_node[][] testX;
    /** {@code getMaxScore()} of each weight matrix, used to scale hit scores. */
    private double[] matrixMaxScores;
    private svm_parameter param;
    private svm_problem prob;
    private svm_model model;
    /** Number of training rows filled so far. */
    private int traini;
    /** Number of test rows filled so far. */
    private int testi;
    /** Destination for per-example calls, or {@code null} when calls are not being saved. */
    private PrintWriter saveCalls;
    /** Fraction of all examples assigned to the training set. */
    private double trainfrac = 0.3d;
    /** Map keys of the training examples, in row order. */
    private List<String> trainKeys = new ArrayList<String>();
    /** Map keys of the test examples, in row order. */
    private List<String> testKeys = new ArrayList<String>();

    /**
     * Parses the superclass arguments plus the {@code trainfrac} and {@code savecalls}
     * options (looked up through {@link Args}).
     *
     * @param strArr command-line arguments
     * @throws IllegalArgumentException if {@code trainfrac} is not within {@code [0, 1]}
     * @throws Exception if the superclass parsing fails or the {@code savecalls} file
     *     cannot be opened
     */
    @Override
    public void parseArgs(String[] strArr) throws Exception {
        super.parseArgs(strArr);
        this.trainfrac = Args.parseDouble(strArr, "trainfrac", this.trainfrac);
        // Written as a negated range check so that NaN is rejected as well. Values outside
        // [0, 1] would otherwise surface later as a NegativeArraySizeException in setupSVM().
        if (!(this.trainfrac >= 0.0d && this.trainfrac <= 1.0d)) {
            throw new IllegalArgumentException(
                    "trainfrac must be between 0 and 1 (inclusive) but was " + this.trainfrac);
        }
        String callsFile = Args.parseString(strArr, "savecalls", null);
        // When no file is requested we simply keep no writer instead of opening a
        // platform-specific null device; recordCall() skips output in that case.
        this.saveCalls = callsFile == null ? null : new PrintWriter(callsFile);
    }

    /**
     * Fills {@code row} with one node per entry of {@code hits}: node {@code m} gets index
     * {@code m} and the hit's score divided by the maximum score of matrix {@code m}
     * ({@code 0} when the hit is {@code null}).
     */
    private void fillRow(svm_node[] row, WMHit[] hits) {
        for (int m = 0; m < hits.length; m++) {
            svm_node node = new svm_node();
            node.index = m;
            node.value = hits[m] == null ? 0.0d : hits[m].getScore() / this.matrixMaxScores[m];
            row[m] = node;
        }
    }

    /**
     * Distributes every entry of {@code map} into the training or test set with label
     * {@code d}.
     *
     * <p>An example goes to the test set when the training set is already full, or when a
     * random draw is at least {@code trainfrac} and the test set still has room; otherwise
     * it goes to the training set. Because the two sets together have exactly as many
     * slots as there are examples, neither array can overflow.
     */
    private void fillsvm(Map<String, WMHit[]> map, double d) {
        for (Map.Entry<String, WMHit[]> entry : map.entrySet()) {
            boolean trainFull = this.traini >= this.trainY.length;
            if (trainFull || (Math.random() >= this.trainfrac && this.testi < this.testY.length)) {
                this.testY[this.testi] = d;
                fillRow(this.testX[this.testi], entry.getValue());
                this.testKeys.add(entry.getKey());
                this.testi++;
            } else {
                this.trainY[this.traini] = d;
                fillRow(this.trainX[this.traini], entry.getValue());
                this.trainKeys.add(entry.getKey());
                this.traini++;
            }
        }
    }

    /**
     * Builds the training/test matrices from the inherited hit maps and configures the
     * libsvm parameters and problem. Must run after the scans have populated
     * {@code fghits}, {@code bghits} and {@code matrices}.
     */
    public void setupSVM() {
        int total = this.fghits.size() + this.bghits.size();
        int trainCount = (int) (total * this.trainfrac);
        int testCount = total - trainCount;
        int featureCount = this.matrices.size();

        this.trainY = new double[trainCount];
        this.testY = new double[testCount];
        this.trainX = new svm_node[trainCount][featureCount];
        this.testX = new svm_node[testCount][featureCount];
        this.traini = 0;
        this.testi = 0;

        this.matrixMaxScores = new double[featureCount];
        for (int m = 0; m < featureCount; m++) {
            this.matrixMaxScores[m] = this.matrices.get(m).getMaxScore();
        }

        fillsvm(this.fghits, 1.0d);
        System.err.println("Used " + this.traini + " of " + this.trainY.length + " from fg dataset for training");
        fillsvm(this.bghits, -1.0d);

        this.param = new svm_parameter();
        this.param.svm_type = svm_parameter.C_SVC;
        this.param.kernel_type = svm_parameter.LINEAR;
        this.param.degree = 1;
        this.param.gamma = 0.0d;
        this.param.coef0 = 0.0d;
        this.param.nu = 0.5d;
        this.param.cache_size = 100.0d;
        this.param.C = 1.0d;
        this.param.eps = 0.001d;
        this.param.p = 0.1d;
        this.param.shrinking = 1;
        this.param.probability = 0;
        this.param.nr_weight = 0;
        this.param.weight_label = new int[0];
        this.param.weight = new double[0];

        this.prob = new svm_problem();
        this.prob.l = this.trainY.length;
        this.prob.x = this.trainX;
        this.prob.y = this.trainY;
    }

    /** Trains the model on the problem prepared by {@link #setupSVM()}. */
    public void trainSVM() {
        this.model = svm.svm_train(this.prob, this.param);
    }

    /** Writes {@code key + label} to the calls file, if one was requested. */
    private void recordCall(String key, String label) {
        if (this.saveCalls != null) {
            this.saveCalls.println(key + label);
        }
    }

    /**
     * Classifies every test example, records each call as {@code "<key> <actual><predicted>"}
     * (each {@code +} or {@code -}) and prints the confusion counts to standard output.
     */
    public void testSVM() {
        int truePos = 0;
        int falseNeg = 0;
        int falsePos = 0;
        int trueNeg = 0;
        for (int row = 0; row < this.testY.length; row++) {
            boolean predictedPos = svm.svm_predict(this.model, this.testX[row]) > 0.0d;
            String key = this.testKeys.get(row);
            if (this.testY[row] > 0.0d) {
                if (predictedPos) {
                    recordCall(key, " ++");
                    truePos++;
                } else {
                    recordCall(key, " +-");
                    falseNeg++;
                }
            } else if (predictedPos) {
                recordCall(key, " -+");
                falsePos++;
            } else {
                recordCall(key, " --");
                trueNeg++;
            }
        }
        // PrintWriter buffers and swallows I/O errors; checkError() flushes the buffered
        // calls to disk and tells us whether any write failed.
        if (this.saveCalls != null && this.saveCalls.checkError()) {
            System.err.println("Warning: failed to write some calls to the savecalls file");
        }
        System.out.println(String.format("++ %d, +- %d, -+ %d, -- %d", truePos, falseNeg, falsePos, trueNeg));
    }

    /** Intentionally empty; the results are printed by {@link #testSVM()}. */
    public void report() {
    }

    /**
     * Runs the full pipeline: argument parsing, masking, scanning, SVM setup, training
     * and testing.
     */
    public static void main(String[] strArr) throws Exception {
        SVMCombinatorial sVMCombinatorial = new SVMCombinatorial();
        sVMCombinatorial.parseArgs(strArr);
        try {
            System.err.println("Masking and saving");
            sVMCombinatorial.maskSequence();
            sVMCombinatorial.saveSequences();
            System.err.println("Doing weight matrix scanning");
            sVMCombinatorial.doScans();
            System.err.println("Translating to SVM Format");
            sVMCombinatorial.setupSVM();
            System.err.println("Training");
            sVMCombinatorial.trainSVM();
            System.err.println("Testing");
            sVMCombinatorial.testSVM();
            sVMCombinatorial.report();
        } finally {
            // Release the calls file even if a pipeline stage fails.
            if (sVMCombinatorial.saveCalls != null) {
                sVMCombinatorial.saveCalls.close();
            }
        }
    }
}
