Analysis pipeline for ALPACA performance and accuracy

The following Snakemake pipeline performs the complete performance and accuracy analysis of ALPACA, as presented in my thesis “Parallelization, Scalability and Reproducibility in Next-Generation Sequencing Analysis”. The corresponding Snakefile can be downloaded here.

# vim: set syntax=python


import csv
import json
import math
import os
import random
import sys
import time
from functools import partial

import numpy as np
import pysam
import vcf
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid.axislines import Subplot
from scipy.stats import kde, binned_statistic


# Shared figure writer: plt.savefig with bbox_inches="tight" preset, so plot
# rules only have to pass the output path.
savefig = partial(plt.savefig, bbox_inches="tight")


# Prefix every shell command so that a failing stage inside a pipe chain makes
# the whole command fail, not just the last stage.
shell.prefix("set -o pipefail; ")


######################################### Config ###############################

# see also https://github.com/broadinstitute/somatic-benchmark
# Central workflow configuration. All rules below read sample names, queries
# and plot settings from this dictionary.
config = {
    # Remote locations of the raw data, keyed by individual.
    "raw_samples": {
        "NA12878": {  # daughter
            "gold_variants": (
                "ftp://ftp-trace.ncbi.nih.gov/giab/ftp/data/NA12878/variant_calls/"
                "NIST/NISTIntegratedCalls_14datasets_131103_allcall_UGHapMerge_HetHomVarPASS"
                "_VQSRv2.18_all_nouncert_excludesimplerep_excludesegdups_"
                "excludedecoy_excludeRepSeqSTRs_noCNVs.vcf.gz"
            ),
            "giab_gold_regions": (
                "ftp://ftp-trace.ncbi.nih.gov/giab/ftp/data/NA12878/variant_calls/NIST/"
                "union13callableMQonlymerged_addcert_nouncert_excludesimplerep_"
                "excludesegdups_excludedecoy_excludeRepSeqSTRs_noCNVs"
                "_v2.18_2mindatasets_5minYesNoRatio.bed.gz"
            ),
            "reads": (
                "ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/working/20120117_ceu_trio_b37_decoy/CEUTrio.HiSeq.WEx.b37_decoy.NA12878.clean.dedup.recal.20120117.bam"
            )
        },
        "NA12891": {  # father
            "gold_variants": (
                "ftp://platgene:G3n3s4me@ussd-ftp.illumina.com/NA12891_S1.genome.vcf.gz"
            ),
            "reads": (
                "ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/working/20120117_ceu_trio_b37_decoy/CEUTrio.HiSeq.WEx.b37_decoy.NA12891.clean.dedup.recal.20120117.bam"
            )
        },
        "NA12892": {  # mother
            "gold_variants": (
                "ftp://platgene:G3n3s4me@ussd-ftp.illumina.com/NA12892_S1.genome.vcf.gz"
            ),
            "reads": (
                "ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/working/20120117_ceu_trio_b37_decoy/CEUTrio.HiSeq.WEx.b37_decoy.NA12892.clean.dedup.recal.20120117.bam"
            )
        }
    },
    # Queries have the form "<called samples>-<filter samples>"; the i-th
    # query is paired with the i-th plot layout (see the zip-expansions of
    # the plot targets). Layout letters: x/y = draw axis label, l = legend.
    "specificity": {
        "queries": [
            "A0-A1", "A0-(A1+A2)", "A0-(A1+A2+A3)", "(A0+A1)-(A2+A3)"
        ],
        "query_plot_layouts": [
            "y", "l", "xy", "x"
        ]
    },
    "sensitivity": {
        "min_depth": 5,
        "sample": "NA12878",
        "filter_sample": "NA12892",
        "queries": [
            "A0-B0", "A0-(B0+B1)", "(A0+A1)-B0", "(A0+A1)-(B0+B1)"
        ],
        "query_plot_layouts": [
            "y", "l", "xy", "x"
        ],
        # upper x-axis limit of the sensitivity plots
        "xlim": 16,
        # contigs removed from the gold variants via vcftools --not-chr
        "exclude": ['GL000207.1', 'GL000226.1', 'GL000229.1', 'GL000231.1', 'GL000210.1', 'GL000239.1', 'GL000235.1', 'GL000201.1', 'GL000247.1', 'GL000245.1', 'GL000197.1', 'GL000203.1', 'GL000246.1', 'GL000249.1', 'GL000196.1', 'GL000248.1', 'GL000244.1', 'GL000238.1', 'GL000202.1', 'GL000234.1', 'GL000232.1', 'GL000206.1', 'GL000240.1', 'GL000236.1', 'GL000241.1', 'GL000243.1', 'GL000242.1', 'GL000230.1', 'GL000237.1', 'GL000233.1', 'GL000204.1', 'GL000198.1', 'GL000208.1', 'GL000191.1', 'GL000227.1', 'GL000228.1', 'GL000214.1', 'GL000221.1', 'GL000209.1', 'GL000218.1', 'GL000220.1', 'GL000213.1', 'GL000211.1', 'GL000199.1', 'GL000217.1', 'GL000216.1', 'GL000215.1', 'GL000205.1', 'GL000219.1', 'GL000224.1', 'GL000223.1', 'GL000195.1', 'GL000212.1', 'GL000222.1', 'GL000200.1', 'GL000193.1', 'GL000194.1', 'GL000225.1', 'GL000192.1', 'MT', 'NC_012920', 'NC_012920', 'hs37d5']
    },
    # Short sample name used in queries -> base name of the read file under
    # data/reads/.
    "samples": {
        "A0": "NA12878.subsample-0.25-0",
        "A1": "NA12878.subsample-0.25-1",
        "A2": "NA12878.subsample-0.25-2",
        "A3": "NA12878.subsample-0.25-3",
        "B0": "NA12892.subsample-0.25-0",
        "B1": "NA12892.subsample-0.25-1",
        #"B0A0": "NA12892.subsample-0.25-0+NA12878.subsample-0.25-0__at__NA12878-NA12892__with__0.8"
    },
    "max_pileup_depth": 250,
    # Directory prefix of a caller -> label shown in plots and tables.
    "caller_names": {
        "gatk": "GATK",
        "alpaca": "ALPACA",
        "freebayes": "FreeBayes",
        "samtools": "SAMtools"
    },
    # Minimum-quality thresholds swept by the parameter space plots.
    "parameter_space": {
        "min_qual": list(range(10, 80, 10)),
    }
}


def get_sample_name(
    wildcards,
    sample_to_name=dict(
        (full_name, short_name)
        for short_name, full_name in config["samples"].items()
    ),
):
    """Return the short query name (e.g. ``A0``) for ``wildcards.sample``.

    ``sample_to_name`` is the inverse of ``config["samples"]``; it is built
    once, when the function is defined.
    """
    short_name = sample_to_name[wildcards.sample]
    return short_name


def get_samples_from_query(query, filter_only=False, call_only=False):
    """Return the sample names mentioned in a query such as ``A0-(A1+A2)``.

    With ``filter_only`` only the part after the first ``-`` is considered
    (an empty list is returned if the query has no ``-``); with ``call_only``
    only the part before the first ``-``. Parentheses and blanks around the
    names are stripped.
    """
    if filter_only:
        parts = query.split("-")
        if len(parts) < 2:
            return []
        query = parts[1]
    if call_only:
        query = query.split("-")[0]

    samples = []
    for summand in query.split("+"):
        for token in summand.split("-"):
            samples.append(token.strip(" ()"))
    return samples


def expand_samples_from_query(pattern, filter_only=False, call_only=False, orig_name=False):
    """Build an input function that expands ``pattern`` over a query's samples.

    The returned function takes the rule's wildcards, extracts the sample
    names from ``wildcards.query`` (restricted by ``filter_only`` /
    ``call_only``) and fills them into the ``{sample}`` placeholder of
    ``pattern``. With ``orig_name`` the short names are first translated via
    ``config["samples"]``.
    """
    def apply(wildcards):
        names = get_samples_from_query(
            wildcards.query, filter_only=filter_only, call_only=call_only
        )
        if not orig_name:
            return expand(pattern, sample=names)
        return expand(pattern, sample=[config["samples"][n] for n in names])

    return apply


