#!/usr/bin/python
# Generate SSHFP DNS records (RFC4255) from knownhosts files or ssh-keyscan
# Copyright 2006-2010 by Xelerance http://www.xelerance.com/
# Copyright 2012 Paul Wouters <pwouters@redhat.com> 
# License: GNU GENERAL PUBLIC LICENSE Version 2 or later

# Program version; reported by -v/--version and in the generated zone header.
VERSION = "2.5"

import os
import sys
import subprocess
import optparse
import base64
import time
# dnspython (www.dnspython.org) provides the resolver and zone transfer support.
# dns.exception and dns.rdatatype are imported explicitly because the AXFR code
# below refers to them by name.
try:
	import dns.exception
	import dns.rdatatype
	import dns.resolver
	import dns.query
	import dns.zone
except ImportError:
	print >> sys.stderr, "sshfp requires the python-dns package from http://www.dnspython.org/"
	print >> sys.stderr, "Fedora: yum install python-dns"
	print >> sys.stderr, "Debian: apt-get install python-dnspython   (NOT python-dns!)"
	print >> sys.stderr, "openSUSE: zypper in python-dnspython"
	sys.exit(1)

# Pick the SHA-1 constructor used by create_sshfp(): hashlib when it can be
# imported, otherwise the legacy "sha" module.
try:
	import hashlib
	digest = hashlib.sha1
except ImportError:
	import sha
	digest = sha.new


DEFAULT_KNOWN_HOSTS_FILE = "~/.ssh/known_hosts"

# Run-time settings read by the functions below. main() overwrites them from
# the command line; the values here mirror the option defaults so that the
# functions also work when this file is imported as a module. (A "global"
# statement at module level defines nothing, so real assignments are needed.)
all_hosts = False		# process every host instead of only the requested ones
trailing = False		# append a trailing dot to generated hostnames
quiet = False			# suppress diagnostics on stderr
port = 22			# port passed to ssh-keyscan
timeout = 5			# ssh-keyscan timeout in seconds
algos = ["dsa", "rsa"]		# key algorithms to accept

def show_version():
	"""Write the program name and version to stderr."""
	sys.stderr.write("sshfp version: " + VERSION + "\n")

def create_sshfp(hostname, keytype, keyblob):
	"""Build one SSHFP resource record line (SHA-1 fingerprint) for a host key.

	hostname -- name the key belongs to; a dotted-quad IPv4 address is
	            rewritten to its in-addr.arpa. reverse name
	keytype  -- key algorithm name, "ssh-rsa" or "ssh-dss"
	keyblob  -- base64 encoded public key

	Returns "<hostname> IN SSHFP <algorithm> 1 <fingerprint>", the empty
	string for any other key type, or "ERROR" when keyblob cannot be
	decoded. Reads the module globals "trailing" and "digest".
	"""
	if keytype == "ssh-rsa":
		algorithm = "1"
	elif keytype == "ssh-dss":
		algorithm = "2"
	else:
		return ""
	try:
		rawkey = base64.b64decode(keyblob)
	except (TypeError, ValueError):
		# bad base64 is a TypeError on Python 2 and a ValueError subclass on later versions
		print >> sys.stderr, "FAILED on hostname "+hostname+" with keyblob "+keyblob
		return "ERROR"
	fpsha1 = digest(rawkey).hexdigest().upper()
	# Only a name made of exactly four numeric labels is an IPv4 address and
	# gets a reverse entry; anything shorter (e.g. "10.1") is kept as a name.
	parts = hostname.split(".")
	non_numeric = [label for label in parts if not label.isdigit()]
	reverse = len(parts) == 4 and not non_numeric
	if reverse:
		parts.reverse()
		hostname = ".".join(parts) + ".in-addr.arpa."
	# we don't know whether we need a trailing dot :(
	# eg if someone did "ssh ns.foo" we don't know if this really is "ns.foo." or "ns.foo" plus resolv.conf domainname
	elif trailing and hostname[-1:] != ".":
		hostname = hostname + "."
	return hostname + " IN SSHFP " + algorithm + " 1 " + fpsha1

