#!/usr/bin/env python3.5

import argparse
import ast
from collections import Iterable
from glob import iglob
from itertools import chain
from pathlib import Path
import re
import sys
import tokenize

# Get a list of all standard library modules. The module lives alongside
# this script, so make sure the script's directory is importable. Entries
# in sys.path must be strings; any other type (including pathlib.Path) is
# ignored by the import system, hence the explicit str().
sys.path.insert(0, str(Path(__file__).parent.absolute()))
from python_standard_libs import python_standard_libs


def flatten(*things):
    """Recursively flatten iterable parts of `things`.

    Strings and bytes are treated as leaves; they are never iterated.

    For example::

      >>> sorted(flatten([1, 2, {3, 4, (5, 6)}]))
      [1, 2, 3, 4, 5, 6]

    :param things: Any number of objects, nested to any depth.
    :return: An iterator.
    """
    # Imported from its proper home: the `collections.Iterable` alias was
    # removed in Python 3.10, whereas `collections.abc` has provided the
    # same class since Python 3.3.
    from collections.abc import Iterable

    def _flatten(thing):
        # String- and byte-like objects are treated as leaves; iterating
        # through either yields more of the same, each of which is also
        # iterable, and so on, until the heat-death of the universe.
        is_leaf = (
            isinstance(thing, (bytes, str)) or
            not isinstance(thing, Iterable))
        if is_leaf:
            # Return a single-item iterator so that it can be chained with
            # any others.
            return iter((thing,))
        # Recurse and merge in order to flatten nested structures.
        return chain.from_iterable(map(_flatten, thing))

    return _flatten(things)


class Pattern:
    """A set of glob-like patterns for dotted import names.

    Each pattern string may hold several alternatives separated by ``|``.
    Within an alternative, ``*`` matches one or more characters other than
    ``.`` (i.e. a single dotted component), ``**`` matches one or more
    characters of any kind (i.e. any number of components), and everything
    else is matched literally.

    `compile` must be called before `match`.
    """

    def __init__(self, *patterns):
        """Store `patterns`, flattening any nested iterables.

        :param patterns: Pattern strings, or iterables of them (including
            other `Pattern` instances, which iterate over their patterns).
        """
        super(Pattern, self).__init__()
        self.patterns = tuple(flatten(patterns))

    def _compile_pattern(self, pattern):
        """Yield one regular expression per ``|`` alternative in `pattern`.

        :raises ValueError: If a run of three or more ``*`` is found.
        """
        for part in pattern.split("|"):
            expr = []
            # Split into runs of asterisks and runs of everything else.
            for component in re.findall('([*]+|[^*]+)', part):
                if component == "*":
                    expr.append("[^.]+")
                elif component == "**":
                    expr.append(".+")
                elif component.count("*") >= 3:
                    raise ValueError(component)
                else:
                    expr.append(re.escape(component))
            yield "".join(expr)

    def compile(self):
        """Build the regular expression used by `match`.

        :raises ValueError: If any pattern contains three or more
            consecutive ``*``.
        """
        alternatives = chain.from_iterable(
            map(self._compile_pattern, self.patterns))
        # The flags are passed as an argument rather than as a trailing
        # inline "(?ms)" group: global inline flags anywhere but the start
        # of an expression are rejected by newer versions of Python.
        self._matcher = re.compile(
            r"(%s)\Z" % "|".join(alternatives),
            re.MULTILINE | re.DOTALL)

    def match(self, name):
        """Return True if the whole of `name` matches any pattern."""
        return self._matcher.match(name) is not None

    def __iter__(self):
        """Iterate over the stored pattern strings."""
        yield from self.patterns


# Verdicts that an action's check() can produce for an imported name.
# INDIFFERENT means the action has no opinion about the name.
ALLOWED = True
DENIED = False
INDIFFERENT = None


class Action:
    """Base for policy actions; wraps the patterns it applies to.

    Subclasses provide ``check(name)``.
    """

    def __init__(self, *patterns):
        """Wrap `patterns` (strings or nested iterables) in a `Pattern`."""
        super().__init__()
        self.pattern = Pattern(patterns)

    def compile(self):
        """Compile the wrapped pattern so that it can be matched."""
        self.pattern.compile()

    def __iter__(self):
        """Iterate over the wrapped pattern strings."""
        for pattern in self.pattern.patterns:
            yield pattern


