# vim: set syntax=python


import random
import sys
import csv
import time
import json
import math

import numpy as np
import pysam
import vcf
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid.axislines import Subplot
from scipy.stats import kde, binned_statistic


# `partial` is not imported at the top of the file; import it here so this
# module-level statement does not depend on names leaking in from elsewhere.
from functools import partial

# plt.savefig with tight bounding boxes by default (callers may still override).
savefig = partial(plt.savefig, bbox_inches="tight")


# Every shell command runs with pipefail, so a failing command anywhere in a
# pipeline fails the whole job instead of only the last command counting.
shell.prefix("set -o pipefail; ")


######################################### Config ###############################

# see also https://github.com/broadinstitute/somatic-benchmark
config = {
    # Download URLs per CEU trio member: gold-standard variant calls,
    # (NA12878 only) the GIAB region BED file, and the exome read BAM.
    "raw_samples": {
        "NA12878": {  # daughter
            "gold_variants": (
                "ftp://ftp-trace.ncbi.nih.gov/giab/ftp/data/NA12878/variant_calls/"
                "NIST/NISTIntegratedCalls_14datasets_131103_allcall_UGHapMerge_HetHomVarPASS"
                "_VQSRv2.18_all_nouncert_excludesimplerep_excludesegdups_"
                "excludedecoy_excludeRepSeqSTRs_noCNVs.vcf.gz"
            ),
            "giab_gold_regions": (
                "ftp://ftp-trace.ncbi.nih.gov/giab/ftp/data/NA12878/variant_calls/NIST/"
                "union13callableMQonlymerged_addcert_nouncert_excludesimplerep_"
                "excludesegdups_excludedecoy_excludeRepSeqSTRs_noCNVs"
                "_v2.18_2mindatasets_5minYesNoRatio.bed.gz"
            ),
            "reads": (
                "ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/working/20120117_ceu_trio_b37_decoy/CEUTrio.HiSeq.WEx.b37_decoy.NA12878.clean.dedup.recal.20120117.bam"
            )
        },
        "NA12891": {  # father
            "gold_variants": (
                "ftp://platgene:G3n3s4me@ussd-ftp.illumina.com/NA12891_S1.genome.vcf.gz"
            ),
            "reads": (
                "ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/working/20120117_ceu_trio_b37_decoy/CEUTrio.HiSeq.WEx.b37_decoy.NA12891.clean.dedup.recal.20120117.bam"
            )
        },
        "NA12892": {  # mother
            "gold_variants": (
                "ftp://platgene:G3n3s4me@ussd-ftp.illumina.com/NA12892_S1.genome.vcf.gz"
            ),
            "reads": (
                "ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/working/20120117_ceu_trio_b37_decoy/CEUTrio.HiSeq.WEx.b37_decoy.NA12892.clean.dedup.recal.20120117.bam"
            )
        }
    },
    "specificity": {
        # Queries over the aliases in "samples": samples left of "-" are the
        # called samples, samples right of it the filter samples, and "+"
        # combines samples (see get_samples_from_query).
        "queries": [
            "A0-A1", "A0-(A1+A2)", "A0-(A1+A2+A3)", "(A0+A1)-(A2+A3)"
        ],
        # Plot layout per query (paired by position): "x"/"y" add the axis
        # labels, "l" adds the legend.
        "query_plot_layouts": [
            "y", "l", "xy", "x"
        ]
    },
    "sensitivity": {
        # NOTE(review): not referenced in the visible part of this file.
        "min_depth": 5,
        # sample whose gold variants form the truth set
        "sample": "NA12878",
        # sample whose gold variants are subtracted from the truth set
        "filter_sample": "NA12892",
        # same query syntax as for specificity
        "queries": [
            "A0-B0", "A0-(B0+B1)", "(A0+A1)-B0", "(A0+A1)-(B0+B1)"
        ],
        "query_plot_layouts": [
            "y", "l", "xy", "x"
        ],
        # upper x-axis (minimum depth) limit of the sensitivity plots
        "xlim": 16,
        # contigs dropped from the gold variants (vcftools --not-chr).
        # NOTE(review): 'NC_012920' is listed twice; harmless but redundant.
        "exclude": ['GL000207.1', 'GL000226.1', 'GL000229.1', 'GL000231.1', 'GL000210.1', 'GL000239.1', 'GL000235.1', 'GL000201.1', 'GL000247.1', 'GL000245.1', 'GL000197.1', 'GL000203.1', 'GL000246.1', 'GL000249.1', 'GL000196.1', 'GL000248.1', 'GL000244.1', 'GL000238.1', 'GL000202.1', 'GL000234.1', 'GL000232.1', 'GL000206.1', 'GL000240.1', 'GL000236.1', 'GL000241.1', 'GL000243.1', 'GL000242.1', 'GL000230.1', 'GL000237.1', 'GL000233.1', 'GL000204.1', 'GL000198.1', 'GL000208.1', 'GL000191.1', 'GL000227.1', 'GL000228.1', 'GL000214.1', 'GL000221.1', 'GL000209.1', 'GL000218.1', 'GL000220.1', 'GL000213.1', 'GL000211.1', 'GL000199.1', 'GL000217.1', 'GL000216.1', 'GL000215.1', 'GL000205.1', 'GL000219.1', 'GL000224.1', 'GL000223.1', 'GL000195.1', 'GL000212.1', 'GL000222.1', 'GL000200.1', 'GL000193.1', 'GL000194.1', 'GL000225.1', 'GL000192.1', 'MT', 'NC_012920', 'NC_012920', 'hs37d5']
    },
    # Query alias -> sample name; the reads are data/reads/<sample name>.bam
    # (presumably 25% subsamples of the original BAMs — confirm).
    "samples": {
        "A0": "NA12878.subsample-0.25-0",
        "A1": "NA12878.subsample-0.25-1",
        "A2": "NA12878.subsample-0.25-2",
        "A3": "NA12878.subsample-0.25-3",
        "B0": "NA12892.subsample-0.25-0",
        "B1": "NA12892.subsample-0.25-1",
        #"B0A0": "NA12892.subsample-0.25-0+NA12878.subsample-0.25-0__at__NA12878-NA12892__with__0.8"
    },
    # NOTE(review): not referenced in the visible part of this file; the
    # mpileup call in specificity_covered_depths hard-codes -d 250 instead.
    "max_pileup_depth": 250,
    # display names used in plot legends and in the performance table header
    "caller_names": {
        "gatk": "GATK",
        "alpaca": "ALPACA",
        "freebayes": "FreeBayes",
        "samtools": "SAMtools"
    },
    # alpaca minimum QUAL thresholds (10, 20, ..., 70) for the parameter-space plots
    "parameter_space": {
        "min_qual": list(range(10, 80, 10)),
    }
}


def get_sample_name(
    wildcards,
    sample_to_name={sample: name for name, sample in config["samples"].items()}
):
    """Translate the sample name in ``wildcards.sample`` back to its alias.

    ``sample_to_name`` is the inverse of ``config["samples"]`` (for example
    "NA12878.subsample-0.25-0" -> "A0"). It is built once, when the function
    is defined, and is only read here.

    Raises:
        KeyError: if the sample name is not a value of ``config["samples"]``.
    """
    bam_sample = wildcards.sample
    alias = sample_to_name[bam_sample]
    return alias


def get_samples_from_query(query, filter_only=False, call_only=False):
    """Return the sample aliases mentioned in a query such as "(A0+A1)-B0".

    Samples left of the first "-" are the called samples, samples right of
    it are the filter samples; "+" joins samples and parentheses are ignored.

    Args:
        query: query string.
        filter_only: keep only the part after the first "-". A query without
            "-" then has no filter samples and ``[]`` is returned.
        call_only: keep only the part before the first "-".

    Returns:
        The aliases in order of appearance.
    """
    sides = query.split("-")
    if filter_only:
        if len(sides) < 2:
            return []
        query = sides[1]
    if call_only:
        query = query.split("-")[0]
    names = []
    for term in query.split("+"):
        names.extend(piece.strip(" ()") for piece in term.split("-"))
    return names


def expand_samples_from_query(pattern, filter_only=False, call_only=False, orig_name=False):
    """Build a Snakemake input function expanding ``pattern`` over a query's samples.

    The returned function reads ``wildcards.query``, selects its samples with
    ``get_samples_from_query`` (honouring ``filter_only``/``call_only``) and
    expands the ``{sample}`` field of ``pattern`` with them. With
    ``orig_name`` the aliases are first mapped to their sample names via
    ``config["samples"]``.
    """
    def apply(wildcards):
        aliases = get_samples_from_query(
            wildcards.query, filter_only=filter_only, call_only=call_only
        )
        if not orig_name:
            return expand(pattern, sample=aliases)
        sample_names = [config["samples"][alias] for alias in aliases]
        return expand(pattern, sample=sample_names)
    return apply