# Shared rules for handling SAM/BAM files, loaded from a remote workflow
# repository.
include:
    "https://bitbucket.org/johanneskoester/snakemake-workflows/raw/master/bio/ngs/rules/mapping/samfiles.rules"


# One specificity plot per query; zip pairs the i-th query with the i-th
# layout instead of building the full product.
target_specificity =  expand(
    "plots/specificity/{query}.depth.{layout}.pdf",
    zip,
    query=config["specificity"]["queries"],
    layout=config["specificity"]["query_plot_layouts"]
)


# One sensitivity plot per query, paired with its layout in the same way.
target_sensitivity = expand(
    "plots/sensitivity/{query}.depth.{layout}.pdf",
    zip,
    query=config["sensitivity"]["queries"],
    layout=config["sensitivity"]["query_plot_layouts"]
)


########################### Target rules #######################################


# Default target: all specificity and sensitivity plots, the compression
# plot, and the read counts of the two listed individuals.
rule all:
    input:
        target_specificity,
        target_sensitivity,
        "plots/compression/runtime_vs_size.pdf",
        expand(
            "data/reads/{sample}.readcount.txt",
            sample="NA12878 NA12892".split()
        )


# Target rule: only the specificity plots.
rule specificity:
    input:
        target_specificity


# Target rule: only the sensitivity plots.
rule sensitivity:
    input:
        target_sensitivity


# Target rule: GATK call sets for every specificity and sensitivity query.
rule other_vcf:
    input:
        expand(
            "{caller}/{query}.vcf",
            caller="gatk".split(),
            query=config["specificity"]["queries"] + config["sensitivity"]["queries"]
        )


######################################### Sample stats #########################


# Write the total number of reads of a BAM file: the sum of the 3rd and 4th
# column of `samtools idxstats` over all reference sequences.
rule read_count:
    input:
        "{prefix}.bam"
    output:
        "{prefix}.readcount.txt"
    shell:
        "samtools idxstats {input} | awk '{{s+=$3+$4}} END {{print s}}' "
        "> {output}"


# Build a histogram of the 4th mpileup column (the per-site depth) for a BAM
# file. Output is tab-separated: depth, number of sites with that depth.
rule coverage_dist:
    input:
        "{prefix}.bam"
    output:
        "{prefix}.coverage.txt"
    resources: benchmark=1
    shell:
        "samtools mpileup {input} | cut -f4 | sort -n | uniq -c | "
        # uniq -c prints "count value"; swap the columns to "value count"
        r"awk '{{print $2,$1}}' OFS='\t' > {output}"


# Plot the depth histograms of all configured samples into one semilog
# figure and print the mean depth of each sample.
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_coverages:
    input:
        expand("data/reads/{sample}.coverage.txt", sample=config["samples"].values())
    output:
        "plots/coverage.pdf"
    resources: benchmark=1
    run:
        figure()
        for f in input:
            # the first row of each histogram is skipped
            cov = np.loadtxt(f, dtype=np.int64)[1:]
            depth, freq = cov[:,0], cov[:,1]
            plt.semilogy(depth, freq, alpha=0.5)
            sites = freq.sum()
            # frequency-weighted mean depth over the remaining rows
            mean = (depth * freq).sum() / sites
            print(mean)
            #plt.semilogy([mean, mean], [0, 10 ** 6], "--")
        plt.xlabel("depth")
        plt.ylabel("frequency")
        plt.savefig(output[0])


######################################### Specificity ##########################


# Callers (directory prefixes of their call sets) drawn in the specificity
# plot of each query; every query gets its own list object.
specificity_plot_caller = dict(
    (query, "alpaca/0.05 gatk freebayes samtools".split())
    for query in config["specificity"]["queries"]
)


# Histogram of the summed depth over the filter samples of a query. For each
# pileup site the depth columns of all BAMs are added up (column 4 and every
# third column after it); output is tab-separated: number of sites, depth.
rule specificity_covered_depths:
    input:
        bams=lambda wildcards: expand(
            "data/reads/{sample}.bam",
            sample=[
                config["samples"][sample]
                for sample in get_samples_from_query(wildcards.query, filter_only=True)
            ]
        )
    output:
        "specificity/covered_depths/{query}.txt"
    resources: benchmark=1
    shell:
        # use the same samtools params as alpaca
        "samtools mpileup -d 250 -q 17 -Q 13 -C 50 {input.bams} | "
        r"awk -F '\t' '{{depth=0}} {{for (i=4; i<=NF; i+=3) depth+=$i}} {{print depth}}' | "
        r"sort -n | uniq -c | awk '{{print $1,$2}}' OFS='\t' > '{output}'"


def specificity_calls(pattern, blacklist=()):
    """Build an input function listing the call sets plotted for a query.

    The returned function expands ``pattern`` (placeholders ``{caller}`` and
    ``{query}``) over the callers registered for ``wildcards.query`` in
    ``specificity_plot_caller``, skipping every caller in ``blacklist``.

    The default for ``blacklist`` is an immutable empty tuple rather than a
    shared mutable list; it is only used for membership tests.
    """
    def apply(wildcards):
        return expand(
            pattern,
            query=wildcards.query,
            caller=[
                caller for caller in specificity_plot_caller[wildcards.query]
                if caller not in blacklist
            ]
        )
    return apply


def get_covered_depths(f):
    """Load a depth histogram into an array indexed by depth.

    ``f`` is a whitespace-separated table with two integer columns: number
    of sites, depth. Returns ``(covered_depths, max_covered_depth)`` where
    ``covered_depths[d]`` is the number of sites with depth ``d`` (0 for
    depths missing from the file) and ``max_covered_depth`` is the largest
    depth present.
    """
    # ndmin=2 keeps a histogram with a single line two-dimensional; without
    # it the column indexing below fails for one-line files.
    table = np.loadtxt(f, dtype=np.int64, ndmin=2)
    counts, depths = table[:, 0], table[:, 1]
    max_covered_depth = depths.max()
    covered_depths = np.zeros(max_covered_depth + 1, dtype=np.int64)
    covered_depths[depths] = counts
    # depths higher than 250 are capped by alpaca.
    # however we won't plot them, so biased counts do not matter there
    return covered_depths, max_covered_depth


def plot_fpr(f, covered_depths, max_covered_depth, style, label=None, recordfilter=None):
    """Plot the false positive rate against the minimum depth.

    Every SNP record of the VCF file ``f`` (optionally restricted by the
    predicate ``recordfilter``) is counted as a false positive at its depth,
    i.e. the sum of the ``DP`` values of all its samples (missing ``DP``
    counts as 0). ``covered_depths[d]`` is the number of sites with depth
    ``d``. For each minimum depth the rate is the number of false positives
    at or above it divided by the number of sites at or above it. The curve
    is drawn on the current axes with line ``style`` and legend ``label``.
    """
    with open(f) as handle:
        records = vcf.Reader(handle)
        if recordfilter:
            records = filter(recordfilter, records)
        # Explicit integer dtype: with zero SNP records the array would
        # otherwise default to float64, which np.bincount rejects.
        fp_depths = np.array(
            [
                sum(
                    (call.data.DP if call.data.DP is not None else 0)
                    for call in record.samples
                )
                for record in records if record.is_snp
            ],
            dtype=np.int64,
        )

    fp = np.bincount(fp_depths, minlength=max_covered_depth + 1)
    tn = covered_depths - fp

    # reverse cumulative sums: entry d holds the total for depth >= d
    cum_fp = fp[::-1].cumsum()[::-1]
    cum_tn = tn[::-1].cumsum()[::-1]

    cum_n = cum_fp + cum_tn
    fpr = cum_fp / cum_n

    plt.semilogy(np.arange(max_covered_depth + 1), fpr, style, label=label)


