package edu.mit.csail.cgs.tools.motifs;

import cern.jet.random.Binomial;
import cern.jet.random.engine.RandomEngine;
import com.jidesoft.swing.ButtonStyle;
import edu.mit.csail.cgs.datasets.motifs.WeightMatrix;
import edu.mit.csail.cgs.datasets.motifs.WeightMatrixPainter;
import edu.mit.csail.cgs.datasets.species.Genome;
import edu.mit.csail.cgs.ewok.verbs.SequenceGenerator;
import edu.mit.csail.cgs.tools.utils.Args;
import edu.mit.csail.cgs.utils.NotFoundException;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
import org.apache.batik.util.SVGConstants;
import org.jfree.chart.encoders.ImageFormat;

/* loaded from: input_file:edu/mit/csail/cgs/tools/motifs/DiscriminativeKmers.class */
/**
 * Finds k-mers that are over-represented in a foreground sequence set relative to a background
 * set, clusters the enriched k-mers by similarity (allowing small shifts and reverse complements),
 * and turns sufficiently large clusters into {@link WeightMatrix} motifs.
 *
 * <p>K-mers are packed two bits per base into a {@code long} (A=0, C=1, G=2, T=3), with the last
 * base of the k-mer stored in the two least significant bits.
 */
public class DiscriminativeKmers {
    private Genome genome;
    private double minfoldchange;
    private int k;
    // Keeps the low 2*k bits of a packed k-mer. Stored as a long: an int mask overflows for k > 15
    // and is sign-extended to all ones when widened to the long parameter of count().
    private long mask;
    private int maxmismatch;
    private int minclustersize;
    private int minclustercount;
    private int parsedregionexpand;
    private Map<String, char[]> foreground;
    private Map<String, char[]> background;
    private boolean printKmers;
    private List<WeightMatrix> pwms;
    private String outbase;
    // Largest number of bases a k-mer may be slid against a cluster centroid when aligning them.
    private static final int maxshift = 3;
    // Largest k whose two-bits-per-base packing fits in a long.
    private static final int MAX_K = 32;
    private static final char[] toChar = {'A', 'C', 'G', 'T'};
    private int randombgcount = 1000;
    private int randombgsize = 100;
    // NOTE(review): not referenced by any method of this class.
    private SequenceGenerator seqgen = new SequenceGenerator();

    /**
     * Maps a nucleotide character to its two-bit code.
     *
     * @return 0, 1, 2 or 3 for A, C, G or T (either case), or -1 for any other character
     */
    private static int baseIndex(char c) {
        switch (c) {
            case 'A':
            case 'a':
                return 0;
            case 'C':
            case 'c':
                return 1;
            case 'G':
            case 'g':
                return 2;
            case 'T':
            case 't':
                return 3;
            default:
                return -1;
        }
    }

    /** Two-bit code used for packing; characters other than A/C/G/T (e.g. N) pack as A (0). */
    private static int packCode(char c) {
        return Math.max(0, baseIndex(c));
    }

    /**
     * Packs a sequence into a long, two bits per base, with the last character in the low bits.
     * Characters other than A/C/G/T pack as A. Only the last 32 characters survive, because
     * earlier bits are shifted out of the long.
     */
    public static long charsToLong(char[] cArr) {
        long packed = 0;
        for (char c : cArr) {
            packed = (packed << 2) | packCode(c);
        }
        return packed;
    }

    /**
     * Appends one base to a packed k-mer and drops the oldest base.
     *
     * @param kmer the packed k-mer so far
     * @param kmerMask mask selecting the low 2*k bits, which keeps the result k bases long
     * @param c the base to append (non-ACGT characters are treated as A)
     * @return the updated packed k-mer
     */
    public static long addChar(long kmer, long kmerMask, char c) {
        return ((kmer << 2) | packCode(c)) & kmerMask;
    }

    /** Unpacks the low {@code i} bases of {@code j} into a String. */
    public static String longToString(long j, int i) {
        return new String(longToChars(j, i));
    }

    /** Unpacks the low {@code i} bases of {@code j}; the lowest two bits become the last char. */
    public static char[] longToChars(long j, int i) {
        char[] chars = new char[i];
        for (int pos = i - 1; pos >= 0; pos--) {
            chars[pos] = toChar[(int) (j & 3)];
            j >>= 2;
        }
        return chars;
    }