# Remote rule file; presumably provides the BAM helper rules (e.g. the .bai
# index files required below) — confirm.
include:
    "https://bitbucket.org/johanneskoester/snakemake-workflows/raw/master/bio/ngs/rules/mapping/samfiles.rules"


# One specificity plot per query; `zip` pairs each query with the layout at
# the same position instead of forming all combinations.
target_specificity =  expand(
    "plots/specificity/{query}.depth.{layout}.pdf",
    zip,
    query=config["specificity"]["queries"],
    layout=config["specificity"]["query_plot_layouts"]
)


# One sensitivity plot per query, paired with its layout in the same way.
target_sensitivity = expand(
    "plots/sensitivity/{query}.depth.{layout}.pdf",
    zip,
    query=config["sensitivity"]["queries"],
    layout=config["sensitivity"]["query_plot_layouts"]
)


########################### Target rules #######################################


# Default target: all specificity and sensitivity plots, the compression
# benchmark plot and the read counts of the NA12878 and NA12892 BAMs.
rule all:
    input:
        target_specificity,
        target_sensitivity,
        "plots/compression/runtime_vs_size.pdf",
        expand(
            "data/reads/{sample}.readcount.txt",
            sample="NA12878 NA12892".split()
        )


# Only the specificity (FPR) plots.
rule specificity:
    input:
        target_specificity


# Only the sensitivity (TPR) plots.
rule sensitivity:
    input:
        target_sensitivity


# GATK call sets for every specificity and sensitivity query.
rule other_vcf:
    input:
        expand(
            "{caller}/{query}.vcf",
            caller="gatk".split(),
            query=config["specificity"]["queries"] + config["sensitivity"]["queries"]
        )


######################################### Sample stats #########################


# Total read count of a BAM file. samtools idxstats reports mapped (column 3)
# and unmapped (column 4) reads per reference sequence, and awk sums both
# columns over all sequences.
rule read_count:
    input:
        "{prefix}.bam"
    output:
        "{prefix}.readcount.txt"
    shell:
        "samtools idxstats {input} | awk '{{s+=$3+$4}} END {{print s}}' "
        "> {output}"


# Depth histogram of a BAM file. Column 4 of samtools mpileup (the depth of
# the first sample) is counted per value, and one "depth<TAB>number of
# positions" line is written per depth, sorted by depth.
rule coverage_dist:
    input:
        "{prefix}.bam"
    output:
        "{prefix}.coverage.txt"
    resources: benchmark=1
    shell:
        "samtools mpileup {input} | cut -f4 | sort -n | uniq -c | "
        r"awk '{{print $2,$1}}' OFS='\t' > {output}"


# Depth histograms of all configured samples (log-scale frequency axis).
rule plot_coverages:
    input:
        expand("data/reads/{sample}.coverage.txt", sample=config["samples"].values())
    output:
        "plots/coverage.pdf"
    resources: benchmark=1
    run:
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure()
        for f in input:
            # rows are (depth, number of positions) sorted by depth; the first
            # (lowest-depth) row is dropped — TODO confirm the reason
            cov = np.loadtxt(f, dtype=np.int64)[1:]
            depth, freq = cov[:,0], cov[:,1]
            plt.semilogy(depth, freq, alpha=0.5)
            # mean depth over the plotted positions; only printed, not drawn
            sites = freq.sum()
            mean = (depth * freq).sum() / sites
            print(mean)
            #plt.semilogy([mean, mean], [0, 10 ** 6], "--")
        plt.xlabel("depth")
        plt.ylabel("frequency")
        plt.savefig(output[0])


######################################### Specificity ##########################


# Callers plotted for each specificity query, in legend/line-style order.
# "alpaca/0.05" refers to the alpaca calls made with FDR 0.05.
specificity_plot_caller = {
    query: ["alpaca/0.05", "gatk", "freebayes", "samtools"]
    for query in config["specificity"]["queries"]
}


# Depth histogram over the filter samples of a specificity query (the samples
# after "-"). For every position reported by samtools mpileup, the
# per-sample depth columns (4, 7, 10, ...) are summed. The output has one
# "number of positions<TAB>depth" line per depth and is read by
# get_covered_depths.
rule specificity_covered_depths:
    input:
        bams=lambda wildcards: expand(
            "data/reads/{sample}.bam",
            sample=[
                config["samples"][sample]
                for sample in get_samples_from_query(wildcards.query, filter_only=True)
            ]
        )
    output:
        "specificity/covered_depths/{query}.txt"
    resources: benchmark=1
    shell:
        # use the same samtools params as alpaca
        "samtools mpileup -d 250 -q 17 -Q 13 -C 50 {input.bams} | "
        r"awk -F '\t' '{{depth=0}} {{for (i=4; i<=NF; i+=3) depth+=$i}} {{print depth}}' | "
        r"sort -n | uniq -c | awk '{{print $1,$2}}' OFS='\t' > '{output}'"


def specificity_calls(pattern, blacklist=()):
    """Build a Snakemake input function listing one call set per plotted caller.

    The returned function expands ``pattern`` (with ``{query}`` and
    ``{caller}`` fields) for ``wildcards.query`` and every caller in
    ``specificity_plot_caller[wildcards.query]`` that is not in ``blacklist``.
    The caller order is kept, so the result lines up with that list.

    Args:
        pattern: file pattern with ``{query}`` and ``{caller}`` fields.
        blacklist: callers to leave out. Any container supporting ``in``
            works; the default is an immutable empty tuple rather than a
            shared mutable list.
    """
    def apply(wildcards):
        callers = [
            caller for caller in specificity_plot_caller[wildcards.query]
            if caller not in blacklist
        ]
        return expand(pattern, query=wildcards.query, caller=callers)
    return apply


def get_covered_depths(f):
    """Load a depth histogram written by the specificity_covered_depths rule.

    Each line of ``f`` is "<number of positions><TAB><depth>". The histogram
    is turned into a dense array indexed by depth.

    Args:
        f: path (or file object) of the histogram.

    Returns:
        ``(covered_depths, max_covered_depth)``. ``covered_depths[d]`` is the
        number of positions with depth ``d`` (0 for depths missing from the
        file) and ``max_covered_depth`` is the largest depth in the file. An
        empty histogram yields ``(array([0]), 0)``.
    """
    # ndmin=2 keeps a single-line histogram two-dimensional; without it
    # loadtxt returns a 1-D row and the column indexing below fails.
    _covered_depths = np.loadtxt(f, dtype=np.int64, ndmin=2)
    if _covered_depths.size == 0:
        return np.zeros(1, dtype=np.int64), 0
    counts, depths = _covered_depths[:, 0], _covered_depths[:, 1]
    max_covered_depth = depths.max()
    covered_depths = np.zeros(max_covered_depth + 1, dtype=np.int64)
    covered_depths[depths] = counts
    # depths higher than 250 are capped by alpaca.
    # however we won't plot them, so biased counts do not matter there
    return covered_depths, max_covered_depth


def plot_fpr(f, covered_depths, max_covered_depth, style, label=None, recordfilter=None):
    """Plot a call set's false positive rate against the minimum depth.

    Every SNP record in the VCF file ``f`` counts as a false positive. Its
    depth is the sum of the per-sample ``DP`` values, with a missing ``DP``
    counted as 0. For each depth ``d`` the plotted value is

        #FP with depth >= d / #covered positions with depth >= d

    on a log-scale y axis.

    Args:
        f: path of the VCF file with the calls.
        covered_depths: number of covered positions per depth, as returned by
            ``get_covered_depths``.
        max_covered_depth: largest depth in ``covered_depths``.
        style: matplotlib format string for the line.
        label: optional legend label.
        recordfilter: optional predicate; only records for which it returns
            a truthy value are counted.
    """
    n_depths = max_covered_depth + 1
    with open(f) as handle:
        reader = vcf.Reader(handle)
        if recordfilter:
            reader = filter(recordfilter, reader)
        # Explicit integer dtype: an empty call set would otherwise produce a
        # float64 array, which np.bincount rejects.
        fp_depths = np.array([
            sum(
                (call.data.DP if call.data.DP is not None else 0)
                for call in record.samples
            ) for record in reader if record.is_snp
        ], dtype=np.int64)

    # A call deeper than every covered position would make `fp` longer than
    # `covered_depths` and break the subtraction below. Clipping such calls
    # into the top bin keeps every "depth >= d" count (d <= max) exact.
    fp = np.bincount(np.minimum(fp_depths, max_covered_depth), minlength=n_depths)
    tn = covered_depths - fp

    # reverse cumulative sums: entry d counts everything with depth >= d
    cum_fp = fp[::-1].cumsum()[::-1]
    cum_tn = tn[::-1].cumsum()[::-1]

    cum_n = cum_fp + cum_tn
    # depths without any covered position give 0/0 = NaN, which is not drawn
    with np.errstate(divide="ignore", invalid="ignore"):
        fpr = cum_fp / cum_n

    plt.semilogy(np.arange(n_depths), fpr, style, label=label)