# Plot the false positive rate by minimum depth for every caller of a query.
# The layout wildcard selects decorations: x/y = axis label, l = legend.
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_specificity:
    input:
        calls=specificity_calls("{caller}/{query}.vcf"),
        covered="specificity/covered_depths/{query}.txt"
    output:
        "plots/specificity/{query}.depth.{layout,[lxy]+}.pdf"
    run:
        covered_depths, max_covered_depth = get_covered_depths(input.covered)

        # one line style per caller, in the order of specificity_plot_caller
        styles = "- -- : -. |-".split()
        figure(figsize=(2.8, 2.5))
        for i, (f, caller) in enumerate(zip(
            input.calls, specificity_plot_caller[wildcards.query]
        )):
            plot_fpr(
                f, covered_depths, max_covered_depth,
                # label by the caller prefix, e.g. "alpaca/0.05" -> "alpaca"
                styles[i], label=config["caller_names"][caller.split("/")[0]]
            )
        if "x" in wildcards.layout:
            plt.xlabel("minimum depth")
        if "y" in wildcards.layout:
            plt.ylabel("FPR")
        plt.xlim((1,60))
        plt.ylim((1e-8, 1e-2))
        if "l" in wildcards.layout:
            plt.legend(loc="upper right", handlelength=2.5)
        savefig(output[0], bbox_inches="tight")


# Compare the false positive rate of ALPACA for query A0-A1 under a sweep of
# minimum-quality thresholds (dotted lines) with the call set from the 0.05
# filter (solid line).
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_specificity_parameter_space:
    input:
        calls="alpaca/10/A0-A1.vcf",
        fdrcalls="alpaca/0.05/A0-A1.vcf",
        covered="specificity/covered_depths/A0-A1.txt"
    output:
        "plots/specificity/parameter_space.pdf"
    run:
        figure(figsize=(2.8, 2.5))
        covered_depths, max_covered_depth = get_covered_depths(input.covered)
        for min_qual in config["parameter_space"]["min_qual"]:
            plot_fpr(
                input.calls, covered_depths, max_covered_depth, ":",
                # the lambda reads min_qual when called; plot_fpr consumes
                # the filtered records before the loop advances
                recordfilter=lambda record: record.QUAL >= min_qual
            )
        plot_fpr(
            input.fdrcalls, covered_depths, max_covered_depth, "-"
        )
        plt.xlabel("minimum depth")
        plt.ylabel("FPR")
        plt.xlim((1,60))
        #plt.ylim((1e-6, 1e-2))
        savefig(output[0], bbox_inches="tight")


##################################### Sensitivity ##############################


# Callers (directory prefixes of their call sets) drawn in the sensitivity
# plot of each query; every query gets its own list object.
sensitivity_plot_caller = dict(
    (query, "alpaca/0.05 gatk freebayes samtools".split())
    for query in config["sensitivity"]["queries"]
)


# Restrict the gold standard variants of the sensitivity sample to records
# overlapping the gene annotation, then drop indels and all contigs listed
# in config["sensitivity"]["exclude"].
rule sensitivity_gold_variants:
    input:
        vcf=expand("data/gold_variants/{sample}.vcf", sample=config["sensitivity"]["sample"]),
        gtf="ref.genes.gtf"
    output:
        "sensitivity/gold_variants.vcf"
    params:
        # joined so that, after the leading --not-chr in the command, every
        # contig gets its own --not-chr option
        exclude=" --not-chr ".join(config["sensitivity"]["exclude"])
    resources: benchmark=1
    shell:
        "bedtools intersect -header -u -a {input.vcf} -b {input.gtf} | "
        "vcftools --vcf - --stdout --remove-indels --not-chr {params.exclude} "
        "--recode > {output}"


rule sensitivity_exclusive_gold_variants:
    input:
        filter_vcf=expand("data/gold_variants/{sample}.fixed.vcf", sample=config["sensitivity"]["filter_sample"]),
        vcf="sensitivity/gold_variants.vcf"
    output:
        "sensitivity/exclusive_gold_variants.vcf"
    resources: benchmark=1
    shell:
        "grep '#' {input.vcf} > {output} && "
        "bedtools subtract -A -a {input.vcf} -b {input.filter_vcf} >> {output}"


rule sensitivity_true_positives:
    input:
        calls="{caller}/{query}.vcf",
        gold="sensitivity/exclusive_gold_variants.vcf"
    output:
        "sensitivity/tp/{caller}/{query,[^\.]+}.vcf"
    resources: benchmark=1
    shell:
        'bedtools intersect -header -u -f 1 -a "{input.calls}" -b {input.gold} > "{output}"'


rule sensitivity_false_negatives:
    input:
        calls="{caller}/{query}.vcf",
        gold="sensitivity/exclusive_gold_variants.vcf"
    output:
        "sensitivity/fn/{caller}/{query,[^\.]+}.vcf"
    resources: benchmark=1
    shell:
        'grep "#" {input.gold} > "{output}" && '
        'bedtools subtract -a {input.gold} -b "{input.calls}" >> "{output}"'


# Run `alpaca peek` on the index of one sample for the loci of a TP/FN VCF
# (read from stdin) and store its output as CSV.
rule sensitivity_allele_depth:
    input:
        vcf="sensitivity/{status}/{caller}/{query}.vcf",
        # the sample wildcard is a short name (e.g. A0); the index file is
        # named after the full sample name from config["samples"]
        index=lambda wildcards: expand(
            "alpaca/{sample}.index.gh.hdf5",
            sample=config["samples"][wildcards.sample]
        )
    output:
        "sensitivity/{status}/{caller}/{query}.{sample}.allele_depth.csv"
    threads: 12
    resources: benchmark=1
    shell:
        "alpaca --threads {threads} peek {input.index} < '{input.vcf}' > '{output}'"


def sensitivity_expand_called_samples(pattern):
    """Build a function expanding ``pattern`` over the called samples of a query.

    The returned function takes the rule's wildcards and fills ``{query}``,
    ``{status}`` and ``{caller}`` from them, and ``{sample}`` with every
    sample on the call side (before the ``-``) of ``wildcards.query``.
    """
    def apply(wildcards):
        called = get_samples_from_query(wildcards.query, call_only=True)
        fixed = dict(
            query=wildcards.query,
            status=wildcards.status,
            caller=wildcards.caller,
        )
        return expand(pattern, sample=called, **fixed)

    return apply


# Combine the depth columns of all called samples of a query into one
# tab-separated file with one column per sample.
rule sensitivity_depth:
    input:
        sensitivity_expand_called_samples(
            "sensitivity/{status}/{caller}/{query}.{sample}.allele_depth.csv"
        )
    output:
        "sensitivity/{status}/{caller}/{query}.depth.csv"
    params:
        # one process substitution "<( cut ... )" per called sample
        pipes=sensitivity_expand_called_samples(
            # cut the third column since it contains the overall depth
            "<( cut -f 3 'sensitivity/{status}/{caller}/{query}.{sample}.allele_depth.csv')"
        )
    resources: benchmark=1
    shell:
        # paste the 3rd columns into one file
        "paste {params.pipes} > '{output}'"


def plot_sensitivity(tp_file, fn_file, style, label=None):
    """Plot the sensitivity (true positive rate) against the minimum depth.

    ``tp_file`` and ``fn_file`` are tables with one header line, one row per
    site and one integer depth column per sample. The depth of a site is the
    sum over its columns. For each minimum depth the rate is the number of
    true positives at or above it divided by the number of true positives
    plus false negatives at or above it. The curve is drawn on the current
    axes with line ``style`` and legend ``label``.
    """
    def sum_depths(f):
        # ndmin=2 always yields a (sites, samples) array, so a file with a
        # single data row but several sample columns is summed to one site
        # instead of being taken for one site per column.
        depths = np.loadtxt(f, skiprows=1, dtype=np.int32, ndmin=2)
        return depths.sum(axis=1)

    tp_depths = sum_depths(tp_file)
    fn_depths = sum_depths(fn_file)
    max_depth = max(np.max(fn_depths), np.max(tp_depths))

    fn = np.bincount(fn_depths, minlength=max_depth + 1)
    tp = np.bincount(tp_depths, minlength=max_depth + 1)

    # reverse cumulative sums: entry d holds the total for depth >= d
    cum_fn = fn[::-1].cumsum()[::-1]
    cum_tp = tp[::-1].cumsum()[::-1]

    cum_p = cum_tp + cum_fn
    sensitivity = cum_tp / cum_p

    plt.plot(np.arange(max_depth + 1), sensitivity, style, label=label)


