#!/usr/bin/env python3
#
#Copyright (C) 2013- The University of Notre Dame
#This software is distributed under the GNU General Public License.
#See the file COPYING for details.
#
# This program implements a distributed version of BWA, using Makeflow and WorkQueue

# Author: Nick Hazekamp
# Date: 09/03/2013


import optparse, subprocess, os, sys, re, string, logging as log, stat
from collections import defaultdict

# Pass-through BWA arguments gathered while parsing the command line, keyed by
# the algorithm name that was active (see current_algo) when they were seen.
bwa_args = defaultdict(list)

# Algorithm named by the most recent --bwa option; None until one is given.
current_algo = None

def callback_bwa_args(option, opt_str, value, parser):
	"""optparse callback for --bwa: remember which algorithm is being configured.

	Only ``value`` (the algorithm name that follows --bwa) is used. It is
	stored in the module-level ``current_algo`` so that unrecognised options
	seen afterwards can be filed under that algorithm.
	"""
	global current_algo
	current_algo = value

def callback_bwa_help(option, opt_str, value, parser):
	"""optparse callback for --bwa-help: show bwa's own usage text and exit.

	Runs ``bwa`` (or ``bwa <value>`` when an algorithm name is given) so the
	user sees the help printed by the tool itself, then terminates the
	program with status 0.

	If bwa cannot be started (for example it is not on PATH) an error is
	written to stderr and the program exits with status 3 instead of dying
	with a traceback.
	"""
	command = ["bwa"]
	if value:
		command.append(value)
	try:
		subprocess.call(command)
	except OSError as err:
		sys.stderr.write("Unable to run bwa : {0}\n".format(err))
		sys.exit(3)
	sys.exit(0)

class PassThroughParser(optparse.OptionParser):
	"""OptionParser that collects unknown options instead of rejecting them.

	Options this parser does not define are assumed to belong to bwa. They
	are appended to ``bwa_args[current_algo]``, i.e. filed under the
	algorithm named by the most recent --bwa option.
	"""

	def _process_args(self, largs, rargs, values):
		"""Parse ``rargs``, recording unrecognised options as bwa arguments.

		The stock implementation raises on the first unknown option, after
		that option has already been removed from ``rargs``. Parsing is
		therefore restarted in a loop until ``rargs`` is exhausted.
		"""
		while rargs:
			try:
				optparse.OptionParser._process_args(self, largs, rargs, values)
			except (optparse.AmbiguousOptionError, optparse.BadOptionError) as e:
				passthrough = bwa_args[current_algo]
				passthrough.append(e.opt_str)
				# The next token is this option's value only if one exists
				# and it is not itself an option. It is left in rargs, so the
				# next pass also files it as a positional argument; a
				# following option is picked up on its own by the next pass.
				if rargs and not rargs[0].startswith("-"):
					passthrough.append(rargs[0])


def parse_bwa_options(parser):
	"""Register the "BWA Options" group on ``parser``.

	--bwa and --bwa-help are callback options (both take one string value);
	the remaining options simply store their value on the parsed options
	object under the ``dest`` names given below.
	"""
	group = optparse.OptionGroup(parser, "BWA Options")
	group.add_option('--bwa', type="string", action="callback", callback=callback_bwa_args, help='Algorithm specification and options : Example : --bwa mem -t 4')
	group.add_option('--bwa-help', type="string", action="callback", callback=callback_bwa_help, help='Algorithm specification help')
	group.add_option('--ref', dest='ref', help='The reference genome to use or index')
	group.add_option('--query', dest='fastq', help='The (forward) fastq file to use for the mapping')
	group.add_option('--rquery', dest='rfastq', help='The reverse fastq file to use for mapping if paired-end data')
	group.add_option('--output_SAM', dest='output_SAM', help='The file to save the output (SAM format) [FASTQ.sam]')
	# This help text used to be a copy of the --output_SAM description.
	group.add_option('--output_err', dest='output_err', help='The file to save the error output of BWA [FASTQ.err]')
	group.add_option('--algo', dest='algo', help='The type of Alignment Algorithm (backtrack, bwasw, mem) [mem]')
	parser.add_option_group(group)

