#!/usr/bin/env ruby
#
# kumofs
#
# Copyright (C) 2009 FURUHASHI Sadayuki
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
#

# The script takes exactly one argument: the path of the .proto.h source.
unless ARGV.length == 1
	puts "usage: #{$0} <proto file>.proto.h"
	exit 1
end

# $dirname and $genname are globals; they are read later when the include
# guard names and the output directory are built.
srcpath = ARGV.first
$dirname, $genname = File.dirname(srcpath), File.basename(srcpath, ".proto.h")

src = File.read(srcpath)



# Banner (a C comment followed by a newline) written at the top of
# generated files.
AUTOGEN_WARN = "/* This is generated file. */\n"

require 'strscan'


# Surrounds whatever the given block writes with a C preprocessor include
# guard emitted to +f+.
#
# The guard macro is "PROTOGEN_<GENNAME>_<BASENAME>__", where GENNAME is the
# upper-cased global $genname and BASENAME is the upper-cased base name of
# +fname+ with every "." turned into "_".
def include_guard(f, fname, &block)
	suffix = File.basename(fname).upcase.tr(".", "_")
	name = "PROTOGEN_#{$genname.upcase}_#{suffix}__"
	f.puts "#ifndef #{name}"
	f.puts "#define #{name}"
	yield
	f.puts "#endif"
end