# Plot the sensitivity by minimum depth for every caller of a query. The
# layout wildcard selects decorations: x/y = axis label, l = legend.
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_sensitivity:
    input:
        fn=lambda wildcards: expand(
            "sensitivity/fn/{caller}/{query}.depth.csv",
            caller=sensitivity_plot_caller[wildcards.query],
            query=wildcards.query
        ),
        tp=lambda wildcards: expand(
            "sensitivity/tp/{caller}/{query}.depth.csv",
            caller=sensitivity_plot_caller[wildcards.query],
            query=wildcards.query
        )
    output:
        "plots/sensitivity/{query}.depth.{layout,[lxy]+}.pdf"
    run:
        figure(figsize=(2.8, 2.5))
        # one line style per caller, in the order of sensitivity_plot_caller
        styles = "- -- : -. |-".split()
        
        for i, (caller, tp, fn) in enumerate(zip(
            sensitivity_plot_caller[wildcards.query], input.tp, input.fn
        )):
            plot_sensitivity(
                # label by the caller prefix, e.g. "alpaca/0.05" -> "alpaca"
                tp, fn, styles[i], label=config["caller_names"][caller.split("/")[0]]
            )
        if "x" in wildcards.layout:
            plt.xlabel("minimum depth")
        if "y" in wildcards.layout:
            plt.ylabel("TPR")
        plt.ylim((0, 1))
        plt.xlim((0, config["sensitivity"]["xlim"]))
        if "l" in wildcards.layout:
            plt.legend(loc="lower right", handlelength=2.5)
        savefig(output[0], bbox_inches="tight")


# Compare the sensitivity of ALPACA for query A0 under a sweep of
# minimum-quality thresholds (dotted lines) with the call set from the 0.05
# filter (solid line).
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_sensitivity_parameter_space:
    input:
        fn=lambda wildcards: expand(
            "sensitivity/fn/alpaca/{min_qual}/A0.depth.csv",
            min_qual=config["parameter_space"]["min_qual"]
        ),
        tp=lambda wildcards: expand(
            "sensitivity/tp/alpaca/{min_qual}/A0.depth.csv",
            min_qual=config["parameter_space"]["min_qual"]
        ),
        fdr_fn="sensitivity/fn/alpaca/0.05/A0.depth.csv",
        fdr_tp="sensitivity/tp/alpaca/0.05/A0.depth.csv"
    output:
        "plots/sensitivity/parameter_space.pdf"
    run:
        figure(figsize=(2.8, 2.5))
        # min_qual is not used inside the loop; the curves are unlabeled
        for min_qual, tp, fn in zip(config["parameter_space"]["min_qual"], input.tp, input.fn):
            plot_sensitivity(tp, fn, ":")
        plot_sensitivity(input.fdr_tp, input.fdr_fn, "-")
        plt.xlabel("minimum depth")
        plt.ylabel("TPR")
        plt.ylim((0,1))
        plt.xlim((0, config["sensitivity"]["xlim"]))
        savefig(output[0], bbox_inches="tight")


########################### compression ########################################

# Compression profile codes, listed here in reverse order of the string so
# that its first entry ends up last (plot rows are drawn by list position).
compression_profiles = list(reversed("n h l ls lh lsh g gs gh gsh".split()))


def compression_get_benchmarks(write=False):
    """Return the benchmark files of all compression profiles.

    With ``write=True`` these are the indexing benchmarks of sample A0,
    otherwise the benchmarks of merging A0 and A1. The order follows
    ``compression_profiles``.
    """
    if write:
        pattern = "benchmarks/alpaca/{sample}.index.{compression}.json"
    else:
        pattern = "benchmarks/alpaca/A0+A1.merge.{compression}.json"
    a0 = config["samples"]["A0"]
    return expand(pattern, sample=a0, compression=compression_profiles)


# Three-panel figure comparing the compression profiles: index size (left),
# run time of merging (middle) and run time of indexing (right). Each
# profile occupies one row, in the order of compression_profiles.
rule plot_compression_efficiency:
    input:
        benchmarks_read=compression_get_benchmarks(),
        benchmarks_write=compression_get_benchmarks(write=True),
        indexes=expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=config["samples"]["A0"],
            compression=compression_profiles
        )
    output:
        "plots/compression/runtime_vs_size.pdf"
    run:
        def get_runtimes(f):
            # list of wall clock times (seconds) stored in a benchmark file
            with open(f) as f:
                return json.load(f)["wall_clock_times"]["s"]

        # index file size per profile, bytes scaled by 1024 ** 3
        sizes = {
            compr: os.path.getsize(f) / 1024 ** 3
            for compr, f in zip(compression_profiles, input.indexes)
        }
        runtimes_write = {
            compr: get_runtimes(f)
            for compr, f in zip(compression_profiles, input.benchmarks_write)
        }
        runtimes_read = {
            compr: get_runtimes(f)
            for compr, f in zip(compression_profiles, input.benchmarks_read)
        }
        
        fig = plt.figure(figsize=(7,3))
        

        # left panel: keeps the y tick labels but hides the axis line/ticks
        ax = Subplot(fig, 131)
        ax.axis["right"].set_visible(False)
        ax.axis["top"].set_visible(False)
        ax.axis["left"].line.set_visible(False)
        ax.axis["left"].toggle(ticks=False)
        fig.add_subplot(ax)

        def sep():
            # dotted horizontal separators between the profile rows
            x = plt.xlim()
            for i in range(len(compression_profiles)):
                y = [i + 0.5] * 2
                plt.plot(x, y, ":k", linewidth=0.5)

        # shared y range of all three panels
        ylim = (-0.5, len(compression_profiles) + 1.5)
        
        for i, compr in enumerate(compression_profiles):
            plt.plot(sizes[compr], i, "o")
        plt.xlabel("size of index [GB]")
        plt.xlim((0, plt.xlim()[1]))
        plt.ylim(ylim)
        sep()
        plt.gca().set_yticks(np.arange(len(compression_profiles)))
        plt.gca().set_yticklabels(compression_profiles)

        def plot_runtimes(subplt, runtimes, type, xlim): 
            # one run time panel without a left axis; a dot per measurement
            ax = Subplot(fig, subplt)
            ax.axis["right"].set_visible(False)
            ax.axis["top"].set_visible(False)
            ax.axis["left"].set_visible(False)
            fig.add_subplot(ax)

            for i, compr in enumerate(compression_profiles):
                _runtimes = runtimes[compr]
                plt.plot(_runtimes, [i] * len(_runtimes), "o")
            plt.xlabel("run time of {} [s]".format(type))
            plt.ylim(ylim)
            plt.xlim((0, xlim))
            plt.locator_params(nbins=6)
            sep()
        plot_runtimes(132, runtimes_read, "merging", 600)
        plot_runtimes(133, runtimes_write, "indexing", 2500)

        plt.savefig(output[0], bbox_inches="tight")


############################ floating point precision ##########################


# Call query A0-A1 on the merged A0+A1 index of a given compression profile,
# so that call sets from different profiles can be compared.
# NOTE(review): --fdr is passed without a value here, while alpaca_filter
# appends a threshold after it - confirm against the alpaca command line.
rule fp_precision_call:
    input:
        "alpaca/A0+A1.index.{compression}.hdf5"
    output:
        "fp-precision/call.{compression,[nhsgl]+}.vcf"
    log:
        "logs/fp-precision/call.{compression}.log"
    threads: 12
    resources: benchmark=1
    shell:
        "alpaca --debug --threads {threads} --buffersize 10000000 call "
        "{input} A0-A1 --fdr > {output} 2> {log}"