def parse_makeflow_options(parser):
	"""Register the "Makeflow Preparation Options" group on ``parser``.

	Adds --makeflow, --makeflow-config (stored as ``config``), --num_seq
	(integer, default 50000) and the --verbose flag (default False).
	"""
	group = optparse.OptionGroup(parser, "Makeflow Preparation Options")
	group.add_option('--makeflow', dest='makeflow', help='Makeflow destination file [stdout]')
	group.add_option('--makeflow-config', dest='config', help='Makeflow configurations')
	group.add_option('--num_seq', dest='num_seq', help='Number of Sequences per partition [50000]', default=50000, type="int")
	# The default must be the boolean False: the string "False" is truthy.
	group.add_option('--verbose', dest='verbose', help='Show verbose level output', action='store_true', default=False)
	parser.add_option_group(group)

#Helper function for finding executables in path
def search_file(filename, search_path, pathsep=os.pathsep):
    """Return the absolute path of ``filename`` in the first matching directory.

    ``search_path`` is a string of directories joined by ``pathsep`` (the
    same layout as the PATH environment variable). Directories are tried in
    order; ``None`` is returned when none of them contains ``filename``.
    """
    # str.split replaces string.split, which Python 3 no longer provides.
    for directory in search_path.split(pathsep):
        candidate = os.path.join(directory, filename)
        if os.path.exists(candidate):
            return os.path.abspath(candidate)
    return None


def count_splits(fastq, num_reads):
	"""Return how many partition files the splitting script creates for ``fastq``.

	The counting deliberately follows the Perl script written by
	write_fastq_reduce step for step, because the Makeflow rules must name
	exactly the files that script produces: partition 0 always exists, and a
	new partition is started by the first read header seen after
	``num_reads`` headers have been counted in the current one.

	fastq     -- path of the FASTQ file to scan
	num_reads -- number of reads counted per partition before a new one starts
	"""
	log.info("Counting Number of Partitions")

	# Partition 0 is opened before any line is read.
	num_outputs = 1
	read_count = 0

	# The context manager closes the file even if reading fails part-way.
	with open(fastq, "r") as reads:
		for line_count, line in enumerate(reads):
			# A read header is an '@' line on every fourth line; quality
			# lines may start with '@' too, hence the position check.
			if line_count % 4 != 0 or not line.startswith('@'):
				continue
			if read_count == num_reads:
				num_outputs += 1
				read_count = 0
				if num_outputs % 10 == 0:
					log.info("Current Number of Partitions : {0}".format(num_outputs))
			else:
				read_count += 1

	log.info("Total Number of Partitions : {0}".format(num_outputs))
	return num_outputs

def write_fastq_reduce(destination):
	"""Write the Perl script that splits a FASTQ file into partitions.

	The generated script takes ``<fastq> <reads-per-partition>``, writes the
	partitions to ``<fastq>.0``, ``<fastq>.1``, ... and prints the number of
	partitions it created.

	destination -- path of the script file to create (overwritten if present)
	"""
	log.info("Writing Partitioning Script to : {0}".format(destination))
	with open(destination, 'w') as fastq_r:
		# The text must begin with the shebang: a leading blank line would
		# push it to line 2, where it has no effect.
		fastq_r.write('''#!/usr/bin/perl
#
#Copyright (C) 2013- The University of Notre Dame
#This software is distributed under the GNU General Public License.
#See the file COPYING for details.
#
#Programmer: Brian Kachmarck
#Date: 7/28/2009
#
#Revised: Nick Hazekamp
#Date: 12/02/2013
#
#Purpose: Split a FASTQ file into smaller files determined by the number of sequences input

use strict;


my $numargs = $#ARGV + 1;

my $file = $ARGV[0];

my $num_reads= $ARGV[1];

my $num_outputs = 0;
my $line_count=0;

#Open input file
open(INPUT, $file);

my $read_count = 0;
open (OUTPUT,\">$file.$num_outputs\");
$num_outputs++;
while (my $line = <INPUT>) {
	chomp $line;
	#FASTQ files begin sequence with '@' character
	#If line begins with '@' then it is a new sequence and has 3 lines in between
	if ($line =~ /^[@]/ and $line_count % 4 ==0){
		#Check if the new sequence should be placed in a new file, otherwise place it in same file
		if ($read_count == $num_reads){
			close(OUTPUT);
			open(OUTPUT, \">$file.$num_outputs\");
			print OUTPUT $line;
			print OUTPUT \"\\n\";
			$num_outputs++;
			$read_count = 0;
		}
		else{
			print OUTPUT $line;
			print OUTPUT \"\\n\";
			$read_count++;
		}
	}
	#place all other lines in FASTQ file under same sequence
	else {
		print OUTPUT $line;
		print OUTPUT \"\\n\";
	}

	$line_count++;
}
print $num_outputs;

close(INPUT);
''')

