#!/usr/bin/env python3

'''
Wrapper script for running the Rumur model checker.

This script is intended to be installed alongside the `rumur` binary from the
Rumur model checker. It can then be used to quickly generate and run a model, as
an alternative to having to run the model generation, compilation and execution
steps manually.
'''

import atexit
import os
import platform
import re
import shutil
import subprocess as sp
import sys
import tempfile
from typing import Optional

def which(cmd: str) -> Optional[str]:
  '''
  Equivalent of shell `which`: resolve a command name (or path) to an
  executable.

  This uses `shutil.which` instead of spawning the external `which` utility.
  That utility is not installed on every system, and its absence would
  otherwise surface as an uncaught FileNotFoundError.

  Args:
    cmd: a command name to look up in PATH, or a path to check directly

  Returns:
    the path to the executable, or None if it cannot be found
  '''
  return shutil.which(cmd)

# C compiler: $CC when set in the environment, otherwise `cc`, resolved to the
# path of an executable via `which` (None if it cannot be found)
CC = which(os.environ['CC'] if 'CC' in os.environ else 'cc')

def categorise(cc: str) -> str:
  '''
  Determine the vendor of a given C compiler.

  A small test program is compiled with `cc` and then executed. It reports
  the compiler family based on the predefined macros `__clang__` and
  `__GNUC__`.

  Args:
    cc: the C compiler to probe

  Returns:
    'clang', 'gcc', or 'unknown'. 'unknown' is also returned when the test
    program cannot be compiled, launched or run successfully.
  '''

  # Create a temporary area to compile a test file. The context manager
  # guarantees it is removed on every exit path, including when launching the
  # compiler or the test program raises.
  with tempfile.TemporaryDirectory() as tmp:

    # Setup the test file
    src = os.path.join(tmp, 'test.c')
    with open(src, 'wt') as f:
      f.write('#include <stdio.h>\n'
              '#include <stdlib.h>\n'
              'int main(void) {\n'
              '#ifdef __clang__\n'
              '  printf("clang\\n");\n'
              '#elif defined(__GNUC__)\n'
              '  printf("gcc\\n");\n'
              '#else\n'
              '  printf("unknown\\n");\n'
              '#endif\n'
              '  return EXIT_SUCCESS;\n'
              '}\n')

    # Compile it
    aout = os.path.join(tmp, 'a.out')
    try:
      cc_proc = sp.run([cc, '-o', aout, src], universal_newlines=True,
        stdout=sp.DEVNULL, stderr=sp.DEVNULL)
    except OSError:
      # the compiler itself could not be launched
      return 'unknown'
    if cc_proc.returncode != 0:
      return 'unknown'

    # Run it
    try:
      return sp.check_output([aout], universal_newlines=True).strip()
    except (sp.CalledProcessError, OSError):
      # the test program failed, or could not be executed at all (e.g. if the
      # temporary directory does not permit execution)
      return 'unknown'

def supports(flag: str) -> bool:
  '''
  Test whether the C compiler accepts a given command line option.

  A minimal C program is fed to the compiler on stdin along with `flag`. The
  output is discarded. The option counts as supported if the compiler exits
  with status 0.
  '''
  trivial_source = 'int main(void) { return 0; }'.encode('utf-8', 'replace')
  command = [CC, '-o', os.devnull, '-x', 'c', '-', flag]
  result = sp.run(command, stderr=sp.DEVNULL, input=trivial_source)
  return result.returncode == 0

