#! /usr/bin/env python

"""\
Run Rivet analyses on inputted events from file or Unix pipe

Examples:
  %prog [options] <hepmcfile> [<hepmcfile2> ...]
  my_generator -o myfifo & \ %prog [options] myfifo
  agile-runmc <genname> -n 100k -o- | %prog [options]

ENVIRONMENT:
 * RIVET_ANALYSIS_PATH: list of paths to be searched for plugin
     analysis libraries at runtime
 * RIVET_REF_PATH: list of paths to be searched for reference
     data files
 * RIVET_INFO_PATH: list of paths to be searched for analysis
     metadata files
"""

import sys
## Refuse to run on interpreters older than Python 2.4. The message is written
## to stderr, like the other fatal start-up error in this script, so that it is
## not mixed into normal output when stdout is piped or redirected.
if sys.version_info[:3] < (2, 4, 0):
    sys.stderr.write("rivet scripts require Python version >= 2.4.0... exiting\n")
    sys.exit(1)


import os, time
import logging, signal

## Try to rename the process on Linux. This is best-effort: a missing ctypes
## module, a missing libc.so.6 or a failing call are all silently ignored.
try:
    import ctypes
    ## 15 is the PR_SET_NAME option of the Linux prctl(2) call
    PR_SET_NAME = 15
    libc = ctypes.cdll.LoadLibrary('libc.so.6')
    libc.prctl(PR_SET_NAME, 'rivet', 0, 0, 0)
except Exception:
    pass

## Try to use the Psyco optimiser; carry on without it if it is not installed
try:
    import psyco
    psyco.full()
except ImportError:
    pass

PROGPATH = sys.argv[0]
PROGNAME = os.path.basename(PROGPATH)

## Try to bootstrap the Python path by asking the rivet-config script located
## next to this file for the rivet module's install path. This is best-effort:
## any failure is ignored here, and a missing module is reported when the
## rivet import itself is attempted.
import commands
try:
    modname = sys.modules[__name__].__file__
    binpath = os.path.dirname(modname)
    rivetconfigpath = os.path.join(binpath, "rivet-config")
    status, rivetpypath = commands.getstatusoutput(rivetconfigpath + " --pythonpath")
    ## Only trust the output if the command succeeded: on failure it holds the
    ## shell's error message rather than a path
    rivetpypath = rivetpypath.strip()
    if status == 0 and rivetpypath:
        sys.path.append(rivetpypath)
except Exception:
    pass


## Try importing rivet: the script cannot do anything without it, so report
## the failure on stderr and exit with status 1 if the module is unusable
try:
    import rivet
    rivet.check_python_version()
except Exception:
    ## Fetch the active exception object via sys.exc_info()
    e = sys.exc_info()[1]
    sys.stderr.write(
        "%s requires the 'rivet' Python module but it cannot be loaded:\n"
        "%s\n"
        "Try re-running with the -v switch to enable more debugging output \n"
        % (PROGNAME, str(e)))
    sys.exit(1)


## Parse command line options. The module docstring doubles as the usage text.
from optparse import OptionParser, OptionGroup
parser = OptionParser(usage=__doc__, version="rivet v%s" % rivet.version())

## Options for selecting, listing and locating analyses
anagroup = OptionGroup(parser, "Analysis handling")
anagroup.add_option("-a", "--analysis", dest="ANALYSES", action="append",
                    default=[], metavar="ANA",
                    help="add an analysis to the processing list.")
# anagroup.add_option("-A", "--all-analyses", dest="ALL_ANALYSES", action="store_true",
#                     default=False, help="add all analyses to the processing list.")
anagroup.add_option("--list-analyses", dest="LIST_ANALYSES", action="store_true",
                    default=False, help="show the list of available analyses' names. With -v, it shows the descriptions, too")
anagroup.add_option("--list-used-analyses", action="store_true", dest="LIST_USED_ANALYSES",
                    default=False, help="list the analyses used by this command (after subtraction of inappropriate ones)")
anagroup.add_option("--show-analysis", "--show-analyses", dest="SHOW_ANALYSES", action="append",
                    default=[], help="show the details of an analysis")
