#!/usr/bin/env python

# -------------------------------------------------------------------------
# MEGAHIT
# Copyright (C) 2014 - 2015 The University of Hong Kong & L3 Bioinformatics Limited
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
# -------------------------------------------------------------------------


from __future__ import print_function

import getopt
import json
import logging
import multiprocessing
import os
import shutil
import signal
import subprocess
import sys
import tempfile
import time
import math

# Module-level logger; its handlers are attached by setup_logger().
logger = logging.getLogger(__name__)

# Help text template.  It is rendered with str.format(): ``{0}`` is replaced
# by the maximum k reported by megahit_core (see SoftwareInfo.usage_message),
# and literal braces are therefore escaped as ``{{`` / ``}}``.
_usage_message = '''
contact: Dinghua Li <voutcn@gmail.com>

Usage:
  megahit [options] {{-1 <pe1> -2 <pe2> | --12 <pe12> | -r <se>}} [-o <out_dir>]

  Input options that can be specified for multiple times (supporting plain text and gz/bz2 extensions)
    -1                       <pe1>          comma-separated list of fasta/q paired-end #1 files, paired with files in <pe2>
    -2                       <pe2>          comma-separated list of fasta/q paired-end #2 files, paired with files in <pe1>
    --12                     <pe12>         comma-separated list of interleaved fasta/q paired-end files
    -r/--read                <se>           comma-separated list of fasta/q single-end files

Optional Arguments:
  Basic assembly options:
    --min-count              <int>          minimum multiplicity for filtering (k_min+1)-mers [2]
    --k-list                 <int,int,..>   comma-separated list of kmer size
                                            all must be odd, in the range 15-{0}, increment <= 28)
                                            [21,29,39,59,79,99,119,141]

  Another way to set --k-list (overrides --k-list if one of them set):
    --k-min                  <int>          minimum kmer size (<= {0}), must be odd number [21]
    --k-max                  <int>          maximum kmer size (<= {0}), must be odd number [141]
    --k-step                 <int>          increment of kmer size of each iteration (<= 28), must be even number [12]

  Advanced assembly options:
    --no-mercy                              do not add mercy kmers
    --bubble-level           <int>          intensity of bubble merging (0-2), 0 to disable [2]
    --merge-level            <l,s>          merge complex bubbles of length <= l*kmer_size and similarity >= s [20,0.95]
    --prune-level            <int>          strength of low depth pruning (0-3) [2]
    --prune-depth            <int>          remove unitigs with avg kmer depth less than this value [2]
    --disconnect-ratio       <float>        disconnect unitigs if its depth is less than this ratio times 
                                            the total depth of itself and its siblings [0.1]  
    --low-local-ratio        <float>        remove unitigs if its depth is less than this ratio times
                                            the average depth of the neighborhoods [0.2]
    --max-tip-len            <int>          remove tips less than this value [2*k]
    --cleaning-rounds        <int>          number of rounds for graph cleanning [5]
    --no-local                              disable local assembly
    --kmin-1pass                            use 1pass mode to build SdBG of k_min

  Presets parameters:
    --presets                <str>          override a group of parameters; possible values:
                                            meta-sensitive: '--min-count 1 --k-list 21,29,39,49,...,129,141'
                                            meta-large: '--k-min 27 --k-max 127 --k-step 10'
                                            (large & complex metagenomes, like soil)

  Hardware options:
    -m/--memory              <float>        max memory in byte to be used in SdBG construction
                                            (if set between 0-1, fraction of the machine's total memory) [0.9]
    --mem-flag               <int>          SdBG builder memory mode. 0: minimum; 1: moderate;
                                            others: use all memory specified by '-m/--memory' [1]
    -t/--num-cpu-threads     <int>          number of CPU threads [# of logical processors]
    --no-hw-accel                           run MEGAHIT without BMI2 and POPCNT hardware instructions

  Output options:
    -o/--out-dir             <string>       output directory [./megahit_out]
    --out-prefix             <string>       output prefix (the contig file will be OUT_DIR/OUT_PREFIX.contigs.fa)
    --min-contig-len         <int>          minimum length of contigs to output [200]
    --keep-tmp-files                        keep all temporary files
    --tmp-dir                <string>       set temp directory

Other Arguments:
    --continue                              continue a MEGAHIT run from its last available check point.
                                            please set the output directory correctly when using this option.
    --test                                  run MEGAHIT on a toy test dataset
    -h/--help                               print the usage message
    -v/--version                            print version
'''


def check_output(cmd):
    """Run ``cmd`` and return its standard output as text.

    Args:
        cmd: argument list passed to ``subprocess.Popen`` (no shell).

    Returns:
        The captured stdout, with trailing whitespace removed, decoded as
        UTF-8.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.  An explicit raise is used instead of ``assert`` so the
            check is not stripped when Python runs with ``-O``.
    """
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    out, _ = p.communicate()
    if p.returncode != 0:
        raise subprocess.CalledProcessError(p.returncode, cmd, out)
    return out.rstrip().decode('utf-8')


def abspath(path):
    """Return the absolute form of ``path`` after expanding a leading ``~``."""
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)


def mkdir_if_not_exists(path):
    """Create the single directory ``path`` unless something already exists there."""
    if os.path.exists(path):
        return
    os.mkdir(path)


def remove_if_exists(file_name):
    """Delete ``file_name`` when it exists; do nothing otherwise."""
    if not os.path.exists(file_name):
        return
    os.remove(file_name)


class Usage(Exception):
    """Raised for command-line or configuration errors.

    The text is kept in ``msg`` and is also forwarded to ``Exception`` so
    that ``str(exc)`` and uncaught tracebacks show it.
    """

    def __init__(self, msg):
        super(Usage, self).__init__(msg)
        self.msg = msg


class EarlyTerminate(Exception):
    """Raised to stop the k-mer iteration early.

    ``kmer_size`` carries the k value given by the raiser; it is also
    forwarded to ``Exception`` so the value is visible in ``str(exc)``.
    """

    def __init__(self, kmer_size):
        super(EarlyTerminate, self).__init__(kmer_size)
        self.kmer_size = kmer_size


class SoftwareInfo(object):
    """Paths of the MEGAHIT core binaries and values queried from them."""

    # Directory containing this script (symlinks resolved); the core
    # binaries are looked up next to it.
    script_path = os.path.dirname(os.path.realpath(__file__))
    megahit_core = os.path.join(script_path, 'megahit_core')
    megahit_core_popcnt = os.path.join(script_path, 'megahit_core_popcnt')
    megahit_core_no_hw_accel = os.path.join(script_path, 'megahit_core_no_hw_accel')

    @property
    def megahit_version(self):
        """Version banner built from ``megahit_core dumpversion``."""
        version = check_output([self.megahit_core, 'dumpversion'])
        return 'MEGAHIT %s' % version

    @property
    def max_k_allowed(self):
        """Largest k-mer size, as reported by ``megahit_core kmax``."""
        reported = check_output([self.megahit_core, 'kmax'])
        return int(reported)

    @property
    def usage_message(self):
        """Usage text with the maximum k filled in."""
        max_k = self.max_k_allowed
        return _usage_message.format(max_k)