def write_cat_bwa(destination):
	"""Write the Perl script that joins per-partition SAM files into one.

	The generated script is called as ``cat_bwa output input1 input2 ...``.
	It recreates ``output`` and copies every input into it, skipping lines
	that start with '@' in all inputs but the first, and skipping lines that
	contain "main" or "M::".

	destination -- path of the script file to create (overwritten if present)
	"""
	log.info("Writing Joining Script to : {0}".format(destination))
	with open(destination, 'w') as cat_bwa:
		# Two fixes in the Perl text: the die() messages are built with '.'
		# (Perl's '+' is numeric addition), and the handle that is closed at
		# the end is $OUT, the one that was actually opened.
		cat_bwa.write('''#!/usr/bin/perl
#
#Copyright (C) 2013- The University of Notre Dame
#This software is distributed under the GNU General Public License.
#See the file COPYING for details.
#
#
#This script is used to cat SAM files into a single result
#
# Usage: cat_bwa output input1 input2 ... intput*
#
# Author: Nick Hazekamp
# Date: 11/01/2013
#

my $output = shift;
if (-e $output) {
	unlink($output) || die \"Could not delete $output\";
}

open $OUT,'>>',$output or die(\"Could not open \" . $output . \" file.\");

$file = 0;

while ($in = shift) {
	++$file;

	open(IN,$in) or die(\"Could not open \" . $in . \" file.\");
	$print = 0;
	while(my $line = <IN>) {
		if (($line =~ /^@/) and ($file ne 1)) { $print = 0;}
		elsif (($line =~ /main/)) { $print = 0; }
		elsif (($line =~ /M::/)) { $print = 0; }
		else { $print = 1; }
		if ($print gt 0){
			print { $OUT } $line;
		}
	}
	close (IN);
}
close ($OUT);
''')