# One RPC message of a scope: a message id plus one member list per declared
# version. #generate emits one C++ struct (and one rpc_* handler declaration)
# for every version.
#
# A member is an array [type, var, default]; default is nil for members that
# have no default value.
class RPCMessage
	# name:: identifier used for the generated struct(s).
	# id::   message id, or nil while it has not been declared.
	def initialize(name, id = nil)
		@name = name
		@id = id
		@versions = {}   # version (Integer, or nil when unversioned) => members
		@code = ""
		@mode_cluster = false
	end
	attr_reader :name
	attr_accessor :id
	attr_accessor :mode_cluster

	# Returns the declared version keys (Integers, or nil for the unversioned
	# declaration), not the internal version => members hash.
	def versions
		@versions.keys
	end

	# Registers the member list of one version.
	# A non-nil version is converted with to_i. Raises RuntimeError when the
	# same version was already registered.
	def add_version(version, members)
		version = version.to_i if version
		if @versions.include?(version)
			raise "duplicated declaration of #{@name}.#{version}"
		end
		@versions[version] = members
	end

	# Appends raw text that is copied into every generated struct body.
	def append_code(code)
		@code << code
	end

	# Raises RuntimeError when at least one version is declared but no id
	# has been assigned.
	def check
		if @id.nil? && !@versions.empty?
			raise "message id for #{@name} is not declared"
		end
	end

	# Writes the C++ for every version to +f+ (anything responding to puts),
	# ordered by version number with the unversioned declaration treated as 0.
	#
	# The member arrays are only read: no helper methods are defined on them,
	# so frozen arrays are accepted and caller data is left untouched.
	def generate(f)
		@versions.to_a.sort_by do |ver, members|
			ver ? ver : 0
		end.each do |ver, members|
			defaults = members.select {|type, var, default| default }
			struct = ver ? "#{@name}_#{ver}" : @name
			types = members.map {|type, var| type }.join(", ")

			f.puts %[	struct #{struct} : rpc::message<#{@id}, #{ver ? ver : 0}, rpc::#{@mode_cluster ? "node" : "basic_session"}> {]
			f.puts %[		#{members.map {|type, var| "#{type} #{var};" }.join("\n\t\t")}]
			if defaults.empty?
				# No defaults: unpack straight from a fixed-size tuple.
				assigns = members.each_with_index.map {|(type, var), i|
					"this->#{var} = args.get<#{i}>();"
				}.join("\n\t\t\t")
				f.puts %[		void msgpack_unpack(const msgpack::type::tuple<
				#{types} >& args)
		{
			#{assigns}
		}]
			else
				f.puts %[		void msgpack_unpack(msgpack::object o)
		{]
				if members.length > defaults.length
					f.puts %[			const msgpack::type::tuple<#{types} >& args(o.convert());]
				else
					f.puts %[			if(o.type != msgpack::type::ARRAY) { throw msgpack::type_error(); }]
				end
				# Members with a default are only converted when the incoming
				# array is long enough; the others are read from the tuple.
				assigns = members.each_with_index.map {|(type, var, default), i|
					if default
						"if(o.via.array.size > #{i}) { o.via.array.ptr[#{i}].convert(&this->#{var}); }"
					else
						"this->#{var} = args.get<#{i}>();"
					end
				}.join("\n\t\t\t")
				f.puts %[			#{assigns}
		}]
			end

			f.puts %[
		template <typename Packer>
		void msgpack_pack(Packer& pk) const
		{
			pk.pack_array(#{members.size});
			#{members.map {|type, var| "pk.pack(#{var});" }.join("\n\t\t\t")}
		}]

			f.puts %[		typedef rpc::retry<#{struct}> retry;]

			unless members.empty?
				params = members.map {|type, var, default|
					if default
						"\n\t\t\t\t#{type} #{var}_ = #{default}"
					else
						"\n\t\t\t\tconst #{type}& #{var}_"
					end
				}.join(",")
				f.puts %[		#{struct}(#{params}) :]
				f.puts %[				#{members.map {|type, var, default| "#{var}(#{var}_)" }.join(", ")} { }]
			end

			# When every member has a default, the constructor above already
			# takes zero arguments, so no separate default constructor is
			# emitted in that case.
			if defaults.empty?
				f.puts %[		#{struct}() { }]
			elsif members.length > defaults.length
				f.puts %[		#{struct}() :]
				f.puts %[				#{defaults.map {|type, var, default| "#{var}(#{default})"}.join(", ")} { }]
			end

			f.puts %[
		#{@code}
	};]
			f.puts %[	void rpc_#{struct}(rpc::request<#{struct}>&, rpc::auto_zone z,
			rpc::weak_responder);]
			f.puts %[]
		end
	end
end


# Accumulates everything that goes into one generated class header: the
# pass-through text placed before and after the class, the text following the
# class name ("trails"), the messages of the scope, and extra class-body code.
class RPCHeader
	# name::     class name used in the generated header.
	# pre_body:: text emitted before the class; later appended to in place.
	def initialize(name, pre_body)
		@name = name
		@pre_body = pre_body
		@post_body = ""
		@code = ""
		@trails = []
		@messages = {}
	end

	# Adds tokens that are emitted between the class name and the opening "{".
	def add_trails(trails)
		@trails.concat(trails)
	end

	# Returns the message registered as +name+, creating it on first use.
	# A given +id+ is converted with to_i and assigned; assigning an id to a
	# message that already has one raises RuntimeError.
	def message(name, id = nil)
		known = @messages[name]
		if known.nil?
			created = RPCMessage.new(name, id ? id.to_i : nil)
			@messages[name] = created
			return created
		end
		return known unless id

		raise "duplicated declaration of message '#{@name}::#{name}'" if known.id
		known.id = id.to_i
		known
	end

	# Appends pass-through text: it goes after the class as soon as any
	# message has a declared version, and before the class until then.
	def <<(body)
		has_versions = @messages.any? {|_name, msg| !msg.versions.empty? }
		target = has_versions ? @post_body : @pre_body
		target << body
	end

	# Appends raw text that is emitted inside the class after the messages.
	def append_code(code)
		@code << code
	end

	# Runs the consistency check of every message.
	def check
		@messages.each_value {|msg| msg.check }
	end

	# Writes the header to +file+, wrapped in an include guard.
	def generate(file)
		File.open(file, "w") do |out|
			include_guard(out, file) do
				out.puts AUTOGEN_WARN
				out.puts @pre_body
				out.puts "class #{@name} #{@trails.join(' ').strip} {"
				out.puts "public:"
				@messages.each_value {|msg| msg.generate(out) }
				out.puts @code.strip
				out.puts "};"
				out.puts @post_body
			end
		end
	end

	# Calls the block once per message, in declaration order.
	def each_message(&block)
		@messages.each_value {|message| block.call(message) }
	end
end


# Generates the header that lists every message id as an enum constant,
# one enum per scope, inside "namespace proto".
class ProtocolNumber
	# pre_body:: text emitted before the namespace.
	def initialize(pre_body)
		@pre_body = pre_body
		@post_body = ""
	end

	# Appends pass-through text; it is emitted after the namespace.
	def <<(body)
		@post_body << body
	end

	# Nothing to validate here.
	def check
	end

	# Writes the enums to +file+, wrapped in an include guard.
	# headers:: hash of scope name => object whose each_message yields
	#           messages responding to name and id.
	def generate(file, headers)
		File.open(file, "w") do |out|
			include_guard(out, file) do
				out.puts @pre_body
				out.puts "namespace proto {"
				headers.each_pair do |scope, header|
					out.puts "\tenum #{scope} {"
					header.each_message do |msg|
						out.puts "\t\t#{msg.name} = #{msg.id},"
					end
					out.puts "\t};"
				end
				out.puts "}"
				out.puts @post_body
			end
		end
	end
end


# Generates "case" labels for every message version of a set of scopes.
# FIXME: each generated case currently contains nothing but "break;".
class Dispatch
	# name::   dispatch name.
	# scopes:: names of the scopes whose messages are dispatched.
	def initialize(name, scopes)
		@name = name
		@scopes = scopes
	end

	# Resolves the scope names against +headers+ (scope name => header) and
	# remembers the result for #generate. Raises RuntimeError for a scope that
	# has no header.
	def check(headers)
		@headers = @scopes.map do |scope|
			headers[scope] || raise("dispatch scope '#{scope}' is not declared")
		end
	end

	# Writes one "case <id << 16 | version>:" label per message version to
	# +file+; an unversioned declaration counts as version 0.
	# #check must have been called first, as it resolves the headers used here.
	def generate(file)
		File.open(file, "w") do |out|
			out.puts(AUTOGEN_WARN)
			@headers.each do |header|
				header.each_message do |msg|
					msg.versions.each do |ver|
						label = (msg.id << 16) | (ver || 0)
						out.puts "case #{label}:"
						out.puts "\tbreak;"
					end
				end
			end
		end
	end
end

# Registry of every file the generator will write: one RPCHeader per scope,
# one Dispatch per dispatch name, and a single ProtocolNumber that is created
# together with the first header.
class OutputFiles
	def initialize
		@pre_body = ""
		@headers = {}
		@dispatches = {}
		@protonum = nil
	end

	# Records pass-through text and forwards it to every header that exists
	# so far, and to the protocol-number file once it exists.
	def <<(body)
		@pre_body << body
		@headers.each_value {|header| header << body }
		@protonum << body if @protonum
	end

	# Returns the header for scope +name+, creating it when needed.
	# A new header starts from a copy of the text recorded so far; +trails+
	# are always added to the returned header.
	def rpc_header(name, trails = [])
		@protonum ||= ProtocolNumber.new(@pre_body.dup)
		@headers[name] ||= RPCHeader.new(name, @pre_body.dup)
		found = @headers[name]
		found.add_trails(trails)
		found
	end

	# Registers a dispatch over the scopes named in +headers+.
	# Raises RuntimeError when +name+ is already registered.
	def rpc_dispatch(name, headers)
		raise "duplicated declaration of #{name} dispatch" if @dispatches.include?(name)
		@dispatches[name] = Dispatch.new(name, headers)
	end

	# Validates all headers first, then all dispatches against the headers.
	def check
		@headers.each_value {|header| header.check }
		@dispatches.each_value {|dispatch| dispatch.check(@headers) }
	end

	# Writes every file below +dir+. A trailing "_t" is removed from scope and
	# dispatch names to form "<name>.h" and "<name>_dispatch.h"; the message
	# ids go to "proto.h".
	def generate(dir)
		@headers.each_pair do |name, header|
			header.generate(File.join(dir, "#{name.sub(/_t$/, "")}.h"))
		end
		@dispatches.each_pair do |name, dispatch|
			dispatch.generate(File.join(dir, "#{name.sub(/_t$/, "")}_dispatch.h"))
		end
		@protonum.generate(File.join(dir, "proto.h"), @headers) if @protonum
	end
end


$body = OutputFiles.new


s = StringScanner.new(src)

# Walk the source from directive to directive. Text between directives is
# passed through to the output files; each directive is parsed and removed.
while body_len = s.exist?(/@(rpc|c?message|dispatch|code)\b/)
	# body_len counts up to the end of the directive keyword, so subtracting
	# the keyword's own size leaves exactly the text in front of it.
	$body << s.peek(body_len - s.matched_size)
	s.pos += body_len

	case s[1]
	when "rpc"
		# "@rpc <scope> [trails...]" up to the end of the line, followed by
		# the class body, terminated by "@end".
		line = s.scan_until(/\n/)
		raise "'@rpc' line is not terminated" unless line
		scope, *trails = line.strip.split(/\s/)

		header = $body.rpc_header(scope, trails)

		code = s.scan_until(/@end\s\n?/)
		raise "'@end' mismatch" unless code
		code.slice!(-s.matched_size, s.matched_size)

		# Full declarations: message NAME[.VER] [+cluster] [= ID] { ... };
		# NOTE(review): the pattern also allows "_" before the version, but
		# the greedy \w+ appears to swallow "NAME_VER" as the name — confirm.
		code.gsub!(/\bmessage\b\s+(\w+)(?:[\._](\d+))?(\s*\+cluster\b)?(?:\s*\=\s*(\d+)\s*)?\s*\{(.*?)\}\;/m) do
			# Take the captures now: the nested gsub! below overwrites $~.
			msgname, version, mode_cluster, msgid, decl = $~.captures

			message = header.message(msgname, msgid)
			message.mode_cluster = true if mode_cluster

			# Pull out "TYPE NAME [= DEFAULT];" members; whatever is left of
			# the declaration is kept as extra struct code.
			members = []
			decl.gsub!(/\b(\S+)\s+(\w+)(?:\s+\=\s+(\w+))?\s*\;/) do
				type, var, default = $~.captures
				# A member without default may not follow one that has one.
				if !default && members.any? {|t, v, d| d }
					raise "unexpected non-default parameter"
				end
				members.push [type, var, default]
				""
			end
			decl.strip!

			message.add_version(version, members)
			message.append_code(decl)

			""
		end

		# Id-only declarations: message NAME = ID;
		code.gsub!(/\bmessage\b\s+(\w+)\s*\=\s*(\d+)\s*\;\s*/m) do
			header.message($~[1], $~[2])
			""
		end

		# The remainder of the body goes into the class verbatim.
		header.append_code(code)


	when "message", "cmessage"
		# "@message SCOPE::NAME = ID" ended by ";" or a newline.
		line = s.scan_until(/(?:\;|\n)\s*/)

		unless match = /\A\s*(\w+)\:\:(\w+)\s*\=\s*(\d+)\s*?(?:\;|\n)\s*\z/.match(line)
			raise "invalid @message line '#{line}'"
		end

		scope = match[1]
		msgname = match[2]
		msgid = match[3]

		header = $body.rpc_header(scope)
		header.message(msgname, msgid)


	when "dispatch"
		# "@dispatch (NAME) SCOPE[, SCOPE...];"
		line = s.scan_until(/;\s*/)

		# The scope list is captured as a whole and split afterwards: a
		# repeated capture group would only retain its last repetition and
		# lose every scope between the first and the last one.
		unless match = /\A\s*\(\s*(\w+)\s*\)\s+(\w+(?:\s*,\s*\w+)*)\s*\;\s*\z/.match(line)
			raise "invalid @dispatch line '#{line}'"
		end

		name = match[1]
		scopes = match[2].split(/\s*,\s*/)

		$body.rpc_dispatch(name, scopes)


	when "code"
		# "@code SCOPE [SCOPE...]" up to the end of the line, followed by
		# text that is added to each named scope's header, up to "@end".
		line = s.scan_until(/\n/)
		raise "'@code' line is not terminated" unless line
		# Split on runs of whitespace so that doubled separators do not
		# produce an empty scope name (and with it a header named ".h").
		scopes = line.strip.split(/\s+/)

		code = s.scan_until(/@end\s\n?/)
		raise "'@end' mismatch" unless code
		code.slice!(-s.matched_size, s.matched_size)

		scopes.each {|scope|
			header = $body.rpc_header(scope)
			header << code
		}

	end
end

# Whatever follows the last directive is passed through as well.
$body << s.rest

$body.check
$body.generate(File.join($dirname,$genname))

