#!/usr/bin/env python3
import zipfile
import sys
import re
import xml.dom.minidom as xmd
import xml.dom as xdom
import argparse
import logging

log = logging.getLogger("wrep-importtable")


# This URL seems to change from time to time, so you may be better off
# downloading the zip files normally. Note: the version 17 zip file was
# published without XML files inside.
# The tables are currently (as of 2014-08-01) published here:
# http://www.wmo.int/pages/prog/www/WMOCodes/WMO306_vI2/LatestVERSION/LatestVERSION.html
# http://www.wmo.int/pages/prog/www/WMOCodes/WMO306_vI2/PrevVERSIONS/PreviousVERSIONS.html
# NOTE(review): DOWNLOAD_URL is not referenced by any code visible in this
# file; it appears to be kept only as a pointer to the upstream download.
DOWNLOAD_URL = "http://www.wmo.int/pages/prog/www/WMOCodes/WMO306_vI2/LatestVERSION/latestTables.zip"

# Upper-cased unit names that normalise_bufr_unit() rewrites to a different
# spelling. Units not listed here are only upper-cased.
BUFR_UNIT_MAP = {"CCITT IA5": "CCITTIA5"}
# Same role as BUFR_UNIT_MAP, but consulted by normalise_crex_unit(). It is
# empty, so CREX units are only upper-cased.
CREX_UNIT_MAP = {}


class Fail(Exception):
    """
    Error whose message is meant to be shown to the user as is.

    The script entry point prints it to standard error and exits with status 1
    instead of logging a traceback.

    It derives from Exception rather than BaseException, as recommended for
    user-defined exceptions, so that generic ``except Exception`` handlers
    also see it. Handlers that list ``except Fail`` first keep working.
    """


def normalise_bufr_unit(unit):
    """
    Return the normalised spelling of a BUFR unit name.

    The name is upper-cased; if the upper-cased name has an entry in
    BUFR_UNIT_MAP, that entry is returned instead.
    """
    key = unit.upper()
    try:
        return BUFR_UNIT_MAP[key]
    except KeyError:
        return key


def normalise_crex_unit(unit):
    """
    Return the normalised spelling of a CREX unit name.

    The name is upper-cased; if the upper-cased name has an entry in
    CREX_UNIT_MAP, that entry is returned instead.
    """
    key = unit.upper()
    try:
        return CREX_UNIT_MAP[key]
    except KeyError:
        return key


def read_text(el):
    """
    Join the data of every text node found among the direct children of el.

    Nodes of any other type (elements, comments, ...) are ignored, and so is
    text nested inside child elements. Returns "" when there is no text.
    """
    return "".join(
        child.data
        for child in el.childNodes
        if child.nodeType == xdom.Node.TEXT_NODE and child.data
    )


def read_mapping(node):
    """
    Turn a DOM node shaped like <foo>bar</foo> into the pair ("foo", "bar").

    The first item is the node name, the second is the text directly contained
    in the node, as returned by read_text().
    """
    key = node.nodeName
    value = read_text(node)
    return key, value


class Table(object):
    """
    Common state of a table: its version numbers.

    When ``text`` is given, it is immediately handed to ``self.parse()``. This
    class does not define ``parse()`` itself, so that only works on subclasses
    that provide it.
    """

    def __init__(self, v, sv, lv, text=None):
        # Version, subversion and local version, in this order
        self.v, self.sv, self.lv = v, sv, lv
        if text is None:
            return
        self.parse(text)


class TableB(Table):
    """
    B table read from its XML representation.

    After parse(), ``self.entries`` is a list with one dict per B table
    record, sorted by the record's "FXY" value.
    """

    def parse_xml(self, text):
        """
        Parse the XML document in ``text`` and yield one dict per B table
        record.

        Only element children of the document element whose tag name contains
        "TableB" are considered. Each yielded dict maps the tag name of every
        child element of the record to the text it contains.
        """
        dom = xmd.parseString(text)
        for n in dom.documentElement.childNodes:
            if n.nodeType != xdom.Node.ELEMENT_NODE:
                continue
            if "TableB" not in n.nodeName:
                continue
            # If a tag is repeated inside a record, the last occurrence wins
            yield dict(
                read_mapping(x)
                for x in n.childNodes
                if x.nodeType == xdom.Node.ELEMENT_NODE
            )

    def parse(self, text):
        """
        Parse ``text`` and store the records in ``self.entries``, sorted by
        their "FXY" value.

        Raises KeyError if a record has no "FXY" element.
        """
        self.entries = sorted(self.parse_xml(text), key=lambda x: x["FXY"])

    def make_xml_b_lines(self):
        """
        Yield one formatted table line for each record in ``self.entries``.

        Every line carries the FXY code, the element name and the BUFR unit,
        scale, reference value and bit width. The CREX unit, scale and
        character width are appended only when the record has all three of
        them.

        Raises KeyError if a required field is missing, and ValueError if a
        numeric field cannot be converted with int().
        """
        crex_keys = ("CREX_Unit", "CREX_Scale", "CREX_DataWidth_Char")
        for data in self.entries:
            # The element name is looked up under "ElementName_en" first,
            # falling back to "ElementName_E"
            el_name = data.get("ElementName_en", None)
            if el_name is None:
                el_name = data["ElementName_E"]
            line = " %-6.6s %-64.64s %-24.24s %3d %12d %3d" % (
                data["FXY"],
                el_name,
                normalise_bufr_unit(data["BUFR_Unit"]),
                int(data["BUFR_Scale"]),
                int(data["BUFR_ReferenceValue"]),
                int(data["BUFR_DataWidth_Bits"]))
            if all(k in data for k in crex_keys):
                line += " %-24.24s %2d %9d" % (
                    normalise_crex_unit(data["CREX_Unit"]),
                    int(data["CREX_Scale"]),
                    int(data["CREX_DataWidth_Char"]))
            yield line

    def write_files(self):
        """
        Write the formatted table to two files in the current directory.

        The same lines go to B0000000000000VVVSSS.txt and to B00VVSS.txt,
        where the digits are the zero-padded version and subversion.
        """
        name1 = "B0000000000000{:03d}{:03d}.txt".format(self.v, self.sv)
        name2 = "B00{:02d}{:02d}.txt".format(self.v, self.sv)
        # Nested with statements make sure the first file is closed even if
        # opening the second one fails
        with open(name1, "w") as out1:
            log.info("Writing %s...", out1.name)
            with open(name2, "w") as out2:
                log.info("Writing %s...", out2.name)
                for line in self.make_xml_b_lines():
                    print(line, file=out1)
                    print(line, file=out2)