class Allow(Action):
    """Action that permits names matching its patterns."""

    def check(self, name):
        """Return ALLOWED if `name` matches, otherwise INDIFFERENT."""
        matched = self.pattern.match(name)
        return ALLOWED if matched else INDIFFERENT


class Deny(Action):
    """Action that forbids names matching its patterns."""

    def check(self, name):
        """Return DENIED if `name` matches, otherwise INDIFFERENT."""
        matched = self.pattern.match(name)
        return DENIED if matched else INDIFFERENT


class Rule:
    """A policy built from `Allow` and `Deny` actions.

    All allow actions are merged into one `Allow` and all deny actions
    into one `Deny`. A name is allowed only when it matches an allow
    pattern and no deny pattern; names matching neither are denied.
    """

    def __init__(self, *actions):
        """Combine `actions` into a single allow and a single deny action.

        :raises ValueError: If anything other than an `Allow` or `Deny`
            instance is passed.
        """
        super().__init__()
        allows, denies = set(), set()
        for action in actions:
            if isinstance(action, Allow):
                allows.add(action)
            if isinstance(action, Deny):
                denies.add(action)
        unexpected = set(actions).difference(allows, denies)
        if unexpected:
            raise ValueError(
                "Expected Allow or Deny instance, got: %s"
                % ", ".join(map(repr, unexpected)))
        # Actions are iterable over their pattern strings, so wrapping the
        # sets merges every pattern into one action of each kind.
        self.allow = Allow(allows)
        self.deny = Deny(denies)

    def compile(self):
        """Compile both merged actions; required before `check`."""
        self.allow.compile()
        self.deny.compile()

    def check(self, name):
        """Return ALLOWED or DENIED for `name`; deny takes precedence."""
        if self.deny.check(name) is DENIED:
            return DENIED
        if self.allow.check(name) is ALLOWED:
            return ALLOWED
        # Anything not explicitly allowed is denied.
        return DENIED

    def __or__(self, other):
        """Return a new rule combining this one with `other`.

        `other` may be another rule or a single action.
        """
        cls = self.__class__
        if isinstance(other, cls):
            extra = (other.allow, other.deny)
        else:
            extra = (other,)
        return cls(self.allow, self.deny, *extra)


# Common module patterns.

# Every standard library module, plus anything beneath it.
StandardLibraries = Pattern(
    "{0}|{0}.**".format(name) for name in python_standard_libs)

# Libraries that are imported by test code; each entry covers the package
# itself plus anything beneath it.
TestingLibraries = Pattern(
    "django_nose|django_nose.**",
    "dns|dns.**",
    "fixtures|fixtures.**",
    "hypothesis|hypothesis.**",
    "maastesting|maastesting.**",
    "mock|mock.**",  # Deprecated: use unittest.mock.
    "nose|nose.**",
    "postgresfixture|postgresfixture.**",
    "selenium|selenium.**",
    "subunit|subunit.**",
    "testresources|testresources.**",
    "testscenarios|testscenarios.**",
    "testtools|testtools.**",
)


def files(*patterns):
    """Return the set of paths matching any of the glob `patterns`.

    :param patterns: Glob strings, or nested iterables of them; ``**``
        matches recursively.
    :return: A `frozenset` of matching path strings.
    """
    return frozenset(
        filename
        for pattern in flatten(patterns)
        for filename in iglob(pattern, recursive=True))


# Every Python source file beneath src/apiclient. The glob is relative,
# so it is resolved against the current working directory.
APIClient = files("src/apiclient/**/*.py")


# Rack controller (src/provisioningserver) file sets. Three individual
# files are split out so that `checks` can apply a different rule to each.
RackControllerConfig = files(
    "src/provisioningserver/config.py")
RackControllerUtilsInit = files(
    "src/provisioningserver/utils/__init__.py")
RackControllerUtilsFS = files(
    "src/provisioningserver/utils/fs.py")
# Everything else beneath src/provisioningserver, also excluding the
# twisted subdirectory.
RackController = (
    files("src/provisioningserver/**/*.py") -
    files("src/provisioningserver/twisted/**/*.py") -
    RackControllerConfig - RackControllerUtilsInit -
    RackControllerUtilsFS)
