#!/usr/bin/python3
#  OpenVPN 3 Linux client -- Next generation OpenVPN client
#
#  SPDX-License-Identifier: AGPL-3.0-only
#
#  Copyright (C) 2021 - 2023  OpenVPN Inc <sales@openvpn.net>
#  Copyright (C) 2021 - 2023  David Sommerseth <davids@openvpn.net>
#

##
# @file  openvpn3-systemd
#
# @brief OpenVPN 3 session management interface for systemd
#

import argparse
import datetime
import dbus
import errno
import openvpn3
import os
import subprocess
import sys
import systemd.daemon
import time

from dbus.mainloop.glib import DBusGMainLoop
from gi.repository import GLib
from openvpn3 import StatusMajor, StatusMinor, SessionManagerEventType
from openvpn3 import ClientAttentionType

# Refuse to run under a Python 2 interpreter
if sys.version_info < (3,):
    print("Python 3.x is required")
    sys.exit(1)


##
#  Main class for managing OpenVPN 3 VPN sessions via systemd
#
#  This class takes care of the main loop of a running VPN session
#  and provides log entries to stdout.  The systemd journal picks
#  up this output and preserves it.
#
class OpenVPN3systemd(object):
    ##
    #  Sets up the GLib main loop, connects to the D-Bus system bus and
    #  creates the configuration manager and session manager objects.
    #  The session manager is asked to deliver its events to
    #  _SessionManagerEvent().
    #
    def __init__(self):
        # Set up the main GLib loop and connect to the system bus
        self.__mainloop = GLib.MainLoop()
        self.__dbusloop = DBusGMainLoop(set_as_default=True)
        self.__sysbus = dbus.SystemBus(mainloop=self.__dbusloop)

        # Connect to the configuration and session manager
        self.__cmgr = openvpn3.ConfigurationManager(self.__sysbus)
        self.__config = None
        self.__smgr = openvpn3.SessionManager(self.__sysbus)
        self.__smgr.SessionManagerCallback(self._SessionManagerEvent)

        # Initial states for session information
        self.__session = None
        self.__sesspath = None
        self.__error_set = False


    ##
    #  Method for updating the systemd status field of the
    #  systemd service
    #
    #  Once _StatusHandler() has flagged an error, this method does
    #  nothing; no further notifications are sent to systemd.
    #
    #  @param ready  When larger than 0, a 'READY=<ready>' line is sent
    #  @param msg    Optional notification line(s) to append, for example
    #                'STATUS=...'.  None or an empty string adds nothing.
    #
    def sd_notify(self, ready, msg = None):
        if self.__error_set:
            return

        notifstr = ''
        if ready > 0:
            notifstr += 'READY=%i\n' % (ready)
        if msg is not None and len(msg) > 0:
            notifstr += "%s\n" % (msg)

        systemd.daemon.notify(notifstr)


    ##
    #  Simple Log signal callback function.  Called each time a Log event
    #  happens on this session.
    #
    #  The first non-empty line is printed with a timestamp prefix, the
    #  remaining non-empty lines are printed indented below it.
    #
    #  @param group  Log group of the event (not used here)
    #  @param catg   Log category of the event (not used here)
    #  @param msg    Log message, may span several lines
    #
    def _LogHandler(self, group, catg, msg):
        loglines = [l for l in msg.split('\n') if len(l) > 0]
        if len(loglines) < 1:
            return

        print('%s %s' % (datetime.datetime.now(), loglines[0]))
        for line in loglines[1:]:
            print('%s%s' % (' ' * 33, line))


    ##
    #  Simple StatusChange callback function.  Called each time a
    #  StatusChange event happens on this session.
    #
    #  This handler will also shutdown this currently running session if
    #  the session is disconnected outside of this program
    #
    #  @param major  Numeric status group, converted to StatusMajor
    #  @param minor  Numeric status code, converted to StatusMinor
    #  @param msg    Status message provided with the event
    #
    def _StatusHandler(self, major, minor, msg):
        maj = StatusMajor(major)
        mnr = StatusMinor(minor)   # do not shadow the min() builtin

        print('%s [STATUS] (%s, %s) %s' % (datetime.datetime.now(),
                                           maj, mnr, msg))
        self.sd_notify(1, "STATUS=%s:%s, %s" % (maj, mnr, msg))

        # Session got most likely disconnected outside of this program
        if StatusMajor.SESSION == maj and StatusMinor.PROC_STOPPED == mnr:
            self.sd_notify(1, "STATUS=%s:%s, %s\nSTOPPING=1" % (maj, mnr, msg))
            self.__mainloop.quit()
        if StatusMajor.CONNECTION == maj and StatusMinor.CONN_AUTH_FAILED == mnr:
            # 'ERRNO=' is the field name defined by the sd_notify protocol
            self.sd_notify(0, "STATUS=%s:%s, %s\nERRNO=%i" % (maj, mnr, msg, errno.EACCES))
            # From here on sd_notify() becomes a no-op
            self.__error_set = True
            self.__mainloop.quit()
        if StatusMajor.CONNECTION == maj and StatusMinor.CONN_DISCONNECTED == mnr:
            self.sd_notify(1, "STATUS=%s:%s, %s\nSTOPPING=1" % (maj, mnr, msg))
            self.__mainloop.quit()
        if StatusMajor.CONNECTION == maj and StatusMinor.CFG_REQUIRE_USER == mnr:
            self.__request_credentials()
            self.__session.Connect()


    ##
    #  This handles the SessionManagerEvent signals from the session manager.
    #  This is used to catch if the session manager closes the session, which
    #  means this script can stop running.
    #
    #  @param event  Session manager event object; its GetPath() and
    #                GetType() accessors are used to identify which session
    #                the event belongs to and what happened to it
    #
    def _SessionManagerEvent(self, event):
        if 'OPENVPN3_DEBUG' in os.environ:
            print('%s %s' % (datetime.datetime.now(), str(event)))

        # The event object itself is neither a path nor an event type;
        # compare the individual fields it carries
        if self.__sesspath == event.GetPath() and \
           SessionManagerEventType.SESS_DESTROYED == event.GetType():
            self.__mainloop.quit()

    ##
    #  Method for retrieving an openvpn3.Config object from a configuration
    #  profile name.  This is needed to start a new VPN session
    #
    #  @param config  Configuration profile name to look up
    #
    #  Raises RuntimeError if a profile has already been retrieved or if
    #  the name does not match exactly one configuration profile.
    #
    def RetrieveConfig(self, config):
        if self.__config is not None:
            raise RuntimeError("Configuration profile already retrieved")

        cfglist = self.__cmgr.LookupConfigName(config)
        if len(cfglist) == 0:
            raise RuntimeError("No configuration found")
        if len(cfglist) > 1:
            raise RuntimeError("More than one configuration profile found.")

        self.__config = self.__cmgr.Retrieve(cfglist[0])
        self.__config.Validate()
        print("Loaded configuration profile %s (path: %s)" % (
            self.__config.GetProperty('name'),
            self.__config.GetPath()))


    ##
    #  Method for retrieving a running VPN session based on the configuration
    #  profile name.  This is needed to restart or stop an on-going VPN
    #  session via systemctl.
    #
    #  @param config  Configuration profile name the session was started with
    #
    #  Raises RuntimeError if a session has already been retrieved or if
    #  the name does not match exactly one running session.
    #
    def RetrieveSession(self, config):
        if self.__session is not None:
            raise RuntimeError("Session already retrieved")

        sesslist = self.__smgr.LookupConfigName(config)
        if len(sesslist) == 0:
            raise RuntimeError("No running VPN session for configuration")
        if len(sesslist) > 1:
            raise RuntimeError("More than one running session uses this configuration profile")

        self.__session = self.__smgr.Retrieve(sesslist[0])
        print("Retrieved session data for profile %s/%s (path: %s)" % (
            self.__session.GetProperty('config_name'),
            self.__session.GetProperty('session_name'),
            self.__session.GetPath()))


    ##
    #  Starts a new VPN session.
    #
    #  Before this method is called, the RetrieveConfig() method must be called
    #
    #  The method blocks in the GLib main loop until one of the signal
    #  handlers stops the loop or the process is interrupted, and then
    #  disconnects the session.
    #
    #  @param log_level  Log verbosity to set on the session; values of 0
    #                    or less leave the log level untouched
    #
    #  Raises RuntimeError if a session is already set or no configuration
    #  profile has been retrieved, Exception for server locked profiles or
    #  a died backend process, and re-raises D-Bus errors it does not know
    #  how to recover from.
    #
    def Start(self, log_level):
        if self.__session is not None:
            raise RuntimeError("Session already started")
        if self.__config is None:
            raise RuntimeError("No configuration profile has been retrieved")

        #  Create a new VPN Session
        self.__session = self.__smgr.NewTunnel(self.__config)
        # NOTE(review): presumably gives the backend process time to
        # start up before the session object is used - confirm
        time.sleep(2)
        self.__sesspath = self.__session.GetPath()

        # Set up signal callback handlers and the proper log level
        self.__session.LogCallback(self._LogHandler)
        self.__session.StatusChangeCallback(self._StatusHandler)

        if log_level > 0:
            print(">> Changing log level to %i" % log_level)
            self.__session.SetProperty('log_verbosity', dbus.UInt32(log_level))

        print("Session initiated: %s" % self.__session.GetPath())
        done = False
        while not done:
            try:
                self.__session.Ready()
                print("Starting session connection")
                self.__session.Connect()
                print("Session started successfully")

                # Blocks until a signal handler calls quit() on the loop
                self.__mainloop.run()

                print("Disconnecting")
                self.sd_notify(1, "STATUS=Stopping\nSTOPPING=1")
                self.__detach_log_callback()
                try:
                    self.__session.Disconnect()
                except dbus.exceptions.DBusException as e:
                    # UnknownMethod is tolerated here; anything else
                    # is a real failure
                    if str(e).split(':')[0] != 'org.freedesktop.DBus.Error.UnknownMethod':
                        raise
                print("Disconnected")

                # The session has ended; do not loop back to Ready()
                done = True
            except dbus.exceptions.DBusException as excep:
                errmsg = str(excep)
                if 'Backend VPN process is not ready' in errmsg:
                    # Wait a little and try again
                    time.sleep(0.5)
                elif ' Missing user credentials' in errmsg:
                    # Provide the credentials and try again
                    self.__request_credentials()
                elif 'ERR_PROFILE_SERVER_LOCKED_UNSUPPORTED:' in errmsg:
                    self.__session.Disconnect()
                    raise Exception('Server locked profiles are not supported')
                elif ' Backend VPN process has died' in errmsg:
                    raise Exception('VPN backend process stopped unexpectedly. '
                                      + 'Session has closed. :: ' + errmsg)
                else:
                    # Unknown error: retrying without delay would only
                    # spin forever, so let the caller report it
                    raise
            except KeyboardInterrupt:
                print("Disconnecting (sigint)")
                self.sd_notify(1, "STATUS=Stopping (SIGINT)\nSTOPPING=1")
                done = True
                self.__detach_log_callback()
                try:
                    self.__session.Disconnect()
                except dbus.exceptions.DBusException:
                    # Best effort; the process is shutting down anyway
                    pass
            except RuntimeError as err:
                # Report the reason instead of stopping silently
                print("Session start aborted: %s" % str(err))
                done = True


    ##
    #  Restart an on-going VPN session
    #
    #  The RetrieveSession() method must be called before calling
    #  this method
    #
    #  Raises RuntimeError if no session has been retrieved.
    #
    def Restart(self):
        if self.__session is None:
            raise RuntimeError("Session not started")

        self.__session.Restart()


    ##
    #  Stop an on-going VPN session
    #
    #  The RetrieveSession() method must be called before calling
    #  this method
    #
    #  Raises RuntimeError if no session has been retrieved.
    #
    def Stop(self):
        if self.__session is None:
            raise RuntimeError("Session not started")

        self.__session.Disconnect()


    ##
    #  Removes the log callback from the session, ignoring D-Bus errors
    #  since this is only done as part of shutting the session down.
    #
    def __detach_log_callback(self):
        try:
            self.__session.LogCallback(None)
        except dbus.exceptions.DBusException:
            pass


    ##
    #  Queries the user for each credential the session asks for, using
    #  systemd-ask-password, and hands the answers to the session.
    #
    #  If a value cannot be retrieved the session is disconnected and a
    #  RuntimeError is raised.
    #
    def __request_credentials(self):
        args_base = ['/usr/bin/systemd-ask-password', '--icon', 'network-vpn']

        for input_slot in self.__session.FetchUserInputSlots():
            if input_slot.GetTypeGroup()[0] != ClientAttentionType.CREDENTIALS:
                # skip non-user credential requests
                continue

            args = list(args_base)
            # Only echo the input when masking is explicitly disabled
            if False == input_slot.GetInputMask():
                args.append('--echo')
            args.append('%s:' % (input_slot.GetLabel()))
            try:
                r = subprocess.check_output(args)
                input_slot.ProvideInput(r.decode('utf-8').strip())
            except (subprocess.CalledProcessError, OSError) as e:
                # OSError covers systemd-ask-password being unavailable
                self.__session.Disconnect()
                raise RuntimeError("Could not retrieve user credentials: %s" % str(e)) from e