# Tabulate how much the call qualities differ between the "n" and the "h"
# compression profile: the output has one line per absolute QUAL difference
# with the number of loci showing it.
rule table_precision_loss:
    input:
        n="fp-precision/call.n.vcf",
        h="fp-precision/call.h.vcf"
    output:
        "tables/fp-precision-qualdiff.csv"
    resources: benchmark=1
    run:
        with open(input.n) as fn, open(input.h) as fh:
            # (chromosome, position) -> QUAL; a later record at the same
            # locus overwrites an earlier one
            uncompressed = {(rec.CHROM, rec.POS): rec.QUAL for rec in vcf.Reader(fn)}
            compressed = {(rec.CHROM, rec.POS): rec.QUAL for rec in vcf.Reader(fh)}
        print("vcf loaded")
        # both call sets must contain exactly the same loci
        assert compressed.keys() == uncompressed.keys()

        loci = uncompressed.keys()
        x = np.array([uncompressed[locus] for locus in loci])
        y = np.array([compressed[locus] for locus in loci])
        # NOTE(review): np.bincount only accepts integer arrays, so this
        # presumably relies on all QUAL values being parsed as integers -
        # confirm.
        counts = np.bincount(np.abs(x - y))
        with open(output[0], "w") as f:
            for i in range(counts.size):
                print(i, counts[i], sep="\t", file=f)


# Plot the first three columns of the allele frequency likelihoods stored in
# the first slice of sequence "1" of the A0 index ("h" compression).
# The {type} wildcard is not used inside the run block.
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_likelihoods:
    input:
        expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=config["samples"]["A0"],
            compression="h"
        )
    output:
        "plots/likelihoods_{type,(float|phred)}.pdf"
    run:
        from alpaca.index.view import SampleIndexView
        from alpaca.utils import LOG_TO_PHRED_FACTOR
        with SampleIndexView(input[0], buffersize=100) as index:
            seq_slice = next(iter(index["1"]))
            likelihoods = seq_slice.pileup_allelefreq_likelihoods
        # debugging output of the plotted columns
        print(likelihoods[:,0])# * LOG_TO_PHRED_FACTOR)
        print(likelihoods[:,1])# * LOG_TO_PHRED_FACTOR)
        print(likelihoods[:,2])# * LOG_TO_PHRED_FACTOR)
        print(LOG_TO_PHRED_FACTOR)
        figure()
        x = np.arange(likelihoods.shape[0])
        plt.plot(x, likelihoods[:, 0], "-")
        plt.plot(x, likelihoods[:, 1], "--")
        plt.plot(x, likelihoods[:, 2], ":")
        # y axis runs from 0 at the bottom to -200 at the top
        plt.ylim((0, -200))
        plt.savefig(output[0])


############################### run time performance ###########################

# Rows and columns of the run time table: all queries (specificity first),
# the short sample names in sorted order with their full names in the same
# order, and the callers that are compared.
performance_queries = [
    *config["specificity"]["queries"],
    *config["sensitivity"]["queries"],
]
performance_sample_names = sorted(config["samples"].keys())
performance_samples = [
    config["samples"][short_name] for short_name in performance_sample_names
]
performance_callers = "alpaca gatk freebayes samtools".split()


def performance_write_runtimes(samples, merge, queries, out):
    """Write the run times of one caller as a JSON list of strings.

    ``samples`` and ``queries`` are sequences of benchmark files and
    ``merge`` is a single benchmark file; any of them may be ``None`` for a
    step the caller does not have, which is written as ``"-"``. The first
    wall clock time (in seconds) of each benchmark is formatted as ``m:ss``.
    The list written to ``out`` is ordered: samples, merge, queries.
    """
    def get_runtime(path):
        if path is None:
            return "-"
        with open(path) as handle:
            seconds = json.load(handle)["wall_clock_times"]["s"][0]
        # Round the total first and split afterwards; rounding minutes and
        # seconds independently turns e.g. 119.6 s into "1:60".
        minutes, seconds = divmod(int(round(seconds)), 60)
        return "{}:{:02d}".format(minutes, seconds)

    runtimes = [get_runtime(f) for f in samples]
    runtimes.append(get_runtime(merge))
    runtimes.extend(get_runtime(f) for f in queries)

    with open(out, "w") as handle:
        json.dump(runtimes, handle)


# Collect the ALPACA run times: indexing of every sample, merging of all
# indexes, and calling of every query with the 0.05 filter.
rule performance_alpaca:
    input:
        samples=expand("benchmarks/alpaca/{sample}.index.gh.json", sample=performance_samples),
        merge="benchmarks/alpaca/merge.gh.json",
        calls=expand(
            "benchmarks/alpaca/0.05/{query}.call.json", query=performance_queries
        )
    output:
        "performance/alpaca.json"
    run:
        performance_write_runtimes(input.samples, input.merge, input.calls, output[0])


# Collect the GATK run times: the haplotype caller per sample (keyed by
# short sample name) and genotyping per query. There is no merge step, so
# None is passed for it.
rule performance_gatk:
    input:
        samples=expand("benchmarks/gatk/{sample}.haplotype_caller.json", sample=performance_sample_names),
        calls=expand("benchmarks/gatk/{query}.genotyping.json", query=performance_queries)
    output:
        "performance/gatk.json"
    run:
        performance_write_runtimes(input.samples, None, input.calls, output[0])


# Collect the FreeBayes run times. Only per-query benchmarks exist; the
# per-sample and merge entries are filled with None so that the list has the
# same length as for the other callers.
rule performance_freebayes:
    input:
        calls=expand("benchmarks/freebayes/{query}.json", query=performance_queries)
    output:
        "performance/freebayes.json"
    run:
        performance_write_runtimes([None] * len(performance_samples), None, input.calls, output[0])


# Collect the SAMtools run times. Only per-query benchmarks exist; the
# per-sample and merge entries are filled with None so that the list has the
# same length as for the other callers.
rule performance_samtools:
    input:
        calls=expand("benchmarks/samtools/{query}.json", query=performance_queries)
    output:
        "performance/samtools.json"
    run:
        performance_write_runtimes([None] * len(performance_samples), None, input.calls, output[0])


# Assemble the per-caller run time lists into one tab-separated table with
# one row per task (samples, merging, queries) and one column per caller.
rule performance:
    input:
        expand("performance/{caller}.json", caller=performance_callers)
    output:
        "tables/performance.csv"
    run:
        table = []
        for caller, f in zip(performance_callers, input):
            with open(f) as f:
                table.append(json.load(f))
        # transpose from one row per caller to one row per task
        table = np.array(table).T
        with open(output[0], "w") as out:
            writer = csv.writer(out, delimiter="\t")
            writer.writerow(
                ["task"] +
                [config["caller_names"][caller] for caller in performance_callers]
            )
            # row labels follow the order used by performance_write_runtimes
            row_labels = performance_sample_names + ["merging"] + performance_queries
            for label, row in zip(row_labels, table):
                writer.writerow([label] + list(row))


################################ Bayesian Calling ##############################


rule plot_bayesian_calling:
    output:
        "plots/bayesian_calling/single_sample.pdf"
    run:
        from alpaca.prototype import reference_genotype_probability, Pileup, PHRED_TO_LOG_FACTOR

        quals = list(range(ord("("), ord("J"), 10))
        qual_chars = [chr(qual) for qual in quals]
        depths = np.arange(11)

        bases = lambda n, type: (b"AC" * n if type == "heterozygous" else b"CC" * n)
        types = ["heterozygous", "homozygous"]

        figure()
        styles = ["-", "--"]
        handles = [None, None]
        for i, type in enumerate(types):
            for qual in qual_chars:
                pileups = [
                    Pileup(bases(n, type), qual.encode() * n * 2)
                    for n in depths
                ]
                probs = [math.exp(reference_genotype_probability([pileup])) for pileup in pileups]
                handles[i], = plt.semilogy(depths * 2, probs, styles[i], label=type, clip_on=False)

        plt.xlabel("read depth")
        plt.ylabel("reference genotype probability")
        plt.legend(handles, types, "lower left", handlelength=2.5)
        print("QUALS:", [math.exp((qual - 33) * PHRED_TO_LOG_FACTOR) for qual in quals])

        plt.savefig(output[0])