anagroup.add_option("--analysis-path", dest="ANALYSIS_PATH", metavar="PATH", default=None,
                    help="specify the analysis search path (cf. $RIVET_ANALYSIS_PATH).")
anagroup.add_option("--analysis-path-append", dest="ANALYSIS_PATH_APPEND", metavar="PATH", default=None,
                    help="append to the analysis search path (cf. $RIVET_ANALYSIS_PATH).")
anagroup.add_option("--pwd", dest="ANALYSIS_PATH_PWD", action="store_true", default=False,
                    help="append the current directory (pwd) to the analysis search path (cf. $RIVET_ANALYSIS_PATH).")
parser.add_option_group(anagroup)


## Options describing the run itself: output file, normalisation, event count
extragroup = OptionGroup(parser, "Extra run settings")
extragroup.add_option("-H", "--histo-file", dest="HISTOFILE",
                      default="Rivet.aida", help="specify the output histo file path (default = %default)")
## The cross-section is kept as a string here; it is converted after parsing
extragroup.add_option("-x", "--cross-section", dest="CROSS_SECTION",
                      default=None, metavar="XS",
                      help="specify the signal process cross-section in pb")
extragroup.add_option("-n", "--nevts", dest="MAXEVTNUM", type="int",
                      default=None, metavar="NUM",
                      help="restrict the max number of events to read.")
extragroup.add_option("--runname", dest="RUN_NAME", default=None, metavar="NAME",
                      help="give an optional run name, to be prepended as a 'top level directory' in histo paths")
extragroup.add_option("--ignore-beams", dest="IGNORE_BEAMS", action="store_true", default=False,
                      help="Ignore input event beams when checking analysis compatibility.")
parser.add_option_group(extragroup)


## Options for timeouts and periodic histogram output
timinggroup = OptionGroup(parser, "Timeouts and periodic operations")
timinggroup.add_option("--event-timeout", dest="EVENT_TIMEOUT", type="int",
                       default=21600, metavar="NSECS",
                       help="max time in whole seconds to wait for an event to be generated from the specified source (default = %default)")
timinggroup.add_option("--run-timeout", dest="RUN_TIMEOUT", type="int",
                       default=None, metavar="NSECS",
                       help="max time in whole seconds to wait for the run to finish. This can be useful on batch systems such "
                       "as the LCG Grid where tokens expire on a fixed wall-clock and can render long Rivet runs unable to write "
                       "out the final histogram file (default = unlimited)")
timinggroup.add_option("--histo-interval", dest="HISTO_WRITE_INTERVAL", type="int",
                       default=None, metavar="NUM",
                       help="[experimental!] specify the number of events between histogram file updates. "
                       "Default is to only write out at the end of the run. Note that intermediate histograms will be those "
                       "from the analyze step only: analysis finalizing is currently not executed until the end of the run.")
parser.add_option_group(timinggroup)


## Options controlling how chatty the script and the Rivet library are.
## -v and -q write to the same LOGLEVEL destination.
verbgroup = OptionGroup(parser, "Verbosity control")
verbgroup.add_option("-l", dest="NATIVE_LOG_STRS", action="append",
                     default=[], help="set a log level in the Rivet library")
verbgroup.add_option("-v", "--verbose", action="store_const", const=logging.DEBUG, dest="LOGLEVEL",
                     default=logging.INFO, help="print debug (very verbose) messages")
verbgroup.add_option("-q", "--quiet", action="store_const", const=logging.WARNING, dest="LOGLEVEL",
                     default=logging.INFO, help="be very quiet")
parser.add_option_group(verbgroup)
opts, args = parser.parse_args()


## Configure logging
logging.basicConfig(level=opts.LOGLEVEL, format="%(message)s")


## Control native Rivet library logger. Each -l argument is either of the form
## "NAME=LEVEL", or just "LEVEL" which is applied to the top-level "Rivet"
## logger. LEVEL is a case-insensitive level name or an integer value.
NATIVE_LOG_LEVEL_ATTRS = {"TRACE": "TRACE", "DEBUG": "DEBUG", "INFO": "INFO",
                          "WARNING": "WARN", "WARN": "WARN", "ERROR": "ERROR"}