class TableD(Table):
    """
    D table read from its XML representation.

    After parse(), ``self.entries`` is a sorted list of
    (D code, list of codes in its expansion) pairs.
    """

    def parse(self, text):
        """
        Parse ``text`` and store in ``self.entries`` the sorted list of
        (code, expansion) pairs, leaving out codes with an empty expansion.
        """
        non_empty = [item for item in self.parse_xml(text).items() if item[1]]
        non_empty.sort()
        self.entries = non_empty

    def parse_xml(self, text):
        """
        Parse the XML document in ``text`` and return a dict mapping each D
        code to the list of codes it expands to, in document order.

        Only element children of the document element whose tag name contains
        "TableD" are considered.
        """
        expansions = {}
        document = xmd.parseString(text)
        for record in document.documentElement.childNodes:
            if record.nodeType != xdom.Node.ELEMENT_NODE or "TableD" not in record.nodeName:
                continue
            # This is an odd way to encode a D table: each record pairs a D
            # code (FXY1) with a single element of its expansion (FXY2),
            # rather than with the whole list of expanded codes
            codes = {"FXY1": None, "FXY2": None}
            for field in record.childNodes:
                if field.nodeType == xdom.Node.ELEMENT_NODE and field.nodeName in codes:
                    codes[field.nodeName] = read_text(field).strip()
            parent, member = codes["FXY1"], codes["FXY2"]
            if parent is None or member is None:
                # Incomplete record: nothing to add
                continue
            expansions.setdefault(parent, []).append(member)
        return expansions

    def output_lines(self):
        """
        Yield the formatted table lines for the entries in ``self.entries``.

        The first line of an entry has the D code, the number of codes in its
        expansion and the first expanded code; each remaining expanded code
        gets a continuation line of its own.
        """
        for code, members in self.entries:
            first, rest = members[0], members[1:]
            yield " %-6.6s %2d %-6.6s" % (code, len(members), first)
            for member in rest:
                yield "           %-6.6s" % member


class TableDBUFR(TableD):
    """
    D table that is written to a file named D0000000000000VVVSSS.txt.
    """

    def write_files(self):
        """
        Write the formatted table lines to D0000000000000VVVSSS.txt in the
        current directory, where VVV and SSS are the zero-padded version and
        subversion.
        """
        fname = "D0000000000000{:03d}{:03d}.txt".format(self.v, self.sv)
        # The with statement closes the file on every exit path, including a
        # failure before the first line is written
        with open(fname, "w") as out:
            log.info("Writing %s...", out.name)
            for line in self.output_lines():
                print(line, file=out)


class TableDCREX(TableD):
    """
    D table that is written to a file named D00VVSS.txt.
    """

    def write_files(self):
        """
        Write the formatted table lines to D00VVSS.txt in the current
        directory, where VV and SS are the zero-padded version and subversion.
        """
        fname = "D00{:02d}{:02d}.txt".format(self.v, self.sv)
        # The with statement closes the file on every exit path, including a
        # failure before the first line is written
        with open(fname, "w") as out:
            log.info("Writing %s...", out.name)
            for line in self.output_lines():
                print(line, file=out)