    /**
     * Returns the reverse complement of a packed k-mer of length {@code i}. XOR with 3 maps
     * A(0)<->T(3) and C(1)<->G(2).
     */
    public static long reverseComplement(long j, int i) {
        long rc = 0;
        for (int n = 0; n < i; n++) {
            rc = (rc << 2) | ((j ^ 3) & 3);
            j >>= 2;
        }
        return rc;
    }

    /**
     * Counts every k-mer (every window of {@code i} consecutive characters) in a sequence.
     *
     * @param cArr the sequence
     * @param i k, the k-mer length; must be at least 1
     * @param j mask selecting the low 2*k bits
     * @param map counts to add to, or null to start a new map
     * @return the map of packed k-mer to occurrence count; unchanged if the sequence is shorter
     *     than k
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public static Map<Long, Integer> count(char[] cArr, int i, long j, Map<Long, Integer> map) {
        if (i < 1) {
            throw new IllegalArgumentException("k must be positive, got " + i);
        }
        if (map == null) {
            map = new HashMap<Long, Integer>();
        }
        if (cArr.length < i) {
            return map; // too short to contain a single k-mer
        }
        long kmer = 0;
        // Prime with the first k-1 bases ...
        for (int pos = 0; pos < i - 1; pos++) {
            kmer = addChar(kmer, j, cArr[pos]);
        }
        // ... then every following base completes exactly one k-mer, starting with base k-1.
        for (int pos = i - 1; pos < cArr.length; pos++) {
            kmer = addChar(kmer, j, cArr[pos]);
            Integer previous = map.get(kmer);
            map.put(kmer, previous == null ? 1 : previous + 1);
        }
        return map;
    }

    /** Counts how many of the low {@code i} bases of two packed k-mers are identical. */
    public static short countBasesSame(long j, long j2, int i) {
        long diff = j ^ j2;
        short same = 0;
        for (int n = 0; n < i; n++) {
            if ((diff & 3) == 0) {
                same++;
            }
            diff >>= 2;
        }
        return same;
    }

    /**
     * Finds the best same-strand alignment of k-mer {@code j} against {@code j2}, sliding either
     * one by up to {@code i2} bases. Ties keep the earlier (smaller) shift.
     *
     * @return {@code (shift << 16) | matches}, where shift is positive when {@code j} is slid
     *     right relative to {@code j2} and negative when {@code j2} is
     */
    public static int countBasesSameOneDir(long j, long j2, int i, int i2) {
        int best = countBasesSame(j, j2, i);
        int bestShift = 0;
        for (int s = 1; s <= i2; s++) {
            int same = countBasesSame(j >> (2 * s), j2, i - s);
            if (same > best) {
                best = same;
                bestShift = s;
            }
        }
        for (int s = 1; s <= i2; s++) {
            int same = countBasesSame(j, j2 >> (2 * s), i - s);
            if (same > best) {
                best = same;
                bestShift = -s;
            }
        }
        return (bestShift << 16) | best;
    }

    /**
     * Finds the best alignment of {@code j} against {@code j2} on either strand.
     *
     * @return packed result: bit 0 is the reverse-complement flag (see {@link #getRC}), bits 1-16
     *     the number of matching bases ({@link #getSameness}), bits 17-32 the signed shift
     *     ({@link #getShift}). Ties between strands choose the reverse complement.
     */
    public static long countBasesSame(long j, long j2, int i, int i2) {
        long forward = countBasesSameOneDir(j, j2, i, i2);
        long reverse = countBasesSameOneDir(reverseComplement(j, i), j2, i, i2);
        if ((forward & 65535) > (reverse & 65535)) {
            return forward << 1;
        }
        return (reverse << 1) | 1;
    }

    /** Extracts the reverse-complement flag (1 if reverse strand) from a packed alignment. */
    public static short getRC(long j) {
        return (short) (j & 1);
    }

    /** Extracts the matching-base count from a packed alignment. */
    public static short getSameness(long j) {
        return (short) ((j >> 1) & 65535);
    }

    /** Extracts the signed shift from a packed alignment. */
    public static short getShift(long j) {
        return (short) ((j >> 17) & 65535);
    }

