#!/usr/bin/python
__author__ = 'chetan'

import argparse
import csv
import datetime
import errno
import glob
import math
import operator
import os
import re
import subprocess
import sys
import time
import traceback

#'''
#Description of vw_lda wrapper
#
#Input dataset: accepts two kinds of inputs for datasets
#0) list of files separated by space
#1) directory containing files and directories
#....
# The script executes the following steps:
# 0. Read the input file(s) to prepare the dataset
# 1. Convert input dataset to vw format
# 2. Invoke vw with lda
# 3. Convert vw format output to human readable format
# 4. Write the output to file
# 5. Output partial results on to the console
# 6. The following intermediate files are generated:
#       a. lda_ip.vw : vw input file in word:count format
#       b. hashed_lda_ip.vw : input file in vw format
#       c. hash_values.csv : integer mapping for each word

# Number of LDA topics used when -t/--lda is not given on the command line.
DEFAULT_NUM_TOPICS = 100

def elapsed_time_str(e):
    '''
    Format a datetime.timedelta as "MM:SS MM:SS"-style text, e.g. "02:05 MM:SS".

    The duration is rounded to whole seconds *before* it is split into
    minutes and seconds, so a value such as 59.6s is reported as "01:00"
    rather than the impossible "00:60".
    '''
    total_seconds = int(round(e.total_seconds()))
    minutes, seconds = divmod(total_seconds, 60)
    return "%02d:%02d MM:SS" % (minutes, seconds)