class Options:
    """Run-time settings of one MEGAHIT invocation.

    All plain attributes are JSON-serialisable so the whole object can be
    written with ``dump()`` and restored with ``load_for_continue()``.
    """

    def __init__(self):
        # Output / temporary locations; empty means "not chosen yet".
        self.out_dir = ''
        self.temp_dir = ''
        self.test_mode = False
        self.continue_mode = False
        self.force_overwrite = False
        # Values in (0, 1) are a fraction of machine memory, larger values
        # are a byte count (see host_mem).
        self.memory = 0.9
        self.min_contig_len = 200
        self.k_min = 21
        self.k_max = 141
        # Matches the default documented in the usage text for --k-step.
        self.k_step = 12
        self.k_list = [21, 29, 39, 59, 79, 99, 119, 141]
        self.auto_k = True
        self.set_list_by_min_max_step = False
        self.min_count = 2
        self.has_popcnt = True
        self.hw_accel = True
        # -1 means "not set by the user".
        self.max_tip_len = -1
        self.no_mercy = False
        self.no_local = False
        self.bubble_level = 2
        self.merge_len = 20
        self.merge_similar = 0.95
        self.prune_level = 2
        self.prune_depth = 2
        # 0 means "use every logical processor".
        self.num_cpu_threads = 0
        self.disconnect_ratio = 0.1
        self.low_local_ratio = 0.2
        self.cleaning_rounds = 5
        self.keep_tmp_files = False
        self.mem_flag = 1
        self.out_prefix = ''
        self.kmin_1pass = False
        # Input read files.
        self.pe1 = []
        self.pe2 = []
        self.pe12 = []
        self.se = []
        self.presets = ''
        self.verbose = False

    @property
    def log_file_name(self):
        """Path of the log file inside ``out_dir``."""
        if self.out_prefix == '':
            return os.path.join(self.out_dir, 'log')
        else:
            return os.path.join(self.out_dir, self.out_prefix + '.log')

    @property
    def option_file_name(self):
        """Path of the JSON file the options are dumped to."""
        return os.path.join(self.out_dir, 'options.json')

    @property
    def contig_dir(self):
        """Directory holding the per-k intermediate contigs."""
        return os.path.join(self.out_dir, 'intermediate_contigs')

    @property
    def read_lib_path(self):
        """Path of the read library description inside ``temp_dir``."""
        return os.path.join(self.temp_dir, 'reads.lib')

    @property
    def megahit_core(self):
        """Path of the core binary matching the current hardware flags."""
        if self.hw_accel:
            return software_info.megahit_core
        elif self.has_popcnt:
            return software_info.megahit_core_popcnt
        else:
            return software_info.megahit_core_no_hw_accel

    @property
    def host_mem(self):
        """Memory budget in bytes, as an integer.

        Raises:
            Usage: if ``memory`` is not positive, or if it is a fraction and
                the machine's memory size cannot be detected.
        """
        if self.memory <= 0:
            raise Usage('Please specify a positive number for -m flag.')
        elif self.memory < 1:
            try:
                total_mem = detect_available_mem()
                # int(): math.floor returns a float on Python 2, which would
                # be rendered as "N.0" when passed on a command line.
                return int(math.floor(total_mem * self.memory))
            except Exception:
                raise Usage('Failed to detect available memory size, '
                            'please specify memory size in byte via -m option')
        else:
            return int(math.floor(self.memory))

    def dump(self):
        """Write every attribute to ``option_file_name`` as JSON."""
        with open(self.option_file_name, 'w') as f:
            json.dump(self.__dict__, f)

    def load_for_continue(self):
        """Overwrite the attributes with those stored in ``option_file_name``."""
        with open(self.option_file_name, 'r') as f:
            self.__dict__.update(json.load(f))


class Checkpoint:
    """Decorator object that numbers wrapped calls and records finished ones.

    Every call of a wrapped function gets the next sequence number.  A
    finished call is appended to the checkpoint file as ``<number>\\tdone``.
    After ``load_for_continue()`` calls whose number is not greater than the
    last recorded one are skipped.
    """

    def __init__(self):
        self._current_checkpoint = 0
        # None: nothing loaded, every step runs.
        self._logged_checkpoint = None
        self._file = None

    def set_file(self, file):
        """Set the path of the checkpoint file."""
        self._file = file

    def __call__(self, func):
        """Wrap ``func``; the wrapper discards ``func``'s return value."""
        def checked_or_call(*args, **kwargs):
            already_done = (self._logged_checkpoint is not None
                            and self._current_checkpoint <= self._logged_checkpoint)
            if already_done:
                logger.info('passing check point {0}'.format(self._current_checkpoint))
            else:
                func(*args, **kwargs)
                # Recorded only after func returned without raising.
                with open(self._file, 'a') as cpf:
                    cpf.write(str(self._current_checkpoint) + '\t' + 'done' + '\n')
            self._current_checkpoint += 1

        return checked_or_call

    def load_for_continue(self):
        """Read the checkpoint file and remember the last recorded number."""
        self._logged_checkpoint = -1
        if not os.path.exists(self._file):
            return
        with open(self._file, 'r') as f:
            for line in f:
                fields = line.split()
                if len(fields) == 2 and fields[1] == 'done':
                    self._logged_checkpoint = int(fields[0])


# Module-wide singletons shared by the functions below:
#   software_info - paths of the core binaries and version/usage queries
#   opt           - the options of the current run
#   check_point   - decorator that records finished pipeline steps
software_info = SoftwareInfo()
opt = Options()
check_point = Checkpoint()


def check_bin():
    """Raise ``Usage`` when the selected megahit_core binary is missing."""
    core_path = opt.megahit_core
    if os.path.exists(core_path):
        return
    raise Usage('Cannot find megahit_core, please recompile.')


