#!/usr/bin/env python3

import argparse
import collections
import dbus
import hashlib
import logging
import os
import subprocess
import sys
import tempfile
import time

from gi.repository import GUdev

from checkbox_support.dbus import connect_to_system_bus
from checkbox_support.dbus.udisks2 import UDISKS2_BLOCK_INTERFACE
from checkbox_support.dbus.udisks2 import UDISKS2_DRIVE_INTERFACE
from checkbox_support.dbus.udisks2 import UDISKS2_FILESYSTEM_INTERFACE
from checkbox_support.dbus.udisks2 import UDisks2Model, UDisks2Observer
from checkbox_support.dbus.udisks2 import is_udisks2_supported
from checkbox_support.dbus.udisks2 import lookup_udev_device
from checkbox_support.dbus.udisks2 import map_udisks1_connection_bus
from checkbox_support.heuristics.udisks2 import is_memory_card
from checkbox_support.parsers.udevadm import CARD_READER_RE, GENERIC_RE, FLASH_RE
from checkbox_support.udev import get_interconnect_speed
from checkbox_support.udev import get_udev_block_devices


class ActionTimer:
    '''Context manager that times the body of a ``with`` statement.

    On entry ``start`` is set to the current ``time.time()`` value; on exit
    ``stop`` is set the same way and ``interval`` holds ``stop - start``
    (seconds of wall-clock time).
    '''

    def __enter__(self):
        # The timer itself is bound by ``with ... as timer``.
        self.start = time.time()
        return self

    def __exit__(self, *exc_info):
        # Nothing is returned, so an exception raised in the body propagates.
        finished_at = time.time()
        self.stop = finished_at
        self.interval = finished_at - self.start


class RandomData():
    '''Create a temporary file filled with deterministic filler text.

    Attributes:
        tfile: the ``NamedTemporaryFile`` object backing the data file. It is
            created with ``delete=False`` and is closed once the data has
            been written; ``tfile.name`` is the full path of the file, which
            the caller is responsible for removing.
        path: directory part of ``tfile.name``.
        name: base name part of ``tfile.name``.
    '''
    def __init__(self, size):
        """Create the temporary file and fill it with at least *size* bytes."""
        self.tfile = tempfile.NamedTemporaryFile(delete=False)
        self.path, self.name = os.path.split(self.tfile.name)
        self._write_test_data_file(size)

    def _generate_test_data(self):
        """
        Endlessly yield the words of a fixed phrase joined by single spaces.

        After every yield the word order is rotated by the next digit of a
        fixed seed, so the sequence of chunks is the same on every run.
        """
        seed = "104872948765827105728492766217823438120"
        phrase = '''
        Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam
        nonummy nibh euismod tincidunt ut laoreet dolore magna aliquam erat
        volutpat. Ut wisi enim ad minim veniam, quis nostrud exerci tation
        ullamcorper suscipit lobortis nisl ut aliquip ex ea commodo consequat.
        Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse
        molestie consequat, vel illum dolore eu feugiat nulla facilisis at vero
        eros et accumsan et iusto odio dignissim qui blandit praesent luptatum
        zzril delenit augue duis dolore te feugait nulla facilisi.
        '''
        words = phrase.replace('\n', '').split()
        word_deque = collections.deque(words)
        seed_deque = collections.deque(seed)
        while True:
            yield ' '.join(word_deque)
            word_deque.rotate(int(seed_deque[0]))
            seed_deque.rotate(1)

    def _write_test_data_file(self, size):
        """
        Write whole chunks of generated text until at least *size* bytes
        have been written, then close the file.

        The number of bytes handed to ``write()`` is counted directly.
        Polling the on-disk size of a buffered file under-reports what has
        been written, which made the file overshoot *size* and left the tail
        of the data unflushed for later readers.

        Returns ``self``.
        """
        data = self._generate_test_data()
        written = 0
        try:
            while written < size:
                chunk = next(data).encode('UTF-8')
                self.tfile.write(chunk)
                written += len(chunk)
        finally:
            # Closing flushes the buffer and releases the descriptor; the
            # file itself stays on disk because of delete=False.
            self.tfile.close()
        return self