for logstr in opts.NATIVE_LOG_STRS:
    try:
        name, level = logstr.split("=")
    except ValueError:
        ## Not exactly one "=": treat the whole string as a level for "Rivet"
        name, level = "Rivet", logstr
    ## Fix name
    if name != "Rivet" and not name.startswith("Rivet."):
        name = "Rivet." + name
    try:
        ## Map a level name on to the library's constant (looked up lazily,
        ## so only the requested attribute is touched), or else interpret
        ## the string as an integer level
        levelattr = NATIVE_LOG_LEVEL_ATTRS.get(level.upper())
        if levelattr is not None:
            level = getattr(rivet.Log, levelattr)
        else:
            level = int(level)
        logging.debug("Setting log level: %s %d" % (name, level))
        rivet.Log.setLogLevel(name, level)
    except Exception:
        ## A bad level string is not fatal: warn and move on to the next one
        logging.warning("Couldn't process logging string '%s'" % logstr)


## Parse supplied cross-section
def _parse_cross_section(xsstr):
    """Convert a cross-section string to a float number of picobarns.

    xsstr is either a plain number, which is returned as-is (i.e. taken to be
    in pb already), or a number followed by one of the unit suffixes mb, mub,
    nb, pb, fb or ab. The suffix is case-insensitive, the number may carry a
    sign and/or an exponent (e.g. "1.2e-3nb"), and surrounding whitespace is
    allowed.

    Raises ValueError if the string cannot be interpreted.
    """
    import re
    try:
        return float(xsstr)
    except ValueError:
        pass
    ## Conversion factors from each supported unit to pb
    pb_per_unit = {"mb": 1e+9, "mub": 1e+6, "nb": 1e+3,
                   "pb": 1, "fb": 1e-3, "ab": 1e-6}
    match = re.match(r"^\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)\s*([a-zA-Z]+)\s*$", xsstr)
    if not match:
        raise ValueError("Bad cross-section string: %s" % xsstr)
    base, suffix = match.groups()
    factor = pb_per_unit.get(suffix.lower())
    if factor is None:
        raise ValueError("Bad cross-section string: %s" % xsstr)
    return float(base) * factor

if opts.CROSS_SECTION is not None:
    opts.CROSS_SECTION = _parse_cross_section(opts.CROSS_SECTION)


## Print the available CLI options!
#if opts.LIST_OPTIONS:
#    for o in parser.option_list:
#        print o.get_opt_string()
#    sys.exit(0)


## Set up signal handling. RECVD_KILL_SIGNAL stays None until a handled signal
## arrives; the event loop checks it after each processed event.
RECVD_KILL_SIGNAL = None
def handleKillSignal(signum, frame):
    """Flag that signal signum was received, then restore its default handling.

    The signal number is stored in the module-level RECVD_KILL_SIGNAL and the
    handler for that signal is reset to signal.SIG_DFL, so that a repeat of
    the same signal is no longer intercepted. The frame argument is unused.
    """
    global RECVD_KILL_SIGNAL
    logging.critical("Signal handler called with signal %s" % signum)
    RECVD_KILL_SIGNAL = signum
    signal.signal(signum, signal.SIG_DFL)
## Signals to handle
for sig in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT,
            signal.SIGUSR1, signal.SIGUSR2):
    signal.signal(sig, handleKillSignal)
## Registering SIGXCPU is best-effort (presumably because it is not available
## on every platform): a failure here is ignored
try:
    signal.signal(signal.SIGXCPU, handleKillSignal)
except Exception:
    pass


## Override/modify analysis search path: --analysis-path replaces the whole
## path list, then the --analysis-path-append entries and finally "." (for
## --pwd) are added to it, in that order
if opts.ANALYSIS_PATH:
    rivet.setAnalysisLibPaths(opts.ANALYSIS_PATH.split(":"))
extra_analysis_paths = []
if opts.ANALYSIS_PATH_APPEND:
    extra_analysis_paths.extend(opts.ANALYSIS_PATH_APPEND.split(":"))
if opts.ANALYSIS_PATH_PWD:
    extra_analysis_paths.append(".")
for ap in extra_analysis_paths:
    rivet.addAnalysisLibPath(ap)

