#!/usr/bin/env python3
# encoding: utf-8
"""
bgp

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import pwd
import sys
import time
import errno
import socket
import threading
import signal
import asyncore
import subprocess
from struct import unpack

# every attribute of the signal module whose name starts with 'SIG', by name
SIGNAL = {
	name: getattr(signal,name)
	for name in dir(signal)
	if name.startswith('SIG')
}


def flushed (*output):
	"""Print the arguments separated by single spaces, then flush stdout."""
	# print() applies str() to each argument and joins them with one space
	print(*output, flush=True)


def bytestream (value):
	"""Return the items of value as one string of two-digit uppercase hex numbers."""
	digits = ('%02X' % octet for octet in value)
	return ''.join(digits)


def dump (value):
	"""Return value as uppercase hex, with a space after every second byte."""
	def pieces (octets):
		for position, octet in enumerate(octets):
			# a separator goes in front of every even position except the first
			if position and not position % 2:
				yield ' '
			yield '%02X' % octet
	return ''.join(pieces(value))


def cdr_to_length (cidr):
	"""Return the number of bytes (0 to 4) used to carry a prefix of cidr bits."""
	# (byte count, mask length that has to be exceeded), longest first
	for size, floor in ((4, 24), (3, 16), (2, 8), (1, 0)):
		if cidr > floor:
			return size
	return 0


class BGPHandler(asyncore.dispatcher_with_send):
	"""
	One accepted BGP session of the test daemon.

	The first message read is handled by handle_open, which sends the OPEN
	back with its router-id changed. Every later message is handled by
	handle_keepalive, which compares what the peer sends with the expected
	messages given to setup().
	"""

	# how many received messages have been compared with the expectations
	counter = 0

	# KEEPALIVE message: 16 marker bytes, length 0x0013, type 0x04
	keepalive = bytearray([0xFF, ] * 16 + [0x0, 0x13, 0x4])

	# display name per message type; the keys are ints because the headers
	# handed around are bytearrays, whose items are ints
	_name = {
		1: 'OPEN',
		2: 'UPDATE',
		3: 'NOTIFICATION',
		4: 'KEEPALIVE',
	}

	def signal (self,myself,signal_name='SIGUSR1'):
		"""
		Send a signal to the one running ExaBGP process found in `ps x`.

		myself is not used (callers in this file pass a process id).
		signal_name must be a key of SIGNAL. A process matches when its
		last argument ends in '.conf' and contains the base name of
		sys.argv[1].

		Calls sys.exit(1) when the signal name is unknown or when the number
		of matching processes is not exactly one. A failure of os.kill is
		only reported.
		"""
		signal_number = SIGNAL.get(signal_name,'')
		if not signal_number:
			self.announce('invalid signal name in configuration : %s' % signal_name)
			self.announce('options are: %s' % ','.join(SIGNAL.keys()))
			sys.exit(1)

		conf_name = sys.argv[1].split('/')[-1].split('.')[0]

		processes = []

		# the context manager closes the pipe to ps once the listing is read
		with os.popen("/bin/ps x") as listing:
			for line in listing:
				low = line.strip().lower()
				if not low:
					continue
				if 'python' not in low and 'pypy' not in low:
					continue

				fields = line.strip().split()
				pid = fields[0]
				# everything after the first four columns is taken as the command line
				cmdline = fields[4:]

				# no command at all: cmdline[-1] below would raise IndexError
				if not cmdline:
					continue

				if len(cmdline) > 1 and not cmdline[1].endswith('/bgp.py'):
					continue

				if conf_name not in cmdline[-1]:
					continue

				if not cmdline[-1].endswith('.conf'):
					continue

				processes.append(pid)

		if len(processes) == 0:
			self.announce('no running process found, this should not happend, quitting')
			sys.exit(1)

		if len(processes) > 1:
			self.announce('more than one process running, this should not happend, quitting')
			sys.exit(1)

		try:
			self.announce('sending signal %s to ExaBGP (pid %s)\n' % (signal_name,processes[0]))
			os.kill(int(processes[0]),signal_number)
		except Exception as exc:
			self.announce('\n     failed: %s' % str(exc))

	def kind (self, header):
		"""Return the message type, the byte at offset 18 of the header."""
		return header[18]

	def isupdate (self, header):
		"""Tell if the header is the one of an UPDATE (type 2)."""
		return header[18] == 2

	def isnotification (self, header):
		"""Tell if the header is the one of a NOTIFICATION (type 3)."""
		return header[18] == 3

	def name (self, header):
		"""Return the display name of the message type found in the header."""
		return self._name.get(header[18],'SOME WEIRD RFC PACKET')

	@staticmethod
	def _prefixes (data):
		"""Yield each IPv4 prefix packed in data as an 'a.b.c.d/length' string."""
		while data:
			cdr, data = data[0], data[1:]
			# only the bytes covered by the mask are on the wire, the rest is zero
			r = [0,0,0,0]
			for index in range(cdr_to_length(cdr)):
				r[index], data = data[0], data[1:]
			yield '.'.join(str(_) for _ in r) + '/' + str(cdr)

	@staticmethod
	def _wire (header, body):
		"""Format a message as hex: marker:length:type:body."""
		return '%s:%s:%s:%s' % (bytestream(header[:16]),bytestream(header[16:18]),bytestream(header[18:]),bytestream(body))

	def routes (self, header, body):
		"""
		Decode the body of an UPDATE into the strings compared with the expectations.

		Yields 'withdraw:<prefix>' and 'announce:<prefix>' for the IPv4 routes.
		When the message carries neither, yields 'eor:1:1' for a 4 byte body,
		'eor:<a>:<b>' (the last two bytes) for an 11 byte body and 'mp:' otherwise.
		"""
		len_w = unpack('!H',body[0:2])[0]
		withdrawn = bytearray(body[2:2+len_w])
		len_a = unpack('!H',body[2+len_w:2+len_w+2])[0]
		# the attributes are skipped, what follows them is the announced NLRI
		announced = bytearray(body[2+len_w + 2+len_a:])

		if not withdrawn and not announced:
			if len(body) == 4:
				yield 'eor:1:1'
			elif len(body) == 11:
				yield 'eor:%d:%d' % (body[-2],body[-1])
			else:  # undecoded MP route
				yield 'mp:'
			return

		for prefix in self._prefixes(withdrawn):
			yield 'withdraw:%s' % prefix

		for prefix in self._prefixes(announced):
			yield 'announce:%s' % prefix

	def notification (self, header, body):
		"""Yield 'notification:<code>,<subcode>' built from the first two bytes of the body."""
		# one string per message: handle_keepalive calls startswith() on what
		# is yielded, so a tuple cannot be used here
		# NOTE(review): the hex dump of the body is no longer part of the
		# yielded value - confirm the format the expectation files should use
		yield 'notification:%d,%d' % (body[0],body[1])

	def announce (self, *args):
		"""Print the arguments on one line, prefixed with the peer ip and port."""
		flushed('    ',self.ip,self.port,*args)

	def check_signal (self):
		"""When the next expected message is 'signal:<NAME>', consume it and send that signal."""
		if self.messages and self.messages[0].startswith('signal:'):
			name = self.messages.pop(0).split(':')[-1]
			self.signal(os.getppid(),name)

	def setup (self, ip, port, messages, options):
		"""
		Attach the peer address, the options and the expected messages to the handler.

		Each rule of messages is '<sequence>:<announcement>'. The rules are
		grouped by sequence, and an announcement starting with 'raw:' switches
		the handler to raw mode (hex comparison). Returns self.
		"""
		self.ip = ip
		self.port = port
		self.options = options
		self.handle_read = self.handle_open
		self.sequence = {}
		self.raw = False
		for rule in messages:
			sequence,announcement = rule.split(':',1)
			if announcement.startswith('raw:'):
				self.raw = True
				# keep only the hex digits, without the ':' separators
				announcement = announcement[4:].replace(':','')
			self.sequence.setdefault(sequence,[]).append(announcement)
		self.update_sequence()
		return self

	def update_sequence (self):
		"""
		Make the next group of expected messages the current one.

		Returns True when a group was loaded (always, with an empty list, in
		sink and echo mode) and False when no group is left.
		"""
		if self.options['sink'] or self.options['echo']:
			self.messages = []
			return True
		# NOTE(review): the keys are strings, so the order is lexicographic
		# ('10' comes before '2') - confirm this is what the test files expect
		keys = sorted(self.sequence)
		if keys:
			key = keys[0]
			self.messages = self.sequence[key]
			self.step = key
			del self.sequence[key]

			self.check_signal()
			# we had a list with only one signal
			if not self.messages:
				return self.update_sequence()
			return True
		return False

	def read_message (self):
		"""
		Read one whole BGP message from the socket.

		Returns (header, body) as bytearrays, or (None, None) when the
		connection was closed or the length field is smaller than the header.
		EWOULDBLOCK / EAGAIN are retried immediately, other socket errors
		are raised.
		"""
		header = b''
		while len(header) != 19:
			try:
				chunk = self.recv(19-len(header))
			except socket.error as exc:
				if exc.args[0] in (errno.EWOULDBLOCK,errno.EAGAIN):
					continue
				raise
			if not chunk:
				# the TCP session is gone.
				return None,None
			header += chunk

		length = unpack('!H',header[16:18])[0] - 19
		if length < 0:
			# the announced length does not even cover the header
			return None,None

		body = b''
		while len(body) != length:
			try:
				chunk = self.recv(length-len(body))
			except socket.error as exc:
				if exc.args[0] in (errno.EWOULDBLOCK,errno.EAGAIN):
					continue
				raise
			if not chunk:
				# closed in the middle of a message: looping would never end
				return None,None
			body += chunk

		return bytearray(header),bytearray(body)

	def handle_open (self):
		"""
		Answer the OPEN of the peer, then hand the session to handle_keepalive.

		The OPEN is sent back unchanged except for body byte 8, which is
		incremented. It is followed by a KEEPALIVE and, depending on the
		options, by an extra unknown capability and/or a default route.
		"""
		# reply with a IBGP response with the same capability (just changing routerID)
		header,body = self.read_message()

		if header is None:
			self.announce('connection closed')
			self.close()
			return

		routerid = bytearray([body[8] + 1 & 0xFF])
		o = header+body[:8]+routerid+body[9:]

		if self.options['send-unknown-capability']:
			# hack capability 66 into the message

			content = b'loremipsum'
			cap66 = bytearray([66, len(content)]) + content
			param = bytearray([2, len(cap66)]) + cap66
			# bytes 16-17 are the message length, handled as one 16 bit number
			# so that growing it can carry into the high byte
			total = unpack('!H',o[16:18])[0] + len(param)
			# byte 28 is the length of the optional parameters
			o = o[:16] + bytearray([total >> 8, total & 0xFF]) + o[18:28] + \
				bytearray([o[28] + len(param)]) + o[29:] + param

		self.send(o)
		self.send(self.keepalive)

		if self.options['send-default-route']:
			self.send(bytearray(
				[0xFF, ] * 16 +
				[0x00, 0x31] +
				[0x02,] +
				[0x00, 0x00] +
				[0x00, 0x15] +
				[] + [0x40, 0x01, 0x01, 0x00] +
				[] + [0x40, 0x02, 0x00] +
				[] + [0x40, 0x03, 0x04, 0x7F, 0x00, 0x00, 0x01] +
				[] + [0x40, 0x05, 0x04, 0x00, 0x00, 0x00, 0x64] +
				[0x20, 0x00, 0x00, 0x00, 0x00]
			))
			self.announce('sending default-route\n')

		self.handle_read = self.handle_keepalive

	def handle_keepalive (self):
		"""
		Handle every message received after the OPEN.

		In sink mode the message is printed and answered with a KEEPALIVE, in
		echo mode it is printed and sent back. Otherwise it is decoded and
		each resulting string is checked against the current expected
		messages. The program exits with status 1 on the first unexpected
		one, and with status 0 once everything expected was seen and the
		options ask for it.
		"""
		header,body = self.read_message()

		if header is None:
			self.announce('connection closed')
			self.close()
			if self.options['send-notification']:
				self.announce('successful')
				sys.exit(0)
			return

		if self.options['sink']:
			self.announce('received %d: %s' % (self.counter,self._wire(header,body)))
			self.send(self.keepalive)
			return

		if self.options['echo']:
			wire = self._wire(header,body)
			self.announce('received %d: %s' % (self.counter,wire))
			self.send(header+body)
			self.announce('sent     %d: %s' % (self.counter,wire))
			return

		if self.raw:
			# raw mode compares the hex of the whole message, bodyless ones are ignored
			def parser (self, header, body):
				if body:
					yield bytestream(header + body)
		else:
			parser = self._decoder.get(self.kind(header),None)

		if parser:
			for announcement in parser(self,header,body):
				self.send(self.keepalive)
				if announcement.startswith('eor:'):  # skip EOR
					self.announce('skipping eor',announcement)
					continue

				if announcement.startswith('mp:'):  # skip unparsed MP
					self.announce('skipping multiprotocol :',dump(body))
					continue

				self.counter += 1

				if announcement in self.messages:
					self.messages.remove(announcement)
					if self.raw:
						self.announce('received %d (%1s%s):' % (self.counter,self.options['letter'],self.step),'%s:%s:%s:%s' % (announcement[:32],announcement[32:36],announcement[36:38],announcement[38:]))
					else:
						self.announce('received %d (%1s%s):' % (self.counter,self.options['letter'],self.step), announcement)
					self.check_signal()
				else:
					if self.raw:
						self.announce('received %d (%1s%s):' % (self.counter,self.options['letter'],self.step),self._wire(header,body))
					else:
						self.announce('received %d     :' % self.counter,announcement)

					if len(self.messages) > 1:
						self.announce('expected one of the following :')
						for message in self.messages:
							if message.startswith('F'*32):
								self.announce('                 %s:%s:%s:%s' % (message[:32],message[32:36],message[36:38],message[38:]))
							else:
								self.announce('                 %s' % message)
					elif self.messages:
						message = self.messages[0].upper()
						if message.startswith('F'*32):
							self.announce('expected       : %s:%s:%s' % (message[:32],message[32:38],message[38:]))
						else:
							self.announce('expected       : %s' % message)
					else:
						# can happen when the thread is still running
						self.announce('extra data')

					sys.exit(1)

				if not self.messages:
					if self.options['single-shot']:
						self.announce('successful (partial test)')
						sys.exit(0)

					if not self.update_sequence():
						if self.options['exit']:
							self.announce('successful')
							sys.exit(0)
		else:
			self.send(self.keepalive)

		if self.options['send-notification']:
			notification = b'closing session because we can'
			# NOTIFICATION with error code 6, subcode 0 and the text as data
			self.send(bytearray(
				[0xFF,]*16 +
				[0x00, 19+2+len(notification)] +
				[0x03] +
				[0x06] +
				[0x00]
			) + notification)

	# decoder per message type, called as decoder(handler, header, body)
	_decoder = {
		2: routes,
		3: notification,
	}


class BGPServer (asyncore.dispatcher):
	"""
	Listening socket of the test daemon.

	The expected messages are grouped by their leading letter ('A' when
	there is none). Each accepted connection gets the group of the first
	letter still available and is handed to a BGPHandler.
	"""

	def announce (self, *args):
		"""Print the arguments on one line, indented by four spaces."""
		flushed('    ' + ' '.join(str(_) for _ in args))

	def signal (self, myself, signal_name='SIGUSR1'):
		"""Send signal_name to the running ExaBGP, see BGPHandler.signal."""
		# the implementation only relies on self.announce, which this class provides
		BGPHandler.signal(self,myself,signal_name)

	def __init__ (self, host, options):
		"""
		Listen on host and options['port'] and sort options['messages'].

		Lines of the form 'option:...' set behaviour flags, every other line
		is stored as test data under its leading letter.
		"""
		asyncore.dispatcher.__init__(self)

		# a ':' in the address means IPv6
		family = socket.AF_INET6 if ':' in host else socket.AF_INET
		self.create_socket(family, socket.SOCK_STREAM)
		self.set_reuse_addr()
		self.bind((host,options['port']))
		self.listen(5)

		self.messages = {}

		self.options = {
			'send-unknown-capability': False,  # add an unknown capability to the open message
			'send-default-route': False,       # send a default route to the peer
			'send-notification': False,        # send notification messages to the backend
			'signal-SIGUSR1': 0,               # send SIGUSR1 after X seconds
			'single-shot': False,              # we can not test signal on python 2.6
			'sink': False,                     # just accept whatever is sent
			'echo': False,                     # just accept whatever is sent
		}
		self.options.update(options)

		for message in options['messages']:
			stripped = message.strip()

			if stripped == 'option:open:send-unknown-capability':
				self.options['send-unknown-capability'] = True
				continue
			if stripped == 'option:update:send-default-route':
				self.options['send-default-route'] = True
				continue
			if stripped == 'option:notification:send-notification':
				self.options['send-notification'] = True
				continue
			if stripped.startswith('option:SIGUSR1:'):
				def notify (delay,myself):
					time.sleep(delay)
					self.signal(myself)
					time.sleep(10)

				# Python 2.6 can not perform this test as it misses the function
				if hasattr(subprocess,'check_output'):
					delay = int(message.split(':')[-1])
					notifier = threading.Thread(target=notify, args=(delay, os.getpid()))
					# a daemon thread does not keep the process alive after sys.exit()
					notifier.daemon = True
					# without start() the signal would never be sent
					notifier.start()
				else:
					self.options['single-shot'] = True
				continue

			if message[0].isalpha():
				index,content = message[:1].upper(), message[1:]
			else:
				index,content = 'A',message
			self.messages.setdefault(index,[]).append(content)

	def handle_accept (self):
		"""
		Accept a connection and start a BGPHandler with the next group of test data.

		Exits with status 1 when a connection arrives and no test data is left
		(unless in sink or echo mode).
		"""
		# accept first, so that no test data is used up when there is
		# finally no connection to hand it to
		pair = self.accept()
		if pair is None:
			return
		sock,addr = pair

		messages = None
		for number in range(ord('A'),ord('Z')+1):
			letter = chr(number)
			if letter in self.messages:
				messages = self.messages[letter]
				del self.messages[letter]
				break

		if self.options['sink']:
			flushed('\nsink mode - send us whatever, we can take it ! :p\n')
			messages = []
		elif self.options['echo']:
			flushed('\necho mode - send us whatever, we can parrot it ! :p\n')
			messages = []
		elif not messages:
			self.announce('we used all the test data available, can not handle this new connection')
			sys.exit(1)
		else:
			flushed('using :\n   ', '\n    '.join(messages),'\n\nconversation:\n')

		# the handler may end the program once no other group is waiting
		self.options['exit'] = not len(self.messages.keys())
		self.options['letter'] = letter

		# the same options dict is handed to every handler
		BGPHandler(sock).setup(
			*addr[:2],
			messages=messages,
			options=self.options
		)


def drop ():
	"""
	Give up root privileges by switching to the 'nobody' account.

	Does nothing when neither the user id nor the group id is 0. The group
	is changed before the user, as changing the group is no longer allowed
	once the user id is not root anymore. When no 'nobody' account exists
	a warning is written to stderr and the privileges are kept.
	"""
	uid = os.getuid()
	gid = os.getgid()

	if uid and gid:
		return

	user = None
	for name in ('nobody',):
		try:
			user = pwd.getpwnam(name)
			break
		except KeyError:
			continue

	if user is None:
		sys.stderr.write('could not find an unprivileged user, privileges are not dropped\n')
		sys.stderr.flush()
		return

	if not gid:
		# the group id comes from pw_gid, not from pw_uid
		os.setgid(int(user.pw_gid))
	if not uid:
		os.setuid(int(user.pw_uid))


def main ():
	"""
	Parse the environment and the command line, then run the test daemon.

	The port comes from the 'exabgp.tcp.port' / 'exabgp_tcp_port' environment
	variables (179 by default) and can be replaced with --port. Every
	argument which is not an option is read as a file of expected messages.
	Lines which are empty or contain a '#' are ignored.

	Prints the usage and exits with status 1 when the port is invalid or no
	argument is given. Bind failures are reported on stdout.
	"""
	port = os.environ.get('exabgp.tcp.port', os.environ.get('exabgp_tcp_port', '179'))

	# the environment provides a string: it must be a number in the TCP port range
	valid_port = port.isdigit() and 0 < int(port) <= 65535

	if not valid_port or len(sys.argv) <= 1:
		flushed('--sink   accept any BGP messages and reply with a keepalive')
		flushed('--echo   accept any BGP messages send it back to the emiter')
		flushed('--port <port>   port to bind to')
		flushed('a list of expected route announcement/withdrawl in the format <number>:announce:<ipv4-route> <number>:withdraw:<ipv4-route> <number>:raw:<exabgp hex dump : separated>')
		flushed('for example:',sys.argv[0],'1:announce:10.0.0.0/8 1:announce:192.0.2.0/24 2:withdraw:10.0.0.0/8 ')
		flushed('routes with the same <number> can arrive in any order')
		sys.exit(1)

	options = {
		'sink': False,
		'echo': False,
		'port': int(port),
		'messages': []
	}

	arguments = iter(sys.argv[1:])
	for arg in arguments:
		if arg == '--sink':
			options['sink'] = True
			continue

		if arg == '--echo':
			options['echo'] = True
			continue

		if arg == '--port':
			# the value is the next argument, '' when --port is the last one
			value = next(arguments, '')
			if value.isdigit() and 0 < int(value) <= 65535:
				options['port'] = int(value)
				continue
			flushed('invalid port %s' % value)
			sys.exit(1)

		# a bare argument equal to the port in use is ignored
		if arg == str(options['port']):
			continue

		try:
			# read the argument being looked at, which is not always sys.argv[1]
			with open(arg) as content:
				options['messages'] = [_.strip() for _ in content.readlines() if _.strip() and '#' not in _]
		except IOError:
			flushed('could not open file', arg)
			sys.exit(1)

	try:
		BGPServer('127.0.0.1',options)
		try:
			BGPServer('::1',options)
		except Exception:
			# does not work on travis-ci
			# IPv6 is best effort; SystemExit and KeyboardInterrupt still get through
			pass
		drop()
		asyncore.loop()
	except socket.error as exc:
		# report the port really used, which --port may have changed
		if exc.errno == errno.EACCES:
			flushed('failure: could not bind to port %s - most likely not run as root' % options['port'])
		elif exc.errno == errno.EADDRINUSE:
			flushed('failure: could not bind to port %s - port already in use' % options['port'])
		else:
			flushed('failure', str(exc))


# start the test daemon only when the file is run as a script, not when it is imported
if __name__ == '__main__':
	main()