class vw_lda(object):
    '''
    Drives an LDA run of the external `vw` tool from start to finish.

    Pipeline (see process_args):
      1. run `vw-doc2lda` on every input file and collect its output in
         self.lda_ipfile
      2. assign a dense integer id to every distinct word and save the
         mapping to the CSV file named by self.hash_values
      3. rewrite the input with ids instead of words (self.hashed_lda_ipfile)
      4. run `vw --lda ...` and let it write a readable model
      5. map the ids in the readable model back to words and report the
         per-topic word weights
    '''

    def __init__(self):
        self.fcount = 0
        # word -> integer id, in first-seen order (filled by create_hash)
        self.hash = {}
        # number of distinct words seen so far; also the next id to hand out
        self.count = 0
        # id (as a string) -> word, loaded back from disk by populate_hash
        self.new_hash = {}
        self.num_topics = 0
        self.op_file = ""
        self.word_list = []
        self.lda_ipfile = "lda_ip.vw"
        self.hashed_lda_ipfile = "hashed_lda_ip.vw"
        self.hashed_lda_model = ""
        # path of the CSV file holding the "id,word" mapping
        self.hash_values = "hash_values.csv"
        self.max_file_count = 10000
        # highest word id found in the mapping file; -1 until populate_hash runs
        self.max_word_count = -1

    @staticmethod
    def _tokens(line):
        '''
        Split a line on whitespace and drop the first token.

        NOTE(review): the first token is presumably the "|" marker emitted
        by vw-doc2lda -- confirm against that tool's output.
        '''
        return line.split()[1:]

    @staticmethod
    def _parse_token(token):
        '''
        Split "word:count" into (word, count); count is None when absent.
        '''
        parts = token.split(":")
        count = parts[1] if len(parts) > 1 else None
        return parts[0], count

    @staticmethod
    def _is_unit_count(count):
        '''
        True when the textual count is numerically equal to 1.
        '''
        try:
            return float(count) == 1.0
        except ValueError:
            return False

    def create_hash(self, doc):
        '''
        Assign the next free integer id to every word of the line `doc`
        that has not been seen before.

        Tokens are split exactly like get_hashed_line splits them, so every
        word looked up there is guaranteed to have an id.
        '''
        for token in self._tokens(doc):
            word, _ = self._parse_token(token)
            # empty words are ignored
            if word == '':
                continue
            if word not in self.hash:
                self.hash[word] = self.count
                self.count += 1

    def calc_bits_required(self):
        '''
        Return the smallest b with 2**b >= number of distinct words.

        Computed with integer arithmetic (no floating point log) and never
        less than 1, so an empty or single-word vocabulary still yields a
        usable value instead of an exception or 0.
        '''
        if self.count <= 1:
            return 1
        return (self.count - 1).bit_length()

    def generate_hash(self, f):
        '''
        Read the file `f` line by line, give every new word an id and
        then save the complete mapping with print_hash.
        '''
        with open(f, "r") as input_file:
            for line in input_file:
                self.create_hash(line)
        self.print_hash()

    def transform_inputs(self, f):
        '''
        Copy the word:count file `f` to self.hashed_lda_ipfile with every
        word replaced by its integer id. Blank lines are skipped.
        '''
        with open(f, "r") as input_file:
            with open(self.hashed_lda_ipfile, "w") as hashed_file:
                for line in input_file:
                    if not line.strip():
                        continue
                    hashed_file.write(self.get_hashed_line(line))

    def get_hashed_line(self, line):
        '''
        Return `line` rewritten as "| id[:count] id[:count] ...\\n".

        A count equal to 1 is omitted; every other count is copied through
        unchanged.
        NOTE(review): this relies on vw treating a feature without a value
        as having value 1 -- confirm against the vw input format docs.

        Raises KeyError for a word that create_hash has not seen.
        '''
        hashed = []
        for token in self._tokens(line):
            word, count = self._parse_token(token)
            if word == '':
                continue
            word_id = str(self.hash[word])
            if count is None or self._is_unit_count(count):
                hashed.append(word_id)
            else:
                hashed.append(word_id + ":" + count)
        return "| " + " ".join(hashed) + "\n"

    def print_hash(self):
        '''
        Write the id -> word mapping, ordered by id, to the CSV file named
        by self.hash_values (one "id,word" row per word).

        The csv module does the writing so that words containing commas or
        quotes are read back intact by populate_hash.
        '''
        ordered = sorted(self.hash.items(), key=operator.itemgetter(1))
        with open(self.hash_values, "w") as ofile:
            writer = csv.writer(ofile, lineterminator="\n")
            for word, word_id in ordered:
                writer.writerow([word_id, word])

    def convert_docs2lda(self, files):
        '''
        Run `vw-doc2lda -f <file>` for every entry of `files` and append
        each output, followed by a newline, to self.lda_ipfile.

        Exits with status 1 when vw-doc2lda is not installed. A file for
        which the tool returns a non-zero status is reported on stderr and
        skipped. Any other OSError is re-raised.
        '''
        with open(self.lda_ipfile, "w") as lda_ip:
            for f in files:
                # argument list (not a split string) so paths with spaces work
                cmd = ["vw-doc2lda", "-f", f]
                try:
                    op = subprocess.check_output(cmd, universal_newlines=True)
                except OSError as e:
                    if e.errno == errno.ENOENT:
                        print("ERROR: vw-doc2lda not found in the path")
                        sys.exit(1)
                    raise
                except subprocess.CalledProcessError as e:
                    sys.stderr.write(
                        "WARNING: vw-doc2lda failed on %s (exit status %s),"
                        " skipping\n" % (f, e.returncode))
                    continue
                lda_ip.write(op + "\n")

    def read_files_from_dir(self, src_dir, num_docs):
        '''
        Walk `src_dir` recursively and return the paths of at most
        `num_docs` files. 0 or None means "up to self.max_file_count".
        '''
        if not num_docs:
            num_docs = self.max_file_count
        file_list = []
        for root, _dirs, files in os.walk(src_dir):
            for name in files:
                # stop the whole walk, not just this directory
                if len(file_list) >= num_docs:
                    return file_list
                file_list.append(os.path.join(root, name))
        return file_list

    def process_input_dataset(self, src_dir, file_list, num_docs):
        '''
        Choose the input documents and convert them with convert_docs2lda.

        A non-empty `file_list` wins over `src_dir`; it is limited to its
        first `num_docs` entries (0 or None: no limit). The caller's list
        is left untouched.
        '''
        if not file_list:
            selected = self.read_files_from_dir(src_dir, num_docs)
        elif num_docs:
            selected = list(file_list[:num_docs])
        else:
            selected = list(file_list)
        self.convert_docs2lda(selected)

    def run_vw(self, args):
        '''
        Run vw in LDA mode on self.hashed_lda_ipfile, e.g.
        vw -d hashed_lda_ip.vw -b 8 --lda 2 --lda_D 5
        --readable_model lda.model.vw

        Reads args.lda, args.lda_d, args.lda_alpha, args.lda_epsilon and
        args.lda_rho. The -b value comes from calc_bits_required.

        Exits with status 1 when vw is not installed; a non-zero exit
        status of vw is reported on stderr and re-raised as
        subprocess.CalledProcessError.
        '''
        bits = self.calc_bits_required()
        cmd = [
            "vw",
            "-d", str(self.hashed_lda_ipfile),
            "-b", str(bits),
            "--lda", str(args.lda),
            "--lda_D", str(args.lda_d),
            "--readable_model", str(self.hashed_lda_model),
            "--lda_alpha", str(args.lda_alpha),
            "--lda_epsilon", str(args.lda_epsilon),
            "--lda_rho", str(args.lda_rho),
        ]
        print("vw_cmd_line :  " + " ".join(cmd))
        try:
            subprocess.check_output(cmd)
        except OSError as e:
            if e.errno == errno.ENOENT:
                print("ERROR: vw not found in the path")
                sys.exit(1)
            raise
        except subprocess.CalledProcessError as e:
            sys.stderr.write("ERROR: vw exited with status %s\n" % e.returncode)
            raise

    @staticmethod
    def _timed(message, step, *step_args):
        '''
        Print `message`, run step(*step_args) and print how long it took.
        '''
        print(message)
        started = datetime.datetime.now()
        step(*step_args)
        finished = datetime.datetime.now()
        print("Completed in " + elapsed_time_str(finished - started))

    def process_args(self, args):
        '''
        Run the whole pipeline for the parsed command line `args`:
        read the dataset, build the word ids, transform the input,
        run vw and transform its output.

        args.lda and args.lda_d are filled in with their defaults when
        they were not supplied.
        '''
        self.op_file = args.out_file
        self.hashed_lda_model = args.rd_model
        if args.lda is None:
            print("WARNING: -t <number_of_topics> not given, defaulting to %d"
                  % DEFAULT_NUM_TOPICS)
            args.lda = DEFAULT_NUM_TOPICS
        self.num_topics = args.lda

        if not args.lda_d:
            args.lda_d = self.max_file_count

        self._timed('reading dataset...', self.process_input_dataset,
                    args.src_dir, args.flist, args.lda_d)
        self._timed("creating hash... ", self.generate_hash, self.lda_ipfile)
        self._timed("transforming inputs ", self.transform_inputs,
                    self.lda_ipfile)
        self._timed("running vw... ", self.run_vw, args)
        self._timed("transforming outputs ", self.transform_outputs,
                    args.max_terms)

    def populate_hash(self):
        '''
        Load the "id,word" CSV named by self.hash_values into
        self.new_hash (id string -> word) and record the highest id in
        self.max_word_count (-1 when the file has no rows).

        self.hash_values keeps holding the path, so this can be called
        more than once.
        '''
        self.new_hash = {}
        self.max_word_count = -1
        with open(self.hash_values, "r") as hfile:
            for row in csv.reader(hfile, delimiter=","):
                if len(row) < 2:
                    continue
                self.new_hash[row[0]] = row[1]
                self.max_word_count = max(self.max_word_count, int(row[0]))

    def transform_outputs(self, max_terms):
        '''
        Parse the readable model written by vw and map ids back to words.

        Every line after the one starting with "Checksum" is read as
        "<id> <weight topic1> <weight topic2> ...". Reading stops at the
        first id above the highest known word id. The collected rows are
        handed to print_topics_summary.
        '''
        self.populate_hash()
        topics = []
        # exclusive upper bound of the 1-based topic index
        # (this is what print_topics_summary expects)
        num_topics = 0
        in_data = False
        with open(self.hashed_lda_model, "r") as model:
            for line in model:
                if not in_data:
                    in_data = line.startswith("Checksum")
                    continue
                fields = line.split()
                if not fields:
                    continue
                if int(fields[0]) > self.max_word_count:
                    break
                num_topics = len(fields)
                row = {'word': self.new_hash[fields[0]]}
                for i in range(1, num_topics):
                    row["topic" + str(i)] = float(fields[i])
                topics.append(row)

        # print topics on to the console and write to file
        self.print_topics_summary(topics, num_topics, max_terms)

    def print_topics_summary(self, topics, num_topics, top_n):
        '''
        For each topic, list the words and their weights in descending
        order of weight.

        topics     -- list of dicts with the keys 'word', 'topic1', ...
        num_topics -- exclusive upper bound of the topic index: the topics
                      reported are topic1 .. topic<num_topics - 1>
        top_n      -- number of words per topic shown on the console; the
                      file self.op_file always receives every word
        '''
        with open(self.op_file, "w") as op_file:
            print("Topics are...")
            for i in range(1, num_topics):
                topic = "topic" + str(i)
                ranked = sorted(topics, key=operator.itemgetter(topic),
                                reverse=True)
                print(topic)
                op_file.write(topic + "\n")
                for rank, entry in enumerate(ranked):
                    line = "%20s : %5f" % (entry['word'], entry[topic])
                    op_file.write(line + "\n")
                    if rank < top_n:
                        print(line)
                print("\n")
                op_file.write("\n")

    def print_sorted_dict(self, d, top_n):
        '''
        Print the `top_n` entries of `d` with the largest values and
        return all (key, value) pairs sorted by value, largest first.
        '''
        sorted_x = sorted(d.items(), key=operator.itemgetter(1), reverse=True)
        for key, value in sorted_x[:max(top_n, 0)]:
            print("%20s : %5f" % (key, float(value)))
        return sorted_x