# Table of reference genotype probabilities for pairs of samples with
# different numbers of reference ("A") and alternative ("C") bases.
rule table_bayesian_calling_sample_depth:
    output:
        "tables/bayesian_calling/sample_depth.csv"
    run:
        from alpaca.prototype import reference_genotype_probability, Pileup

        # each setup: one (ref count, alt count) pair per sample
        setups = [
            [(18, 18), (2, 2)],
            [(10, 10), (10, 10)],
            [(15, 15), (5, 5)],
            [(20, 20), (40, 0)],
            [(20, 20), (0, 0)]
        ]

        with open(output[0], "w") as out:
            # NOTE(review): the header names two columns, but each row below
            # has a third one (the probability) - confirm this is intended.
            print("A", "B", sep="\t", file=out)
            for setup in setups:
                pileups = []
                for ref, alt in setup:
                    pileups.append(Pileup(b"A" * ref + b"C" * alt, b"<" * (ref + alt)))
                    # sample column, written as "ref+alt"
                    print(ref, alt, sep="+", end="\t", file=out)
                print("{:.4e}".format(math.exp(reference_genotype_probability(pileups))), file=out)


# Plot how the probability of a difference query changes with the depth of
# the filter sample: sample A is fixed at 20 "A" + 20 "C" bases, sample B
# has n "A" + n "C" bases for n = 0..5.
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_algebraic_calling:
    output:
        "plots/algebraic_calling/depth.pdf"
    run:
        from alpaca.prototype import reference_genotype_probability, Pileup, difference
        def calc_prob(n):
            # returns (value for the difference of A and B, value for B)
            pileup_a = Pileup(b"A" * 20 + b"C" * 20, b"<" * 40)
            pileup_b = Pileup(b"A" * n + b"C" * n, b"<" * n * 2)
            prob_a = reference_genotype_probability([pileup_a])
            prob_b = reference_genotype_probability([pileup_b])
            return difference(prob_a, prob_b), prob_b

        depths = np.arange(6)
        figure()
        probs = np.array([calc_prob(n) for n in depths])
        # both columns are exponentiated before plotting
        query_prob = np.exp(probs[:, 0])
        sample_prob = np.exp(probs[:, 1])
        plt.semilogy(depths * 2, query_prob, "-", clip_on=False, label="query probability")
        plt.semilogy(depths * 2, sample_prob, "--", clip_on=False, label="filter sample probability")
        plt.xlabel("filter sample depth")
        plt.legend(loc="lower left", handlelength=2.5)
        plt.savefig(output[0])


# Plot Pr(M = 0) against the number of samples (1 to 1000), with
# Pr(M = m) = het / m for m > 0.
# NOTE(review): figure() is not defined or imported in the visible part of
# this file - confirm where it comes from.
rule plot_priors:
    output:
        "plots/bayesian_calling/priors.pdf"
    run:
        het = 0.001
        def prob(m, sample_count):
            if m == 0:
                # complement of the summed probabilities of m = 1 ...
                # 2 * sample_count - 1
                # NOTE(review): the range excludes m = 2 * sample_count -
                # confirm that this upper bound is intended.
                return 1 - math.fsum(prob(i, sample_count) for i in range(1, 2 * sample_count))
            return het / m
        figure()
        x = np.arange(1, 1001)
        y = [prob(0, n) for n in x]
        plt.plot(x, y, "-")
        plt.xlabel("number of samples")
        plt.ylabel("$Pr(M = 0)$")
        plt.savefig(output[0])


##################################### Alpaca ###################################


# Build the ALPACA index of one sample with the given compression profile.
# The sample is registered in the index under its short name (e.g. A0).
rule alpaca_index:
    input:
        # BAM index; listed as an input but not passed on the command line
        "data/reads/{sample}.bam.bai",
        bam="data/reads/{sample}.bam",
        ref="ref.fasta",
        refindex="ref.fasta.fai"
    output:
        "alpaca/{sample}.index.{compression,[nhsgl]+}.hdf5"
    params:
        sample=get_sample_name
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/{sample}.{compression}.log"
    benchmark:
        "benchmarks/alpaca/{sample}.index.{compression}.json"
    shell:
        "alpaca --debug --threads {threads} --buffersize 100000000 index "
        "--sample-name {params.sample} "
        "--compression {wildcards.compression} "
        "{input.ref} {input.bam} {output} 2> {log}"


# Merge the indexes of all configured samples (same compression profile)
# into one index. {input} expands to the reference followed by all indexes.
rule alpaca_merge:
    input:
        "ref.fasta",
        lambda wildcards: expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=config["samples"].values(),
            compression=wildcards.compression
        )
    output:
        "alpaca/all.index.{compression,[nhsgl]+}.hdf5"
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/merge.{compression}.log"
    benchmark:
        "benchmarks/alpaca/merge.{compression}.json"
    shell:
        "alpaca --debug --threads {threads} merge {input} {output} 2> {log}"


# Merge only the indexes of samples A0 and A1 (same compression profile).
# Its benchmark files are the "merging" run times of the compression plot.
rule alpaca_merge_two:
    input:
        "ref.fasta",
        lambda wildcards: expand(
            "alpaca/{sample}.index.{compression}.hdf5",
            sample=[config["samples"]["A0"], config["samples"]["A1"]],
            compression=wildcards.compression
        )
    output:
        "alpaca/A0+A1.index.{compression,[nhsgl]+}.hdf5"
    threads: 12
    resources: benchmark=100
    log:
        "logs/alpaca/A0+A1.merge.{compression}.log"
    benchmark:
        "benchmarks/alpaca/A0+A1.merge.{compression}.json"
    shell:
        "alpaca --debug --threads {threads} merge {input} {output} 2> {log}"


def alpaca_filter(wildcards):
    """Return the filter option for `alpaca call` encoded in the wildcards.

    ``wildcards.filter`` is interpreted by its textual form: a value that
    parses as an integer selects a minimum quality (``--min-qual``), anything
    else is passed on as an FDR threshold (``--fdr``).
    """
    threshold = wildcards.filter
    try:
        int(threshold)
    except ValueError:
        # not an integer: FDR filtering
        option = "--fdr "
    else:
        option = "--min-qual "
    return option + threshold


# Call variants with `alpaca call` on the merged index built with compression
# "gh". {query} (which must not contain a dot) is passed to alpaca as the
# sample query, {filter} selects the filtering option via alpaca_filter.
rule alpaca_call:
    input:
        "alpaca/all.index.gh.hdf5"
    output:
        "alpaca/{filter}/{query,[^\.]+}.vcf"
    params:
        # "--min-qual <n>" for integer filters, "--fdr <x>" otherwise
        filter=alpaca_filter
    threads: 12
    resources: benchmark=100
    log:
        # receives stderr of the alpaca call
        "logs/alpaca/{filter}/{query}.call.log"
    benchmark:
        "benchmarks/alpaca/{filter}/{query}.call.json"
    shell:
        # query, output and log are quoted for the shell; calls go to stdout
        "alpaca --debug --threads {threads} --buffersize 10000000 call "
        "{input} '{wildcards.query}' {params.filter} > '{output}' 2> '{log}'"


################################# ALPACA show ##################################


# Produce the HTML table for the A0-B0 calls obtained with filter 0.05.
# Only the first 1000 lines of the VCF are used; they are piped through
# `alpaca annotate`, reduced to lines containing '#', 'missense' or 'stop',
# and passed to `alpaca show`, whose stdout becomes the output file.
rule alpaca_show:
    input:
        "alpaca/0.05/A0-B0.vcf"
    output:
        "tables/A0-B0.html"
    shell:
        "head -n 1000 {input} | alpaca annotate | "
        "grep -P '(#|missense|stop)' | alpaca show > {output}"

        
##################################### GATK #####################################