def parse_option(argv):
    """Parse command-line arguments into the module-level ``opt`` object.

    Args:
        argv: the argument list without the program name.

    ``-h/--help`` and ``-v/--version`` print to stdout and terminate the
    process with status 0.  Deprecated options are reported on stderr and
    otherwise ignored.

    Raises:
        Usage: for an unknown option, when no option is given at all, or
            when an option value cannot be converted to the expected number.
    """
    # Each long option appears exactly once: a duplicated entry makes getopt
    # reject every abbreviation of it as "not a unique prefix".
    long_options = ['help',
                    'read=',
                    '12=',
                    'memory=',
                    'out-dir=',
                    'min-contig-len=',
                    'num-cpu-threads=',
                    'kmin-1pass',
                    'k-min=',
                    'k-max=',
                    'k-step=',
                    'k-list=',
                    'min-count=',
                    'no-mercy',
                    'no-local',
                    'max-tip-len=',
                    'bubble-level=',
                    'prune-level=',
                    'prune-depth=',
                    'merge-level=',
                    'disconnect-ratio=',
                    'low-local-ratio=',
                    'cleaning-rounds=',
                    'keep-tmp-files',
                    'tmp-dir=',
                    'mem-flag=',
                    'continue',
                    'version',
                    'verbose',
                    'out-prefix=',
                    'presets=',
                    'test',
                    'no-hw-accel',
                    'force',
                    # deprecated
                    'max-read-len=',
                    'no-low-local',
                    'cpu-only',
                    'gpu-mem=',
                    'use-gpu']
    try:
        opts, args = getopt.getopt(argv, 'hm:o:r:t:v1:2:l:f', long_options)
    except getopt.error as msg:
        raise Usage(software_info.megahit_version + '\n' + str(msg))
    if len(opts) == 0:
        raise Usage(software_info.megahit_version + '\n' + software_info.usage_message)

    try:
        for option, value in opts:
            if option in ('-h', '--help'):
                print(software_info.megahit_version + '\n' + software_info.usage_message)
                # sys.exit rather than the site-provided exit(), which is
                # not guaranteed to exist (e.g. ``python -S``).
                sys.exit(0)
            elif option in ('-o', '--out-dir'):
                opt.out_dir = abspath(value)
            elif option in ('-m', '--memory'):
                opt.memory = float(value)
            elif option == '--min-contig-len':
                opt.min_contig_len = int(value)
            elif option in ('-t', '--num-cpu-threads'):
                opt.num_cpu_threads = int(value)
            elif option == '--kmin-1pass':
                opt.kmin_1pass = True
            elif option == '--k-min':
                opt.k_min = int(value)
                opt.set_list_by_min_max_step = True
                opt.auto_k = False
            elif option == '--k-max':
                opt.k_max = int(value)
                opt.set_list_by_min_max_step = True
                opt.auto_k = False
            elif option == '--k-step':
                opt.k_step = int(value)
                opt.set_list_by_min_max_step = True
                opt.auto_k = False
            elif option == '--k-list':
                opt.k_list = list(map(int, value.split(',')))
                opt.k_list.sort()
                opt.auto_k = False
                opt.set_list_by_min_max_step = False
            elif option == '--min-count':
                opt.min_count = int(value)
            elif option == '--max-tip-len':
                opt.max_tip_len = int(value)
            elif option == '--merge-level':
                # Expects exactly "<length>,<similarity>".
                (opt.merge_len, opt.merge_similar) = map(float, value.split(','))
                opt.merge_len = int(opt.merge_len)
            elif option == '--prune-level':
                opt.prune_level = int(value)
            elif option == '--prune-depth':
                opt.prune_depth = float(value)
            elif option == '--bubble-level':
                opt.bubble_level = int(value)
            elif option == '--no-mercy':
                opt.no_mercy = True
            elif option == '--no-local':
                opt.no_local = True
            elif option == '--disconnect-ratio':
                opt.disconnect_ratio = float(value)
            elif option == '--low-local-ratio':
                opt.low_local_ratio = float(value)
            elif option == '--cleaning-rounds':
                opt.cleaning_rounds = int(value)
            elif option == '--keep-tmp-files':
                opt.keep_tmp_files = True
            elif option == '--mem-flag':
                opt.mem_flag = int(value)
            elif option in ('-v', '--version'):
                print(software_info.megahit_version)
                sys.exit(0)
            elif option == '--verbose':
                opt.verbose = True
            elif option == '--continue':
                opt.continue_mode = True
            elif option == '--out-prefix':
                opt.out_prefix = value
            elif option == '--tmp-dir':
                opt.temp_dir = abspath(value)
            elif option in ('--cpu-only', '-l', '--max-read-len', '--no-low-local',
                            '--use-gpu', '--gpu-mem'):
                print('option {0} is deprecated!'.format(option), file=sys.stderr)
                continue  # deprecated options, just ignore
            elif option in ('-r', '--read'):
                opt.se += [abspath(f) for f in value.split(',')]
            elif option == '-1':
                opt.pe1 += [abspath(f) for f in value.split(',')]
            elif option == '-2':
                opt.pe2 += [abspath(f) for f in value.split(',')]
            elif option == '--12':
                opt.pe12 += [abspath(f) for f in value.split(',')]
            elif option == '--presets':
                opt.presets = value
            elif option in ('-f', '--force'):
                opt.force_overwrite = True
            elif option == '--test':
                opt.test_mode = True
            elif option == '--no-hw-accel':
                opt.hw_accel = False
                opt.has_popcnt = False
            else:
                raise Usage('Invalid option {0}'.format(option))
    except ValueError:
        # int()/float() conversion or the --merge-level unpacking failed:
        # report it as a usage error instead of an unhandled traceback.
        raise Usage('Invalid value {0!r} for option {1}'.format(value, option))


def setup_output_dir():
    """Decide on and create the output, temporary and contig directories.

    Also points the checkpoint recorder at ``<out_dir>/checkpoints.txt`` and,
    in continue mode, reloads the saved options and checkpoints.

    Raises:
        Usage: in a normal (non-forced, non-test) run when the output
            directory already exists.
    """
    if not opt.out_dir:
        opt.out_dir = (tempfile.mkdtemp(prefix='megahit_test_') if opt.test_mode
                       else abspath('./megahit_out'))

    check_point.set_file(os.path.join(opt.out_dir, 'checkpoints.txt'))

    # Decide once whether to resume: load_for_continue() below replaces the
    # option attributes, so the flag must not be re-read afterwards.
    resume = opt.continue_mode
    if resume and not os.path.exists(opt.option_file_name):
        print('Cannot find {0}, switching to normal mode'.format(opt.option_file_name), file=sys.stderr)
        opt.continue_mode = False
        resume = False

    if resume:
        print('Continue mode activated. Ignore all options except for -o/--out-dir.', file=sys.stderr)
        opt.load_for_continue()
        check_point.load_for_continue()
    else:
        overwrite_allowed = opt.force_overwrite or opt.test_mode
        if not overwrite_allowed and os.path.exists(opt.out_dir):
            raise Usage(
                'Output directory ' + opt.out_dir +
                ' already exists, please change the parameter -o to another value to avoid overwriting.')

        if opt.temp_dir != '':
            # A user-supplied location gets a fresh unique sub-directory.
            opt.temp_dir = tempfile.mkdtemp(dir=opt.temp_dir, prefix='megahit_tmp_')
        else:
            opt.temp_dir = os.path.join(opt.out_dir, 'tmp')

    for directory in (opt.out_dir, opt.temp_dir, opt.contig_dir):
        mkdir_if_not_exists(directory)


