#!/usr/bin/env python
"""
Use for benchmarking the performance of other scripts. Provides data about
time, memory use, CPU usage, network in, and network out for the script that
was run, in the form of a CSV.


Usage
=====

NOTE: Make sure you run ``pip install -r requirements-dev.txt`` before running.

To use the script, run::

    ./benchmark "./my-script-to-run"


If no ``--output-file`` was provided, the data will be saved to
``performance.csv``
"""

import argparse
import os
import subprocess
import sys
import time

import psutil

# Determine the interface to track network IO depending on the platform.
if not (sys.platform.startswith('linux') or sys.platform == 'darwin'):
    # TODO: Add support for windows. This would require figuring out what
    # interface to use on windows.
    raise RuntimeError(f'Script cannot be run on {sys.platform}')
# Linux hosts are sampled on eth0 and macOS hosts on en0.
INTERFACE = 'eth0' if sys.platform.startswith('linux') else 'en0'


def benchmark(args):
    """Run ``args.script`` and record its resource usage until it finishes.

    :param args: Parsed command line arguments providing ``script``,
        ``output_file`` and ``data_interval``.
    :returns: The value returned by ``run_benchmark`` when sampling ran to
        completion, or ``1`` when it was interrupted with Ctrl-C.
    """
    parent_pid = os.getpid()
    child_p = run_script(args)
    try:
        # Benchmark the process where the script is being run.
        return run_benchmark(child_p.pid, args.output_file, args.data_interval)
    except KeyboardInterrupt:
        # If there is an interrupt, then try to clean everything up.
        proc = psutil.Process(parent_pid)
        procs = proc.children(recursive=True)

        for child in procs:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                # The child exited on its own after it was listed; there is
                # nothing left to terminate.
                pass

        # Give the children a moment to exit, then force the stragglers.
        _, alive = psutil.wait_procs(procs, timeout=1)
        for child in alive:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # The child went away between the wait and the kill.
                pass
        return 1


def run_script(args):
    """Start ``args.script`` through the shell without waiting for it.

    :param args: Parsed command line arguments; only ``script`` is read.
    :returns: The ``subprocess.Popen`` handle of the started command.
    """
    command = args.script
    child = subprocess.Popen(command, shell=True)
    return child


def run_benchmark(pid, output_file, data_interval):
    """Sample a running process and write one CSV row per sample.

    Each row is ``time,memory,cpu,sent_rate,recv_rate`` where ``time`` is the
    ``time.time()`` timestamp of the sample, ``memory`` is the resident set
    size reported by psutil, ``cpu`` is psutil's ``cpu_percent()`` and the two
    rates are the change in the byte counters of ``INTERFACE`` divided by the
    seconds elapsed since the previous sample. The network counters are those
    of the whole interface, not only of the measured process.

    Sampling stops when the process is no longer running, has become a
    zombie, or can no longer be inspected.

    :param pid: The pid of the process that was started for the script.
    :param output_file: Path of the CSV file to write; it is truncated.
    :param data_interval: Seconds to sleep between samples.
    :returns: ``0``
    """
    p = psutil.Process(pid)
    previous_net = psutil.net_io_counters(pernic=True)[INTERFACE]
    previous_time = time.time()
    # Handle of the process sampled on the previous iteration. It is reused
    # while the same pid keeps being selected because, per the psutil
    # documentation, cpu_percent() reports usage since the previous call on
    # the same Process object and returns a meaningless 0.0 on the first
    # call. A freshly looked up handle would therefore always report 0.0.
    measured = None

    with open(output_file, 'w') as f:
        while p.is_running():
            if p.status() == psutil.STATUS_ZOMBIE:
                p.kill()
                break
            time.sleep(data_interval)
            try:
                candidate = _get_underlying_python_process(p)
                if measured is None or measured.pid != candidate.pid:
                    measured = candidate
                # Collect the memory and cpu usage.
                memory_used = measured.memory_info().rss
                cpu_percent = measured.cpu_percent()
                current_net = psutil.net_io_counters(pernic=True)[INTERFACE]
            except (psutil.AccessDenied, psutil.NoSuchProcess):
                # Trying to get process information from a closed or zombie
                # process will result in corresponding exceptions.
                # ZombieProcess is a subclass of NoSuchProcess, so both the
                # "became a zombie" and the "already gone" cases end here.
                break

            # Collect data on the in/out network io.
            sent_delta = current_net.bytes_sent - previous_net.bytes_sent
            recv_delta = current_net.bytes_recv - previous_net.bytes_recv

            # Determine the lapsed time to determine the network io rate.
            current_time = time.time()
            previous_net = current_net
            dt = current_time - previous_time
            previous_time = current_time
            sent_rate = sent_delta / dt
            recv_rate = recv_delta / dt

            # Save all of the data into a CSV file.
            f.write(
                f"{current_time},{memory_used},{cpu_percent},"
                f"{sent_rate},{recv_rate}\n"
            )
            # Flush after every row so the file is usable while the
            # benchmark is still running or if it is interrupted.
            f.flush()
    return 0


def _get_underlying_python_process(process):
    # For some scripts such as the streaming CLI commands, the process is
    # nested under a shell script that does not account for the python process.
    # We want to always be measuring the python process.
    children = process.children(recursive=True)
    for child_process in children:
        if 'python' in child_process.name().lower():
            return child_process
    return process


def main():
    """Parse the command line and run the benchmark.

    The module docstring is used as the usage text of the parser.

    :returns: The status returned by ``benchmark``.
    """
    cli = argparse.ArgumentParser(usage=__doc__)
    # Registered in this order: the positional script first, then the options.
    argument_specs = (
        ('script', {'help': 'The script to run for benchmarking'}),
        (
            '--data-interval',
            {
                'default': 1,
                'type': float,
                'help': 'The interval in seconds to poll for data points',
            },
        ),
        (
            '--output-file',
            {
                'default': 'performance.csv',
                'help': 'The file to output the data collected to',
            },
        ),
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    return benchmark(parsed)


# When executed as a script, exit with the status returned by main().
if __name__ == '__main__':
    raise SystemExit(main())