def get_known_host_entry(known_hosts, host):
	"""Look up one host in a known_hosts file by running "ssh-keygen -F".

	Anything ssh-keygen writes to stderr is passed on to our stderr unless
	the global "quiet" is set. Returns ssh-keygen's stdout with the lines
	starting with "#" removed.
	"""
	keygen = subprocess.Popen(
		["ssh-keygen", "-f", known_hosts, "-F", host],
		stdout=subprocess.PIPE,
		stderr=subprocess.PIPE)
	(out, err) = keygen.communicate()
	if err and not quiet:
		print >> sys.stderr, err
	return "\n".join([line for line in out.split("\n") if not line.startswith("#")])

def sshfp_from_file(khfile, wantedHosts):
	"""Generate SSHFP records from a known_hosts file.

	khfile      -- path of the known_hosts file ("~" is expanded)
	wantedHosts -- hostnames to look up; ignored when the global
	               "all_hosts" is set

	A file whose first entry starts with "|1|" is treated as hashed: its
	hostnames cannot be enumerated, so each wanted host is looked up through
	ssh-keygen and combining it with "all_hosts" is an error. Exits with
	status 1 when the file cannot be opened. Returns the records joined by
	newlines.
	"""
	global all_hosts

	known_hosts = os.path.expanduser(khfile)
	try:
		khfp = open(known_hosts)
	except IOError:
		print >> sys.stderr, "ERROR: failed to open file "+ known_hosts
		sys.exit(1)
	try:
		data = khfp.read()
	finally:
		khfp.close()

	# Decide on the first real entry, so a leading comment or blank line
	# does not hide the fact that the file is hashed.
	hashed_known_hosts = False
	for line in data.split("\n"):
		if not line.strip() or line.startswith("#"):
			continue
		hashed_known_hosts = line.startswith("|1|")
		break

	fingerprints = []
	if hashed_known_hosts and all_hosts:
		print >> sys.stderr, "ERROR: %s is hashed, cannot use with -a" % known_hosts
		sys.exit(1)
	elif hashed_known_hosts: #only looking for some known hosts
		for host in wantedHosts:
			# ssh-keygen prints one line per stored key and nothing at
			# all for an unknown host, so handle the lines one by one.
			entry = get_known_host_entry(known_hosts, host)
			lines = [line for line in entry.split("\n") if line.strip()]
			if not lines and not quiet:
				print >> sys.stderr, "WARNING: no entry for %s found in %s" % (host, known_hosts)
			for line in lines:
				fingerprints.append(process_record(line, host))
	else:
		fingerprints.append(process_records(data, wantedHosts))
	# drop empty results so the output contains no blank lines
	return "\n".join([fp for fp in fingerprints if fp])

def check_keytype(keytype, hostname=None):
	"""Tell whether keytype is one of the requested algorithms.

	keytype  -- key type name such as "ssh-rsa" or "ssh-dss"
	hostname -- optional host the key belongs to, used only in the
	            diagnostic message

	Returns True when keytype matches an entry of the global "algos",
	otherwise False after printing a note to stderr (unless the global
	"quiet" is set).
	"""
	global algos
	for wanted in algos:
		# The last character is ignored on both sides so that the option
		# value "dsa" matches the key type name "ssh-dss".
		if "ssh-%s" % wanted[:-1] == keytype[:-1]:
			return True
	if not quiet:
		if hostname is None:
			print >> sys.stderr, "Could only find key type %s" % keytype
		else:
			print >> sys.stderr, "Could only find key type %s for %s" % (keytype, hostname)
	return False

def process_record(record, hostname):
	"""Turn the known_hosts entry text for one host into SSHFP records.

	record   -- one or more "host keytype key [comment]" lines
	hostname -- name to publish the records under (the host field of a
	            hashed entry is unusable)

	Blank, comment, marker ("@...") and incomplete lines are skipped, as
	are keys whose type was not requested. Returns the SSHFP records joined
	by newlines, or "" when nothing usable was found.
	"""
	records = []
	for line in record.split("\n"):
		fields = line.split()
		# need at least host, key type and key; anything after is a comment
		if len(fields) < 3 or fields[0][:1] in ("#", "@"):
			continue
		keytype = fields[1]
		key = fields[2]
		if check_keytype(keytype):
			sshfp = create_sshfp(hostname, keytype, key)
			if sshfp:
				records.append(sshfp)
	return "\n".join(records)