## List of analyses
all_analyses = rivet.AnalysisLoader.analysisNames()
if opts.LIST_ANALYSES:
    ## Treat args as case-insensitive regexes if present
    regexes = None
    if args:
        import re
        regexes = [re.compile(arg, re.I) for arg in args]
    for aname in all_analyses:
        if not regexes:
            toshow = True
        else:
            toshow = False
            for regex in regexes:
                if regex.search(aname):
                    toshow = True
                    break
        if toshow:
            msg = aname
            if opts.LOGLEVEL <= logging.INFO:
                a = rivet.AnalysisLoader.getAnalysis(aname)
                msg = "%-25s   %s" % (aname, a.summary())
            print msg
    sys.exit(0)


## Show analyses' details
if len(opts.SHOW_ANALYSES) > 0:
    toshow = []
    for i, a in enumerate(opts.SHOW_ANALYSES):
        a_up = a.upper()
        if a_up in all_analyses and a_up not in toshow:
            toshow.append(a_up)
        else:
            ## Treat as a case-insensitive regex
            import re
            regex = re.compile(a, re.I)
            for ana in all_analyses:
                if regex.search(ana) and a_up not in toshow:
                    toshow.append(ana)

    ## Show the matching analyses' details
    import textwrap
    for i, name in enumerate(sorted(toshow)):
        ana = rivet.AnalysisLoader.getAnalysis(name)
        print ""
        print name
        print len(name) * "="
        print ""
        print ana.summary()
        print ""
        print "Status: %s" % ana.status()
        print ""

        if ana.inspireId():
            print "Inspire ID: %s" % ana.inspireId()
            print "Inspire URL: http://inspire-hep.net/record/%s" % ana.inspireId()
            # TODO: Need a way to get to HepData from an Inspire record
        elif ana.spiresId():
            print "Spires ID: %s" % ana.spiresId()
            print "Inspire URL: http://inspire-hep.net/search?p=find+key+%s" % ana.spiresId()
            print "HepData URL: http://hepdata.cedar.ac.uk/view/irn%s" % ana.spiresId()

        if ana.experiment():
            print "Experiment: %s" % ana.experiment(),
            if ana.collider():
                print "(%s)" % ana.collider()

        if ana.year():
            print "Year of publication: %s" % ana.year()

        print "Authors:"
        for a in ana.authors():
            print "  " + a

        print ""

        print "Description:"
        twrap = textwrap.TextWrapper(width=75, initial_indent=2*" ", subsequent_indent=2*" ")
        print twrap.fill(ana.description())

        print ""

        if ana.requiredBeams():
            def pid_to_str(pid):
                if pid == 11:
                    return "e-"
                elif pid == -11:
                    return "e+"
                elif pid == 2212:
                    return "p+"
                elif pid == -2212:
                    return "p-"
                elif pid == 10000:
                    return "*"
                else:
                    return str(pid)
            beamstrs = []
            for bp in ana.requiredBeams():
                beamstrs.append(pid_to_str(bp[0]) + " " + pid_to_str(bp[1]))
            print "Beams:", ", ".join(beamstrs)

        if ana.requiredEnergies():
            print "Beam energies:", "; ".join(["(%0.1f, %0.1f)" % (epair[0], epair[1]) for epair in ana.requiredEnergies()]), "GeV"
        else:
            print "Beam energies: ANY"

        if ana.runInfo():
            print "Run details:"
            twrap = textwrap.TextWrapper(width=75, initial_indent=2*" ", subsequent_indent=4*" ")
            for l in ana.runInfo().split("\n"):
                print twrap.fill(l)

        if ana.references():
            print ""
            print "References:"
            for r in ana.references():
                url = None
                if r.startswith("arXiv:"):
                    code = r.split()[0].replace("arXiv:", "")
                    url = "http://arxiv.org/abs/" + code
                elif r.startswith("doi:"):
                    code = r.replace("doi:", "")
                    url = "http://dx.doi.org/" + code
                if url is not None:
                    r += " - " + url
                print "  %s" % r

        if i+1 < len(toshow):
            print "\n"
    sys.exit(0)