# FPR versus minimum depth for every caller of one specificity query. The
# layout wildcard (letters l, x, y) selects decorations: "x" and "y" add the
# axis labels, "l" adds the legend.
rule plot_specificity:
    input:
        calls=specificity_calls("{caller}/{query}.vcf"),
        covered="specificity/covered_depths/{query}.txt"
    output:
        "plots/specificity/{query}.depth.{layout,[lxy]+}.pdf"
    run:
        covered_depths, max_covered_depth = get_covered_depths(input.covered)

        # one line style per caller (room for up to five callers)
        styles = "- -- : -. |-".split()
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure(figsize=(2.8, 2.5))
        # input.calls was expanded from specificity_plot_caller, so both
        # sequences list the callers in the same order
        for i, (f, caller) in enumerate(zip(
            input.calls, specificity_plot_caller[wildcards.query]
        )):
            # legend label: display name of the caller without its "/<filter>" part
            plot_fpr(
                f, covered_depths, max_covered_depth,
                styles[i], label=config["caller_names"][caller.split("/")[0]]
            )
        if "x" in wildcards.layout:
            plt.xlabel("minimum depth")
        if "y" in wildcards.layout:
            plt.ylabel("FPR")
        plt.xlim((1,60))
        plt.ylim((1e-8, 1e-2))
        if "l" in wildcards.layout:
            plt.legend(loc="upper right", handlelength=2.5)
        savefig(output[0], bbox_inches="tight")


# FPR of alpaca's A0-A1 calls made with --min-qual 10, re-thresholded at each
# configured min_qual (dotted lines), compared with the FDR 0.05 calls (solid
# line).
rule plot_specificity_parameter_space:
    input:
        calls="alpaca/10/A0-A1.vcf",
        fdrcalls="alpaca/0.05/A0-A1.vcf",
        covered="specificity/covered_depths/A0-A1.txt"
    output:
        "plots/specificity/parameter_space.pdf"
    run:
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure(figsize=(2.8, 2.5))
        covered_depths, max_covered_depth = get_covered_depths(input.covered)
        for min_qual in config["parameter_space"]["min_qual"]:
            # plot_fpr consumes the filter before returning, so the lambda
            # always sees the current min_qual.
            # NOTE(review): a record with QUAL None would raise TypeError on
            # the comparison — confirm that alpaca always sets QUAL.
            plot_fpr(
                input.calls, covered_depths, max_covered_depth, ":",
                recordfilter=lambda record: record.QUAL >= min_qual
            )
        plot_fpr(
            input.fdrcalls, covered_depths, max_covered_depth, "-"
        )
        plt.xlabel("minimum depth")
        plt.ylabel("FPR")
        plt.xlim((1,60))
        #plt.ylim((1e-6, 1e-2))
        savefig(output[0], bbox_inches="tight")


##################################### Sensitivity ##############################


# Callers plotted for each sensitivity query, in legend/line-style order.
# "alpaca/0.05" refers to the alpaca calls made with FDR 0.05.
sensitivity_plot_caller = {
    query: ["alpaca/0.05", "gatk", "freebayes", "samtools"]
    for query in config["sensitivity"]["queries"]
}


# Sensitivity truth set, step 1. Keeps the gold variants of
# config["sensitivity"]["sample"] that overlap a feature of ref.genes.gtf
# (-u reports each variant once). Only SNPs are kept (--remove-indels), and
# the contigs in config["sensitivity"]["exclude"] are dropped.
rule sensitivity_gold_variants:
    input:
        vcf=expand("data/gold_variants/{sample}.vcf", sample=config["sensitivity"]["sample"]),
        gtf="ref.genes.gtf"
    output:
        "sensitivity/gold_variants.vcf"
    params:
        # "A --not-chr B --not-chr C ..."; the shell command adds the first --not-chr
        exclude=" --not-chr ".join(config["sensitivity"]["exclude"])
    resources: benchmark=1
    shell:
        "bedtools intersect -header -u -a {input.vcf} -b {input.gtf} | "
        "vcftools --vcf - --stdout --remove-indels --not-chr {params.exclude} "
        "--recode > {output}"


# Sensitivity truth set, step 2. Removes every gold variant that overlaps a
# gold variant of config["sensitivity"]["filter_sample"] (bedtools subtract
# -A drops a whole record on any overlap). The result is the truth set for
# the TP/FN rules.
rule sensitivity_exclusive_gold_variants:
    input:
        filter_vcf=expand("data/gold_variants/{sample}.fixed.vcf", sample=config["sensitivity"]["filter_sample"]),
        vcf="sensitivity/gold_variants.vcf"
    output:
        "sensitivity/exclusive_gold_variants.vcf"
    resources: benchmark=1
    shell:
        # Copy only the VCF header, i.e. lines starting with '#'. An unanchored
        # '#' pattern would also copy data lines containing '#' anywhere.
        "grep '^#' {input.vcf} > {output} && "
        "bedtools subtract -A -a {input.vcf} -b {input.filter_vcf} >> {output}"


# True positives of a caller/query: calls fully covered (-f 1) by an
# exclusive gold variant. -u reports each call once and -header keeps the
# VCF header of the calls.
rule sensitivity_true_positives:
    input:
        calls="{caller}/{query}.vcf",
        gold="sensitivity/exclusive_gold_variants.vcf"
    output:
        "sensitivity/tp/{caller}/{query,[^\.]+}.vcf"
    resources: benchmark=1
    shell:
        'bedtools intersect -header -u -f 1 -a "{input.calls}" -b {input.gold} > "{output}"'


# False negatives of a caller/query: the parts of exclusive gold variants not
# overlapped by any call (bedtools subtract), written below the header of the
# gold VCF.
rule sensitivity_false_negatives:
    input:
        calls="{caller}/{query}.vcf",
        gold="sensitivity/exclusive_gold_variants.vcf"
    output:
        "sensitivity/fn/{caller}/{query,[^\.]+}.vcf"
    resources: benchmark=1
    shell:
        # Copy only the VCF header, i.e. lines starting with '#'. An unanchored
        # '#' pattern would also copy data lines containing '#' anywhere.
        'grep "^#" {input.gold} > "{output}" && '
        'bedtools subtract -a {input.gold} -b "{input.calls}" >> "{output}"'


# Depths at the sites of a TP/FN VCF in one called sample. The sample wildcard
# is a config alias such as "A0". The sites are looked up with `alpaca peek`
# in the "gh"-compressed index of that sample. The third output column holds
# the overall depth (see sensitivity_depth).
rule sensitivity_allele_depth:
    input:
        vcf="sensitivity/{status}/{caller}/{query}.vcf",
        index=lambda wildcards: expand(
            "alpaca/{sample}.index.gh.hdf5",
            sample=config["samples"][wildcards.sample]
        )
    output:
        "sensitivity/{status}/{caller}/{query}.{sample}.allele_depth.csv"
    threads: 12
    resources: benchmark=1
    shell:
        "alpaca --threads {threads} peek {input.index} < '{input.vcf}' > '{output}'"


def sensitivity_expand_called_samples(pattern):
    """Build an input/params function expanding ``pattern`` over called samples.

    The returned function fills ``{query}``, ``{status}`` and ``{caller}``
    from the wildcards. It expands ``{sample}`` over the called samples of
    the query, i.e. the aliases before "-" (``get_samples_from_query`` with
    ``call_only=True``).
    """
    def apply(wildcards):
        called = get_samples_from_query(wildcards.query, call_only=True)
        fixed = dict(
            query=wildcards.query,
            status=wildcards.status,
            caller=wildcards.caller,
        )
        return expand(pattern, sample=called, **fixed)
    return apply


# Tab-separated depth table of a TP/FN set: one column per called sample,
# holding that sample's overall depth (third column of its allele-depth table,
# header line included). The <( ... ) process substitution requires bash.
rule sensitivity_depth:
    input:
        sensitivity_expand_called_samples(
            "sensitivity/{status}/{caller}/{query}.{sample}.allele_depth.csv"
        )
    output:
        "sensitivity/{status}/{caller}/{query}.depth.csv"
    params:
        pipes=sensitivity_expand_called_samples(
            # cut the third column since it contains the overall depth
            "<( cut -f 3 'sensitivity/{status}/{caller}/{query}.{sample}.allele_depth.csv')"
        )
    resources: benchmark=1
    shell:
        # paste the 3rd columns into one file
        "paste {params.pipes} > '{output}'"