def md5_hash_file(path):
    """
    Return the hexadecimal MD5 digest of the file at *path*.

    The file is read in 8192-byte chunks. If it cannot be opened or read,
    the error is logged and None is returned instead.
    """
    digest = hashlib.md5()
    try:
        with open(path, 'rb') as stream:
            # iter() with a sentinel stops at the first empty read (EOF).
            for chunk in iter(lambda: stream.read(8192), b''):
                digest.update(chunk)
    except IOError as exc:
        logging.error("unable to checksum %s: %s", path, exc)
        return None
    return digest.hexdigest()


class DiskTest():
    ''' Class to contain various methods for testing removable disks

    Attributes filled in by probing (all keyed by block device file):
        rem_disks: mount point of devices mounted before the script ran
        rem_disks_nm: devices not mounted before the script ran; values are
            None until mount() replaces them with temporary mount points
        rem_disks_memory_cards, rem_disks_memory_cards_nm: filled with the
            same entries as the two dictionaries above at probe time
        rem_disks_speed: interconnect speed reported for each device
    '''

    def __init__(self, device, memorycard):
        """
        Probe the system for matching removable disks.

        :param device: iterable of connection bus names to look for
        :param memorycard: True to look for memory cards only, False to
            look for anything but memory cards
        """
        self.rem_disks = {}     # mounted before the script running
        self.rem_disks_nm = {}  # not mounted before the script running
        self.rem_disks_memory_cards = {}
        self.rem_disks_memory_cards_nm = {}
        self.rem_disks_speed = {}
        self.data = ''
        self.device = device
        self.memorycard = memorycard
        self._probe_disks()

    def read_file(self, source):
        """
        Read the whole of *source* into self.data.

        Returns True on success. Failures to open or to read the file are
        logged and reported by returning False.
        """
        try:
            with open(source, 'rb') as infile:
                self.data = infile.read()
        except IOError as exc:
            logging.error("Unable to read data from %s: %s", source, exc)
            return False
        return True

    def write_file(self, data, dest):
        """
        Write *data* to *dest* and force it out to the storage device.

        When *data* is None the content last loaded by read_file()
        (self.data) is written instead. Returns True on success, False
        (after logging) when the file cannot be opened, written or synced.
        """
        payload = self.data if data is None else data
        try:
            # Buffering is disabled so the data goes straight to the OS.
            outfile = open(dest, 'wb', 0)
        except OSError as exc:
            logging.error("Unable to open %s for writing.", dest)
            logging.error("  %s", exc)
            return False
        with outfile:
            try:
                outfile.write(payload)
                outfile.flush()
                # Make sure the data really reached the device before the
                # caller stops its timer.
                os.fsync(outfile.fileno())
            except IOError as exc:
                logging.error("Unable to write data to %s: %s", dest, exc)
                return False
        return True

    def clean_up(self, target):
        """Remove the file *target*, logging (not raising) any failure."""
        try:
            os.unlink(target)
        except OSError as exc:
            logging.error("Unable to remove tempfile %s", target)
            logging.error("  %s", exc)

    def _probe_disks(self):
        """
        Internal method used to probe for available disks

        Indirectly sets:
            self.rem_disks{,_nm,_memory_cards,_memory_cards_nm,_speed}
        """
        bus, _loop = connect_to_system_bus()
        if is_udisks2_supported(bus):
            self._probe_disks_udisks2(bus)
        else:
            self._probe_disks_udisks1(bus)

    def _probe_disks_udisks2(self, bus):
        """
        Internal method used to probe / discover available disks using udisks2
        dbus interface using the provided dbus bus (presumably the system bus)
        """
        # We'll need udisks2 and udev to get the data we need
        udisks2_observer = UDisks2Observer()
        udisks2_model = UDisks2Model(udisks2_observer)
        udisks2_observer.connect_to_bus(bus)
        udev_client = GUdev.Client()
        # Get a collection of all udev devices corresponding to block devices
        udev_devices = get_udev_block_devices(udev_client)
        # Get a collection of all udisks2 objects
        udisks2_objects = udisks2_model.managed_objects
        # The connection buses we are interested in do not depend on the
        # object being inspected, so translate them only once.
        desired_connection_buses = {
            map_udisks1_connection_bus(device) for device in self.device}
        # We need to know about all IO candidates, let's iterate over all
        # the objects reported by udisks2 and keep only those that have both
        # the filesystem and the block device interfaces
        for udisks2_object_path, udisks2_object in udisks2_objects.items():
            if (UDISKS2_FILESYSTEM_INTERFACE not in udisks2_object
                    or UDISKS2_BLOCK_INTERFACE not in udisks2_object):
                continue
            # Find the path of the udisks2 object that represents the drive
            # this object is a part of
            drive_object_path = (
                udisks2_object[UDISKS2_BLOCK_INTERFACE]['Drive'])
            # Lookup the drive object, if any. This can fail when the path
            # does not name any of the managed objects.
            try:
                drive_object = udisks2_objects[drive_object_path]
            except KeyError:
                logging.error(
                    "Unable to locate drive associated with %s",
                    udisks2_object_path)
                continue
            drive_props = drive_object[UDISKS2_DRIVE_INTERFACE]
            # Skip devices that are attached to undesired connection buses
            if drive_props["ConnectionBus"] not in desired_connection_buses:
                continue
            # Lookup the udev object that corresponds to this object
            try:
                udev_device = lookup_udev_device(udisks2_object, udev_devices)
            except LookupError:
                logging.error(
                    "Unable to locate udev object that corresponds to: %s",
                    udisks2_object_path)
                continue
            # Get the block device pathname,
            # to avoid the confusion, this is something like /dev/sdbX
            dev_file = udev_device.get_device_file()
            # Get the list of mount points of this block device
            mount_points = (
                udisks2_object[UDISKS2_FILESYSTEM_INTERFACE]['MountPoints'])
            # Get the speed of the interconnect that is associated with the
            # block device we're looking at. This is purely informational but
            # it is a part of the required API
            interconnect_speed = get_interconnect_speed(udev_device)
            if interconnect_speed:
                self.rem_disks_speed[dev_file] = (
                    interconnect_speed * 10 ** 6)
            else:
                self.rem_disks_speed[dev_file] = None
            # We need to skip-non memory cards if we look for memory cards and
            # vice-versa so let's inspect the drive and use heuristics to
            # detect memory cards (a memory card reader actually) now.
            if self.memorycard != is_memory_card(drive_props['Vendor'],
                                                 drive_props['Model'],
                                                 drive_props['Media']):
                continue
            # The if/else test below simply distributes the mount_point to the
            # appropriate variable, to keep the API requirements. The
            # _memory_cards variables merely mirror the plain ones.
            if mount_points:
                # XXX: Arbitrarily pick the first of the mount points
                mount_point = mount_points[0]
                self.rem_disks_memory_cards[dev_file] = mount_point
                self.rem_disks[dev_file] = mount_point
            else:
                self.rem_disks_memory_cards_nm[dev_file] = None
                self.rem_disks_nm[dev_file] = None

    def _probe_disks_udisks1(self, bus):
        """
        Internal method used to probe / discover available disks using udisks1
        dbus interface using the provided dbus bus (presumably the system bus)
        """
        udisks = 'org.freedesktop.UDisks.Device'
        ud_manager_obj = bus.get_object("org.freedesktop.UDisks",
                                        "/org/freedesktop/UDisks")
        ud_manager = dbus.Interface(ud_manager_obj, 'org.freedesktop.UDisks')
        for dev in ud_manager.EnumerateDevices():
            device_obj = bus.get_object("org.freedesktop.UDisks", dev)
            device_props = dbus.Interface(device_obj, dbus.PROPERTIES_IFACE)
            # Whole drives are skipped; only the devices on them are tested
            if device_props.Get(udisks, "DeviceIsDrive"):
                continue
            dev_bus = device_props.Get(udisks, "DriveConnectionInterface")
            if dev_bus not in self.device:
                continue
            # All three must start out defined: they are only looked up for
            # partitions but are inspected for every device below.
            parent_model = parent_vendor = parent_media = ''
            if device_props.Get(udisks, "DeviceIsPartition"):
                parent_obj = bus.get_object(
                    "org.freedesktop.UDisks",
                    device_props.Get(udisks, "PartitionSlave"))
                parent_props = dbus.Interface(
                    parent_obj, dbus.PROPERTIES_IFACE)
                parent_model = parent_props.Get(udisks, "DriveModel")
                parent_vendor = parent_props.Get(udisks, "DriveVendor")
                parent_media = parent_props.Get(udisks, "DriveMedia")
            looks_like_card = bool(
                FLASH_RE.search(parent_media)
                or CARD_READER_RE.search(parent_model)
                or GENERIC_RE.search(parent_vendor))
            if self.memorycard:
                # Anything on the sdio bus counts as a memory card
                if dev_bus != 'sdio' and not looks_like_card:
                    continue
            elif looks_like_card:
                continue
            dev_file = str(device_props.Get(udisks, "DeviceFile"))
            dev_speed = str(device_props.Get(udisks,
                                             "DriveConnectionSpeed"))
            self.rem_disks_speed[dev_file] = dev_speed
            mount_paths = device_props.Get(udisks, "DeviceMountPaths")
            if len(mount_paths) > 0:
                # Like the udisks2 variant, pick the first mount point
                mount_path = str(mount_paths[0])
                self.rem_disks[dev_file] = mount_path
                self.rem_disks_memory_cards[dev_file] = mount_path
            else:
                self.rem_disks_nm[dev_file] = None
                self.rem_disks_memory_cards_nm[dev_file] = None

    def mount(self):
        """
        Mount every device of self.rem_disks_nm on a fresh temporary
        directory.

        Afterwards self.rem_disks_nm only contains the devices that were
        mounted, mapped to their mount points. Returns the number of devices
        that could not be mounted.
        """
        passed_mount = {}

        for key in self.rem_disks_nm:
            temp_dir = tempfile.mkdtemp()
            if self._mount(key, temp_dir) != 0:
                logging.error("can't mount %s", key)
                # Nobody keeps track of this directory once the mount has
                # failed, so remove it here rather than leaking it.
                try:
                    os.rmdir(temp_dir)
                except OSError as exc:
                    logging.error("Unable to remove %s: %s", temp_dir, exc)
            else:
                passed_mount[key] = temp_dir

        failed = len(self.rem_disks_nm) - len(passed_mount)
        self.rem_disks_nm = passed_mount
        return failed

    def _mount(self, dev_file, mount_point):
        """Run mount(8) and return its exit status."""
        return subprocess.call(['mount', dev_file, mount_point])

    def umount(self):
        """
        Unmount every device of self.rem_disks_nm that has a mount point.

        Returns the number of devices that could not be unmounted.
        """
        errors = 0
        for disk, mount_point in self.rem_disks_nm.items():
            if not mount_point:
                continue
            # NOTE: the device file, not the mount point, is handed over
            if self._umount(disk) != 0:
                errors += 1
                logging.error("can't umount %s on %s", disk, mount_point)
        return errors

    def _umount(self, mount_point):
        """Run umount(8) on *mount_point* and return its exit status."""
        # '-l': lazy umount, dealing problem of unable to umount the device.
        return subprocess.call(['umount', '-l', mount_point])

    def clean_tmp_dir(self):
        """
        Remove the temporary mount directories recorded in
        self.rem_disks_nm that are no longer mount points. Failures are
        logged rather than raised because this runs during final clean-up.
        """
        for mount_point in self.rem_disks_nm.values():
            if not mount_point or os.path.ismount(mount_point):
                continue
            try:
                os.rmdir(mount_point)
            except OSError as exc:
                logging.error("Unable to remove %s: %s", mount_point, exc)


