#!/usr/bin/python3
"""
    Copyright (C) 2023  Michael Ablassmeier <abi@grinser.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import signal
import logging
import argparse
from typing import List
from datetime import datetime
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed
from libvirtnbdbackup import sighandle
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import virt
from libvirtnbdbackup.objects import DomainDisk
from libvirtnbdbackup.virt import checkpoint
from libvirtnbdbackup import output
from libvirtnbdbackup.output import stream
from libvirtnbdbackup import common as lib
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup import exceptions
from libvirtnbdbackup.backup import partialfile
from libvirtnbdbackup.backup import job
from libvirtnbdbackup.backup import disk
from libvirtnbdbackup.backup import metadata
from libvirtnbdbackup.ssh.exceptions import sshError
from libvirtnbdbackup.virt.exceptions import (
    domainNotFound,
    connectionFailed,
)
from libvirtnbdbackup.output.exceptions import OutputException


def main() -> None:
    """Handle backup operation and settings.

    Parses the command line, validates the requested option combination
    against the target directory and the domain state, starts the backup
    job for running domains, saves all disks concurrently and finally
    stores the metadata files.

    The function never reports a status via its return value; failures
    terminate the process through ``sys.exit``:

    * exit code 1: invalid option combination, setup failure, or at least
      one error was logged during the backup.
    * exit code 2: warnings were logged and ``--strict`` was given.
    * exit code 0: early exits for ``--killonly``, ``--printonly``,
      ``--startonly`` or a backup skipped because of ``--threshold``.
    """
    parser = argparse.ArgumentParser(
        description="Backup libvirt/qemu virtual machines",
        epilog=(
            "Examples:\n"
            "   # full backup of domain 'webvm' with all attached disks:\n"
            "\t%(prog)s -d webvm -l full -o /backup/\n"
            "   # incremental backup:\n"
            "\t%(prog)s -d webvm -l inc -o /backup/\n"
            "   # differential backup:\n"
            "\t%(prog)s -d webvm -l diff -o /backup/\n"
            "   # full backup, exclude disk 'vda':\n"
            "\t%(prog)s -d webvm -l full -x vda -o /backup/\n"
            "   # full backup, backup only disk 'vdb':\n"
            "\t%(prog)s -d webvm -l full -i vdb -o /backup/\n"
            "   # full backup, compression enabled:\n"
            "\t%(prog)s -d webvm -l full -z -o /backup/\n"
            "   # full backup, create archive:\n"
            "\t%(prog)s -d webvm -l full -o - > backup.zip\n"
            "   # full backup of vm operating on remote libvirtd:\n"
            "\t%(prog)s -U qemu+ssh://root@remotehost/system "
            "--ssh-user root -d webvm -l full -o /backup/\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    opt = parser.add_argument_group("General options")
    opt.add_argument("-d", "--domain", required=True, type=str, help="Domain to backup")
    opt.add_argument(
        "-l",
        "--level",
        default="copy",
        choices=["copy", "full", "inc", "diff", "auto"],
        type=str,
        help="Backup level. (default: %(default)s)",
    )
    opt.add_argument(
        "-t",
        "--type",
        default="stream",
        type=str,
        choices=["stream", "raw"],
        help="Output type: stream or raw. (default: %(default)s)",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Include full provisioned disk images in backup. (default: %(default)s)",
    )
    opt.add_argument(
        "-o", "--output", required=True, type=str, help="Output target directory"
    )
    opt.add_argument(
        "-C",
        "--checkpointdir",
        required=False,
        default=None,
        type=str,
        help="Persistent libvirt checkpoint storage directory",
    )
    opt.add_argument(
        "-S",
        "--scratchdir",
        default="/var/tmp",
        required=False,
        type=str,
        help="Target dir for temporary scratch file. (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--include",
        default=None,
        type=str,
        help="Backup only disk with target dev name (-i vda)",
    )
    opt.add_argument(
        "-x",
        "--exclude",
        default=None,
        type=str,
        help="Exclude disk(s) with target dev name (-x vda,vdb)",
    )
    opt.add_argument(
        "-f",
        "--socketfile",
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        default=False,
        help="Disable progress bar",
        action="store_true",
    )
    # "-z" without a value selects level 2 (const); if the option is not
    # given at all the value stays False, which is what the later
    # "is not False" checks rely on.
    opt.add_argument(
        "-z",
        "--compress",
        default=False,
        type=int,
        const=2,
        nargs="?",
        help="Compress with lz4 compression level. (default: %(default)s)",
        action="store",
    )
    opt.add_argument(
        "-w",
        "--worker",
        type=int,
        default=None,
        help=(
            "Amount of concurrent workers used "
            "to backup multiple disks. (default: amount of disks)"
        ),
    )
    opt.add_argument(
        "-F",
        "--freeze-mountpoint",
        type=str,
        default=None,
        help=(
            "If qemu agent available, freeze only filesystems on specified mountpoints within"
            " virtual machine (default: all)"
        ),
    )
    opt.add_argument(
        "-e",
        "--strict",
        default=False,
        help=(
            "Change exit code if warnings occur during backup operation. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "-T",
        "--threshold",
        type=int,
        default=None,
        help=("Execute backup only if threshold is reached."),
    )
    remopt = parser.add_argument_group("Remote Backup options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    logopt.add_argument(
        "-L",
        "--syslog",
        default=False,
        action="store_true",
        help="Additionally send log messages to syslog (default: %(default)s)",
    )
    logopt.add_argument(
        "--quiet",
        default=False,
        action="store_true",
        help="Disable logging to stderr (default: %(default)s)",
    )
    argopt.addLogColorArgs(logopt)
    debopt = parser.add_argument_group("Debug options")
    debopt.add_argument(
        "-q",
        "--qemu",
        default=False,
        action="store_true",
        help="Use Qemu tools to query extents.",
    )
    debopt.add_argument(
        "-s",
        "--startonly",
        default=False,
        help="Only initialize backup job via libvirt, do not backup any data",
        action="store_true",
    )
    debopt.add_argument(
        "-k",
        "--killonly",
        default=False,
        help="Kill any running block job",
        action="store_true",
    )
    debopt.add_argument(
        "-p",
        "--printonly",
        default=False,
        help="Quit after printing estimated checkpoint size.",
        action="store_true",
    )
    argopt.addDebugArgs(debopt)

    repository = output.target()
    args = lib.argparse(parser)

    lib.setThreadName()
    # Runtime state shared with the helper modules is attached to the
    # argument namespace: "-" as output target selects stdout mode.
    args.stdout = args.output == "-"
    args.sshClient = None
    args.diskInfo = []
    args.offline = False

    # Quiet mode implies that no progress bar is shown either.
    if args.quiet is True:
        args.noprogress = True

    fileStream = stream.get(args, repository)

    try:
        if not args.stdout:
            fileStream.create(args.output)
    except OutputException as e:
        logging.error("Can't open output file: [%s]", e)
        sys.exit(1)

    # Never run with less than one worker.
    if args.worker is not None and args.worker < 1:
        args.worker = 1

    now = datetime.now().strftime("%m%d%Y%H%M%S")
    logFile = f"{args.output}/backup.{args.level}.{now}.log"
    # Exit if lib.getLogFile returns a falsy value (no usable log file).
    fileLog = lib.getLogFile(logFile) or sys.exit(1)

    # The counter tracks logged errors/warnings; it decides the exit code
    # at the end of the function.
    counter = logCount()  # pylint: disable=unreachable
    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    logging.info("Backup level: [%s]", args.level)
    if args.compress is not False:
        logging.info("Compression enabled, level [%s]", args.compress)

    # Reject option combinations that are not supported.
    if args.compress is not False and args.type == "raw":
        logging.error("Compression not supported with raw output.")
        sys.exit(1)

    if args.stdout is True and args.type == "raw":
        logging.error("Output type raw not supported to stdout.")
        sys.exit(1)

    if args.stdout is True and args.raw is True:
        logging.error("Saving raw images to stdout is not supported.")
        sys.exit(1)

    # Incremental and differential backups require an existing full backup
    # in the target directory (not checked when writing to stdout).
    if (
        args.level not in ("copy", "full", "auto")
        and not lib.hasFullBackup(args)
        and not args.stdout
    ):
        logging.error(
            "Unable to execute [%s] backup: No full backup found in target directory: [%s]",
            args.level,
            args.output,
        )
        sys.exit(1)

    # Level "auto" resolves to "full" for an empty target and to "inc"
    # otherwise. For all other levels the target must be reported as empty
    # by lib.targetIsEmpty.
    if lib.targetIsEmpty(args) and args.level == "auto":
        logging.info("Backup mode auto, target folder is empty: executing full backup.")
        args.level = "full"
    elif not lib.targetIsEmpty(args) and args.level == "auto":
        if not lib.hasFullBackup(args):
            logging.error(
                "Can't execute switch to auto incremental backup: "
                "specified target folder [%s] does not contain full backup.",
                args.output,
            )
            sys.exit(1)
        logging.info("Backup mode auto: executing incremental backup.")
        args.level = "inc"
    elif not args.stdout and not args.startonly and not args.killonly:
        if not lib.targetIsEmpty(args):
            logging.error("Target directory already contains full or copy backup.")
            sys.exit(1)

    if args.raw is True and args.level in ("inc", "diff"):
        logging.warning(
            "Raw disks can't be included during incremental or differential backup."
        )
        logging.warning("Excluding raw disks.")
        args.raw = False

    if args.type == "raw" and args.level in ("inc", "diff"):
        logging.error(
            "Stream format raw does not support incremental or differential backup."
        )
        sys.exit(1)

    if partialfile.exists(args):
        sys.exit(1)

    # Default checkpoint storage is a sub directory of the backup target.
    if not args.checkpointdir:
        args.checkpointdir = f"{args.output}/checkpoints"
    else:
        logging.info("Store checkpoints in: [%s]", args.checkpointdir)

    # Handle creation failures the same way as for the output target
    # instead of terminating with an uncaught exception.
    try:
        fileStream.create(args.checkpointdir)
    except OutputException as e:
        logging.error("Can't create checkpoint directory: [%s]", e)
        sys.exit(1)

    try:
        virtClient = virt.client(args)
        domObj = virtClient.getDomain(args.domain)
    except domainNotFound as e:
        logging.error("%s", e)
        sys.exit(1)
    except connectionFailed as e:
        logging.error("Can't connect libvirt daemon: [%s]", e)
        sys.exit(1)

    logging.info("Libvirt library version: [%s]", virtClient.libvirtVersion)

    if virtClient.hasIncrementalEnabled(domObj) is False:
        logging.error(
            (
                "Virtual machine does not support required backup features, "
                "please adjust virtual machine configuration."
            )
        )
        sys.exit(1)

    try:
        checkpoint.checkForeign(args, domObj)
    except exceptions.CheckpointException:
        sys.exit(1)

    # Inactive domain: a requested full backup is downgraded to copy and
    # the run is flagged as offline.
    if domObj.isActive() == 0:
        if args.level == "full":
            logging.warning("Domain is offline, resetting backup options.")
            args.level = "copy"
            logging.warning("New Backup level: [%s].", args.level)
        args.offline = True

    if args.offline is True and args.startonly is True:
        logging.error("Virtual machine is currently offline")
        logging.error("Virtual machine must be active for this function.")
        sys.exit(1)

    # Route SIGINT to the backup signal handler with the current job
    # context bound as leading arguments.
    signal.signal(
        signal.SIGINT,
        partial(sighandle.Backup.catch, args, domObj, virtClient, logging),
    )

    vmConfig = virtClient.getDomainConfig(domObj)
    disks: List[DomainDisk] = virtClient.getDomainDisks(args, vmConfig)
    args.info = virtClient.getDomainInfo(vmConfig)

    if args.level != "copy" and lib.hasQcowDisks(disks) is False:
        args.level = "copy"
        logging.info("Only raw disks attached, switching to backup mode copy.")

    if not disks:
        logging.error("Unable to detect disks suitable for backup.")
        # Metadata is still saved although no disk can be backed up.
        metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)
        sys.exit(1)
    if (
        not args.killonly
        and not args.offline
        and virtClient.blockJobActive(domObj, disks)
    ):
        logging.error("Detected an active backup operation for running domain.")
        logging.error(
            "Check with [virsh domjobinfo %s] or use option -k to kill the active job.",
            args.domain,
        )
        sys.exit(1)

    logging.info(
        "Backup will save [%s] attached disks.",
        len(disks),
    )
    # Default to one worker per disk and never use more workers than disks.
    if args.worker is None or args.worker > len(disks):
        args.worker = len(disks)
    logging.info("Concurrent backup processes: [%s]", args.worker)

    if args.killonly is True:
        logging.info("Stopping backup job")
        if not virtClient.stopBackup(domObj):
            sys.exit(1)
        sys.exit(0)

    try:
        checkpoint.create(args, domObj)
    except exceptions.CheckpointException as errmsg:
        logging.error(errmsg)
        sys.exit(1)

    # Size estimation is only done if a parent checkpoint is set and the
    # domain is running.
    if args.printonly and args.cpt.parent and not args.offline:
        size = checkpoint.getSize(domObj, args.cpt.parent)
        logging.info("Estimated checkpoint backup size: [%s] Bytes", size)
        sys.exit(0)

    if args.threshold and args.cpt.parent and not args.offline:
        size = checkpoint.getSize(domObj, args.cpt.parent)
        if size < args.threshold:
            logging.info(
                "Backup size [%s] does not meet required threshold [%s], skipping backup.",
                size,
                args.threshold,
            )
            sys.exit(0)

    # A non-empty remoteHost means the domain runs on a remote system,
    # which requires an ssh session.
    if virtClient.remoteHost != "":
        args.sshClient = lib.sshSession(args, virtClient.remoteHost)
        if not args.sshClient:
            logging.error("Remote backup detected but ssh session setup failed")
            sys.exit(1)
        logging.info(
            "Remote NBD Endpoint host: [%s]",
            virtClient.remoteHost,
        )
        if args.offline is True:
            logging.info(
                "Remote ports used for backup: [%s-%s]",
                args.nbd_port,
                args.nbd_port + args.worker,
            )
    else:
        logging.info("Local NBD Endpoint socket: [%s]", args.socketfile)

    # The backup job (and its scratch directory) is only set up for
    # running domains.
    if args.offline is not True:
        logging.info("Temporary scratch file target directory: [%s]", args.scratchdir)
        try:
            fileStream.create(args.scratchdir)
        except OutputException as e:
            logging.error("Can't create scratch directory: [%s]", e)
            sys.exit(1)
        if not job.start(args, virtClient, domObj, disks):
            sys.exit(1)

    if args.level not in ("copy", "diff") and args.offline is False:
        logging.info("Started backup job with checkpoint, saving information.")
        try:
            checkpoint.save(args)
        except exceptions.CheckpointException as e:
            logging.error("Extending checkpoint file failed: [%s]", e)
            # The backup job has been started above: stop it before
            # exiting, like the failure path below does, so no active job
            # is left behind on the domain.
            virtClient.stopBackup(domObj)
            sys.exit(1)
        if not checkpoint.backup(args, domObj):
            virtClient.stopBackup(domObj)
            sys.exit(1)

    if args.startonly is True:
        logging.info("Started backup job for debugging, exiting.")
        sys.exit(0)

    # One task per disk. Failures are only logged here; the error counter
    # attached to the logger decides the exit code further below, so the
    # job is always stopped and the metadata always saved.
    try:
        with ThreadPoolExecutor(max_workers=args.worker) as executor:
            futures = {
                executor.submit(
                    disk.backup, args, Disk, count, fileStream, virtClient
                ): Disk
                for count, Disk in enumerate(disks)
            }
            for future in as_completed(futures):
                if future.result() is not True:
                    raise exceptions.DiskBackupFailed("Backup of one disk failed")
    except exceptions.BackupException as e:
        logging.error("Disk backup failed: [%s]", e)
    except sshError as e:
        logging.error("Remote Disk backup failed: [%s]", e)
    except Exception as e:  # pylint: disable=broad-except
        logging.fatal("Unknown Exception during backup: %s", e)
        logging.exception(e)

    if args.offline is False:
        logging.info("Backup jobs finished, stopping backup task.")
        virtClient.stopBackup(domObj)

    virtClient.close()

    metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)

    if domObj.autostart() == 1:
        metadata.backupAutoStart(args)

    # Close the ssh session before evaluating the result, so it is also
    # released if the backup ends with errors.
    if args.sshClient:
        args.sshClient.disconnect()

    if counter.count.errors > 0:
        logging.error("Error during backup")
        sys.exit(1)

    if counter.count.warnings > 0 and args.strict is True:
        logging.info(
            "[%s] Warnings detected during backup operation, forcing exit code 2",
            counter.count.warnings,
        )
        sys.exit(2)

    logging.info("Finished successfully")

# Run the backup only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