def write_makeflow(destination, configuration, num_reads_per_split, fastq_name, cat_name, bwa_args, algo):
	"""Write the Makeflow rules for a partitioned BWA run.

	Rules are emitted in this order: partitioning of the query file(s),
	indexing of the reference, one alignment per partition, and joining of
	the per-partition results.

	The query, reverse query, reference and optional output names are read
	from the module-level ``options`` object (attributes ``fastq``,
	``rfastq``, ``ref``, ``output_SAM`` and ``output_err``).

	destination         -- path of the Makeflow file; a falsy value writes to stdout
	configuration       -- optional file whose content is copied to the top of the output
	num_reads_per_split -- reads per partition; a value <= 0 disables partitioning
	fastq_name          -- path of the partitioning script
	cat_name            -- path of the joining script
	bwa_args            -- mapping of bwa sub-command name to a list of extra arguments
	algo                -- "backtrack" selects aln followed by samse/sampe; any
	                       other value is used directly as the bwa sub-command

	Raises ValueError when the reverse query yields fewer partitions than the
	forward query, since the two are paired by partition index.
	"""
	def extra_args(stage):
		# Extra command-line arguments for one bwa sub-command, as one string.
		return ' '.join(bwa_args.get(stage) or [])

	if destination:
		makeflow = open(destination, 'w')
		log.info("Writing Makeflow to : {0}".format(destination))
	else:
		makeflow = sys.stdout
		log.info("Writing Makeflow to : STDOUT")

	try:
		if configuration:
			log.info("Reading Configuration from : {0}".format(configuration))
			with open(configuration, 'r') as config:
				makeflow.write(config.read())

		# Rules and Commands for Partitioning the Query File(s)
		rinputlist = []
		if num_reads_per_split > 0:
			splits = count_splits(options.fastq, num_reads_per_split)
			inputlist = ["{0}.{1}".format(options.fastq, i) for i in range(splits)]

			log.info("Writing Partitioning Rules")

			makeflow.write("\n\n{0} : {1} {2}".format(' '.join(inputlist), fastq_name, options.fastq))
			makeflow.write("\n\tLOCAL perl {0} {1} {2}".format(fastq_name, options.fastq, num_reads_per_split))

			if options.rfastq:
				rsplits = count_splits(options.rfastq, num_reads_per_split)
				rinputlist = ["{0}.{1}".format(options.rfastq, i) for i in range(rsplits)]

				makeflow.write("\n\n{0} : {1} {2}".format(' '.join(rinputlist), fastq_name, options.rfastq))
				makeflow.write("\n\tLOCAL perl {0} {1} {2}".format(fastq_name, options.rfastq, num_reads_per_split))
		else:
			# No partitioning: the query files are aligned as they are.
			splits = 1
			inputlist = [options.fastq]
			if options.rfastq:
				rinputlist = [options.rfastq]

		# Forward partition i is paired with reverse partition i, so every
		# forward partition needs a reverse counterpart.
		if options.rfastq and len(rinputlist) < splits:
			raise ValueError("Reverse query {0} yields {1} partitions but forward query {2} yields {3}".format(options.rfastq, len(rinputlist), options.fastq, splits))

		# Rule and Command for Indexing FASTA File
		ref = options.ref
		index = "{0}.bwt {0}.pac {0}.amb {0}.ann {0}.sa".format(ref)

		log.info("Writing Indexing Rules")

		makeflow.write("\n\n{0} : bwa {1}".format(index, options.ref))
		makeflow.write("\n\tLOCAL ./bwa index {0} {1}".format(extra_args("index"), options.ref))

		# From here on "index" is the dependency list of every alignment
		# rule: the reference itself plus all of its index files.
		index = options.ref + " " + index
		results_to_cat = []
		errors_to_cat = []

		log.info("Writing BWA {0} Rules".format(algo))

		if algo != "backtrack":
			# Every sub-command except bwasw is given "-v 0" before the
			# user-supplied extras.
			align_args = ("" if algo == "bwasw" else " -v 0 ") + extra_args(algo)

		for i in range(splits):
			query1 = inputlist[i]
			sai1 = query1 + ".sai"
			query2 = sai2 = ""
			if options.rfastq:
				query2 = rinputlist[i]
				sai2 = query2 + ".sai"

			output = query1 + ".sam"
			results_to_cat.append(output)

			error = query1 + ".err"
			errors_to_cat.append(error)

			if algo == "backtrack":
				aln_a = extra_args("aln")

				makeflow.write("\n\n{0} : bwa {1} {2}".format(sai1, index, query1))
				makeflow.write("\n\t./bwa aln {0} {1} {2} > {3}".format(aln_a, ref, query1, sai1))

				if options.rfastq:
					makeflow.write("\n\n{0} : bwa {1} {2}".format(sai2, index, query2))
					makeflow.write("\n\t./bwa aln {0} {1} {2} > {3}".format(aln_a, ref, query2, sai2))

					makeflow.write("\n\n{0} {1} : bwa {2} {3} {4} {5} {6}".format(output, error, index, query1, query2, sai1, sai2))
					makeflow.write("\n\t./bwa sampe {0} {1} {2} {3} {4} {5} > {6} 2> {7}".format(extra_args("sampe"), ref, sai1, sai2, query1, query2, output, error))
				else:
					# No leading space before the targets, like every other rule.
					makeflow.write("\n\n{0} {1} : bwa {2} {3} {4}".format(output, error, index, query1, sai1))
					makeflow.write("\n\t./bwa samse {0} {1} {2} {3} > {4} 2> {5}".format(extra_args("samse"), ref, sai1, query1, output, error))
			else:
				makeflow.write("\n\n{0} {1} : bwa {2} {3} {4}".format(output, error, index, query1, query2))
				makeflow.write("\n\t./bwa {0} {1} {2} {3} {4} > {5} 2> {6}".format(algo, align_args, ref, query1, query2, output, error))

		log.info("Writing Joining Rules")

		# Default result names drop a fastq/fq style extension from the query
		# name. A name without any extension is kept whole; rpartition would
		# otherwise leave an empty base and produce names such as ".sam".
		(base, sep, file_ext) = options.fastq.rpartition(".")
		if not sep or not re.search("f(ast)?q", file_ext, re.IGNORECASE):
			base = options.fastq

		output = options.output_SAM or "{0}.sam".format(base)
		error = options.output_err or "{0}.err".format(base)

		if num_reads_per_split > 0:
			makeflow.write("\n\n{0} : {1} {2}".format(output, cat_name, ' '.join(results_to_cat)))
			makeflow.write("\n\tLOCAL perl {0} {1} {2}".format(cat_name, output, ' '.join(results_to_cat)))

			makeflow.write("\n\n{0} : {1}".format(error, ' '.join(errors_to_cat)))
			makeflow.write("\n\tLOCAL cat {0} > {1}".format(' '.join(errors_to_cat), error))
		else:
			# A single unsplit result only needs renaming, and only when the
			# requested name differs from the one the alignment produced.
			if output != results_to_cat[0]:
				makeflow.write("\n\n{0} : {1}".format(output, results_to_cat[0]))
				makeflow.write("\n\tLOCAL mv {0} {1}".format(results_to_cat[0], output))

			if error != errors_to_cat[0]:
				makeflow.write("\n\n{0} : {1}".format(error, errors_to_cat[0]))
				makeflow.write("\n\tLOCAL mv {0} {1}".format(errors_to_cat[0], error))

		makeflow.write("\n")

	finally:
		# Close only a file opened here; stdout belongs to the caller.
		if makeflow is not sys.stdout:
			makeflow.close()