# Baseline import policy for rack controller code. In a pattern, "*"
# matches a single dotted component and "**" matches any number of them.
RackControllerRule = Rule(
    Allow("apiclient.creds.*"),
    Allow("apiclient.maas_client.*"),
    Allow("apiclient.utils.*"),
    Allow("bson"),
    Allow("crochet|crochet.**"),
    Allow("curtin|curtin.**"),
    Allow("distro_info|distro_info.*"),
    Allow("formencode|formencode.**"),
    Allow("jsonschema|jsonschema.**"),
    Allow("lxml|lxml.**"),
    Allow("netaddr|netaddr.**"),
    Allow("netifaces|netifaces.*"),
    Allow("oauthlib.oauth1"),
    Allow("paramiko.**"),
    Allow("pexpect"),
    Allow("provisioningserver|provisioningserver.**"),
    Allow("seamicroclient|seamicroclient.**"),
    Allow("simplestreams.**"),
    Allow("tempita"),
    Allow("tftp|tftp.**"),
    Allow("twisted.**"),
    Allow("yaml"),
    Allow("zope.interface|zope.interface.**"),
    Allow(StandardLibraries),
)


# Region controller (src/maasserver and src/metadataserver) file sets.
# The two kinds of migrations are split out so that `checks` can apply a
# different rule to each.
RegionMigrationsSouth = files(
    "src/maasserver/migrations/south/**/*.py",
    "src/metadataserver/migrations/south/**/*.py")
RegionMigrationsBuiltin = files(
    "src/maasserver/migrations/builtin/**/*.py",
    "src/metadataserver/migrations/builtin/**/*.py")
# Everything else beneath src/maasserver and src/metadataserver.
RegionController = (
    files("src/maasserver/**/*.py", "src/metadataserver/**/*.py") -
    RegionMigrationsSouth - RegionMigrationsBuiltin)
# Baseline import policy for region controller code. In a pattern, "*"
# matches a single dotted component and "**" matches any number of them.
RegionControllerRule = Rule(
    Allow("apiclient.creds.*"),
    Allow("apiclient.multipart.*"),
    Allow("apiclient.utils.*"),
    Allow("apt_pkg"),
    Allow("bson"),
    Allow("convoy|convoy.**"),
    Allow("crochet|crochet.**"),
    Allow("curtin|curtin.**"),
    Allow("distro_info|distro_info.*"),
    Allow("django|django.**"),
    Allow("docutils.core"),
    Allow("formencode|formencode.**"),
    Allow("jsonschema|jsonschema.**"),
    Allow("lxml|lxml.**"),
    Allow("maascli.utils.parse_docstring"),
    Allow("maasserver|maasserver.**"),
    Allow("metadataserver|metadataserver.**"),
    Allow("netaddr|netaddr.**"),
    Allow("oauth|oauth.**"),
    Allow("OpenSSL|OpenSSL.**"),
    Allow("petname"),
    Allow("piston3|piston3.**"),
    Allow("provisioningserver|provisioningserver.**"),
    Allow("psycopg2|psycopg2.**"),
    Allow("pytz.UTC"),
    Allow("requests|requests.**"),
    Allow("simplestreams.**"),
    Allow("tempita"),
    Allow("twisted.**"),
    Allow("yaml"),
    Allow("zope.interface|zope.interface.**"),
    Allow(StandardLibraries),
)


# Every Python source file beneath src/maas, and the baseline import
# policy applied to those files.
RegionControllerConfig = files("src/maas/**/*.py")
RegionControllerConfigRule = Rule(
    Allow("django|django.**"),
    Allow("formencode.validators.*"),
    Allow("maasserver.config.*"),
    Allow("maas|maas.**"),
    Allow("provisioningserver|provisioningserver.**"),
    Allow("psycopg2|psycopg2.**"),
    Allow(StandardLibraries),
)


# Test code anywhere beneath src: test_*.py modules, anything inside a
# directory named "testing", and modules named testing.py. `checks`
# intersects this with (or subtracts it from) the other file sets.
Tests = files(
    "src/**/test_*.py",
    "src/**/testing/**/*.py",
    "src/**/testing.py")


# Every Python source file beneath src/maastesting/protractor, and the
# import policy applied to those files.
TestProtractor = files(
    "src/maastesting/protractor/**/*.py")