def plot_sensitivity(tp_file, fn_file, style, label=None):
    """Plot the true positive rate against the minimum depth.

    ``tp_file`` and ``fn_file`` are tab-separated depth tables, as written by
    the sensitivity_depth rule. Each has one header line and one column per
    called sample; a variant's depth is the sum of its row. For each depth
    ``d`` the plotted value is

        #TP with depth >= d / (#TP + #FN) with depth >= d

    Args:
        tp_file: depth table of the true positives.
        fn_file: depth table of the false negatives.
        style: matplotlib format string for the line.
        label: optional legend label.
    """
    def sum_depths(f):
        # ndmin=2 keeps the table two-dimensional for a single row and/or a
        # single column, so summing over axis 1 always yields one depth per
        # variant. Without it, a single row with several samples would come
        # back as a 1-D array of per-sample values, and a single cell as a
        # 0-d scalar.
        d = np.loadtxt(f, skiprows=1, dtype=np.int32, ndmin=2)
        return d.sum(axis=1)

    tp_depths = sum_depths(tp_file)
    fn_depths = sum_depths(fn_file)
    # a header-only table contributes no variants; fall back to depth 0
    max_depth = max(
        (depths.max() for depths in (tp_depths, fn_depths) if depths.size),
        default=0
    )

    fn = np.bincount(fn_depths, minlength=max_depth + 1)
    tp = np.bincount(tp_depths, minlength=max_depth + 1)

    # reverse cumulative sums: entry d counts variants with depth >= d
    cum_fn = fn[::-1].cumsum()[::-1]
    cum_tp = tp[::-1].cumsum()[::-1]

    cum_p = cum_tp + cum_fn
    # depths without any variant give 0/0 = NaN, which is not drawn
    with np.errstate(divide="ignore", invalid="ignore"):
        sensitivity = cum_tp / cum_p

    plt.plot(np.arange(max_depth + 1), sensitivity, style, label=label)


# TPR versus minimum depth for every caller of one sensitivity query. The
# layout wildcard (letters l, x, y) selects decorations: "x" and "y" add the
# axis labels, "l" adds the legend.
rule plot_sensitivity:
    input:
        fn=lambda wildcards: expand(
            "sensitivity/fn/{caller}/{query}.depth.csv",
            caller=sensitivity_plot_caller[wildcards.query],
            query=wildcards.query
        ),
        tp=lambda wildcards: expand(
            "sensitivity/tp/{caller}/{query}.depth.csv",
            caller=sensitivity_plot_caller[wildcards.query],
            query=wildcards.query
        )
    output:
        "plots/sensitivity/{query}.depth.{layout,[lxy]+}.pdf"
    run:
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure(figsize=(2.8, 2.5))
        # one line style per caller (room for up to five callers)
        styles = "- -- : -. |-".split()
        
        # input.tp/input.fn were expanded from sensitivity_plot_caller, so all
        # three sequences list the callers in the same order
        for i, (caller, tp, fn) in enumerate(zip(
            sensitivity_plot_caller[wildcards.query], input.tp, input.fn
        )):
            # legend label: display name of the caller without its "/<filter>" part
            plot_sensitivity(
                tp, fn, styles[i], label=config["caller_names"][caller.split("/")[0]]
            )
        if "x" in wildcards.layout:
            plt.xlabel("minimum depth")
        if "y" in wildcards.layout:
            plt.ylabel("TPR")
        plt.ylim((0, 1))
        plt.xlim((0, config["sensitivity"]["xlim"]))
        if "l" in wildcards.layout:
            plt.legend(loc="lower right", handlelength=2.5)
        savefig(output[0], bbox_inches="tight")


# TPR of alpaca's calls for the query "A0" at every configured min_qual
# threshold (dotted lines), compared with the FDR 0.05 calls (solid line).
rule plot_sensitivity_parameter_space:
    input:
        fn=lambda wildcards: expand(
            "sensitivity/fn/alpaca/{min_qual}/A0.depth.csv",
            min_qual=config["parameter_space"]["min_qual"]
        ),
        tp=lambda wildcards: expand(
            "sensitivity/tp/alpaca/{min_qual}/A0.depth.csv",
            min_qual=config["parameter_space"]["min_qual"]
        ),
        fdr_fn="sensitivity/fn/alpaca/0.05/A0.depth.csv",
        fdr_tp="sensitivity/tp/alpaca/0.05/A0.depth.csv"
    output:
        "plots/sensitivity/parameter_space.pdf"
    run:
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure(figsize=(2.8, 2.5))
        # tp/fn are expanded in config order, so zip pairs each threshold with
        # its tables (min_qual itself is not used in the loop body)
        for min_qual, tp, fn in zip(config["parameter_space"]["min_qual"], input.tp, input.fn):
            plot_sensitivity(tp, fn, ":")
        plot_sensitivity(input.fdr_tp, input.fdr_fn, "-")
        plt.xlabel("minimum depth")
        plt.ylabel("TPR")
        plt.ylim((0,1))
        plt.xlim((0, config["sensitivity"]["xlim"]))
        savefig(output[0], bbox_inches="tight")


########################### compression ########################################

# Compression profiles to benchmark. The order matters:
# plot_compression_efficiency draws profile i at y = i, so the first entry
# ends up at the bottom.
compression_profiles = ["gsh", "gh", "gs", "g", "lsh", "lh", "ls", "l", "h", "n"]


def compression_get_benchmarks(write=False):
    """Return the benchmark files of every compression profile.

    Args:
        write: if true, return the benchmarks of indexing sample A0 (written
            by alpaca_index). Otherwise return the benchmarks of merging the
            A0 and A1 indexes (written by alpaca_merge_two).

    Returns:
        One path per entry of ``compression_profiles``, in that order.
    """
    if write:
        pattern = "benchmarks/alpaca/{sample}.index.{compression}.json"
    else:
        pattern = "benchmarks/alpaca/A0+A1.merge.{compression}.json"
    a0_sample = config["samples"]["A0"]
    return expand(pattern, sample=a0_sample, compression=compression_profiles)


# Compression benchmark with one row per compression profile and three panels:
# left, size of A0's index; middle, run times of merging the A0 and A1
# indexes; right, run times of indexing A0.
rule plot_compression_efficiency:
    input:
        benchmarks_read=compression_get_benchmarks(),
        benchmarks_write=compression_get_benchmarks(write=True),
        indexes=expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=config["samples"]["A0"],
            compression=compression_profiles
        )
    output:
        "plots/compression/runtime_vs_size.pdf"
    run:
        # all recorded wall-clock times (seconds) of a Snakemake benchmark file
        def get_runtimes(f):
            with open(f) as f:
                return json.load(f)["wall_clock_times"]["s"]

        # index sizes in units of 1024 ** 3 bytes (the axis label says GB).
        # NOTE(review): `os` is not imported in the visible part of this file;
        # presumably it is provided by the Snakemake namespace — confirm.
        sizes = {
            compr: os.path.getsize(f) / 1024 ** 3
            for compr, f in zip(compression_profiles, input.indexes)
        }
        runtimes_write = {
            compr: get_runtimes(f)
            for compr, f in zip(compression_profiles, input.benchmarks_write)
        }
        runtimes_read = {
            compr: get_runtimes(f)
            for compr, f in zip(compression_profiles, input.benchmarks_read)
        }
        
        fig = plt.figure(figsize=(7,3))
        

        # Left panel (axisartist Subplot): only the bottom axis line is drawn;
        # the left axis keeps its tick labels (the profile names) but loses
        # its line and ticks.
        ax = Subplot(fig, 131)
        ax.axis["right"].set_visible(False)
        ax.axis["top"].set_visible(False)
        ax.axis["left"].line.set_visible(False)
        ax.axis["left"].toggle(ticks=False)
        fig.add_subplot(ax)

        # dotted horizontal separators above each profile row (y = i + 0.5)
        def sep():
            x = plt.xlim()
            for i in range(len(compression_profiles)):
                y = [i + 0.5] * 2
                plt.plot(x, y, ":k", linewidth=0.5)

        # shared y range of all three panels
        ylim = (-0.5, len(compression_profiles) + 1.5)
        
        for i, compr in enumerate(compression_profiles):
            plt.plot(sizes[compr], i, "o")
        plt.xlabel("size of index [GB]")
        plt.xlim((0, plt.xlim()[1]))
        plt.ylim(ylim)
        sep()
        plt.gca().set_yticks(np.arange(len(compression_profiles)))
        plt.gca().set_yticklabels(compression_profiles)

        # Runtime panel: one dot per recorded run of each profile. The left
        # axis is hidden because the left panel already shows the profile names.
        def plot_runtimes(subplt, runtimes, type, xlim): 
            ax = Subplot(fig, subplt)
            ax.axis["right"].set_visible(False)
            ax.axis["top"].set_visible(False)
            ax.axis["left"].set_visible(False)
            fig.add_subplot(ax)

            for i, compr in enumerate(compression_profiles):
                _runtimes = runtimes[compr]
                plt.plot(_runtimes, [i] * len(_runtimes), "o")
            plt.xlabel("run time of {} [s]".format(type))
            plt.ylim(ylim)
            plt.xlim((0, xlim))
            plt.locator_params(nbins=6)
            sep()
        plot_runtimes(132, runtimes_read, "merging", 600)
        plot_runtimes(133, runtimes_write, "indexing", 2500)

        plt.savefig(output[0], bbox_inches="tight")