def setup_logger():
    """Attach a file handler (DEBUG) and a console handler (INFO) to the logger.

    The file handler appends to ``opt.log_file_name``.  The MEGAHIT version
    banner is logged once both handlers are in place.
    """
    fmt = logging.Formatter('%(asctime)s - %(message)s', '%Y-%m-%d %H:%M:%S')

    to_file = logging.FileHandler(opt.log_file_name, 'a')
    to_console = logging.StreamHandler()

    for handler, level in ((to_file, logging.DEBUG), (to_console, logging.INFO)):
        handler.setLevel(level)
        handler.setFormatter(fmt)

    logger.setLevel(logging.DEBUG)
    for handler in (to_file, to_console):
        logger.addHandler(handler)

    logger.info(software_info.megahit_version)


def check_and_correct_option():
    """Validate the parsed options and normalise derived values in ``opt``.

    Applies presets, builds the k list from min/max/step when requested,
    checks every numeric option against its allowed range and clamps the
    thread count and bubble level.

    Raises:
        Usage: when an option is out of range or the k list is invalid.
    """
    # set mode
    if opt.memory < 0:
        raise Usage('-m cannot be less than 0!')

    if opt.presets != '':
        opt.auto_k = True

        if opt.presets == 'meta-sensitive':
            opt.min_count = 1
            opt.k_list = [21, 29, 39, 49, 59, 69, 79, 89, 99, 109, 119, 129, 141]
            opt.set_list_by_min_max_step = False
        elif opt.presets == 'meta-large':
            # Only the k range is overridden, as documented in the usage
            # text ('--k-min 27 --k-max 127 --k-step 10'); --min-count keeps
            # the value chosen by the user or its default.
            opt.k_min = 27
            opt.k_max = 127
            opt.k_step = 10
            opt.set_list_by_min_max_step = True
        else:
            raise Usage('Invalid preset: ' + opt.presets)

    if opt.set_list_by_min_max_step:
        # A zero or negative step would never reach k_max.
        if opt.k_step <= 0:
            raise Usage('k-step must be a positive even number!')
        if opt.k_step % 2 == 1:
            raise Usage('k-step must be even number!')
        if opt.k_min > opt.k_max:
            raise Usage('Error: k_min > k_max!')

        # k_min, k_min + step, ... (all below k_max), then k_max itself.
        opt.k_list = list(range(opt.k_min, opt.k_max, opt.k_step))
        opt.k_list.append(opt.k_max)

    if len(opt.k_list) == 0:
        raise Usage('k list should not be empty!')

    max_k_allowed = software_info.max_k_allowed
    if opt.k_list[0] < 15 or opt.k_list[-1] > max_k_allowed:
        raise Usage('All k\'s should be in range [15, %d]' % max_k_allowed)

    for k in opt.k_list:
        if k % 2 == 0:
            raise Usage('All k must be odd number!')

    for i in range(1, len(opt.k_list)):
        if opt.k_list[i] - opt.k_list[i - 1] > 28:
            raise Usage('--k-step (adjacent k difference) must be <= 28')

    opt.k_min, opt.k_max = opt.k_list[0], opt.k_list[-1]

    if opt.k_max < opt.k_min:
        raise Usage('--k-min should not be larger than --k-max.')
    if opt.min_count <= 0:
        raise Usage('--min-count must be greater than 0.')
    elif opt.min_count == 1:
        # With min_count 1 the 1-pass builder is used and mercy k-mers are
        # switched off.
        opt.kmin_1pass = True
        opt.no_mercy = True
    if opt.prune_level < 0 or opt.prune_level > 3:
        raise Usage('--prune-level must be in 0-3.')
    if opt.merge_len < 0:
        raise Usage('--merge-level: length must be >= 0')
    if opt.merge_similar < 0 or opt.merge_similar > 1:
        raise Usage('--merge-level: similarity must be in [0, 1]')
    if opt.disconnect_ratio < 0 or opt.disconnect_ratio > 0.5:
        raise Usage('--disconnect-ratio should be in [0, 0.5].')
    if opt.low_local_ratio <= 0 or opt.low_local_ratio > 0.5:
        raise Usage('--low-local-ratio should be in (0, 0.5].')
    if opt.cleaning_rounds <= 0:
        raise Usage('--cleaning-rounds must be >= 1')

    max_threads = multiprocessing.cpu_count()
    if opt.num_cpu_threads > max_threads:
        logger.warning('Maximum number of available CPU thread is %d.' % max_threads)
        logger.warning('Number of thread is reset to the %d.' % max_threads)
        opt.num_cpu_threads = max_threads
    if opt.num_cpu_threads == 0:
        opt.num_cpu_threads = max_threads
    if opt.prune_depth < 0 and opt.prune_level < 3:
        opt.prune_depth = opt.min_count
    if opt.bubble_level < 0:
        logger.warning('Reset bubble level to 0.')
        opt.bubble_level = 0
    if opt.bubble_level > 2:
        logger.warning('Reset bubble level to 2.')
        opt.bubble_level = 2


def find_test_data_path():
    """Return the ``test_data`` directory that contains the bundled test reads.

    Looks relative to the script directory, first in ``..`` and then in
    ``../share/megahit``.

    Raises:
        Usage: if neither location holds all required files.
    """
    required = ['r1.il.fa.gz', 'r2.il.fa.bz2', 'r3_1.fa', 'r3_2.fa', 'r4.fa']
    script_path = software_info.script_path
    for parent in ('..', '../share/megahit'):
        test_data_dir = abspath(os.path.join(script_path, parent, 'test_data'))
        if not os.path.isdir(test_data_dir):
            continue
        present = os.listdir(test_data_dir)
        if all(name in present for name in required):
            return test_data_dir
    raise Usage('Test data not found! Script path = {0}'.format(script_path))