# Run the GATK HaplotypeCaller in GVCF mode for one sample.
# The {sample} wildcard is a key of config["samples"]; the mapped value is the
# name used in the path of the BAM file.
rule gatk_haplotype_caller:
    input:
        # BAM index and sequence dictionary: not referenced in the command,
        # only required to exist
        lambda wildcards: expand("data/reads/{sample}.bam.bai", sample=config["samples"][wildcards.sample]),
        "ref.dict",
        bam=lambda wildcards: expand("data/reads/{sample}.bam", sample=config["samples"][wildcards.sample]),
        ref="ref.fasta"
    output:
        "gatk/{sample}.gvcf"
    log:
        # receives both stdout and stderr of GATK
        "logs/gatk/{sample}.log"
    benchmark:
        "benchmarks/gatk/{sample}.haplotype_caller.json"
    threads: 2
    resources: benchmark=100
    shell:
        "gatk -T HaplotypeCaller -nct {threads} -R {input.ref} -I {input.bam} "
        "--emitRefConfidence GVCF --variant_index_type LINEAR "
        "-variant_index_parameter 128000 "
        "-o {output} &> {log}"


# Jointly genotype the per-sample gVCFs belonging to {query} with GATK
# GenotypeGVCFs. The result is the raw (not yet query-filtered) call set.
rule gatk_genotyping:
    input:
        # gVCFs of the samples in the query; expand_samples_from_query is
        # defined outside this section
        expand_samples_from_query("gatk/{sample}.gvcf"),
        ref="ref.fasta"
    output:
        "gatk/{query,[^\.]+}.raw.vcf"
    params:
        # the same gVCFs, each prefixed with "--variant" for the command line
        gvcfs=expand_samples_from_query("--variant gatk/{sample}.gvcf")
    log:
        # receives both stdout and stderr of GATK
        "logs/gatk/{query}.call.log"
    benchmark:
        "benchmarks/gatk/{query}.genotyping.json"
    threads: 12
    resources: benchmark=100
    shell:
        "gatk -T GenotypeGVCFs {params.gvcfs} -nt {threads} -R {input.ref} "
        "-o '{output}' &> '{log}'"


def gatk_query_to_select(query):
    """Build the ``-select`` options for GATK SelectVariants from a query.

    The samples of the query are obtained from ``get_samples_from_query``
    (once with ``filter_only=True``, once with ``call_only=True``) and mapped
    to their configured names via ``config["samples"]``.

    Args:
        query: the sample query, as used in the ``{query}`` wildcard.

    Returns:
        A string with one ``-select '<JEXL>'`` option demanding that at least
        one call sample has a called, non-hom-ref genotype. If the query has
        filter samples, a second ``-select`` option is appended demanding that
        every filter sample is hom-ref or not called.
    """
    translate = config["samples"].get

    # every filter sample must be hom-ref or uncalled
    filter_samples = " && ".join(map(
        '(vc.getGenotype("{0}").isHomRef() || !vc.getGenotype("{0}").isCalled())'.format,
        map(translate, get_samples_from_query(query, filter_only=True))
    ))
    # at least one call sample must carry a called non-reference genotype
    call_samples = " || ".join(map(
        '(vc.getGenotype("{0}").isCalled() && !vc.getGenotype("{0}").isHomRef())'.format,
        map(translate, get_samples_from_query(query, call_only=True))
    ))

    select = "-select '{}'".format(call_samples)
    if filter_samples:
        select += " -select '{}'".format(filter_samples)
    return select

# Reduce the raw calls of freebayes, gatk or samtools to the variants that
# match {query}, using GATK SelectVariants with the expressions generated by
# gatk_query_to_select.
rule gatk_filter_by_query:
    input:
        ref="ref.fasta",
        vcf="{caller}/{query}.raw.vcf"
    output:
        # the query must not contain a dot
        "{caller,(freebayes|gatk|samtools)}/{query,[^\.]+}.vcf"
    params:
        select=lambda wildcards: gatk_query_to_select(wildcards.query)
    log:
        # receives both stdout and stderr of GATK
        "logs/{caller}/{query}.select.log"
    benchmark:
        "benchmarks/{caller}/{query}.select.json"
    threads: 2
    resources: benchmark=1
    shell:
        "gatk -T SelectVariants -R {input.ref} --variant '{input.vcf}' "
        "-nt {threads} {params.select} -o '{output}' &> '{log}'"


##################################### Freebayes ################################

# Call variants for the samples of {query} with freebayes-parallel.
# The reference is split into regions by fasta_generate_regions.py (called
# with the FASTA index and 1000000); indels, MNPs and complex events are
# switched off on the command line.
rule freebayes:
    input:
        # BAM indexes: not referenced in the command, only required to exist
        expand_samples_from_query("data/reads/{sample}.bam.bai", orig_name=True),
        bams=expand_samples_from_query("data/reads/{sample}.bam", orig_name=True),
        ref="ref.fasta",
        fai="ref.fasta.fai"
    output:
        "freebayes/{query,[^\.]+}.raw.vcf"
    log:
        # receives stderr; the calls are written to stdout
        "logs/freebayes/{query}.log"
    benchmark:
        "benchmarks/freebayes/{query}.json"
    threads: 12
    resources: benchmark=100
    shell:
        "freebayes-parallel <(fasta_generate_regions.py {input.fai} 1000000) "
        "{threads} -f {input.ref} --no-indels --no-mnps --no-complex "
        "{input.bams} > '{output}' 2> '{log}'"


################################# MuTect #######################################


# NOTE: MuTect crashes on the datasets used in this analysis.
# Run MuTect for one of the two supported queries (A0-A1 or A0-B0).
# The first BAM of the query is passed as tumor, the second as normal.
rule mutect:
    input:
        # BAM indexes: not referenced in the command, only required to exist
        expand_samples_from_query("data/reads/{sample}.bam.bai", orig_name=True),
        bams=expand_samples_from_query("data/reads/{sample}.bam", orig_name=True),
        ref="ref.fasta",
        fai="ref.fasta.fai"
    output:
        "mutect/{query,(A0-A1|A0-B0)}.call_stats.out"
    log:
        # receives both stdout and stderr of MuTect
        "logs/mutect/{query}.log"
    benchmark:
        "benchmarks/mutect/{query}.json"
    threads: 12
    resources: benchmark=100
    shell:
        "mutect -nt {threads} --analysis_type MuTect --reference_sequence {input.ref} "
        "--input_file:normal {input.bams[1]} --input_file:tumor {input.bams[0]} "
        "--out {output} &> {log}"


#################################### Samtools ##################################


# Call variants for the samples of {query} with samtools mpileup and
# bcftools call, parallelized over the sequences listed in the FASTA index.
# The command consists of three steps:
#   1. for every sequence name (first column of the .fai), GNU parallel runs
#      `samtools mpileup ... -r <name> | bcftools call -m -v` and stores the
#      result as "<output>.<name>.bcf";
#   2. the per-sequence files are passed to `bcftools concat -Ov`, with all
#      stdout collected in the final output;
#   3. the per-sequence files are removed.
# NOTE(review): `xargs -I` starts one command per input line, so step 2 runs
# `bcftools concat` once per sequence rather than once over all files; the
# output then presumably contains one VCF header per sequence -- confirm that
# downstream tools accept this.
# NOTE(review): the declared log file is never written by the command.
rule samtools:
    input:
        # BAM indexes: not referenced in the command, only required to exist
        expand_samples_from_query("data/reads/{sample}.bam.bai", orig_name=True),
        bams=expand_samples_from_query("data/reads/{sample}.bam", orig_name=True),
        ref="ref.fasta",
        fai="ref.fasta.fai"
    output:
        "samtools/{query,[^\.]+}.raw.vcf"
    log:
        "logs/samtools/{query}.log"
    benchmark:
        "benchmarks/samtools/{query}.json"
    threads: 12
    resources: benchmark=100
    shell:
        "cut -f1 {input.fai} | parallel -j {threads} "
        "'samtools mpileup -t DP --skip-indels -gu -C50 -f {input.ref} "
        "-r {{}} {input.bams} | "
        "bcftools call -m -v - > \"{output}.{{}}.bcf\"'; "
        "cut -f1 {input.fai} | "
        "xargs -I '{{}}' bcftools concat -Ov '{output}.{{}}.bcf' > '{output}'; "
        "cut -f1 {input.fai} | xargs -I '{{}}' rm '{output}.{{}}.bcf'"