TestProtractorRule = Rule(
    Allow("django|django.**"),
    Allow("maasserver.eventloop"),  # XXX: Dodgy!
    Allow("provisioningserver.testing.config.*"),
    Allow("twisted.scripts.twistd"),
    Allow(StandardLibraries),
    Allow(TestingLibraries),
)


# Every Python source file beneath src/maastesting except the protractor
# files (which have their own rule), and the import policy applied to
# those files.
TestHelpers = (
    files("src/maastesting/**/*.py") -
    TestProtractor)
TestHelpersRule = Rule(
    Allow("crochet|crochet.**"),
    Allow("django|django.**"),
    Allow("netaddr|netaddr.**"),
    Allow("twisted.**"),
    Allow(StandardLibraries),
    Allow(TestingLibraries),
)


# The policy table consumed by scan(): a list of (file set, rule) pairs.
# Every import found in each file of a set is checked against the paired
# rule. File sets are frozensets, so "-", "&" and "|" on them are set
# difference, intersection and union; "|" on a Rule combines it with
# another Rule or with a single action.
checks = [
    #
    # API CLIENT
    #
    (
        APIClient - Tests,
        Rule(
            Allow("apiclient|apiclient.**"),
            Allow("django.utils.**"),
            Allow("oauth.oauth"),
            Allow(StandardLibraries),
        ),
    ),
    (
        APIClient & Tests,
        Rule(
            Allow("apiclient|apiclient.**"),
            Allow("django.**"),
            Allow("oauth.oauth"),
            Allow("piston3|piston3.**"),
            Allow(StandardLibraries),
            Allow(TestingLibraries),
        ),
    ),
    #
    # RACK CONTROLLER
    #
    (
        RackController - Tests,
        RackControllerRule,
    ),
    (
        RackControllerConfig | RackControllerUtilsFS,
        RackControllerRule | Allow("maastesting.root"),
    ),
    (
        RackControllerUtilsInit,
        RackControllerRule | Rule(
            Allow("maastesting.typecheck.typed"),
            Allow("provisioningserver.path.get_tentative_path"),
            Allow("tempita"),
            Allow(StandardLibraries),
        ),
    ),
    (
        RackController & Tests,
        RackControllerRule | Rule(
            Allow("apiclient.testing.credentials.make_api_credentials"),
            Allow(StandardLibraries),
            Allow(TestingLibraries),
        ),
    ),
    #
    # REGION CONTROLLER
    #
    (
        RegionController - Tests,
        RegionControllerRule,
    ),
    (
        RegionMigrationsSouth,
        RegionControllerRule | Allow("south|south.**"),
    ),
    (
        RegionMigrationsBuiltin,
        Rule(
            Allow("django|django.**"),
            Allow("netaddr|netaddr.**"),
            Allow("piston3.models"),
            Allow(StandardLibraries),

            # XXX: The following three permissions are temporary;
            # they MUST be eliminated before MAAS 2.0 release.
            Allow("maasserver|maasserver.**"),
            Allow("metadataserver|metadataserver.**"),
            Allow("provisioningserver|provisioningserver.**"),
        )
    ),
    (
        RegionController & Tests,
        RegionControllerRule | Rule(
            Allow(TestingLibraries),
        ),
    ),

    #
    # REGION CONTROLLER CONFIGURATION
    #
    (
        RegionControllerConfig - Tests,
        RegionControllerConfigRule,
    ),
    (
        RegionControllerConfig & Tests,
        RegionControllerConfigRule | Rule(
            Allow(TestingLibraries),
        ),
    ),
    #
    # TESTING HELPERS
    #
    (
        TestHelpers,
        TestHelpersRule,
    ),
    #
    # TESTING HELPERS: PROTRACTOR
    #
    (
        TestProtractor,
        TestProtractorRule,
    ),
]