def process_records(data, hostnames):
	"""Process all known_hosts style records in a string.

	data      -- text with one "host[,alias...] keytype key [comment]"
	             record per line
	hostnames -- hostnames to select

	If the global "all_hosts" is True, SSHFP entries are returned for all
	records with an allowed key type. Otherwise only records whose first
	host name is in hostnames are used. Comment lines, blank lines and
	marker lines ("@cert-authority", "@revoked") are skipped, and a
	"[host]:port" name is reduced to the bare host because SSHFP records
	carry no port. Returns the sorted records joined by newlines, or "".
	"""
	all_records = []
	for record in data.split("\n"):
		record = record.strip()
		if not record or record.startswith("#"):
			continue
		if record.startswith("@"):
			# marked lines are CA or revoked keys, never a host's own key
			continue
		fields = record.split()
		if len(fields) < 3:
			if not quiet:
				# diagnostics belong on stderr; stdout carries the records
				print >> sys.stderr, "Print unable to read record '%s'" % record
			continue
		# an optional fourth field is a free-form comment
		(host, keytype, key) = fields[:3]
		host = host.split(",")[0]
		if host.startswith("[") and "]" in host:
			host = host[1:host.index("]")]
		if all_hosts or host in hostnames or host == hostnames:
			if not check_keytype(keytype):
				continue
			sshfp = create_sshfp(host, keytype, key)
			if sshfp:
				all_records.append(sshfp)
	all_records.sort()
	return "\n".join(all_records)

def get_record(domain, qtype):
	"""Resolve domain and return the first answer of type qtype.

	For "A" the rdata object itself is returned, for "NS" the target as a
	string. Returns 0 when the name does not exist or has no record of that
	type. Any other qtype that yields an answer is reported as an error and
	the program exits with status 1.
	"""
	try:
		answers = dns.resolver.query(domain, qtype)
	except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
		return 0
	# only the first answer is of interest; answers[0].target does not work
	for first in answers:
		if qtype == "A":
			return first
		elif qtype == "NS":
			return str(first.target)
		print >> sys.stderr, "ERROR: error in get_record, unknown type " + qtype
		sys.exit(1)
	return None

def get_axfr_record(domain, nameserver):
	"""Transfer the zone for domain from nameserver and return it.

	A dns.exception.FormError from the transfer is re-raised as a new
	FormError carrying the domain name.
	"""
	try:
		transfer = dns.query.xfr(nameserver, domain, rdtype=dns.rdatatype.AXFR)
		return dns.zone.from_xfr(transfer)
	except dns.exception.FormError:
		raise dns.exception.FormError(domain)

def sshfp_from_axfr(domain, nameserver):
	if " " in domain:
		print >> sys.stderr, "ERROR: space in domain '"+domain+"' can't be right, aborted"
		sys.exit(1)
	if not nameserver:
		nameserver = get_record(domain, "NS")
		if not nameserver:
			print >> sys.stderr, "WARNING: no NS record found for domain "+domain+". trying as host record instead"
			# better then nothing
			return sshfp_from_dns([domain])
	hosts = []
	#print "nameserver:" + str(ns)
	try:
		# print "trying axfr for "+domain+"@"+nameserver
		axfr = get_axfr_record(domain, nameserver)
	except dns.exception.FormError, badDomain:
		print >> sys.stderr, "AXFR error: %s - No permission or not authorative for %s; aborting" % (nameserver, badDomain)
		sys.exit(1)

	for (name, ttl, rdata) in axfr.iterate_rdatas('A'):
		#print "name:" +str(name) +", ttl:"+ str(ttl)+ ", rdata:"+str(rdata)
		if "@" in str(name): 
			hosts.append(domain + ".")
		else:
			if not str(name) == "localhost":
				hosts.append("%s.%s." % (name, domain))
	return sshfp_from_dns(hosts)

def sshfp_from_dns(hosts):
	global quiet
	global port
	global timeout
	global algos
	cmd = ["ssh-keyscan", "-p", str(port), "-T", str(timeout), "-t", ",".join(algos)] + hosts
	process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
	(stdout, stderr) = process.communicate()
	if not quiet:
		print >> sys.stderr, stderr
	khdns = stdout
	return process_records(khdns, hosts)
	