#MAIN

# Parse Command Line
parser = PassThroughParser()

parser.usage="usage: %prog --ref <reference> --query <query> [--rquery rquery] --algo <algorithm> [options]"

parse_makeflow_options(parser)
parse_bwa_options(parser)

(options, args) = parser.parse_args()

# INFO messages are shown only in verbose mode; otherwise logging keeps its
# default threshold.
if options.verbose is True:
	log.basicConfig(format="%(levelname)s: %(message)s", level=log.DEBUG)
	log.info("Verbose output.")
else:
	log.basicConfig(format="%(levelname)s: %(message)s")

if not options.ref:
	log.error("Reference Fasta required : --ref")
	sys.exit(4)
log.info("Reference Fasta : {0}".format(options.ref))

if not options.fastq:
	log.error("Query Fastq required : --query")
	sys.exit(4)
log.info("Query Fastq : {0}".format(options.fastq))

algo = options.algo or "mem"
log.info("Selected Algorithm : BWA {0}".format(algo))

# Compare against None rather than testing truthiness: "--num_seq 0" is a
# real request (no partitioning) and must not fall back to the default.
num_reads_per_split = 50000
if options.num_seq is not None:
	num_reads_per_split = int(options.num_seq)
log.info("Partition Size : {0}".format(num_reads_per_split))

# (Re)create both helper scripts in the working directory and make them
# executable.
fastq_name = "fastq_reduce"
cat_name = "cat_bwa"
for script_name, write_script in ((fastq_name, write_fastq_reduce), (cat_name, write_cat_bwa)):
	if os.path.exists(script_name):
		os.remove(script_name)
	write_script(script_name)
	st = os.stat(script_name)
	os.chmod(script_name, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
	log.info("{0} made executable".format(script_name))

# The generated rules run ./bwa, so a bwa found on PATH is linked into the
# working directory unless a local one is already there.
path = os.getenv("PATH") or ""	# PATH may be unset
path += os.pathsep + "."

log.info("Searching PATH for BWA")
bwa = search_file("bwa", path)
if os.path.exists("./bwa"):
	log.info("BWA found locally")
elif bwa:
	log.info("BWA linked from : " + bwa)
	os.symlink(bwa, "bwa")
else:
	log.error("Unable to find bwa")
	sys.exit(3)

write_makeflow(options.makeflow, options.config, num_reads_per_split, fastq_name, cat_name, bwa_args, algo)