def check_reads():
    """Check that the input read files are consistent and exist.

    In test mode the read lists are first replaced by the bundled test data.

    Raises:
        Usage: if the numbers of ``-1`` and ``-2`` files differ or an input
            file does not exist.
    """
    if opt.test_mode:
        data_dir = find_test_data_path()

        def in_data_dir(name):
            return os.path.join(data_dir, name)

        opt.pe12 = [in_data_dir('r1.il.fa.gz'), in_data_dir('r2.il.fa.bz2')]
        opt.pe1 = [in_data_dir('r3_1.fa')]
        opt.pe2 = [in_data_dir('r3_2.fa')]
        opt.se = [in_data_dir('r4.fa'), in_data_dir('loop.fa')]

    if len(opt.pe1) != len(opt.pe2):
        raise Usage('Number of paired-end files not match!')
    for group in (opt.pe1, opt.pe2, opt.se, opt.pe12):
        for read_file in group:
            if not os.path.exists(read_file):
                raise Usage('Cannot find file ' + read_file)


def detect_available_mem():
    """Return the machine's physical memory size in bytes.

    Uses ``os.sysconf`` (page size times number of physical pages).  If
    ``sysconf`` raises ``ValueError`` the value is parsed from
    ``sysctl hw.memsize`` on darwin or from ``free`` on linux; on any other
    platform the ``ValueError`` is re-raised.

    Raises:
        SystemError: if ``sysconf`` reports a negative value.
    """
    try:
        page_size = os.sysconf('SC_PAGE_SIZE')
        page_count = os.sysconf('SC_PHYS_PAGES')
        if min(page_size, page_count) < 0:
            raise SystemError
        return page_size * page_count
    except ValueError:
        platform = sys.platform
        if "darwin" in platform:
            sysctl_lines = os.popen("sysctl hw.memsize").readlines()
            return int(float(sysctl_lines[0].split()[1]))
        if "linux" in platform:
            free_lines = os.popen("free").readlines()
            # The second column of the second line is multiplied by 1024.
            return int(float(free_lines[1].split()[1]) * 1024)
        raise


def cpu_dispatch():
    """Pick the megahit_core variant that matches the CPU's capabilities.

    Unless hardware acceleration was disabled by the user, the core binary
    is asked (``checkcpu`` / ``checkpopcnt``, answer ``'1'`` meaning yes)
    whether the features are available, and ``opt.hw_accel`` /
    ``opt.has_popcnt`` are updated so that ``opt.megahit_core`` resolves to
    the matching binary.
    """
    if not opt.hw_accel:
        logger.info('Using megahit_core without POPCNT and BMI2 support, '
                    'because --no-hw-accel option manually specified')
        return

    has_hw_accel = check_output([opt.megahit_core, 'checkcpu'])
    if has_hw_accel == '1':
        logger.info('Using megahit_core with POPCNT and BMI2 support')
        return

    opt.hw_accel = False
    has_popcnt = check_output([opt.megahit_core, 'checkpopcnt'])
    if has_popcnt == '1':
        opt.has_popcnt = True
        logger.info('Using megahit_core with POPCNT support')
    else:
        # Record the negative result too; otherwise has_popcnt keeps its
        # default of True and the POPCNT binary would still be selected.
        opt.has_popcnt = False
        logger.info('Using megahit_core without POPCNT and BMI2 support, '
                    'because the features not detected by CPUID ')


def graph_prefix(kmer_k):
    """Return ``<temp_dir>/k<K>/<K>``, creating the ``k<K>`` directory if needed."""
    k_text = str(kmer_k)
    k_dir = os.path.join(opt.temp_dir, 'k' + k_text)
    mkdir_if_not_exists(k_dir)
    return os.path.join(k_dir, k_text)


def contig_prefix(kmer_k):
    """Return the path prefix ``<contig_dir>/k<K>`` of the contig files for ``kmer_k``."""
    base_name = 'k' + str(kmer_k)
    return os.path.join(opt.contig_dir, base_name)


def remove_temp_after_build(kmer_k):
    """Delete the temporary edge and mercy files left for ``kmer_k``."""
    prefix = graph_prefix(kmer_k)
    thread_ids = range(opt.num_cpu_threads)

    for tid in thread_ids:
        remove_if_exists('{0}.edges.{1}'.format(prefix, tid))
    # Mercy candidate files are numbered 0..63 regardless of thread count.
    for part in range(64):
        remove_if_exists('{0}.mercy_cand.{1}'.format(prefix, part))
    for tid in thread_ids:
        remove_if_exists('{0}.mercy.{1}'.format(prefix, tid))
    remove_if_exists(prefix + '.cand')


def remove_temp_after_assemble(kmer_k):
    """Delete the graph files stored under the graph prefix of ``kmer_k``."""
    prefix = graph_prefix(kmer_k)
    suffixes = ['w', 'last', 'isd', 'dn', 'f', 'mul', 'mul2']
    suffixes += ['sdbg.' + str(tid) for tid in range(opt.num_cpu_threads)]
    for suffix in suffixes:
        remove_if_exists(prefix + '.' + suffix)


def inpipe_cmd(file_name):
    """Return the shell command that decompresses ``file_name`` to stdout.

    Args:
        file_name: path of an input read file.

    Returns:
        ``'gzip -cd <file>'`` for ``.gz`` files, ``'bzip2 -cd <file>'`` for
        ``.bz2`` files and ``''`` for everything else.  The file name is
        shell-quoted because the command is run through a shell; names made
        only of shell-safe characters are returned unchanged.
    """
    try:
        from shlex import quote
    except ImportError:  # Python 2
        from pipes import quote

    if file_name.endswith('.gz'):
        return 'gzip -cd ' + quote(file_name)
    elif file_name.endswith('.bz2'):
        return 'bzip2 -cd ' + quote(file_name)
    else:
        return ''


@check_point
def create_library_file():
    """Write the read library description to ``opt.read_lib_path``.

    Every library takes two lines: the original file name(s), then the
    library type (``interleaved``, ``pe`` or ``se``) followed by the path(s)
    to read from.  For inputs that need a decompression command the path is
    ``<temp_dir>/inpipe.<type>.<index>`` instead of the file itself.
    """
    def source_path(file_name, tag, index):
        if inpipe_cmd(file_name) != '':
            return os.path.join(opt.temp_dir, 'inpipe.{0}.{1}'.format(tag, index))
        return file_name

    with open(opt.read_lib_path, 'w') as lib:
        for idx, interleaved in enumerate(opt.pe12):
            lib.write(interleaved + '\n')
            lib.write('interleaved ' + source_path(interleaved, 'pe12', idx) + '\n')

        for idx, (mate1, mate2) in enumerate(zip(opt.pe1, opt.pe2)):
            lib.write(mate1 + ',' + mate2 + '\n')
            src1 = source_path(mate1, 'pe1', idx)
            src2 = source_path(mate2, 'pe2', idx)
            lib.write('pe ' + src1 + ' ' + src2 + '\n')

        for idx, single in enumerate(opt.se):
            lib.write(single + '\n')
            lib.write('se ' + source_path(single, 'se', idx) + '\n')