############################ floating point precision ##########################


# Call the query A0-A1 on the merged A0+A1 index built with the given
# compression profile; table_precision_loss compares the "n" and "h" results.
# NOTE(review): --fdr is given without a value here, while alpaca_call passes
# "--fdr <level>" — confirm that this is intended.
rule fp_precision_call:
    input:
        "alpaca/A0+A1.index.{compression}.hdf5"
    output:
        "fp-precision/call.{compression,[nhsgl]+}.vcf"
    log:
        "logs/fp-precision/call.{compression}.log"
    threads: 12
    resources: benchmark=1
    shell:
        "alpaca --debug --threads {threads} --buffersize 10000000 call "
        "{input} A0-A1 --fdr > {output} 2> {log}"


# Histogram of absolute QUAL differences between the calls on the "n" and the
# "h" index, which must report exactly the same loci. Writes one
# "difference<TAB>count" line per integer difference, starting at 0.
rule table_precision_loss:
    input:
        n="fp-precision/call.n.vcf",
        h="fp-precision/call.h.vcf"
    output:
        "tables/fp-precision-qualdiff.csv"
    resources: benchmark=1
    run:
        # QUAL per (CHROM, POS) of both call sets
        with open(input.n) as fn, open(input.h) as fh:
            uncompressed = {(rec.CHROM, rec.POS): rec.QUAL for rec in vcf.Reader(fn)}
            compressed = {(rec.CHROM, rec.POS): rec.QUAL for rec in vcf.Reader(fh)}
        print("vcf loaded")
        assert compressed.keys() == uncompressed.keys()

        loci = uncompressed.keys()
        x = np.array([uncompressed[locus] for locus in loci])
        y = np.array([compressed[locus] for locus in loci])
        # NOTE(review): np.bincount requires non-negative integers, so this
        # only works if every QUAL is parsed as an int — confirm.
        counts = np.bincount(np.abs(x - y))
        with open(output[0], "w") as f:
            for i in range(counts.size):
                print(i, counts[i], sep="\t", file=f)


# Plot the three likelihood columns of the first slice of sequence "1" in A0's
# "h"-compressed index, printing the raw values as well (looks like a debugging
# aid — confirm). NOTE(review): the {type} wildcard is not used by the run block.
rule plot_likelihoods:
    input:
        expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=config["samples"]["A0"],
            compression="h"
        )
    output:
        "plots/likelihoods_{type,(float|phred)}.pdf"
    run:
        from alpaca.index.view import SampleIndexView
        from alpaca.utils import LOG_TO_PHRED_FACTOR
        with SampleIndexView(input[0], buffersize=100) as index:
            seq_slice = next(iter(index["1"]))
            likelihoods = seq_slice.pileup_allelefreq_likelihoods
        print(likelihoods[:,0])# * LOG_TO_PHRED_FACTOR)
        print(likelihoods[:,1])# * LOG_TO_PHRED_FACTOR)
        print(likelihoods[:,2])# * LOG_TO_PHRED_FACTOR)
        print(LOG_TO_PHRED_FACTOR)
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure()
        x = np.arange(likelihoods.shape[0])
        plt.plot(x, likelihoods[:, 0], "-")
        plt.plot(x, likelihoods[:, 1], "--")
        plt.plot(x, likelihoods[:, 2], ":")
        # bottom=0, top=-200: the y axis is drawn inverted
        plt.ylim((0, -200))
        plt.savefig(output[0])


############################### run time performance ###########################

# Rows of the performance table: every query of both benchmarks, ...
performance_queries = (
    config["specificity"]["queries"] + config["sensitivity"]["queries"]
)
# ... and every sample alias in sorted order, with its sample name.
performance_sample_names = sorted(config["samples"].keys())
performance_samples = [
    config["samples"][alias] for alias in performance_sample_names
]
# Columns of the performance table, in output order.
performance_callers = ["alpaca", "gatk", "freebayes", "samtools"]


def performance_write_runtimes(samples, merge, queries, out):
    """Write one caller's wall-clock run times as a JSON list of "M:SS" strings.

    The list holds, in order: one entry per path in ``samples``, one entry
    for ``merge``, and one entry per path in ``queries``. Each entry is the
    first value of ``wall_clock_times.s`` in the Snakemake benchmark JSON
    file. A ``None`` path yields "-" (step not applicable to the caller).

    Args:
        samples: benchmark paths (or None) of the per-sample steps.
        merge: benchmark path (or None) of the merge step.
        queries: benchmark paths (or None) of the calling steps.
        out: path of the JSON file to write.
    """
    def get_runtime(path):
        if path is None:
            return "-"
        with open(path) as handle:
            seconds = json.load(handle)["wall_clock_times"]["s"][0]
        # Round to whole seconds *before* splitting off the minutes.
        # Otherwise, e.g., 119.7 s becomes "1:60" instead of "2:00".
        minutes, seconds = divmod(int(round(seconds)), 60)
        return "{}:{:02d}".format(minutes, seconds)

    # Collect everything first so that a failing read does not leave a
    # truncated output file behind.
    runtimes = (
        [get_runtime(path) for path in samples] +
        [get_runtime(merge)] +
        [get_runtime(path) for path in queries]
    )
    with open(out, "w") as handle:
        json.dump(runtimes, handle)


# Run times of alpaca: per-sample indexing ("gh" profile), merging all indexes,
# and calling each query with FDR 0.05.
rule performance_alpaca:
    input:
        samples=expand("benchmarks/alpaca/{sample}.index.gh.json", sample=performance_samples),
        merge="benchmarks/alpaca/merge.gh.json",
        calls=expand(
            "benchmarks/alpaca/0.05/{query}.call.json", query=performance_queries
        )
    output:
        "performance/alpaca.json"
    run:
        performance_write_runtimes(input.samples, input.merge, input.calls, output[0])


# Run times of GATK: per-sample HaplotypeCaller runs (keyed by alias), no merge
# step, and genotyping per query.
rule performance_gatk:
    input:
        samples=expand("benchmarks/gatk/{sample}.haplotype_caller.json", sample=performance_sample_names),
        calls=expand("benchmarks/gatk/{query}.genotyping.json", query=performance_queries)
    output:
        "performance/gatk.json"
    run:
        performance_write_runtimes(input.samples, None, input.calls, output[0])


# FreeBayes only has per-query calling steps; the sample and merge entries are "-".
rule performance_freebayes:
    input:
        calls=expand("benchmarks/freebayes/{query}.json", query=performance_queries)
    output:
        "performance/freebayes.json"
    run:
        performance_write_runtimes([None] * len(performance_samples), None, input.calls, output[0])


# SAMtools only has per-query calling steps; the sample and merge entries are "-".
rule performance_samtools:
    input:
        calls=expand("benchmarks/samtools/{query}.json", query=performance_queries)
    output:
        "performance/samtools.json"
    run:
        performance_write_runtimes([None] * len(performance_samples), None, input.calls, output[0])


# Tab-separated run time table: one column per caller (display names) and one
# row per task (sample aliases, "merging", queries).
rule performance:
    input:
        expand("performance/{caller}.json", caller=performance_callers)
    output:
        "tables/performance.csv"
    run:
        table = []
        for caller, f in zip(performance_callers, input):
            with open(f) as f:
                table.append(json.load(f))
        # transpose: one row per task instead of one row per caller
        table = np.array(table).T
        with open(output[0], "w") as out:
            writer = csv.writer(out, delimiter="\t")
            writer.writerow(
                ["task"] +
                [config["caller_names"][caller] for caller in performance_callers]
            )
            # same order as the lists written by performance_write_runtimes
            row_labels = performance_sample_names + ["merging"] + performance_queries
            for label, row in zip(row_labels, table):
                writer.writerow([label] + list(row))


################################ Bayesian Calling ##############################