def needs_libatomic() -> bool:
  '''
  Check whether the compiler needs -latomic for a double-word
  compare-and-swap.

  A test program is compiled and linked from stdin, and the output is
  discarded. Per its own comment, the program replicates the double-word
  atomic helpers in ../resources/header.c. If building it fails, linking
  against libatomic is assumed to be required.

  Returns:
    True if building the test program failed (libatomic presumed needed),
    False if it built successfully without -latomic
  '''

  # CAS program to ask it to compile
  program = '''
#include <stdbool.h>
#include <stdint.h>

// replicate what is in ../resources/header.c

#define THREADS 2

#if __SIZEOF_POINTER__ <= 4
  typedef uint64_t dword_t;
#elif __SIZEOF_POINTER__ <= 8
  typedef unsigned __int128 dword_t;
#else
  #error "unexpected pointer size; what scalar type to use for dword_t?"
#endif

static dword_t atomic_read(dword_t *p) {

  if (THREADS == 1) {
    return *p;
  }

#if defined(__x86_64__) || defined(__i386__)
  /* x86-64: MOV is not guaranteed to be atomic on 128-bit naturally aligned
   *   memory. The way to work around this is apparently the following
   *   degenerate CMPXCHG16B.
   * i386: __atomic_load_n emits code calling a libatomic function that takes a
   *   lock, making this no longer lock free. Force a CMPXCHG8B by using the
   *   __sync built-in instead.
   */
  return __sync_val_compare_and_swap(p, 0, 0);
#endif

  return __atomic_load_n(p, __ATOMIC_SEQ_CST);
}

static void atomic_write(dword_t *p, dword_t v) {

  if (THREADS == 1) {
    *p = v;
    return;
  }

#if defined(__x86_64__) || defined(__i386__)
  /* As explained above, we need some extra gymnastics to avoid a call to
   * libatomic on x86-64 and i386.
   */
  dword_t expected;
  dword_t old = 0;
  do {
    expected = old;
    old = __sync_val_compare_and_swap(p, expected, v);
  } while (expected != old);
  return;
#endif

  __atomic_store_n(p, v, __ATOMIC_SEQ_CST);
}

static bool atomic_cas(dword_t *p, dword_t expected, dword_t new) {

  if (THREADS == 1) {
    if (*p == expected) {
      *p = new;
      return true;
    }
    return false;
  }

#if defined(__x86_64__) || defined(__i386__)
  /* Make GCC >= 7.1 emit cmpxchg on x86-64 and i386. See
   * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80878.
   */
  return __sync_bool_compare_and_swap(p, expected, new);
#endif

  return __atomic_compare_exchange_n(p, &expected, new, false, __ATOMIC_SEQ_CST,
    __ATOMIC_SEQ_CST);
}

static dword_t atomic_cas_val(dword_t *p, dword_t expected, dword_t new) {

  if (THREADS == 1) {
    dword_t old = *p;
    if (old == expected) {
      *p = new;
    }
    return old;
  }

#if defined(__x86_64__) || defined(__i386__)
  /* Make GCC >= 7.1 emit cmpxchg on x86-64 and i386. See
   * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80878.
   */
  return __sync_val_compare_and_swap(p, expected, new);
#endif


  (void)__atomic_compare_exchange_n(p, &expected, new, false, __ATOMIC_SEQ_CST,
    __ATOMIC_SEQ_CST);
  return expected;
}

int main(void) {
  dword_t target = 0;

  target = atomic_read(&target);

  atomic_write(&target, 42);

  atomic_cas(&target, 42, 0);

  return (int)atomic_cas_val(&target, 0, 42);
}
'''

  # compile and link it (no -c, so linking is exercised) from stdin, adding
  # -mcx16 when the compiler accepts it so CMPXCHG16B can be used where that
  # must be enabled explicitly
  args = [CC, '-x', 'c', '-std=c11', '-', '-o', os.devnull]
  if supports('-mcx16'):
    args.append('-mcx16')
  p = sp.run(args, stderr=sp.DEVNULL, input=program.encode('utf-8', 'replace'))

  # a failed build means libatomic is needed, so this returns True on FAILURE
  # NOTE(review): any build failure is interpreted this way, not only a
  # missing atomic symbol at link time
  return p.returncode != 0

def optimisation_flags() -> [str]:
  '''
  Assemble the C compiler optimisation options to use on this platform.

  The result always starts with -O3. Each further option is probed against
  the compiler and only included if it is accepted.
  '''

  # host-specific and link-time options, each used only if accepted:
  #   -march=native  optimise code for the current host architecture
  #   -mtune=native  optimise code for the current host CPU
  #   -flto          enable link-time optimisation
  accepted = [opt for opt in ('-march=native', '-mtune=native', '-flto')
              if supports(opt)]
  result = ['-O3'] + accepted

  # allow GCC to perform more advanced interprocedural optimisations
  vendor = categorise(CC)
  if vendor == 'gcc':
    result.append('-fwhole-program')

  # on platforms that need it made explicit, enable CMPXCHG16B
  if supports('-mcx16'):
    result.append('-mcx16')

  return result