    /**
     * Renders a weight matrix as an 800x200 PNG logo.
     *
     * @param weightMatrix the motif to draw
     * @param str output file path
     * @throws IOException if the image cannot be written
     */
    public static void paintMotif(WeightMatrix weightMatrix, String str) throws IOException {
        final int width = 800;
        final int height = 200;
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        // setRenderingHints is a Graphics2D method; createGraphics() already returns one.
        Graphics2D g = image.createGraphics();
        try {
            g.setRenderingHints(new RenderingHints(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON));
            g.setColor(Color.WHITE);
            g.fillRect(0, 0, width, height);
            new WeightMatrixPainter().paint(weightMatrix, g, 0, 0, width, height);
        } finally {
            g.dispose();
        }
        if (!ImageIO.write(image, "png", new File(str))) {
            throw new IOException("No PNG image writer available for " + str);
        }
    }

    /**
     * Computes background base frequencies (A, C, G, T order) over all background sequences,
     * ignoring non-ACGT characters.
     *
     * @throws IllegalStateException if the background contains no A/C/G/T at all
     */
    private double[] backgroundBaseFrequencies() {
        double[] freqs = new double[4];
        for (char[] seq : this.background.values()) {
            for (char c : seq) {
                int base = baseIndex(c);
                if (base >= 0) {
                    freqs[base] += 1.0d;
                }
            }
        }
        double total = freqs[0] + freqs[1] + freqs[2] + freqs[3];
        if (total == 0) {
            throw new IllegalStateException("Background sequences contain no A/C/G/T bases");
        }
        for (int base = 0; base < freqs.length; base++) {
            freqs[base] /= total;
        }
        return freqs;
    }

    /**
     * Builds a log-odds weight matrix from per-column base counts.
     *
     * @param iArr total count for each column
     * @param iArr2 per-column counts for A, C, G, T (in that order)
     * @return a matrix whose entry for column c and base b is
     *     {@code ln((count[c][b] / total[c]) / backgroundFreq[b])}
     */
    public WeightMatrix toWeightMatrix(int[] iArr, int[][] iArr2) {
        double[] bgFreqs = backgroundBaseFrequencies();
        WeightMatrix weightMatrix = new WeightMatrix(iArr.length);
        for (int col = 0; col < iArr.length; col++) {
            // Divide as doubles: int/int would truncate every frequency to 0 and give log(0).
            double total = iArr[col];
            for (int base = 0; base < 4; base++) {
                weightMatrix.matrix[col][toChar[base]] = (float) Math.log((iArr2[col][base] / total) / bgFreqs[base]);
            }
        }
        return weightMatrix;
    }

    /**
     * Greedily clusters k-mers. Candidates are taken in {@code KmerCountComparator} order. Each
     * one joins the cluster whose centroid it best matches (either strand, up to {@code maxshift}
     * shift) if it has strictly more than {@code i - i2} matching bases; otherwise it starts a new
     * cluster. Note: sorts and removes the first element of {@code list} in place.
     *
     * @param list candidate k-mers (mutated)
     * @param i k
     * @param i2 mismatch allowance
     * @return clusters sorted by {@code KmerClusterComparator}; empty if {@code list} is empty
     */
    public List<KmerCluster> cluster(List<KmerCount> list, int i, int i2) {
        List<KmerCluster> clusters = new ArrayList<KmerCluster>();
        if (list.isEmpty()) {
            return clusters;
        }
        Collections.sort(list, new KmerCountComparator());
        clusters.add(new KmerCluster(list.remove(0)));
        for (KmerCount kmerCount : list) {
            int bestCluster = -1;
            int bestSameness = -1;
            for (int c = 0; c < clusters.size(); c++) {
                int sameness = getSameness(countBasesSame(kmerCount.kmer, clusters.get(c).centroid, i, maxshift));
                if (sameness > bestSameness) {
                    bestSameness = sameness;
                    bestCluster = c;
                }
            }
            if (bestSameness > i - i2) {
                clusters.get(bestCluster).members.add(kmerCount);
            } else {
                clusters.add(new KmerCluster(kmerCount));
            }
        }
        Collections.sort(clusters, new KmerClusterComparator());
        return clusters;
    }