# Reference genotype probability of a single sample against read depth. It is
# computed for b"AC"-repeat ("heterozygous") and b"CC"-repeat ("homozygous")
# pileups at several uniform base qualities, with one line per quality.
rule plot_bayesian_calling:
    output:
        "plots/bayesian_calling/single_sample.pdf"
    run:
        from alpaca.prototype import reference_genotype_probability, Pileup, PHRED_TO_LOG_FACTOR

        # quality characters '(', '2', '<', 'F' (ASCII 40, 50, 60, 70)
        quals = list(range(ord("("), ord("J"), 10))
        qual_chars = [chr(qual) for qual in quals]
        depths = np.arange(11)

        # n repeats of a two-base unit, so the pileup depth is 2 * n
        bases = lambda n, kind: (b"AC" * n if kind == "heterozygous" else b"CC" * n)
        types = ["heterozygous", "homozygous"]

        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure()
        styles = ["-", "--"]
        # one line handle per genotype kind, for the legend
        handles = [None, None]
        for i, kind in enumerate(types):
            for qual in qual_chars:
                pileups = [
                    Pileup(bases(n, kind), qual.encode() * n * 2)
                    for n in depths
                ]
                probs = [math.exp(reference_genotype_probability([pileup])) for pileup in pileups]
                handles[i], = plt.semilogy(depths * 2, probs, styles[i], label=kind, clip_on=False)

        plt.xlabel("read depth")
        plt.ylabel("reference genotype probability")
        # loc must be passed by keyword: passing it positionally after
        # handles/labels is deprecated in newer matplotlib versions.
        plt.legend(handles, types, loc="lower left", handlelength=2.5)
        print("QUALS:", [math.exp((qual - 33) * PHRED_TO_LOG_FACTOR) for qual in quals])

        plt.savefig(output[0])


# Reference genotype probability of two jointly evaluated samples A and B for
# several (ref, alt) read-count setups. ref reads are b"A", alt reads b"C",
# and every base has quality '<'. Each row is "ref+alt<TAB>ref+alt<TAB>prob".
# NOTE(review): the header names only A and B, not the probability column.
rule table_bayesian_calling_sample_depth:
    output:
        "tables/bayesian_calling/sample_depth.csv"
    run:
        from alpaca.prototype import reference_genotype_probability, Pileup

        # one (ref, alt) pair per sample: [(A_ref, A_alt), (B_ref, B_alt)]
        setups = [
            [(18, 18), (2, 2)],
            [(10, 10), (10, 10)],
            [(15, 15), (5, 5)],
            [(20, 20), (40, 0)],
            [(20, 20), (0, 0)]
        ]

        with open(output[0], "w") as out:
            print("A", "B", sep="\t", file=out)
            for setup in setups:
                pileups = []
                for ref, alt in setup:
                    pileups.append(Pileup(b"A" * ref + b"C" * alt, b"<" * (ref + alt)))
                    print(ref, alt, sep="+", end="\t", file=out)
                print("{:.4e}".format(math.exp(reference_genotype_probability(pileups))), file=out)


# Algebraic calling example. Sample A has a fixed 40-read pileup (20 A and 20 C
# bases); filter sample B has n A and n C bases. The plot shows the query
# probability difference(P_A, P_B) and B's own probability against B's depth
# (2n).
rule plot_algebraic_calling:
    output:
        "plots/algebraic_calling/depth.pdf"
    run:
        from alpaca.prototype import reference_genotype_probability, Pileup, difference
        # returns (query probability, filter sample probability), both as
        # produced by alpaca (exponentiated below)
        def calc_prob(n):
            pileup_a = Pileup(b"A" * 20 + b"C" * 20, b"<" * 40)
            pileup_b = Pileup(b"A" * n + b"C" * n, b"<" * n * 2)
            prob_a = reference_genotype_probability([pileup_a])
            prob_b = reference_genotype_probability([pileup_b])
            return difference(prob_a, prob_b), prob_b

        depths = np.arange(6)
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure()
        probs = np.array([calc_prob(n) for n in depths])
        query_prob = np.exp(probs[:, 0])
        sample_prob = np.exp(probs[:, 1])
        plt.semilogy(depths * 2, query_prob, "-", clip_on=False, label="query probability")
        plt.semilogy(depths * 2, sample_prob, "--", clip_on=False, label="filter sample probability")
        plt.xlabel("filter sample depth")
        plt.legend(loc="lower left", handlelength=2.5)
        plt.savefig(output[0])


# Prior Pr(M = 0) against the number of samples n, where
# Pr(M = m) = het / m for m > 0 and Pr(M = 0) = 1 - sum_{m=1}^{2n-1} het / m.
rule plot_priors:
    output:
        "plots/bayesian_calling/priors.pdf"
    run:
        het = 0.001
        def prob(m, sample_count):
            if m == 0:
                # math.fsum for an accurate sum of many small terms
                return 1 - math.fsum(prob(i, sample_count) for i in range(1, 2 * sample_count))
            return het / m
        # NOTE(review): `figure` is not imported or defined in the visible part
        # of this file (plt.figure is probably intended) — confirm.
        figure()
        x = np.arange(1, 1001)
        y = [prob(0, n) for n in x]
        plt.plot(x, y, "-")
        plt.xlabel("number of samples")
        plt.ylabel("$Pr(M = 0)$")
        plt.savefig(output[0])


##################################### Alpaca ###################################


# Build the alpaca index of one BAM file with the given compression profile.
# The sample inside the index is named by its config alias (get_sample_name,
# e.g. "A0"). The .bai input only ensures that the BAM index exists.
rule alpaca_index:
    input:
        "data/reads/{sample}.bam.bai",
        bam="data/reads/{sample}.bam",
        ref="ref.fasta",
        refindex="ref.fasta.fai"
    output:
        "alpaca/{sample}.index.{compression,[nhsgl]+}.hdf5"
    params:
        sample=get_sample_name
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/{sample}.{compression}.log"
    benchmark:
        "benchmarks/alpaca/{sample}.index.{compression}.json"
    shell:
        "alpaca --debug --threads {threads} --buffersize 100000000 index "
        "--sample-name {params.sample} "
        "--compression {wildcards.compression} "
        "{input.ref} {input.bam} {output} 2> {log}"


# Merge the indexes of all configured samples (same compression profile) into
# one index. {input} expands to ref.fasta followed by the sample indexes.
rule alpaca_merge:
    input:
        "ref.fasta",
        lambda wildcards: expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=config["samples"].values(),
            compression=wildcards.compression
        )
    output:
        "alpaca/all.index.{compression,[nhsgl]+}.hdf5"
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/merge.{compression}.log"
    benchmark:
        "benchmarks/alpaca/merge.{compression}.json"
    shell:
        "alpaca --debug --threads {threads} merge {input} {output} 2> {log}"


# Merge only the A0 and A1 indexes. The result is used by the compression
# benchmark (merge run times) and by fp_precision_call.
rule alpaca_merge_two:
    input:
        "ref.fasta",
        lambda wildcards: expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=[config["samples"]["A0"], config["samples"]["A1"]],
            compression=wildcards.compression
        )
    output:
        "alpaca/A0+A1.index.{compression,[nhsgl]+}.hdf5"
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/A0+A1.merge.{compression}.log"
    benchmark:
        "benchmarks/alpaca/A0+A1.merge.{compression}.json"
    shell:
        "alpaca --debug --threads {threads} merge {input} {output} 2> {log}"


def alpaca_filter(wildcards):
    """Translate the ``{filter}`` wildcard into an ``alpaca call`` option.

    A value that ``int()`` accepts is used as a minimum quality threshold
    (``--min-qual``); any other value is passed on as an FDR threshold
    (``--fdr``). The wildcard text itself is forwarded unchanged.
    """
    value = wildcards.filter
    try:
        int(value)
    except ValueError:
        option = "--fdr"
    else:
        option = "--min-qual"
    return option + " " + value


rule alpaca_call:
    # Call variants for a sample query (e.g. "A0-B0") on the merged,
    # "gh"-compressed index of all samples. The {filter} wildcard becomes
    # either `--min-qual N` (integer) or `--fdr X` (anything else); see
    # alpaca_filter.
    input:
        "alpaca/all.index.gh.hdf5"
    output:
        # raw string: "\." is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        r"alpaca/{filter}/{query,[^\.]+}.vcf"
    params:
        filter=alpaca_filter
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/{filter}/{query}.call.log"
    benchmark:
        "benchmarks/alpaca/{filter}/{query}.call.json"
    shell:
        "alpaca --debug --threads {threads} --buffersize 10000000 call "
        "{input} '{wildcards.query}' {params.filter} > '{output}' 2> '{log}'"


################################# ALPACA show ##################################