class Tables(object):
    """
    The B table and the BUFR and CREX D tables found in a zip file.

    Each ``table_*`` attribute stays None until read_zip() finds a zip member
    whose name matches the corresponding pattern.
    """

    # Names of the zip members holding the tables. The dot before "xml" is
    # escaped so that it only matches a literal dot.
    re_fname_b = re.compile(r"^BUFRCREX[^/]+/BUFRCREX_(?P<v>\d+)_(?P<sv>\d+)_(?P<lv>\d+)_TableB_(?:E|en)\.xml$")
    re_fname_bufr_d = re.compile(r"^BUFRCREX[^/]+/BUFR_(?P<v>\d+)_(?P<sv>\d+)_(?P<lv>\d+)_TableD_(?:E|en)\.xml$")
    re_fname_crex_d = re.compile(r"^BUFRCREX[^/]+/CREX_(?P<v>\d+)_(?P<sv>\d+)_(?P<lv>\d+)_TableD_(?:E|en)\.xml$")

    def __init__(self):
        self.table_b = None
        self.table_d_bufr = None
        self.table_d_crex = None

    def read_zip(self, infd):
        """
        Read the XML table files from inside a zip file.

        ``infd`` is anything zipfile.ZipFile accepts as a file. Members whose
        name matches one of the ``re_fname_*`` patterns are parsed into the
        corresponding ``table_*`` attribute, with the version numbers taken
        from the member name. All other members are skipped.
        """
        def read_version(mo):
            return int(mo.group("v")), int(mo.group("sv")), int(mo.group("lv"))

        # (file name pattern, description for logging, attribute, table class)
        # Built here rather than at class level so the table classes are only
        # looked up when a zip file is actually read.
        handlers = (
            (self.re_fname_b, "BUFR/CREX B", "table_b", TableB),
            (self.re_fname_bufr_d, "BUFR D", "table_d_bufr", TableDBUFR),
            (self.re_fname_crex_d, "CREX D", "table_d_crex", TableDCREX),
        )

        with zipfile.ZipFile(infd, "r") as zf:
            # testzip() returns the name of the first corrupted member, or
            # None. Report it instead of discarding the result.
            bad = zf.testzip()
            if bad is not None:
                log.warning("%s: corrupted member found in zip file", bad)

            for info in zf.infolist():
                for pattern, label, attr, table_cls in handlers:
                    mo = pattern.match(info.filename)
                    if mo is None:
                        continue
                    log.info("Processing %s table %s", label, info.filename)
                    with zf.open(info, "r") as fd:
                        text = fd.read()
                    setattr(self, attr, table_cls(*read_version(mo), text=text))
                    break
                else:
                    log.info("Skipping %s", info.filename)

    def write_files(self):
        """
        Write the files of every table that has been loaded.

        Tables that were not found by read_zip() are reported with a warning
        and skipped, so the remaining tables still get written.
        """
        loaded = (
            ("BUFR/CREX B", self.table_b),
            ("BUFR D", self.table_d_bufr),
            ("CREX D", self.table_d_crex),
        )
        for label, table in loaded:
            if table is None:
                log.warning("No %s table has been loaded: nothing to write for it", label)
                continue
            table.write_files()


def main():
    """
    Command line entry point.

    Parses the command line, sets up logging to standard error, then builds
    the table files from the first of these options that was given:

    * ``--zipfile``: read the tables from a zip file and write all of them
    * ``--btable``: read a B table from an XML file and write it
    * ``--dtable``: read a D table from an XML file and write it

    Tables read from single XML files get 255 as version, subversion and local
    version. Exits with a usage error if none of the three options is given.
    """
    parser = argparse.ArgumentParser(
        description="builds ECMWF-style BUFR/CREX table files from zipped WMO XML table information."
                    " Files are written in the current directory."
                    " New files can be found at "
                    " https://community.wmo.int/activity-areas/wis/latest-version"
                    # " http://www.wmo.int/pages/prog/www/WMOCodes/WMO306_vI2/LatestVERSION/LatestVERSION.html"
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument("--zipfile", action="store", help="zip file to open")
    parser.add_argument("--btable", action="store", help="xml file to open for B table")
    parser.add_argument("--dtable", action="store", help="xml file to open for D table")
    args = parser.parse_args()

    # Without an input there is nothing to do: say so instead of silently
    # reporting success
    if not (args.zipfile or args.btable or args.dtable):
        parser.error("one of --zipfile, --btable or --dtable is required")

    log_format = "%(asctime)-15s %(levelname)s %(message)s"
    level = logging.WARN
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    logging.basicConfig(level=level, stream=sys.stderr, format=log_format)

    if args.zipfile:
        tables = Tables()
        with open(args.zipfile, "rb") as infd:
            tables.read_zip(infd)
        tables.write_files()
    elif args.btable:
        # XML is read as bytes, as it is when it comes from a zip file, so
        # that the XML parser honours the encoding declared in the document
        # instead of relying on the locale's default text encoding
        with open(args.btable, "rb") as fd:
            text = fd.read()
        TableB(255, 255, 255, text=text).write_files()
    elif args.dtable:
        with open(args.dtable, "rb") as fd:
            text = fd.read()
        TableDBUFR(255, 255, 255, text=text).write_files()

    log.info("Done. You can copy the table files to the table directory (see wrep --info for its location).")


if __name__ == "__main__":
    try:
        main()
    except Fail as e:
        # Expected failure: show only the message
        print(e, file=sys.stderr)
        sys.exit(1)
    except Exception:
        log.exception("uncaught exception")
        # Exit with a non-zero status so that callers can tell that the run
        # did not succeed
        sys.exit(1)