@check_point
def build_library():
    """Convert the input reads to the binary library with ``megahit_core buildlib``.

    For every input that needs a decompression command a FIFO is created in
    the temp directory and the command is started in its own session,
    writing into that FIFO.  After ``buildlib`` returns, each decompressor
    is waited for; if one exited with a non-zero status the process exits
    with that status.  Decompressors that are still running are terminated
    and the FIFOs removed in all cases.
    """
    try:
        from shlex import quote
    except ImportError:  # Python 2
        from pipes import quote

    cmd = [opt.megahit_core, 'buildlib', opt.read_lib_path, opt.read_lib_path]
    fifos = list()
    pipes = list()

    def create_fifo(read_type, num, command):
        fifo_path = os.path.join(opt.temp_dir, 'inpipe.{0}.{1}'.format(read_type, num))
        remove_if_exists(fifo_path)
        os.mkfifo(fifo_path)
        fifos.append(fifo_path)
        # The redirection target is quoted so that a temp directory with
        # spaces or shell metacharacters does not break the command.
        # os.setsid gives the shell its own process group, which is what
        # os.killpg() targets during clean-up.
        p = subprocess.Popen('{0} > {1}'.format(command, quote(fifo_path)),
                             shell=True, preexec_fn=os.setsid)
        pipes.append(p)

    def create_fifo_if_piped(read_type, num, file_name):
        command = inpipe_cmd(file_name)
        if command != '':
            create_fifo(read_type, num, command)

    try:
        # create inpipe
        for i, file_name in enumerate(opt.pe12):
            create_fifo_if_piped('pe12', i, file_name)

        for i, file_name in enumerate(opt.pe1):
            create_fifo_if_piped('pe1', i, file_name)
            create_fifo_if_piped('pe2', i, opt.pe2[i])

        for i, file_name in enumerate(opt.se):
            create_fifo_if_piped('se', i, file_name)

        run_sub_command(cmd, 'Convert reads to binary library', verbose=True)

        for p in pipes:
            pipe_ret = p.wait()
            if pipe_ret != 0:
                logger.error('Error occurs when reading inputs')
                # sys.exit rather than the site-provided exit(), which is
                # not guaranteed to exist.
                sys.exit(pipe_ret)

    finally:
        for p in pipes:
            if p.poll() is None:
                os.killpg(p.pid, signal.SIGTERM)
        for f in fifos:
            remove_if_exists(f)


def get_max_read_len():
    """Return the largest read length found in ``<read_lib_path>.lib_info``.

    The third whitespace-separated field of every second line, starting with
    the third line of the file, is read as an integer; 0 is returned when
    there is no such line.
    """
    with open(opt.read_lib_path + '.lib_info') as read_info:
        info_lines = read_info.readlines()[2::2]
        lengths = [int(info.split()[2]) for info in info_lines]
    return max([0] + lengths)


def set_max_k_by_lib():
    """Drop k values that are too large for the longest read.

    Only acts when ``opt.auto_k`` is set and the k list has more than one
    entry.  Values of k below ``max read length + 20`` are kept.

    Returns:
        True if ``opt.k_list``, ``opt.k_min`` and ``opt.k_max`` were
        updated, False if nothing was changed.
    """
    if len(opt.k_list) == 1 or not opt.auto_k:
        return False

    k_limit = get_max_read_len() + 20
    kept = [k for k in opt.k_list if k < k_limit]
    if kept:
        opt.k_list = kept
        opt.k_min = kept[0]
        opt.k_max = kept[-1]
        return True
    return False


@check_point
def build_first_graph_1pass(option):
    """Run ``megahit_core read2sdbg`` with ``option`` to build the graph for k_min.

    ``--need_mercy`` is appended unless mercy k-mers are disabled.  The
    temporary build files are removed afterwards unless they are to be kept.
    """
    cmd = [opt.megahit_core, 'read2sdbg']
    cmd.extend(option)
    if not opt.no_mercy:
        cmd.append('--need_mercy')

    description = 'Extracting solid (k+1)-mers and building sdbg for k = %d' % opt.k_min
    run_sub_command(cmd, description)

    if opt.keep_tmp_files:
        return
    remove_temp_after_build(opt.k_min)


@check_point
def count_mink(option):
    """Run ``megahit_core count`` with ``option`` for k_min."""
    description = 'Extract solid (k+1)-mers for k = %d ' % opt.k_min
    run_sub_command([opt.megahit_core, 'count'] + option, description)


def build_first_graph():
    """Build the graph for k_min, in one pass or as count + build."""
    common_option = [
        '-k', str(opt.k_min),
        '-m', str(opt.min_count),
        '--host_mem', str(opt.host_mem),
        '--mem_flag', str(opt.mem_flag),
        '--output_prefix', graph_prefix(opt.k_min),
        '--num_cpu_threads', str(opt.num_cpu_threads),
        '--read_lib_file', opt.read_lib_path,
    ]

    if opt.kmin_1pass:
        build_first_graph_1pass(common_option)
        return

    # Two-step mode: count the k-mers first, then build the graph from them.
    count_mink(common_option)
    build_graph(opt.k_min, 0)


@check_point
def build_graph(kmer_k, kmer_from):
    """Run ``megahit_core seq2sdbg`` to build the graph for ``kmer_k``.

    Inputs are added to the command only if the files exist: the edge files
    under the graph prefix of ``kmer_k`` and the ``.addi.fa``, ``.local.fa``
    and ``.contigs.fa`` files of ``kmer_from``.

    Raises:
        EarlyTerminate: if ``kmer_from`` is not 0 and the edge, additional
            and local contig files together hold no data.
    """
    # Evaluated in this order so option errors surface before any directory
    # is created by graph_prefix().
    host_mem = str(opt.host_mem)
    mem_flag = str(opt.mem_flag)
    out_prefix = graph_prefix(kmer_k)

    build_cmd = [opt.megahit_core, 'seq2sdbg',
                 '--host_mem', host_mem,
                 '--mem_flag', mem_flag,
                 '--output_prefix', out_prefix,
                 '--num_cpu_threads', str(opt.num_cpu_threads),
                 '-k', str(kmer_k),
                 '--kmer_from', str(kmer_from)]

    file_size = 0

    # Consecutively numbered edge files: <prefix>.edges.0, .edges.1, ...
    tid = 0
    while os.path.exists(graph_prefix(kmer_k) + '.edges.' + str(tid)):
        if tid == 0:
            build_cmd += ['--input_prefix', graph_prefix(kmer_k)]
        file_size += os.path.getsize(graph_prefix(kmer_k) + '.edges.' + str(tid))
        tid += 1

    addi_fa = contig_prefix(kmer_from) + '.addi.fa'
    if os.path.exists(addi_fa):
        build_cmd += ['--addi_contig', addi_fa]
        file_size += os.path.getsize(addi_fa)

    local_fa = contig_prefix(kmer_from) + '.local.fa'
    if os.path.exists(local_fa):
        build_cmd += ['--local_contig', local_fa]
        file_size += os.path.getsize(local_fa)

    # The contig file is passed on but does not count towards file_size.
    contigs_fa = contig_prefix(kmer_from) + '.contigs.fa'
    if os.path.exists(contigs_fa):
        build_cmd += ['--contig', contigs_fa]
        build_cmd += ['--bubble', contig_prefix(kmer_from) + '.bubble_seq.fa']

    if kmer_from != 0 and file_size == 0:
        raise EarlyTerminate(kmer_from)

    if kmer_k == opt.k_min and not opt.no_mercy:
        build_cmd.append('--need_mercy')

    run_sub_command(build_cmd, 'Build graph for k = %d ' % kmer_k)

    if opt.keep_tmp_files:
        return
    remove_temp_after_build(kmer_k)