rule alpaca_show:
    # Render an HTML preview table of the A0-B0 calls at FDR 0.05: only the
    # first 1000 lines of the VCF (header lines included) are annotated, and
    # only lines containing "#", "missense" or "stop" are kept.
    # NOTE(review): with the file-wide `set -o pipefail`, grep exiting 1
    # (no matching line) would make this rule fail.
    input:
        "alpaca/0.05/A0-B0.vcf"
    output:
        "tables/A0-B0.html"
    shell:
        "head -n 1000 {input} | alpaca annotate | "
        "grep -P '(#|missense|stop)' | alpaca show > {output}"

        
##################################### GATK #####################################


rule gatk_haplotype_caller:
    # Produce a per-sample GVCF with GATK HaplotypeCaller (GVCF reference
    # confidence mode). {sample} is a query-level sample name;
    # config["samples"] maps it to the BAM name under data/reads/.
    # The unnamed ".bam.bai" and "ref.dict" inputs are not referenced in
    # the command; they are listed so these files exist before GATK runs.
    # NOTE(review): the explicit LINEAR index type/parameter are presumably
    # needed because the output does not end in ".g.vcf"; confirm for the
    # GATK version in use.
    input:
        lambda wildcards: expand("data/reads/{sample}.bam.bai", sample=config["samples"][wildcards.sample]),
        "ref.dict",
        bam=lambda wildcards: expand("data/reads/{sample}.bam", sample=config["samples"][wildcards.sample]),
        ref="ref.fasta"
    output:
        "gatk/{sample}.gvcf"
    log:
        "logs/gatk/{sample}.log"
    benchmark:
        "benchmarks/gatk/{sample}.haplotype_caller.json"
    threads: 2
    resources: benchmark=100
    shell:
        "gatk -T HaplotypeCaller -nct {threads} -R {input.ref} -I {input.bam} "
        "--emitRefConfidence GVCF --variant_index_type LINEAR "
        "-variant_index_parameter 128000 "
        "-o {output} &> {log}"


rule gatk_genotyping:
    # Jointly genotype the per-sample GVCFs of the samples that
    # expand_samples_from_query derives from {query}, using GATK
    # GenotypeGVCFs. The result is unfiltered ("raw"); records are selected
    # per query afterwards by gatk_filter_by_query.
    input:
        expand_samples_from_query("gatk/{sample}.gvcf"),
        ref="ref.fasta"
    output:
        # raw string: "\." is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        r"gatk/{query,[^\.]+}.raw.vcf"
    params:
        # one "--variant <gvcf>" argument per sample of the query
        gvcfs=expand_samples_from_query("--variant gatk/{sample}.gvcf")
    log:
        "logs/gatk/{query}.call.log"
    benchmark:
        "benchmarks/gatk/{query}.genotyping.json"
    threads: 12
    resources: benchmark=100
    shell:
        "gatk -T GenotypeGVCFs {params.gvcfs} -nt {threads} -R {input.ref} "
        "-o '{output}' &> '{log}'"


def gatk_query_to_select(query):
    """Build the GATK SelectVariants ``-select`` arguments for a query.

    The first ``-select`` expression keeps records in which at least one of
    the query's call samples has a called, non-hom-ref genotype. If the
    query also has filter samples, a second ``-select`` expression requires
    every filter sample to be hom-ref or uncalled.

    Query sample names are translated to VCF sample names through
    ``config["samples"]``; a name missing from that mapping is embedded as
    the text "None" (``dict.get`` semantics).

    Returns:
        The argument string, with each JEXL expression single-quoted for
        the shell.
    """
    translate = config["samples"].get

    # NOTE(review): the two -select expressions are meant to be combined
    # with AND by SelectVariants; confirm for the GATK version in use.
    filter_template = (
        '(vc.getGenotype("{0}").isHomRef() || !vc.getGenotype("{0}").isCalled())'
    )
    call_template = (
        '(vc.getGenotype("{0}").isCalled() && !vc.getGenotype("{0}").isHomRef())'
    )

    filter_samples = " && ".join(
        filter_template.format(translate(sample))
        for sample in get_samples_from_query(query, filter_only=True)
    )
    call_samples = " || ".join(
        call_template.format(translate(sample))
        for sample in get_samples_from_query(query, call_only=True)
    )

    select = "-select '{}'".format(call_samples)
    if filter_samples:
        select += " -select '{}'".format(filter_samples)
    return select

rule gatk_filter_by_query:
    # Turn a caller's raw multi-sample VCF (freebayes, gatk or samtools)
    # into the query-specific call set by applying the JEXL expressions
    # from gatk_query_to_select with GATK SelectVariants.
    input:
        ref="ref.fasta",
        vcf="{caller}/{query}.raw.vcf"
    output:
        # raw string: "\." is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        r"{caller,(freebayes|gatk|samtools)}/{query,[^\.]+}.vcf"
    params:
        select=lambda wildcards: gatk_query_to_select(wildcards.query)
    log:
        "logs/{caller}/{query}.select.log"
    benchmark:
        "benchmarks/{caller}/{query}.select.json"
    threads: 2
    resources: benchmark=1
    shell:
        "gatk -T SelectVariants -R {input.ref} --variant '{input.vcf}' "
        "-nt {threads} {params.select} -o '{output}' &> '{log}'"


##################################### Freebayes ################################

rule freebayes:
    # Call SNVs only (no indels, MNPs or complex events) jointly over the
    # BAMs of the query with freebayes-parallel, splitting the reference
    # into 1,000,000 bp regions via fasta_generate_regions.py.
    # The `<(...)` process substitution requires bash.
    input:
        expand_samples_from_query("data/reads/{sample}.bam.bai", orig_name=True),
        bams=expand_samples_from_query("data/reads/{sample}.bam", orig_name=True),
        ref="ref.fasta",
        fai="ref.fasta.fai"
    output:
        # raw string: "\." is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        r"freebayes/{query,[^\.]+}.raw.vcf"
    log:
        "logs/freebayes/{query}.log"
    benchmark:
        "benchmarks/freebayes/{query}.json"
    threads: 12
    resources: benchmark=100
    shell:
        "freebayes-parallel <(fasta_generate_regions.py {input.fai} 1000000) "
        "{threads} -f {input.ref} --no-indels --no-mnps --no-complex "
        "{input.bams} > '{output}' 2> '{log}'"


################################# MuTect #######################################


# MuTect crashes with the used datasets
rule mutect:
    # Somatic calling with MuTect, restricted to the A0-A1 and A0-B0 queries
    # by the output wildcard constraint. The first BAM of the query is
    # passed as tumor and the second as normal.
    # NOTE(review): this relies on expand_samples_from_query listing the
    # call sample before the filter sample; confirm.
    input:
        expand_samples_from_query("data/reads/{sample}.bam.bai", orig_name=True),
        bams=expand_samples_from_query("data/reads/{sample}.bam", orig_name=True),
        ref="ref.fasta",
        fai="ref.fasta.fai"
    output:
        "mutect/{query,(A0-A1|A0-B0)}.call_stats.out"
    log:
        "logs/mutect/{query}.log"
    benchmark:
        "benchmarks/mutect/{query}.json"
    threads: 12
    resources: benchmark=100
    shell:
        "mutect -nt {threads} --analysis_type MuTect --reference_sequence {input.ref} "
        "--input_file:normal {input.bams[1]} --input_file:tumor {input.bams[0]} "
        "--out {output} &> {log}"


#################################### Samtools ##################################


rule samtools:
    # Call SNVs (indels skipped) with samtools mpileup + bcftools call,
    # running one GNU parallel job per contig listed in the .fai, then
    # concatenate the per-contig results into {output}.
    #
    # - The per-contig files are listed in "{output}.parts" and joined by a
    #   single `bcftools concat -f` call. Running concat once per contig
    #   through `xargs -I` would write one complete VCF (header included)
    #   per contig into {output}.
    # - Steps are chained with && because the file's shell prefix sets only
    #   pipefail, not errexit. With ";" a failing calling step would be
    #   masked by the exit status of the final cleanup command.
    input:
        expand_samples_from_query("data/reads/{sample}.bam.bai", orig_name=True),
        bams=expand_samples_from_query("data/reads/{sample}.bam", orig_name=True),
        ref="ref.fasta",
        fai="ref.fasta.fai"
    output:
        # raw string: "\." is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        r"samtools/{query,[^\.]+}.raw.vcf"
    log:
        "logs/samtools/{query}.log"
    benchmark:
        "benchmarks/samtools/{query}.json"
    threads: 12
    resources: benchmark=100
    shell:
        "cut -f1 {input.fai} | parallel -j {threads} "
        "'samtools mpileup -t DP --skip-indels -gu -C50 -f {input.ref} "
        "-r {{}} {input.bams} | "
        "bcftools call -m -v - > \"{output}.{{}}.bcf\"' && "
        "cut -f1 {input.fai} | sed 's|.*|{output}.&.bcf|' > '{output}.parts' && "
        "bcftools concat -Ov -f '{output}.parts' > '{output}' && "
        "xargs rm < '{output}.parts' && "
        "rm '{output}.parts'"