##################################### Utils ####################################



def get_subsample_param(wildcards):
    """Return the argument given to ``samtools view -s`` by rule subsample.

    The integer formed by the sample name without its first two characters is
    added to the run number; the text of ``wildcards.fraction`` is appended
    to that sum unchanged.
    """
    sample_number = int(wildcards.sample[2:])
    run_number = int(wildcards.run)
    return "".join((str(sample_number + run_number), str(wildcards.fraction)))


# Draw a random subsample of a BAM file with `samtools view -s`.
# The output is temporary (it is the ".pre" input of rule set_read_group);
# {run} must be numeric and contributes to the value computed by
# get_subsample_param.
rule subsample:
    input:
        "data/reads/{sample}.bam"
    output:
        temp("data/reads/{sample}.subsample-{fraction}-{run,\d+}.pre.bam")
    params:
        # argument for -s, derived from sample name, run and fraction
        subsample=get_subsample_param
    resources: benchmark=1
    shell:
        "samtools view -bs {params.subsample} {input} > {output}"


# Give a subsampled BAM its own read group with Picard
# AddOrReplaceReadGroups: ID, library, platform unit and sample name are all
# set to "<sample>.subsample-<fraction>-<run>", the platform to "illumina".
rule set_read_group:
    input:
        "data/reads/{sample}.subsample-{fraction}-{run}.pre.bam"
    output:
        "data/reads/{sample}.subsample-{fraction}-{run,\d+}.bam"
    params:
        # common value for RGID, RGLB, RGPU and RGSM
        rgsm="{sample}.subsample-{fraction}-{run}"
    log:
        # receives both stdout and stderr of Picard
        "logs/reads/{sample}.subsample-{fraction}-{run}.set_read_group.log"
    resources: benchmark=1
    shell:
        "picard-tools AddOrReplaceReadGroups INPUT={input} OUTPUT={output} "
        "RGID={params.rgsm} RGLB={params.rgsm} RGPU={params.rgsm} "
        "VALIDATION_STRINGENCY=SILENT "
        "RGPL=illumina RGSM={params.rgsm} &> {log}"


# Create an artificial mixture of two samples at known variant loci.
# For every record of the variants file, the reads of the base sample that
# overlap the locus are partially replaced by reads of the spike sample: the
# number k of replaced reads is drawn from Binomial(n, p), with n the number
# of base reads at the locus and p taken from the {p} wildcard. Loci with
# fewer than config["sensitivity"]["min_depth"] base reads or fewer than k
# spike reads are skipped. The mixed reads go to the output BAM, the records
# of all loci that were actually used go to the ".valid.vcf" output.
# NOTE(review): reads are written locus by locus, so a read overlapping
# several loci can be written more than once -- confirm this is acceptable
# for downstream tools.
rule spike_sample:
    input:
        base="data/reads/{base}.bam",
        spike="data/reads/{spike}.bam",
        variants="data/gold_variants/{variants}.vcf"
    output:
        # named, because the run block refers to output.bam and output.vcf
        bam="data/reads/{base}+{spike}__at__{variants}__with__{p}.bam",
        vcf="data/reads/{base}+{spike}__at__{variants}__with__{p}.valid.vcf"
    resources: benchmark=1
    run:
        p = float(wildcards.p)
        min_depth = config["sensitivity"]["min_depth"]
        with pysam.Samfile(input.base, "rb") as base, \
                pysam.Samfile(input.spike, "rb") as spike, \
                open(input.variants) as variants_file, \
                open(output.vcf, "w") as valid_file:
            with pysam.Samfile(output.bam, "wb", template=base) as out_reads:
                variants = vcf_reader(variants_file)
                out_variants = vcf_writer(valid_file, variants)

                for record in variants:
                    # fetch the reads overlapping this locus; VCF positions
                    # are 1-based while pysam regions are 0-based, half-open
                    start, stop = record.POS - 1, record.POS
                    base_reads = list(base.fetch(record.CHROM, start, stop))
                    spike_reads = list(spike.fetch(record.CHROM, start, stop))

                    if len(base_reads) < min_depth:
                        print("skipping locus due to low base sample coverage", file=sys.stderr)
                        continue

                    # select the number of reads to spike in from binomial distribution
                    k = np.random.binomial(len(base_reads), p)
                    if len(spike_reads) < k:
                        print("skipping locus due to low spike sample coverage", file=sys.stderr)
                        continue

                    # random.sample returns the chosen reads themselves:
                    # keep n - k base reads and add k spike reads
                    base_keep = random.sample(base_reads, len(base_reads) - k)
                    spike_keep = random.sample(spike_reads, k)
                    mixed = base_keep + spike_keep

                    # write reads and variants
                    for read in mixed:
                        out_reads.write(read)
                    out_variants.write(record)


# Compute the gold variants of {a} that are not present in {b} by running
# `bedtools subtract -A` on the two VCF files; the result is stored under the
# name "{a}-{b}".
rule subtract_variants:
    input:
        a="data/gold_variants/{a}.vcf",
        b="data/gold_variants/{b}.vcf"
    output:
        "data/gold_variants/{a}-{b}.vcf"
    resources: benchmark=1
    shell:
        "bedtools subtract -A -a {input.a} -b {input.b} > {output}"


# Download the BAM file of a raw sample from the URL stored under
# config["raw_samples"][<sample>]["reads"].
rule download_reads:
    output:
        "data/reads/{sample}.bam"
    params:
        url=lambda wildcards: config["raw_samples"][wildcards.sample]["reads"]
    shell:
        "wget -O {output} {params.url}"


# Download the gold standard variants of a raw sample from the URL stored
# under config["raw_samples"][<sample>]["gold_variants"].
# The download is saved as "<output>.gz" and decompressed in place; the exit
# status of gzip is deliberately ignored (`|| true`).
rule download_variants:
    output:
        "data/gold_variants/{sample}.vcf"
    params:
        url=lambda wildcards: config["raw_samples"][wildcards.sample]["gold_variants"]
    shell:
        """
        wget -O - {params.url} > {output}.gz
        gzip -d {output}.gz || true
        """

# Write a ".fixed.vcf" copy of a gold variants file in which the first
# occurrence of "chr" on every line has been removed.
# NOTE(review): the substitution is not anchored to the start of the line and
# also applies to header lines -- presumably meant to strip the "chr" prefix
# of chromosome names; confirm that no other text is affected.
rule fix_variants:
    input:
        "data/gold_variants/{sample}.vcf"
    output:
        "data/gold_variants/{sample}.fixed.vcf"
    resources: benchmark=1
    shell:
        "sed 's/chr//' {input} > {output}"


# Download the hs37d5 reference assembly from the 1000 Genomes FTP server and
# decompress it to ref.fasta. The exit status of gzip is ignored on purpose
# (see the remark about trailing garbage in the command itself).
rule download_ref:
    output:
        "ref.fasta"
    shell:
        """
        wget -O {output}.gz ftp://ftp.1000genomes.ebi.ac.uk/vol1/ftp/technical/reference/phase2_reference_assembly_sequence/hs37d5.fa.gz
        gzip -d --quiet {output}.gz || true  # ignore trailing garbage
        """


# Download the Ensembl release 75 gene annotation (GRCh37, GTF) and store it
# decompressed as ref.genes.gtf; the archive is streamed through zcat without
# being saved.
rule download_track:
    output:
        "ref.genes.gtf"
    shell:
        "wget -O - ftp://ftp.ensembl.org/pub/release-75/gtf/homo_sapiens/Homo_sapiens.GRCh37.75.gtf.gz | zcat --quiet > {output}"


def figure(right=False, top=False, left=True, bottom=True, figsize=None):
    fig = plt.figure(figsize=figsize)
    ax = Subplot(fig, 111)
    ax.axis["right"].set_visible(right)
    ax.axis["top"].set_visible(top)
    ax.axis["left"].set_visible(left)
    ax.axis["bottom"].set_visible(bottom)
    fig.add_subplot(ax)