    /**
     * For each cluster meeting the size and count thresholds: builds a position count profile
     * over k + 2*maxshift columns (pseudocount 1 per base), converts it to a weight matrix, and
     * scans it with {@code CompareEnrichment.doScan}. When the scan's {@code freqone} exceeds 0.1,
     * prints the profile and the scan result, and paints the motif to {@code <outbase><n>.png}.
     */
    public void printClusters(List<KmerCluster> list) {
        int motifIndex = 0;
        this.pwms = new ArrayList<WeightMatrix>();
        // The k-mer itself plus room to slide up to maxshift bases in either direction.
        int width = this.k + 2 * maxshift;
        for (KmerCluster kmerCluster : list) {
            if (kmerCluster.totalCount() < this.minclustercount || kmerCluster.members.size() < this.minclustersize) {
                continue;
            }
            int[] columnTotals = new int[width];
            int[][] baseCounts = new int[width][4];
            for (int col = 0; col < width; col++) {
                columnTotals[col] = 4;
                for (int base = 0; base < 4; base++) {
                    baseCounts[col][base] = 1;
                }
            }
            for (KmerCount kmerCount : kmerCluster.members) {
                long alignment = countBasesSame(kmerCount.kmer, kmerCluster.centroid, this.k, maxshift);
                long kmer = kmerCount.kmer;
                if (getRC(alignment) == 1) {
                    kmer = reverseComplement(kmer, this.k);
                }
                String kmerString = longToString(kmer, this.k);
                int offset = maxshift + getShift(alignment);
                // Consume the k-mer from its last base (low bits) back to its first.
                for (int fromEnd = 1; fromEnd <= this.k; fromEnd++) {
                    int col = offset + this.k - fromEnd;
                    columnTotals[col] += kmerCount.count;
                    baseCounts[col][(int) (kmer & 3)] += kmerCount.count;
                    kmer >>= 2;
                }
                if (this.printKmers) {
                    StringBuilder sb = new StringBuilder("  ");
                    for (int pad = 0; pad < offset; pad++) {
                        sb.append(" ");
                    }
                    sb.append(kmerString);
                    for (int pad = 0; pad < 2 * maxshift - offset; pad++) {
                        sb.append(" ");
                    }
                    sb.append("\t").append(kmerCount.count);
                    System.out.println(sb.toString());
                }
            }
            WeightMatrix weightMatrix = toWeightMatrix(columnTotals, baseCounts);
            weightMatrix.name = this.outbase + "_" + motifIndex;
            weightMatrix.version = String.format("mfc %.2f expand %d size %d count %d", Double.valueOf(this.minfoldchange), Integer.valueOf(this.parsedregionexpand), Integer.valueOf(this.minclustersize), Integer.valueOf(this.minclustercount));
            weightMatrix.type = "DiscriminativeKmers";
            CEResult doScan = CompareEnrichment.doScan(weightMatrix, this.foreground, this.background, null, null, 0.1d, 0.01d, 2.0d, 0.05d, 1.0d, null, null, false);
            if (doScan.freqone > 0.1d) {
                System.out.println("Cluster centroid " + longToString(kmerCluster.centroid, this.k) + " and total count " + kmerCluster.totalCount());
                for (int base = 0; base < 4; base++) {
                    System.out.print(toChar[base]);
                    for (int col = 0; col < width; col++) {
                        // Cast before dividing: int/int would print 0.00 for every column.
                        double freq = (double) baseCounts[col][base] / columnTotals[col];
                        System.out.print(String.format("\t%.2f", Double.valueOf(freq)));
                    }
                    System.out.println();
                }
                System.out.println(doScan.toString());
                String pngFile = this.outbase + motifIndex + ".png";
                motifIndex++;
                try {
                    paintMotif(weightMatrix, pngFile);
                } catch (IOException e) {
                    // Best effort: a failed image should not abort reporting the remaining clusters.
                    System.err.println("Could not write " + pngFile + ": " + e);
                }
            }
        }
    }

    /**
     * Sets the k-mer length and the matching bit mask.
     *
     * @throws IllegalArgumentException unless {@code 1 <= i <= 32} (the packing limit of a long)
     */
    public void setK(int i) {
        if (i < 1 || i > MAX_K) {
            throw new IllegalArgumentException("k must be between 1 and " + MAX_K + ", got " + i);
        }
        this.k = i;
        long m = 0;
        for (int n = 0; n < i; n++) {
            m = (m << 2) | 3;
        }
        this.mask = m;
    }