## Identify HepMC files/streams: the positional arguments if there are any,
## otherwise the single entry "-"
## TODO: check readability, deal with stdin
HEPMCFILES = args or ["-"]


## Event number logging
def logNEvt(n, starttime, maxevtnum):
    """Log a progress message for event number n.

    The log level depends on how "round" n is: multiples of 10000 are logged
    at CRITICAL, of 1000 at WARNING+5, of 500 at WARNING, of 200 at INFO+5,
    of 100 at INFO, of 10 at DEBUG+5 and everything else at DEBUG.

    starttime is the time.time() value at which the run started. If maxevtnum
    is not None, the remaining time is extrapolated linearly from the average
    time per event so far and an ETA is included in the message. No ETA is
    given when n is not positive, since no rate can be computed then.
    """
    ## Roundest step first, so that the first match is the highest level
    levelsteps = ((10000, logging.CRITICAL),
                  (1000, logging.WARNING + 5),
                  (500, logging.WARNING),
                  (200, logging.INFO + 5),
                  (100, logging.INFO),
                  (10, logging.DEBUG + 5))
    nevtloglevel = logging.DEBUG
    for step, steplevel in levelsteps:
        if n % step == 0:
            nevtloglevel = steplevel
            break
    timecurrent = time.time()
    timeelapsed = timecurrent - starttime
    if maxevtnum is None or n <= 0:
        logging.log(nevtloglevel, "Event %d (%d s elapsed)" % (n, timeelapsed))
    else:
        timeleft = (maxevtnum-n)*timeelapsed/n
        eta = time.strftime("%a %b %d %H:%M", time.localtime(timecurrent + timeleft))
        logging.log(nevtloglevel, "Event %d (%d s elapsed / %d s left) -> ETA: %s"
                    % (n, timeelapsed, timeleft, eta))


## Set up analysis handler, using an empty run name if none was given
RUNNAME = opts.RUN_NAME or ""
ah = rivet.AnalysisHandler(RUNNAME)
ah.setIgnoreBeams(opts.IGNORE_BEAMS)
# if opts.ALL_ANALYSES:
#     opts.ANALYSES = all_analyses
for a in opts.ANALYSES:
    ## NB. the name is matched exactly as given (no upper-casing)
    if a in all_analyses:
        logging.debug("Adding analysis '%s'", a)
        ah.addAnalysis(a)
        continue
    ## Not a valid analysis name: print the available ones and exit
    logging.warning("'%s' is not a valid analysis. Available analyses are:", a)
    for aa in all_analyses:
        logging.warning("    %s", aa)
    logging.warning("Exiting...")
    sys.exit(1)

## Create the run object which reads and processes events for the handler
run = rivet.Run(ah)
if opts.CROSS_SECTION is not None:
    logging.info("User-supplied cross-section = %e pb", opts.CROSS_SECTION)
    run.setCrossSection(opts.CROSS_SECTION)
if opts.LIST_USED_ANALYSES is not None:
    run.setListAnalyses(opts.LIST_USED_ANALYSES)

## Log the Rivet version and the machine we are running on
import platform
logging.info("Rivet %s running on machine %s (%s)",
             rivet.version(), platform.node(), platform.machine())


def min_nonnull(a, b):
    """Return the smaller of a and b, treating None as greater than any number.

    If exactly one argument is None the other one is returned; the result is
    None only if both are. None is never passed to min(), so the function
    does not depend on how the interpreter orders None against numbers.
    """
    if a is None:
        return b
    if b is None:
        return a
    return min(a, b)


## Set up an event timeout handler: if either timeout is enabled, SIGALRM is
## turned into a TimeoutException raised in the main flow of the script
class TimeoutException(Exception):
    "Raised by the SIGALRM handler when an armed alarm expires"
if opts.EVENT_TIMEOUT or opts.RUN_TIMEOUT:
    def evttimeouthandler(signum, frame):
        "Warn about the expired timeout, then raise TimeoutException"
        timeout = min_nonnull(opts.EVENT_TIMEOUT, opts.RUN_TIMEOUT)
        logging.warning("It has taken more than %d secs to get an event! Is the input event stream working?" % timeout)
        raise TimeoutException("Event timeout")
    signal.signal(signal.SIGALRM, evttimeouthandler)