##################################### Utils ####################################



def get_subsample_param(wildcards):
    """Build the ``samtools view -s`` argument for one subsampling run.

    The numeric part is a seed: the integer formed by the sample name
    after its first two characters, plus the run number. The
    ``{fraction}`` wildcard text is appended to it verbatim.
    """
    sample_number = int(wildcards.sample[2:])
    run_number = int(wildcards.run)
    return str(sample_number + run_number) + str(wildcards.fraction)


rule subsample:
    # Randomly subsample a BAM with `samtools view -s`. The argument is
    # built by get_subsample_param from the sample name, the {run} number
    # and the {fraction}. The result is a temporary ".pre.bam" that
    # set_read_group consumes.
    input:
        "data/reads/{sample}.bam"
    output:
        # raw string: "\d" is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        temp(r"data/reads/{sample}.subsample-{fraction}-{run,\d+}.pre.bam")
    params:
        subsample=get_subsample_param
    resources: benchmark=1
    shell:
        "samtools view -bs {params.subsample} {input} > {output}"


rule set_read_group:
    # Replace the read groups of a subsampled BAM with a single read group
    # whose ID, LB, PU and SM are all "{sample}.subsample-{fraction}-{run}",
    # presumably so each subsample is treated as a distinct sample
    # downstream.
    input:
        "data/reads/{sample}.subsample-{fraction}-{run}.pre.bam"
    output:
        # raw string: "\d" is not a valid escape sequence in a normal
        # Python string literal (SyntaxWarning on Python 3.12+)
        r"data/reads/{sample}.subsample-{fraction}-{run,\d+}.bam"
    params:
        rgsm="{sample}.subsample-{fraction}-{run}"
    log:
        "logs/reads/{sample}.subsample-{fraction}-{run}.set_read_group.log"
    resources: benchmark=1
    shell:
        "picard-tools AddOrReplaceReadGroups INPUT={input} OUTPUT={output} "
        "RGID={params.rgsm} RGLB={params.rgsm} RGPU={params.rgsm} "
        "VALIDATION_STRINGENCY=SILENT "
        "RGPL=illumina RGSM={params.rgsm} &> {log}"


rule spike_sample:
    # Create a mixed BAM by spiking reads of {spike} into {base} at the loci
    # of the {variants} gold set. At each locus, k ~ Binomial(depth, p) base
    # reads are replaced by k randomly chosen spike reads. Loci with too
    # little coverage are skipped; only the loci actually spiked are written
    # to the ".valid.vcf" output.
    input:
        base="data/reads/{base}.bam",
        spike="data/reads/{spike}.bam",
        variants="data/gold_variants/{variants}.vcf"
    output:
        # named so the run block can refer to output.bam / output.vcf
        bam="data/reads/{base}+{spike}__at__{variants}__with__{p}.bam",
        vcf="data/reads/{base}+{spike}__at__{variants}__with__{p}.valid.vcf"
    resources: benchmark=1
    run:
        p = float(wildcards.p)
        min_depth = config["sensitivity"]["min_depth"]
        with pysam.Samfile(input.base, "rb") as base, pysam.Samfile(input.spike, "rb") as spike, open(input.variants) as variants, open(output.vcf, "w") as out_variants:
            with pysam.Samfile(output.bam, "wb", template=base) as out_reads:
                variants = vcf_reader(variants)
                out_variants = vcf_writer(out_variants, variants)

                # NOTE(review): reads are written in VCF record order, and a
                # read overlapping several nearby loci is written once per
                # locus; the BAM may need sorting/deduplication before it is
                # indexed or used for calling.
                for record in variants:
                    # VCF POS is 1-based while pysam coordinates are 0-based
                    # and half-open. The explicit stop restricts the fetch to
                    # reads overlapping this position; with only a start,
                    # fetch would return reads up to the end of the contig.
                    start = record.POS - 1
                    base_reads = list(base.fetch(record.CHROM, start, start + 1))
                    spike_reads = list(spike.fetch(record.CHROM, start, start + 1))

                    if len(base_reads) < min_depth:
                        print("skipping locus due to low base sample coverage", file=sys.stderr)
                        continue

                    # number of base reads to replace by spike reads
                    k = np.random.binomial(len(base_reads), p)
                    if len(spike_reads) < k:
                        print("skipping locus due to low spike sample coverage", file=sys.stderr)
                        continue

                    # random.sample returns the chosen reads themselves, so
                    # the two selections are concatenated directly
                    base_keep = random.sample(base_reads, len(base_reads) - k)
                    spike_keep = random.sample(spike_reads, k)
                    mixed = base_keep + spike_keep

                    # write reads and variants
                    for read in mixed:
                        out_reads.write(read)
                    out_variants.write(record)


rule subtract_variants:
    # Gold variants of {a} that do not overlap any variant of {b}
    # (`-A` drops every {a} record that has any overlap).
    # `-header` copies the VCF header of {a}; without it the output holds
    # only records and is not a valid VCF.
    input:
        a="data/gold_variants/{a}.vcf",
        b="data/gold_variants/{b}.vcf"
    output:
        "data/gold_variants/{a}-{b}.vcf"
    resources: benchmark=1
    shell:
        "bedtools subtract -header -A -a {input.a} -b {input.b} > {output}"


rule download_reads:
    # Download the BAM of a raw sample from the "reads" URL configured in
    # config["raw_samples"].
    output:
        "data/reads/{sample}.bam"
    params:
        url=lambda wildcards: config["raw_samples"][wildcards.sample]["reads"]
    shell:
        "wget -O {output} {params.url}"


rule download_variants:
    # Download and decompress the gzipped gold-standard VCF of a raw sample
    # from the "gold_variants" URL configured in config["raw_samples"].
    # `|| true` ignores gzip's exit status, presumably to tolerate trailing
    # garbage as in download_ref.
    # NOTE(review): if decompression really fails, {output} will be missing
    # and only Snakemake's missing-output check reports it.
    output:
        "data/gold_variants/{sample}.vcf"
    params:
        url=lambda wildcards: config["raw_samples"][wildcards.sample]["gold_variants"]
    shell:
        """
        wget -O - {params.url} > {output}.gz
        gzip -d {output}.gz || true
        """

rule fix_variants:
    # Strip the "chr" prefix from chromosome names (e.g. "chr1" -> "1"),
    # presumably to match the contig names of ref.fasta.
    # Only the CHROM column (line start) and ##contig header IDs are
    # rewritten. The previous unanchored `s/chr//` removed the first "chr"
    # anywhere on every line, e.g. inside INFO values or header
    # descriptions.
    # NOTE(review): "chrM" becomes "M", not "MT"; confirm this matches the
    # reference if mitochondrial records matter.
    input:
        "data/gold_variants/{sample}.vcf"
    output:
        "data/gold_variants/{sample}.fixed.vcf"
    resources: benchmark=1
    shell:
        r"sed -E 's/^chr//; s/^(##contig=<ID=)chr/\1/' {input} > {output}"


rule download_ref:
    # Download and decompress the hs37d5 reference FASTA from the 1000
    # Genomes FTP. gzip's exit status is ignored because the archive has
    # trailing garbage (see the shell comment).
    output:
        "ref.fasta"
    shell:
        """
        wget -O {output}.gz ftp://ftp.1000genomes.ebi.ac.uk/vol1/ftp/technical/reference/phase2_reference_assembly_sequence/hs37d5.fa.gz
        gzip -d --quiet {output}.gz || true  # ignore trailing garbage
        """


rule download_track:
    # Stream the Ensembl release 75 GRCh37 gene annotation (GTF) and
    # decompress it into ref.genes.gtf.
    # NOTE(review): under `set -o pipefail`, a non-zero zcat exit status
    # (e.g. on warnings) would fail the rule despite --quiet; verify.
    output:
        "ref.genes.gtf"
    shell:
        "wget -O - ftp://ftp.ensembl.org/pub/release-75/gtf/homo_sapiens/Homo_sapiens.GRCh37.75.gtf.gz | zcat --quiet > {output}"


def figure(right=False, top=False, left=True, bottom=True, figsize=None):
    fig = plt.figure(figsize=figsize)
    ax = Subplot(fig, 111)
    ax.axis["right"].set_visible(right)
    ax.axis["top"].set_visible(top)
    ax.axis["left"].set_visible(left)
    ax.axis["bottom"].set_visible(bottom)
    fig.add_subplot(ax)