@check_point
def iterate(cur_k, step):
    """Run ``megahit_core iterate`` to go from ``cur_k`` to ``cur_k + step``.

    The command reads the contigs and bubble sequences of ``cur_k`` together
    with the binary read library (``opt.read_lib_path + '.bin'``) and writes
    under ``graph_prefix(cur_k + step)``.

    :param cur_k: k-mer size of the round that has just been assembled.
    :param step: increment added to ``cur_k`` to obtain the next k-mer size.
    """
    next_k = cur_k + step
    contig_base = contig_prefix(cur_k)

    flagged_args = (
        ('-c', contig_base + '.contigs.fa'),
        ('-b', contig_base + '.bubble_seq.fa'),
        ('-t', str(opt.num_cpu_threads)),
        ('-k', str(cur_k)),
        ('-s', str(step)),
        ('-o', graph_prefix(next_k)),
        ('-r', opt.read_lib_path + '.bin'),
    )

    iterate_cmd = [opt.megahit_core, 'iterate']
    for flag, value in flagged_args:
        iterate_cmd.extend((flag, value))

    description = 'Extract iterative edges from k = %d to %d ' % (cur_k, next_k)
    run_sub_command(iterate_cmd, description)


@check_point
def assemble(cur_k):
    """Run ``megahit_core assemble`` on the graph built for ``cur_k``.

    :param cur_k: k-mer size of the graph to assemble; contigs are written
        under ``contig_prefix(cur_k)``.

    Command-line details:

    * ``--min_standalone`` is derived from ``opt.k_max`` and
      ``opt.min_contig_len``, or from ``opt.max_tip_len`` when that option is
      non-negative.
    * ``--max_tip_len`` is computed from ``opt.min_contig_len`` and ``cur_k``
      when ``opt.max_tip_len`` is ``-1`` and ``cur_k * 3 - 1`` exceeds
      ``opt.min_contig_len * 1.5``; otherwise ``opt.max_tip_len`` is passed
      through unchanged.
    * ``--careful_bubble`` is added for every round before ``opt.k_max`` and
      ``--is_final_round`` for the ``opt.k_max`` round.
    * ``--output_standalone`` is added when ``opt.no_local`` is set.

    Unless ``opt.keep_tmp_files`` is set, temporary files of non-final rounds
    are removed afterwards via ``remove_temp_after_assemble``.
    """
    min_standalone = max(min(opt.k_max * 3 - 1, int(opt.min_contig_len * 1.5)), opt.min_contig_len)
    if opt.max_tip_len >= 0:
        min_standalone = max(opt.max_tip_len + opt.k_max - 1, opt.min_contig_len)

    # Each option appears exactly once ('--cleaning_rounds' used to be
    # appended twice with the same value).
    assembly_cmd = [opt.megahit_core, 'assemble',
                    '-s', graph_prefix(cur_k),
                    '-o', contig_prefix(cur_k),
                    '-t', str(opt.num_cpu_threads),
                    '--min_standalone', str(min_standalone),
                    '--prune_level', str(opt.prune_level),
                    '--merge_len', str(int(opt.merge_len)),
                    '--merge_similar', str(opt.merge_similar),
                    '--cleaning_rounds', str(opt.cleaning_rounds),
                    '--disconnect_ratio', str(opt.disconnect_ratio),
                    '--low_local_ratio', str(opt.low_local_ratio),
                    '--min_depth', str(opt.prune_depth),
                    '--bubble_level', str(opt.bubble_level)]

    if opt.max_tip_len == -1 and cur_k * 3 - 1 > opt.min_contig_len * 1.5:
        # ``min_contig_len * 1.5`` is a float; truncate so the option receives
        # an integer such as '182' instead of a float string such as '182.0'.
        max_tip_len = int(max(1, opt.min_contig_len * 1.5 + 1 - cur_k))
    else:
        max_tip_len = opt.max_tip_len
    assembly_cmd += ['--max_tip_len', str(max_tip_len)]

    if cur_k < opt.k_max:
        assembly_cmd.append('--careful_bubble')

    if cur_k == opt.k_max:
        assembly_cmd.append('--is_final_round')

    if opt.no_local:
        assembly_cmd.append('--output_standalone')

    run_sub_command(assembly_cmd, 'Assemble contigs from SdBG for k = %d' % cur_k)

    if (not opt.keep_tmp_files) and (cur_k != opt.k_max):
        remove_temp_after_assemble(cur_k)


@check_point
def local_assemble(cur_k, kmer_to):
    """Run ``megahit_core local`` on the contigs of ``cur_k``.

    :param cur_k: k-mer size whose ``.contigs.fa`` file is the input; the
        result is written next to it as ``.local.fa``.
    :param kmer_to: value forwarded to the ``--kmax`` option.
    """
    contig_base = contig_prefix(cur_k)

    la_cmd = [opt.megahit_core, 'local']
    la_cmd.extend(['-c', contig_base + '.contigs.fa'])
    la_cmd.extend(['-l', opt.read_lib_path])
    la_cmd.extend(['-t', str(opt.num_cpu_threads)])
    la_cmd.extend(['-o', contig_base + '.local.fa'])
    la_cmd.extend(['--kmax', str(kmer_to)])

    description = 'Local assembly for k = %d' % cur_k
    run_sub_command(la_cmd, description)