    /**
     * Parses command-line options and loads the foreground and background sequence sets.
     *
     * <p>The foreground comes from {@code --first}, or from stdin when absent. The background
     * comes from {@code --second}, or is generated from random regions when absent. Files ending
     * in .fasta/.fa are read as FASTA; other files are read as region lists.
     */
    public void parseArgs(String[] strArr) throws NotFoundException, IOException, FileNotFoundException {
        setK(Args.parseInteger(strArr, "k", 10));
        this.printKmers = Args.parseFlags(strArr).contains("printkmers");
        this.minfoldchange = Args.parseDouble(strArr, "minfoldchange", 1.0d);
        this.parsedregionexpand = Args.parseInteger(strArr, "expand", 30);
        this.randombgcount = Args.parseInteger(strArr, "randombgcount", 1000);
        this.randombgsize = Args.parseInteger(strArr, "randombgsize", 100);
        this.maxmismatch = Args.parseInteger(strArr, "maxmismatch", 3);
        this.minclustersize = Args.parseInteger(strArr, "minclustersize", 2);
        this.minclustercount = Args.parseInteger(strArr, "minclustercount", 30);
        this.outbase = Args.parseString(strArr, "outbase", "motif");
        this.genome = Args.parseGenome(strArr).cdr();
        String first = Args.parseString(strArr, "first", null);
        String second = Args.parseString(strArr, "second", null);
        if (first == null) {
            System.err.println("No --first specified.  Reading from stdin");
            this.foreground = CompareEnrichment.readRegions(this.genome, new BufferedReader(new InputStreamReader(System.in)), this.parsedregionexpand, null);
        } else {
            this.foreground = readSequences(first);
        }
        if (second == null) {
            System.err.println("No background file given.  Generating " + this.randombgcount + " regions of size " + this.randombgsize);
            this.background = CompareEnrichment.randomRegions(this.genome, this.randombgcount, this.randombgsize);
        } else {
            this.background = readSequences(second);
        }
    }

    /**
     * Reads a sequence set from a file (FASTA if the name ends in .fasta/.fa, otherwise a region
     * list) and closes the file afterwards. NOTE(review): assumes the CompareEnrichment readers
     * consume the reader fully before returning, since they return a finished Map.
     */
    private Map<String, char[]> readSequences(String path) throws NotFoundException, IOException {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            if (path.matches(".*\\.fasta") || path.matches(".*\\.fa")) {
                return CompareEnrichment.readFasta(reader);
            }
            return CompareEnrichment.readRegions(this.genome, reader, this.parsedregionexpand, null);
        } finally {
            reader.close();
        }
    }

    /** Sums all counts in a k-mer count map. */
    private static int sumCounts(Map<Long, Integer> counts) {
        int total = 0;
        for (int c : counts.values()) {
            total += c;
        }
        return total;
    }

    /**
     * Counts k-mers in both sets and keeps foreground k-mers whose frequency exceeds
     * {@code minfoldchange} times their background frequency (background count +1 pseudocount).
     * Each kept k-mer's count is reduced by the number expected from the background. The
     * survivors are clustered and the clusters reported.
     */
    public void run() {
        Map<Long, Integer> fgCounts = new HashMap<Long, Integer>();
        Map<Long, Integer> bgCounts = new HashMap<Long, Integer>();
        for (char[] seq : this.foreground.values()) {
            count(seq, this.k, this.mask, fgCounts);
        }
        for (char[] seq : this.background.values()) {
            count(seq, this.k, this.mask, bgCounts);
        }
        int fgTotal = sumCounts(fgCounts);
        int bgTotal = sumCounts(bgCounts);
        System.err.println("Read " + fgTotal + " kmers from the fg set and " + bgTotal + " from the background set");
        List<KmerCount> enriched = new ArrayList<KmerCount>();
        for (Map.Entry<Long, Integer> entry : fgCounts.entrySet()) {
            long kmer = entry.getKey();
            int fgCount = entry.getValue();
            Integer bgCount = bgCounts.get(kmer);
            // Floating-point division; int/int would make both frequencies 0.
            double bgFreq = ((bgCount == null ? 0 : bgCount) + 1) / (double) bgTotal;
            double fgFreq = fgCount / (double) fgTotal;
            if (fgFreq > bgFreq * this.minfoldchange) {
                KmerCount kmerCount = new KmerCount(kmer, fgCount);
                kmerCount.count = (int) (kmerCount.count - (bgFreq * fgTotal));
                if (kmerCount.count > 0) {
                    enriched.add(kmerCount);
                }
            }
        }
        printClusters(cluster(enriched, this.k, this.maxmismatch));
    }

    /** Command-line entry point: parse arguments, then run the analysis. */
    public static void main(String[] strArr) throws Exception {
        DiscriminativeKmers discriminativeKmers = new DiscriminativeKmers();
        discriminativeKmers.parseArgs(strArr);
        discriminativeKmers.run();
    }
}
