#!/usr/bin/python3
import os
import signal
import socket
import subprocess
import sys
import time
import types


# Re-implement sd_notify to avoid the need to tweak apparmor permissions to run systemd-notify(1)
def sd_notify(unset_environment: bool, state: str) -> bool:
    """Send a notification datagram to the socket named by $NOTIFY_SOCKET.

    Args:
        unset_environment: When true, remove NOTIFY_SOCKET from this
            process's environment before returning, so that later calls
            (and child processes) no longer see it.
        state: The notification payload, e.g. "READY=1\n". It is sent
            UTF-8 encoded as a single datagram.

    Returns:
        False if NOTIFY_SOCKET is not set (nothing was sent), True once the
        message has been sent.

    Raises:
        OSError: If connecting to or writing to the socket fails.
    """
    addr = os.getenv("NOTIFY_SOCKET")
    if addr is None:
        return False
    # A leading "@" denotes a Linux abstract-namespace socket; the kernel
    # expects such addresses to start with a NUL byte instead.
    target = "\0" + addr[1:] if addr.startswith("@") else addr
    try:
        with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
            sock.connect(target)
            sock.sendmsg([state.encode("UTF-8")])
    finally:
        if unset_environment:
            # os.unsetenv() alone leaves os.environ (and so os.getenv())
            # untouched; removing the key from os.environ updates both the
            # mapping and the real process environment.
            os.environ.pop("NOTIFY_SOCKET", None)
    return True

def write_file(name: str) -> bool:
    """Create (or truncate to empty) the file *name* inside $SNAP_COMMON.

    Returns False without touching the filesystem when SNAP_COMMON is not
    set in the environment, True once the empty file has been written.
    """
    base_dir = os.environ.get("SNAP_COMMON")
    if base_dir is not None:
        target = os.path.join(base_dir, name)
        # Opening in "w" mode creates the file or truncates an existing
        # one; nothing is written, the file only serves as a marker.
        open(target, "w").close()
        return True
    return False

def main() -> None:
    """Announce readiness, then idle until SIGUSR1, SIGUSR2 or SIGHUP arrives.

    Readiness is reported through sd_notify() or, when NOTIFY_SOCKET is not
    set, by creating the file "ready" via write_file(). While waiting, a
    "running ..." line is printed once per second. When one of the signals
    arrives, its lower-case name is printed and a marker file named
    "<argv[1]>-<signame>" is created via write_file(), after which the
    function returns.

    Raises:
        SystemExit: If the readiness notification fails with PermissionError.
    """
    # The process name is optional on the command line; use the same
    # placeholder for both the log line and the marker file name so that a
    # missing argument cannot raise IndexError when the signal arrives.
    name = sys.argv[1] if len(sys.argv) > 1 else "???"

    # Keep track of the signal that stopped us so that we don't attempt IO
    # from the signal handler. It also doubles as the loop condition.
    incoming_signum = None

    # When one of the signals given below arrives, just record which one.
    def on_signal(signum: int, frame: types.FrameType) -> None:
        nonlocal incoming_signum
        incoming_signum = signum

    for signum in signal.SIGUSR1, signal.SIGUSR2, signal.SIGHUP:
        signal.signal(signum, on_signal)

    # Tell systemd that we are considered ready now or touch a file in
    # $SNAP_COMMON/ready that conveys the same message. The tests can use that
    # to synchronize.
    try:
        if not sd_notify(False, "READY=1\n"):
            write_file("ready")
    except PermissionError as exc:
        raise SystemExit("cannot send systemd notification message") from exc

    # Keep looping until a signal arrives, logging a message to show we are
    # alive. Flush explicitly: stdout is block-buffered when it is not a
    # terminal, which would otherwise delay the messages until exit.
    while incoming_signum is None:
        print("running {} process".format(name), flush=True)
        time.sleep(1)

    # Report the signal after the loop rather than inside it, so a signal
    # delivered at any point (even before the first iteration, or right
    # after a check inside the loop) is never lost.
    signame = signal.Signals(incoming_signum).name.lower()
    print("got {}".format(signame), flush=True)
    write_file("{}-{}".format(name, signame))


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