@check_point
def merge_final(final_k):
    """Write the final contig FASTA file into ``opt.out_dir``.

    All ``*.final.contigs.fa`` files in ``opt.contig_dir`` plus the
    ``.contigs.fa`` file of ``final_k`` are concatenated with ``cat`` and
    piped through ``megahit_core filterbylen <opt.min_contig_len>``; the
    filter's stdout becomes the output file. The output is named
    ``final.contigs.fa``, or ``<opt.out_prefix>.contigs.fa`` when a prefix
    was given.

    :param final_k: k-mer size of the last assembled round.

    Exits the program with the pipeline's exit code when it is non-zero.
    """
    logger.info('Merging to output final contigs ')
    final_contig_name = os.path.join(opt.out_dir, 'final.contigs.fa')
    if opt.out_prefix != '':
        final_contig_name = os.path.join(opt.out_dir, opt.out_prefix + '.contigs.fa')

    with open(final_contig_name, 'w') as final_contigs:
        # NOTE(review): the paths are interpolated unquoted into a shell
        # command line, so directories containing spaces or shell
        # metacharacters will break it. The exit status seen below is the one
        # the shell reports for the pipeline, normally that of its last
        # command, so a failing ``cat`` may go unnoticed.
        merge_cmd = 'cat ' + opt.contig_dir + '/*.final.contigs.fa ' + \
                    contig_prefix(final_k) + '.contigs.fa | ' + \
                    opt.megahit_core + ' filterbylen ' + str(opt.min_contig_len)
        p = subprocess.Popen(merge_cmd, shell=True, stdout=final_contigs, stderr=subprocess.PIPE)
        # communicate() drains stderr and waits for the process to finish.
        _, log = p.communicate()
        # Never let undecodable bytes in the child's stderr abort the run.
        logger.info(log.rstrip().decode('utf-8', 'replace'))
        ret_code = p.returncode

    if ret_code != 0:
        logger.error('Error occurs when merging final contigs, please refer to %s for detail' % opt.log_file_name)
        logger.error('Exit code %d' % ret_code)
        # sys.exit instead of the site-provided exit(), which is meant for
        # interactive use and is absent when the site module is not loaded.
        sys.exit(ret_code)


def run_sub_command(cmd, msg, verbose=False):
    """Run ``cmd`` as a child process and forward its stderr to the logger.

    :param cmd: argument list handed to ``subprocess.Popen``.
    :param msg: description logged at INFO level before the command starts.
    :param verbose: log the child's stderr at INFO instead of DEBUG level;
        forced on when ``opt.verbose`` is set.

    Exits the program with the child's exit code when it is non-zero, and
    with ``signal.SIGINT`` after terminating the child on Ctrl-C.
    """
    if opt.verbose:
        verbose = True
    logger.info(msg)
    logger.debug('command %s' % ' '.join(cmd))

    log_line = logger.info if verbose else logger.debug

    p = subprocess.Popen(cmd, stderr=subprocess.PIPE)

    try:
        while True:
            raw_line = p.stderr.readline()
            # Only an empty read means EOF. Testing the stripped line would
            # stop at the first blank line, dropping the rest of the output
            # and risking a stall once the unread pipe fills up.
            if not raw_line:
                break
            # The pipe yields bytes; decode so the log shows text, and never
            # fail on undecodable bytes.
            line = raw_line.rstrip().decode('utf-8', 'replace')
            if line:
                log_line(line)

        ret_code = p.wait()

        if ret_code != 0:
            logger.error('Error occurs, please refer to %s for detail' % opt.log_file_name)
            logger.error('Command: %s; Exit code %d' % (' '.join(cmd), ret_code))
            sys.exit(ret_code)
    except KeyboardInterrupt:
        p.terminate()
        p.wait()
        sys.exit(signal.SIGINT)


def main(argv=None):
    """Command-line entry point: run the whole assembly pipeline.

    :param argv: full argument vector including the program name; defaults
        to ``sys.argv``.

    Steps, in order:

    1. Check binaries, parse options, set up the output directory and the
       logger, validate options and reads, dispatch CPUs.
    2. Create and build the read library; ``opt.k_max`` may be reset by
       ``set_max_k_by_lib``.
    3. Build and assemble the graph for ``opt.k_min``.
    4. Walk through ``opt.k_list``: optional local assembly, ``iterate``,
       ``build_graph`` and ``assemble`` for each following k.
    5. Merge the final contigs, for ``opt.k_max`` or for the k carried by
       ``EarlyTerminate`` when a round raised it.
    6. Remove temporary files unless ``opt.keep_tmp_files`` is set and create
       an empty ``done`` marker file in the output directory.

    A ``Usage`` error is printed to stderr and the process exits with
    status 1.
    """
    if argv is None:
        argv = sys.argv

    try:
        start_time = time.time()
        check_bin()
        parse_option(argv[1:])
        setup_output_dir()
        setup_logger()

        check_and_correct_option()

        check_reads()
        cpu_dispatch()
        opt.dump()

        create_library_file()
        build_library()

        if set_max_k_by_lib():
            logger.info('k-max reset to: %d ' % opt.k_max)

        logger.info('Start assembly. Number of CPU threads %d ' % opt.num_cpu_threads)
        logger.info('k list: %s ' % ','.join(map(str, opt.k_list)))
        logger.info('Memory used: %d' % opt.host_mem)

        build_first_graph()
        assemble(opt.k_min)

        cur_k = opt.k_min
        next_k_idx = 0

        try:
            while cur_k < opt.k_max:
                next_k_idx += 1
                next_k = opt.k_list[next_k_idx]
                k_step = next_k - cur_k

                if not opt.no_local:
                    local_assemble(cur_k, next_k)

                iterate(cur_k, k_step)
                build_graph(next_k, cur_k)

                assemble(next_k)
                cur_k = next_k
            merge_final(opt.k_max)

        except EarlyTerminate as et:
            # A round raised EarlyTerminate; merge using the k-mer size
            # carried by the exception.
            merge_final(et.kmer_size)

        if not opt.keep_tmp_files and os.path.exists(opt.temp_dir):
            shutil.rmtree(opt.temp_dir)

        # Empty marker file signalling that the run completed.
        open(os.path.join(opt.out_dir, 'done'), 'w').close()

        # In test mode the whole output directory is discarded as well.
        if not opt.keep_tmp_files and opt.test_mode:
            shutil.rmtree(opt.out_dir)

        logger.info('ALL DONE. Time elapsed: %f seconds ' % (time.time() - start_time))

    except Usage as usg:
        print(os.path.basename(sys.argv[0]) + ': ' + str(usg.msg), file=sys.stderr)
        # sys.exit instead of the site-provided exit(), which is meant for
        # interactive use and is absent when the site module is not loaded.
        sys.exit(1)


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