def _build_parser():
    """Return the argument parser describing the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument('device',
                        choices=['usb', 'firewire', 'sdio',
                                 'scsi', 'ata_serial_esata'],
                        nargs='+',
                        help=("The type of removable media "
                              "(usb, firewire, sdio, scsi or ata_serial_esata) "
                              "to test."))
    parser.add_argument('-l', '--list',
                        action='store_true',
                        default=False,
                        help="List the removable devices and mounting status")
    parser.add_argument('-m', '--min-speed',
                        action='store',
                        default=0,
                        type=int,
                        help="Minimum speed a device must support to be "
                             "considered eligible for being tested (bits/s)")
    parser.add_argument('-p', '--pass-speed',
                        action='store',
                        default=0,
                        type=int,
                        help="Minimum average throughput from all eligible "
                             "devices for the test to pass (MB/s)")
    parser.add_argument('-i', '--iterations',
                        action='store',
                        default=1,
                        type=int,
                        help=("The number of test cycles to run. One cycle is "
                              "comprised of generating --count data files of "
                              "--size bytes and writing them to each device."))
    parser.add_argument('-c', '--count',
                        action='store',
                        default=1,
                        type=int,
                        help='The number of random data files to generate')
    parser.add_argument('-s', '--size',
                        action='store',
                        type=int,
                        default=1048576,
                        help=("The size of the test data file to use "
                              "in Bytes. Default is %(default)s"))
    parser.add_argument('-n', '--skip-not-mount',
                        action='store_true',
                        default=False,
                        help=("skip the removable devices "
                              "which haven't been mounted before the test."))
    parser.add_argument('--memorycard', action="store_true",
                        help=("Memory cards devices on bus other than sdio "
                              "require this parameter to identify "
                              "them as such"))
    return parser


def _print_device_list(mounted, not_mounted):
    """Print the mounted and the not mounted devices, or "None" for each."""
    print('-' * 20)
    print("Removable devices currently mounted:")
    if mounted:
        for disk, mnt_point in mounted.items():
            print("%s : %s" % (disk, mnt_point))
    else:
        print("None")

    print("Removable devices currently not mounted:")
    if not_mounted:
        for disk in not_mounted:
            print(disk)
    else:
        print("None")
    print('-' * 20)


def _speed_as_int(speed):
    """
    Return *speed* as an integer number of bits/s.

    A speed that is unknown (None) or cannot be parsed counts as 0, so such
    a device never satisfies a non-zero --min-speed.
    """
    try:
        return int(speed)
    except (TypeError, ValueError):
        return 0


def main():
    """
    Entry point of the removable storage test.

    Probes for removable disks on the requested buses and then either lists
    them (--list) or writes generated data files to every eligible disk,
    comparing MD5 checksums of the originals and the copies and reporting
    write speeds.

    Returns 0 when the run passes and 1 otherwise.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # With zero cycles or zero files there is nothing to measure, and the
    # summary below would have nothing to average.
    if args.iterations < 1 or args.count < 1:
        parser.error("--iterations and --count must be at least 1")

    test = DiskTest(args.device, args.memorycard)

    # If we don't have removable drives attached
    if not test.rem_disks and not test.rem_disks_nm:
        logging.error("No removable drives were detected, aborting")
        return 1

    if args.list:  # Simply output a list of drives detected
        if args.memorycard:
            _print_device_list(test.rem_disks_memory_cards,
                               test.rem_disks_memory_cards_nm)
        else:
            _print_device_list(test.rem_disks, test.rem_disks_nm)
        return 0

    # Otherwise: create a file, copy to disk and compare hashes
    errors = 0
    if args.skip_not_mount:
        disks_all = test.rem_disks
    else:
        # mount those that haven't been mounted yet.
        errors_mount = test.mount()
        if errors_mount > 0:
            print("There're total %d device(s) failed at mounting."
                  % errors_mount)
            errors += errors_mount
        disks_all = dict(list(test.rem_disks.items())
                         + list(test.rem_disks_nm.items()))

    if not disks_all:
        logging.error("No device being mounted successfully "
                      "for testing, aborting")
        return 1

    print("Found the following mounted %s partitions:"
          % ', '.join(args.device))
    for disk, mount_point in disks_all.items():
        supported_speed = test.rem_disks_speed[disk]
        print("    %s : %s : %s bits/s" %
              (disk, mount_point, supported_speed),
              end="")
        if (args.min_speed
                and args.min_speed > _speed_as_int(supported_speed)):
            print(" (Will not test it, speed is below %s bits/s)" %
                  args.min_speed, end="")
        print("")
    print('-' * 20)

    disks_eligible = {
        disk: mount_point for disk, mount_point in disks_all.items()
        if not args.min_speed
        or _speed_as_int(test.rem_disks_speed[disk]) >= args.min_speed}

    test_files = {}
    avg_write_speed = 0.00
    try:
        # Generate our data file(s). This happens inside the try block so
        # that the devices mounted above are released even if it fails.
        for count in range(args.count):
            test_files[count] = RandomData(args.size)
        total_write_size = sum(
            os.path.getsize(data_file.tfile.name)
            for data_file in test_files.values())
        # The source files never change, checksum each of them only once.
        parent_hashes = {
            index: md5_hash_file(data_file.tfile.name)
            for index, data_file in test_files.items()}

        for disk, mount_point in disks_eligible.items():
            print("%s (Total Data Size / iteration: %0.4f MB):" %
                  (disk, (total_write_size / 1024 / 1024)))
            iteration_write_size = (
                total_write_size * args.iterations) / 1024 / 1024
            iteration_write_times = []
            for iteration in range(args.iterations):
                target_file_list = []
                write_times = []
                for file_index in range(args.count):
                    parent_file = test_files[file_index].tfile.name
                    parent_hash = parent_hashes[file_index]
                    target_filename = (
                        test_files[file_index].name +
                        '.%s' % iteration)
                    target_file = os.path.join(mount_point,
                                               target_filename)
                    target_file_list.append(target_file)
                    # Without fresh data we would silently write whatever
                    # was loaded last, so treat a failed read as an error.
                    if not test.read_file(parent_file):
                        logging.error(
                            "Failed to copy %s to %s",
                            parent_file, target_file)
                        errors += 1
                        continue
                    with ActionTimer() as timer:
                        if not test.write_file(test.data,
                                               target_file):
                            logging.error(
                                "Failed to copy %s to %s",
                                parent_file, target_file)
                            errors += 1
                            continue
                    write_times.append(timer.interval)
                    child_hash = md5_hash_file(target_file)
                    if parent_hash != child_hash:
                        logging.warning(
                            "[Iteration %s] Parent and Child"
                            " copy hashes mismatch on %s!",
                            iteration, target_file)
                        logging.warning(
                            "\tParent hash: %s", parent_hash)
                        logging.warning(
                            "\tChild hash: %s", child_hash)
                        errors += 1
                for target_file in target_file_list:
                    test.clean_up(target_file)
                total_write_time = sum(write_times)
                try:
                    avg_write_speed = ((
                        total_write_size / total_write_time)
                        / 1024 / 1024)
                except ZeroDivisionError:
                    # Nothing was written (or it took no measurable time)
                    avg_write_speed = 0.00
                iteration_write_times.append(total_write_time)
                print("\t[Iteration %s] Average Speed: %0.4f"
                      % (iteration, avg_write_speed))
            iteration_write_time = sum(iteration_write_times)
            print("\tSummary:")
            print("\t\tTotal Data Attempted: %0.4f MB"
                  % iteration_write_size)
            print("\t\tTotal Time to write: %0.4f secs"
                  % iteration_write_time)
            print("\t\tAverage Write Time: %0.4f secs" %
                  (iteration_write_time / args.iterations))
            try:
                avg_write_speed = (iteration_write_size /
                                   iteration_write_time)
            except ZeroDivisionError:
                avg_write_speed = 0.00
            print("\t\tAverage Write Speed: %0.4f MB/s" %
                  avg_write_speed)
    finally:
        for data_file in test_files.values():
            test.clean_up(data_file.tfile.name)
        if test.rem_disks_nm:
            if test.umount() != 0:
                errors += 1
            test.clean_tmp_dir()

    if errors > 0:
        logging.warning(
            "Completed %s test iterations, but there were"
            " errors", args.iterations)
        return 1
    if not disks_eligible:
        logging.error(
            "No %s disks with speed higher than %s bits/s",
            args.device, args.min_speed)
        return 1

    # Pass is not assured!
    # NOTE(review): avg_write_speed is the summary figure of the device
    # that was tested last; confirm whether every eligible device is meant
    # to be checked against --pass-speed.
    if not args.pass_speed or avg_write_speed >= args.pass_speed:
        return 0
    print("FAIL: Average speed was lower than desired "
          "pass speed of %s MB/s" % args.pass_speed)
    return 1

# When run as a script, use main()'s return value as the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
