#!/usr/bin/env python3

# Base URL that every file listed in `todo` is downloaded from.
# NOTE(review): the trailing hex string presumably pins one upstream commit — confirm.
prefix = 'https://raw.githubusercontent.com/awslabs/s2n-bignum/acbb18e6343f12a7944de72c1ec0991739600f8c'

# One entry per assembly file to import.  Fields, in order:
#   1. expected SHA-256 hex digest of the downloaded file
#   2. directory below `prefix` that holds the file
#   3. file name without the `.S` suffix
#   4. output directory that receives the generated files
#   5. text written to the `architectures` file in that output directory
todo = (
  ('3c3f8f9496266306172251ed3b64c989b10b8dfbe1ee8353a19b932de22cf0d0','x86/curve25519','curve25519_x25519base'         ,'crypto_nG/montgomery25519/amd64-s2nbignum-adx','amd64 bmi2 adx'),
  ('4b6eef9217cc3822667734eedd23a6a0a7068d6f9b2f401d647f5fde3629c2db','x86/curve25519','curve25519_x25519base_alt'     ,'crypto_nG/montgomery25519/amd64-s2nbignum'    ,'amd64'),
  ('64b08e280715ce2e8b6e48156bcb6bc9cd7ac117e52415e2e5d4b11fe57fa715','x86/curve25519','curve25519_x25519'             ,'crypto_nP/montgomery25519/amd64-s2nbignum-adx','amd64 bmi2 adx'),
  ('dff3fafb80cf180012d71e41e2c508622dca256f03ae6010943528cbfcbc37d3','x86/curve25519','curve25519_x25519_alt'         ,'crypto_nP/montgomery25519/amd64-s2nbignum'    ,'amd64'),
  ('ee7bc766256150ab216f828d11f690485487e1cb86bc288a24fa5953b011050b','arm/curve25519','curve25519_x25519base_byte'    ,'crypto_nG/montgomery25519/arm64-s2nbignum'    ,'arm64'),
  ('afdd9d920cf40fbfb0032e94efe98f7b415ab977c670ec56cb21c32fff9e5b96','arm/curve25519','curve25519_x25519base_byte_alt','crypto_nG/montgomery25519/arm64-s2nbignum-alt','arm64'),
  ('29ad8c639e1d8edac6712900192d9ac1ee5fc4097e1ff09d0b754193ae1acaf6','arm/curve25519','curve25519_x25519_byte'        ,'crypto_nP/montgomery25519/arm64-s2nbignum'    ,'arm64'),
  ('1ea1ef06b85faa46ac7c01578ceaee69142ba36d5a1e2abd035afc2b07c5dc1b','arm/curve25519','curve25519_x25519_byte_alt'    ,'crypto_nP/montgomery25519/arm64-s2nbignum-alt','arm64'),
)

import os
import re
import urllib.request
import hashlib

def preprocess(asm): # expand #define x(y) and remove semicolons
  """Expand multi-line function-like ``#define`` macros in assembly lines.

  Parameters:
    asm: iterable of ``bytes`` lines without line terminators.

  Returns:
    A new list of ``bytes`` lines.  Macro definitions are dropped from the
    output.  Every other line that has the shape ``name(args)`` is replaced
    by a ``// expanding name(args):`` comment followed by the macro body
    with the arguments substituted; the substituted lines are themselves
    expanded again, so macros may invoke other macros.  All remaining lines
    are passed through unchanged.

  Raises:
    AssertionError: if a ``name(args)`` line names an unknown macro or is
      called with a different number of commas than the definition has.

  Only definitions whose ``#define`` line ends in a continuation backslash
  are recognised.  A definition still open at the end of the input is
  discarded.
  """
  macros = {}     # macro name -> (comma-separated parameter bytes, body lines)
  processed = []

  def expand(processed,line):
    # Append `line` to `processed`, expanding it first if it is a macro call.
    m = re.match(rb'\s*([a-zA-Z0-9_]*)\((.*)\)$',line)
    if m is None:
      processed.append(line)
      return
    fun,args = m.group(1),m.group(2)
    processed.append(b'// expanding %b(%b):' % (fun,args))
    assert fun in macros, 'unknown macro %r' % (fun,)
    params,body = macros[fun]
    assert params.count(b',') == args.count(b','), 'argument count mismatch for %r' % (fun,)
    tr = dict(zip(params.split(b','),args.split(b',')))
    for result in body:
      for param,arg in tr.items():
        # Literal substitution.  re.sub would read `param` as a pattern and
        # `arg` as a replacement template, so an argument such as \dst would
        # raise re.error or be rewritten instead of being copied verbatim.
        result = result.replace(param,arg)
      expand(processed,result)

  inmacro = False
  macro = macro_args = macro_out = None
  for line in asm:
    if inmacro:
      # Drop the continuation backslash plus trailing blanks and semicolons.
      macro_out.append(line.rstrip(b'\\ ;'))
      if not line.endswith(b'\\'):
        # First body line without a continuation ends the definition.
        macros[macro] = macro_args,macro_out
        inmacro = False
      continue
    m = re.match(rb'#define (.*)\((.*)\) *\\$',line)
    if m is None:
      expand(processed,line)
      continue
    inmacro = True
    macro,macro_args = m.group(1),m.group(2)
    macro_out = []

  return processed