if "__main__" == __name__:
    # Debug runs default to verbose logging; otherwise the log level
    # of the session is left untouched unless --log-level is given
    start_log_level = 6 if 'OPENVPN3_DEBUG' in os.environ else -1

    argp = argparse.ArgumentParser(prog='openvpn3-systemd',
                                   description='OpenVPN 3 systemd session management')
    argp.add_argument('config', metavar='CONFIG_PROFILE', type=str,
                      help='Configuration profile name of session which is to be managed')

    # The three operation modes share the same destination; the last
    # one given on the command line wins
    mode_options = (
        ('start', 'Starts a VPN session with this configuration profile'),
        ('restart', 'Restarts a running VPN session with this configuration profile'),
        ('stop', 'Stops a running VPN session with this configuration profile'),
    )
    for mode_name, mode_help in mode_options:
        argp.add_argument('--' + mode_name, action='store_const',
                          const=mode_name, dest='mode', help=mode_help)

    argp.add_argument('--log-level', type=int, action='store', default=start_log_level,
                      help='Set the log verbosity level')
    args = argp.parse_args()

    if args.mode not in ('start', 'restart', 'stop'):
        print('%s: ** ERROR ** One of --start, --restart or --stop is required' % sys.argv[0])
        sys.exit(1)

    try:
        o3srv = OpenVPN3systemd()
        if args.mode == 'start':
            o3srv.RetrieveConfig(args.config)
            o3srv.Start(args.log_level)
        else:
            # Both remaining modes operate on an already running session
            o3srv.RetrieveSession(args.config)
            if args.mode == 'restart':
                print('Restarting VPN session')
                o3srv.Restart()
            else:
                print('Stopping VPN session')
                o3srv.Stop()
    except Exception as err:
        # RuntimeError is an Exception too; all failures are reported alike
        print('%s: ** ERROR **  %s' % (sys.argv[0], str(err)))