## Init run based on one event, read from the first input. An input given as
## "NAME:WEIGHT" is split at its last colon into a name and a float weight.
hepmcfile, hepmcfileweight = HEPMCFILES[0], 1.0
if ":" in hepmcfile:
    hepmcfile, weightstr = hepmcfile.rsplit(":", 1)
    hepmcfileweight = float(weightstr)
try:
    ## Guard the initialisation with an alarm if any timeout is enabled
    if opts.EVENT_TIMEOUT or opts.RUN_TIMEOUT:
        signal.alarm(min_nonnull(opts.EVENT_TIMEOUT, opts.RUN_TIMEOUT))
    init_ok = run.init(hepmcfile, hepmcfileweight)
    signal.alarm(0)
except TimeoutException:
    logging.error("Timeout in initialisation from event file '%s'... exiting" % hepmcfile)
    sys.exit(3)
if not init_ok:
    logging.error("Failed to initialise using event file '%s'... exiting" % hepmcfile)
    sys.exit(2)


## Event loop: process the events of each input in turn. evtnum counts events
## across all inputs, so the -n limit applies to the run as a whole.
evtnum = 0
starttime = time.time()
for fileidx, hepmcfile in enumerate(HEPMCFILES):
    ## Stop without opening any further input once the run is over, i.e. a
    ## signal / run timeout was flagged or the event limit has been reached
    if RECVD_KILL_SIGNAL is not None:
        break
    if opts.MAXEVTNUM is not None and evtnum >= opts.MAXEVTNUM:
        break

    ## Apply a file-level weight derived from the filename
    hepmcfileweight = 1.0
    if ":" in hepmcfile:
        hepmcfile, hepmcfileweight = hepmcfile.rsplit(":", 1)
        hepmcfileweight = float(hepmcfileweight)
    ## Open next HepMC file (NB. this doesn't apply to the first file: it was already used for the run init)
    if fileidx > 0:
        run.openFile(hepmcfile, hepmcfileweight)
        if not run.readEvent():
            logging.warning("Could not read events from '%s'" % hepmcfile)
            continue
    msg = "Reading events from '%s'" % hepmcfile
    if hepmcfileweight != 1.0:
        msg += " (file weight = %e)" % hepmcfileweight
    logging.info(msg)
    while opts.MAXEVTNUM is None or evtnum < opts.MAXEVTNUM:
        evtnum += 1
        logNEvt(evtnum, starttime, opts.MAXEVTNUM)

        ## Process this event; on failure give up on the current input
        processed_ok = run.processEvent()
        if not processed_ok:
            logging.warning("Event processing failed for evt #%i!" % evtnum)
            break

        ## Set flag to exit event loop if run timeout exceeded
        if opts.RUN_TIMEOUT and (time.time() - starttime) > opts.RUN_TIMEOUT:
            logging.warning("Run timeout of %d secs exceeded... exiting gracefully" % opts.RUN_TIMEOUT)
            RECVD_KILL_SIGNAL = True

        ## Exit the loop if signalled
        if RECVD_KILL_SIGNAL is not None:
            break

        ## Read next event (with timeout handling if requested)
        try:
            if opts.EVENT_TIMEOUT:
                signal.alarm(opts.EVENT_TIMEOUT)
            read_ok = run.readEvent()
            signal.alarm(0)
            if not read_ok:
                break
        except TimeoutException:
            logging.error("Timeout in reading event from '%s'... exiting" % hepmcfile)
            sys.exit(3)

        ## Write a histo file snapshot if appropriate. An interval of zero
        ## (like None) disables the snapshots instead of dividing by zero.
        if opts.HISTO_WRITE_INTERVAL:
            if evtnum % opts.HISTO_WRITE_INTERVAL == 0:
                ah.writeData(opts.HISTOFILE)

logging.info("Finished event loop")
run.finalize()

## Report the cross-section on stdout, then finalize the analyses and write
## the histogram file
sys.stdout.write("Cross-section = %e pb\n" % ah.crossSection())
ah.finalize()
ah.writeData(opts.HISTOFILE)