# For every entry in `todo`: download the file, check its SHA-256, rewrite the
# symbol declarations, expand macros, and write the assembly file together
# with its `architectures` and `wrapper.c` companions into the target directory.
for expected,subdir,source,target,architectures in todo:
  url = f'{prefix}/{subdir}/{source}.S'
  print(url)
  with urllib.request.urlopen(url) as http:
    blob = http.read()

  # Refuse to continue if the download does not match the pinned digest.
  received = hashlib.sha256(blob).hexdigest()
  if received != expected:
    raise Exception(f'hash mismatch: expected {expected}, received {received}')

  rewritten = []
  for line in blob.splitlines():
    if line == b'#include "_internal_s2n_bignum.h"':
      continue
    if line.startswith(b'S2N_BN_SYMBOL') and line.endswith(b':'):
      # Replace the label by hidden+global declarations and labels for both
      # the plain and the underscore-prefixed spelling of the symbol.
      parenfun = line[len(b'S2N_BN_SYMBOL'):-1]
      symbols = (b'CRYPTO_SHARED_NAMESPACE'+parenfun,b'_CRYPTO_SHARED_NAMESPACE'+parenfun)
      for symbol in symbols:
        rewritten.append(b'ASM_HIDDEN '+symbol)
        rewritten.append(b'.globl '+symbol)
      rewritten.extend(symbol+b':' for symbol in symbols)
      continue
    if line.startswith(b'        S2N'):
      continue
    if line.startswith(b'// .section .rodata'):
      # The rodata marker needs its header as the very first output line.
      rewritten.insert(0,b'#include "crypto_asm_rodata.h"')
      line = b'ASM_RODATA'
    rewritten.append(line)
  body = b'\n'.join(preprocess(rewritten))+b'\n'

  # Name referenced from wrapper.c: sources without `_byte` get it inserted
  # after `x25519`, or after `x25519base` when that is present.
  if '_byte' in source:
    fun = source
  else:
    fun = source.replace('x25519','x25519_byte').replace('x25519_bytebase','x25519base_byte')

  os.makedirs(target,exist_ok=True)

  with open(f'{target}/architectures','w') as f:
    f.write(architectures+'\n')

  with open(f'{target}/{source}.S','wb') as f:
    f.writelines((
      b'// generated by use-s2n-bignum script\n',
      b'// starting from s2n-bignum\n',
      b'\n',
      b'#include "crypto_asm_hidden.h"\n',
      b'\n',
      body,
    ))

  # Pick the wrapper.c pieces that differ between the nG and nP variants.
  if target.startswith('crypto_nG'):
    header = '#include "crypto_nG.h"\n\n'
    prototype = 'extern void wrapped(uint8_t *,const uint8_t *);\n\n'
    signature = 'void crypto_nG(unsigned char *nG,const unsigned char *n)\n'
    call = '  wrapped(nG,n);\n'
  else:
    header = '#include "crypto_nP.h"\n\n'
    prototype = 'extern void wrapped(uint8_t *,const uint8_t *,const uint8_t *);\n\n'
    signature = 'void crypto_nP(unsigned char *q,const unsigned char *n,const unsigned char *p)\n'
    call = '  wrapped(q,n,p);\n'

  with open(f'{target}/wrapper.c','w') as f:
    f.writelines((
      '#include <stdint.h>\n',
      header,
      f'#define wrapped CRYPTO_SHARED_NAMESPACE({fun})\n\n',
      prototype,
      signature,
      '{\n',
      call,
      '}\n',
    ))