def main():
	global all_hosts
	global trailing
	global quiet
	global port
	global timeout
	global algos

	parser = optparse.OptionParser()
	parser.add_option("-k", "--knownhosts", "--known-hosts", 
			action="store",
			dest="known_hosts",
			metavar="KNOWN_HOSTS_FILE",
			default=DEFAULT_KNOWN_HOSTS_FILE,
			help="obtain public ssh keys from the known_hosts file KNOWN_HOSTS_FILE")
	parser.add_option("-s", "--scan", 
			action="store_true",
			default=False,
			dest="scan",
			help="scan the listed hosts for public keys using ssh-keyscan")
	parser.add_option("-a", "--all",
			action="store_true",
			dest="all_hosts",
			help="scan all hosts in the known_hosts file when used with -k. When used with -s, attempt a zone transfer to obtain all A records in the domain provided")
	parser.add_option("-d", "--trailing-dot",
			action="store_true",
			dest="trailing_dot",
			help="add a trailing dot to the hostname in the SSHFP records")
	parser.add_option("-o", "--output",
			action="store",
			metavar="FILENAME",
			default=None,
			help="write to FILENAME instead of stdout")
	parser.add_option("-p", "--port",
			action="store",
			metavar="PORT",
			type="int",
			default=22,
			help="use PORT for scanning")
	parser.add_option("-v", "--version",
			action="store_true",
			dest="version",
			help="print version information and exit")
	parser.add_option("-q", "--quiet",
			action="store_true",
			dest="quiet")
	parser.add_option("-T", "--timeout",
			action="store",
			type="int",
			dest="timeout",
			default=5,
			help="scanning timeout (default %default)")
	parser.add_option("-t", "--type",
			action="append",
			type="choice",
			dest="algo",
			choices=["rsa", "dsa"],
			default=[],
			help="key type to fetch (may be specified more than once, default dsa,rsa)")
	parser.add_option("-n", "--nameserver",
			action="store",
			type="string",
			dest="nameserver",
			default="",
			help="nameserver to use for AXFR (only valid with -s -a)")
	(options, args) = parser.parse_args()

	# parse options
	khfile = options.known_hosts
	dodns = options.scan
	nameserver = ""
	domain = ""
	output = options.output
	quiet = options.quiet
	data = ""
	trailing = options.trailing_dot
	timeout = options.timeout
	algos = options.algo or ["dsa", "rsa"]
	all_hosts = options.all_hosts
	port = options.port
	hostnames = ()

	if options.version:
		show_version()
		sys.exit(0)
	if not quiet and port != 22:
		print >> sys.stderr, "WARNING: non-standard port numbers are not designated in SSHFP records"
	if not quiet and options.known_hosts and options.scan:
		print >> sys.stderr, "WARNING: Ignoring -k option, -s was passwd"
	if options.nameserver and not options.scan and not options.all_hosts:
		print >> sys.stderr, "ERROR: Cannot specify -n without -s and -a"
		sys.exit(1)
	if not options.scan and options.all_hosts and args:
		print >> sys.stderr, "WARNING: -a and hosts both passed, ignoring manual host list"
	if not args and (not all_hosts):
		print >> sys.stderr, "WARNING: Assuming -a"
		all_hosts = True

	if options.scan and options.all_hosts:
		datal = []
		for arg in args:
			datal.append(sshfp_from_axfr(arg, options.nameserver))
		if not quiet:
			datal.insert(0, ";")
			datal.insert(0, "; Generated by sshfp %s from %s at %s" % (VERSION, nameserver, time.ctime()))
			datal.insert(0, ";")
		data = "\n".join(datal)
	elif options.scan:	# Scan specified hosts
		if not args:
			print >> sys.stderr, "ERROR: You asked me to scan, but didn't give any hosts to scan"
			sys.exit(1)
		data = sshfp_from_dns(args)
	else: # known_hosts
		data = sshfp_from_file(khfile, args)

	if output:
		try:
			fp = open(output, "w")
		except IOError:
			print >> sys.stderr, "ERROR: can't open '%s'' for writing" % output
			sys.exit(1)
		else:
			fp.write(data)
			fp.close()
	else:
		print data

# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
	main()