def has_no_la57() -> bool:
  '''
  Does our hardware lack support for Intel 5-level paging (LA57)?

  This inspects the CPU flags listed in /proc/cpuinfo. The answer is
  conservative: False is returned whenever absence of LA57 cannot be
  positively established. That covers a non-x86-64 platform, an unreadable
  procfs, or no flags field being found.

  Returns:
    True only on x86-64 when the first `flags` line of /proc/cpuinfo does not
    list `la57`
  '''

  # if we are not on an x86-64 platform, this is irrelevant
  if platform.machine() not in ('amd64', 'x86_64'):
    return False

  # read cpuinfo, looking for la57 as a supported flag by the CPU
  try:
    with open('/proc/cpuinfo', 'rt') as f:
      for line in f:
        flags = re.match(r'flags\s*:\s*(?P<flags>.*)$', line)
        if flags is not None:
          # found the flags field; now is it missing LA57?
          return re.search(r'\bla57\b', flags.group('flags')) is None

  except OSError:
    # procfs is unavailable or unreadable. OSError covers FileNotFoundError,
    # PermissionError and any other I/O failure. This check is only an
    # optimisation hint, so never let it abort the wrapper.
    return False

  # if we read the entire cpuinfo and did not find the flags field,
  # conservatively assume we may support LA57
  return False

def main(args: [str]) -> int:
  '''
  Generate, compile and run a Rumur model checker.

  Args:
    args: command line arguments in `sys.argv` form. args[0] is this script;
      the remainder are passed through to `rumur`.

  Returns:
    0 if generation, compilation and checking all succeeded. If generating
    the checker fails, the exit status of `rumur`. Otherwise -1 for any
    other failure.
  '''

  # Find the Rumur binary, preferring one installed alongside this script
  rumur_bin = which(os.path.join(
    os.path.abspath(os.path.dirname(__file__)), 'rumur'))
  if rumur_bin is None:
    rumur_bin = which('rumur')
  if rumur_bin is None:
    sys.stderr.write('rumur binary not found\n')
    return -1

  # if the user asked for help or version information, run Rumur directly
  # (execv replaces this process)
  for arg in args[1:]:
    if arg.startswith(('-h', '--h', '--vers')):
      os.execv(rumur_bin, [rumur_bin] + args[1:])

  if CC is None:
    sys.stderr.write('no C compiler found\n')
    return -1

  argv = [rumur_bin]
  # if this hardware does not support 5-level paging, we can more aggressively
  # compress pointers
  if has_no_la57():
    argv += ['--pointer-bits', '48']
  argv += args[1:] + ['--output', '/dev/stdout']

  # Generate the checker, capturing the generated C source from rumur's
  # stdout. Progress messages are flushed so they are not reordered relative
  # to child process output when our stdout is a file or pipe.
  # NOTE(review): stdin is an empty pipe here, so rumur cannot read a model
  # piped into this script's stdin — confirm this is intended.
  print('Generating the checker...', flush=True)
  rumur_proc = sp.run(argv, stdin=sp.PIPE, stdout=sp.PIPE)
  if rumur_proc.returncode != 0:
    return rumur_proc.returncode
  checker_c = rumur_proc.stdout

  # Setup a temporary directory in which to generate the checker, removed
  # when the interpreter exits
  tmp = tempfile.mkdtemp()
  atexit.register(shutil.rmtree, tmp)

  # Compile the checker, feeding the generated source on stdin
  print('Compiling the checker...', flush=True)
  aout = os.path.join(tmp, 'a.out')
  argv = [CC, '-std=c11'] + optimisation_flags() + ['-o', aout, '-x', 'c',
    '-', '-lpthread']
  if needs_libatomic():
    argv.append('-latomic')
  cc_proc = sp.run(argv, input=checker_c)
  if cc_proc.returncode != 0:
    return -1

  # Run the checker
  print('Running the checker...', flush=True)
  checker_proc = sp.run([aout])
  return 0 if checker_proc.returncode == 0 else -1

if __name__ == '__main__':
  try:
    status = main(sys.argv)
  except KeyboardInterrupt:
    # exit status conventionally used for an interrupt (Ctrl-C)
    status = 130
  sys.exit(status)