def main(argv):
    '''
    Parse the command line arguments in `argv` (program name excluded)
    and run the vw_lda pipeline with them.

    With no arguments the help text is printed and the process exits with
    status 1. When the pipeline raises, the exception and its traceback
    are printed, followed by the parsed arguments, and the process exits
    with status 1 so that callers can detect the failure.
    '''
    if argv is None:
        argv = sys.argv[1:]

    #mandatory arguments
    helper = argparse.ArgumentParser(description='Inputs for using vw with lda')
    group = helper.add_mutually_exclusive_group(required=True)
    group.add_argument('-f', '--file_list', dest='flist', nargs='+', type=str,
                       required=False, metavar='', default=[],
                       help='List of input files separated by space')
    group.add_argument('-s', '--src_dir', dest='src_dir', type=str,
                       required=False, default='.',
                       help='Directory containing input documents, defaults to current directory', metavar='')
    #optional arguments
    helper.add_argument('-t', '--lda', dest='lda', type=int,
                        help='Run lda with <int> topics, defaults to ' + str(DEFAULT_NUM_TOPICS),
                        required=False, metavar='')
    helper.add_argument('-n', '--lda_D', dest='lda_d', type=int, required=False,
                        help='Number of documents, defaults to 0, all files to be included', metavar='')
    # NOTE(review): -b is parsed but vw_lda.run_vw derives its own -b value
    # from the vocabulary size -- confirm whether this option should be used.
    helper.add_argument('-b', '--bit_precision', default=18, dest='b',
                        help='Number of bits in the feature table, defaults to 18', metavar='')
    helper.add_argument('-a', '--lda_alpha', dest='lda_alpha', type=float,
                        required=False, default=0.1,
                        help='Prior on sparsity of per document topic, defaults to 0.1',
                        metavar='')
    helper.add_argument('-r', '--lda_rho', dest='lda_rho', type=float,
                        required=False, default=0.1,
                        help='Prior on sparsity of topic distributions, defaults to 0.1',
                        metavar='')
    helper.add_argument('-e', '--lda_epsilon', dest='lda_epsilon', type=float,
                        required=False, help='Loop convergence threshold, defaults to 0.1',
                        default=0.1,
                        metavar='')
    helper.add_argument('-o', '--output_file', dest='out_file', type=str,
                        required=False, default="op.model", metavar='',
                        help='Output file which has the list of topics, defaults to op.model')
    helper.add_argument('-m', '--readable_model', dest='rd_model', type=str,
                        required=False, default="hashed_lda_model.vw", metavar='',
                        help='Output human readable final regressor, defaults to hashed_lda_model.vw')
    helper.add_argument('-mt', '--max_terms', dest='max_terms', default=10,
                        type=int, required=False, metavar='',
                        help='Max terms per topic to print on the console, defaults to 10')

    if not argv:
        helper.print_help()
        sys.exit(1)

    # parse the list we were given rather than re-reading sys.argv
    args = helper.parse_args(argv)
    failed = False
    try:
        vw_lda().process_args(args)
    except Exception as e:
        print("Exception: " + str(e))
        traceback.print_exc()
        failed = True

    print("The arguments are %s" % (args,))
    if failed:
        sys.exit(1)


# Script entry point: pass the command line (without the program name) to main().
if __name__ == "__main__":
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