def extract(node):
    """Yield the dotted name of everything imported beneath `node`.

    ``import a.b`` yields ``"a.b"``, and ``from a.b import c`` yields
    ``"a.b.c"``. Imports nested inside functions, classes, and so on are
    included too.

    A relative import that names no module, e.g. ``from . import x``,
    yields the name prefixed by its leading dots (``".x"``) rather than
    the meaningless ``"None.x"``.

    :param node: An `ast` node, typically a parsed module.
    :return: A generator of strings; names may be repeated.
    """
    for child in ast.walk(node):
        if isinstance(child, ast.Import):
            for alias in child.names:
                yield alias.name
        elif isinstance(child, ast.ImportFrom):
            if child.module is None:
                # "from . import x": there is no module name to prepend,
                # only the relative level.
                prefix = "." * child.level
            else:
                # NOTE(review): the level of "from .mod import x" is not
                # reflected here, so it is reported as "mod.x"; confirm
                # whether relative imports should be resolved instead.
                prefix = child.module + "."
            for alias in child.names:
                yield prefix + alias.name
        else:
            pass  # Not an import.


def scan(checks):
    """Check the imports of every file against its paired rule.

    :param checks: An iterable of ``(filenames, rule)`` pairs.
    :return: A generator of ``(filename, allowed, denied)`` tuples, where
        `allowed` and `denied` are sets of imported names. Files are
        visited in sorted order within each pair.
    """
    for filenames, policy in checks:
        policy.compile()
        for filename in sorted(filenames):
            # tokenize.open() honours the file's declared source encoding.
            with tokenize.open(filename) as source:
                tree = ast.parse(source.read())
            imported = set(extract(tree))
            permitted = {name for name in imported if policy.check(name)}
            yield filename, permitted, imported - permitted


# The printing helpers are chosen once, at import time, according to
# whether standard output is a terminal.
if not sys.stdout.isatty():

    # Not a terminal: plain text only.

    def print_filename(filename):
        """Print the name of a scanned file."""
        print(filename)

    def print_allowed(name):
        """Print an allowed import."""
        print(" allowed:", name)

    def print_denied(name):
        """Print a denied import."""
        print("  denied:", name)

else:

    # A terminal: wrap the labels in ANSI colour escapes, resetting to
    # the default foreground colour afterwards.

    def print_filename(filename):
        """Print the name of a scanned file, in cyan."""
        print('\x1b[36m' + filename + '\x1b[39m')

    def print_allowed(name):
        """Print an allowed import, with the label in green."""
        print(" \x1b[32mallowed:\x1b[39m", name)

    def print_denied(name):
        """Print a denied import, with the label in red."""
        print("  \x1b[31mdenied:\x1b[39m", name)


def main(args):
    """Check every configured file's imports and report the results.

    :param args: Command-line arguments, excluding the program name.
    :return: True if no imports were denied, otherwise False.
    """
    parser = argparse.ArgumentParser(
        description="Statically check imports against policy.",
        add_help=False)
    parser.add_argument(
        "--show-allowed", action="store_true", dest="show_allowed",
        help="Log allowed imports.")
    # The two --hide-* flags switch off output that is on by default, so
    # their help text describes what is suppressed.
    parser.add_argument(
        "--hide-denied", action="store_false", dest="show_denied",
        help="Do not log denied imports.")
    parser.add_argument(
        "--hide-summary", action="store_false", dest="show_summary",
        help="Do not show the summary of allowed and denied imports.")
    # Help is registered by hand so that it can be left out of the
    # option listing.
    parser.add_argument(
        "-h", "--help", action="help", help=argparse.SUPPRESS)
    options = parser.parse_args(args)

    allowedcount, deniedcount = 0, 0
    for filename, allowed, denied in scan(checks):
        allowedcount += len(allowed)
        deniedcount += len(denied)
        # Only the names that the user asked to see are reported.
        report_allowed = allowed if options.show_allowed else ()
        report_denied = denied if options.show_denied else ()
        if report_allowed or report_denied:
            print_filename(filename)
        for imported_name in sorted(report_allowed):
            print_allowed(imported_name)
        for imported_name in sorted(report_denied):
            print_denied(imported_name)

    if options.show_summary:
        # Separate the summary from any per-file output printed above.
        printed_allowed = options.show_allowed and allowedcount != 0
        printed_denied = options.show_denied and deniedcount != 0
        if printed_allowed or printed_denied:
            print()

        print(allowedcount, "imported names were ALLOWED.")
        print(deniedcount, "imported names were DENIED.")

    return deniedcount == 0


if __name__ == '__main__':
    # Exit with status 0 when nothing was denied, otherwise 1.
    succeeded = main(sys.argv[1:])
    raise SystemExit(0 if succeeded else 1)
