"""
Evaluate model with non-polymer/small molecule ligands against reference.

Example: ost compare-ligand-structures \\
    -m model.pdb \\
    -ml ligand.sdf \\
    -r reference.cif \\
    --lddt-pli --rmsd

Structures of polymer entities (proteins and nucleotides) can be given in PDB
or mmCIF format. If the structure is given in mmCIF format, only the asymmetric
unit (AU) is used for scoring.

Ligands can be given as path to SDF files containing the ligand for both model
(--model-ligands/-ml) and reference (--reference-ligands/-rl). If omitted,
ligands will be detected in the model and reference structures. For structures
given in mmCIF format, this is based on the annotation as "non polymer entity"
(i.e. ligands in the _pdbx_entity_nonpoly mmCIF category) and works reliably.
For structures given in PDB format, this is based on the HET records and is
normally not what you want. You should always give ligands as SDF for
structures in PDB format.

Polymer/oligomeric ligands (saccharides, peptides, nucleotides) are not
supported.

Only minimal cleanup steps are performed (remove hydrogens, and for structures
of polymers only, remove unknown atoms and cleanup element column).

Ligands in mmCIF and PDB files must comply with the PDB component dictionary
definition, and have properly named residues and atoms, in order for
ligand connectivity to be loaded correctly. Ligands loaded from SDF files
are exempt from this restriction, meaning any arbitrary ligand can be assessed.

Output is written in JSON format (default: out.json). In case of no additional
options, this is a dictionary with three keys:

 * "model_ligands": A list of ligands in the model. If ligands were provided
   explicitly with --model-ligands, elements of the list will be the paths to
   the ligand SDF file(s). Otherwise, they will be the chain name, residue
   number and insertion code of the ligand, separated by a dot.
 * "reference_ligands": A list of ligands in the reference. If ligands were
   provided explicitly with --reference-ligands, elements of the list will be
   the paths to the ligand SDF file(s). Otherwise, they will be the chain name,
   residue number and insertion code of the ligand, separated by a dot.
 * "status": SUCCESS if everything ran through. In case of failure, the only
   content of the JSON output will be \"status\" set to FAILURE and an
   additional key: "traceback".

Each score is opt-in and can be enabled with optional arguments; enabled scores
are added to the output. Keys correspond to the values in "model_ligands" above.
Only assigned (mapped) ligands are reported.
"""

import argparse
import json
import os
import traceback

import ost
from ost import io
from ost.mol.alg import ligand_scoring


def _ParseArgs():
    """Build the command line parser and parse the process arguments.

    The module docstring is used verbatim as the program description.

    :returns: The parsed arguments as an :class:`argparse.Namespace`. The
              three radius options are always floats, whether they come from
              the defaults or from the command line.
    """
    parser = argparse.ArgumentParser(description = __doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     prog="ost compare-ligand-structures")

    parser.add_argument(
        "-m",
        "--mdl",
        "--model",
        dest="model",
        required=True,
        help=("Path to model file."))

    parser.add_argument(
        "-ml",
        "--mdl-ligands",
        "--model-ligands",
        dest="model_ligands",
        nargs="*",
        default=None,
        help=("Path to model ligand files."))

    parser.add_argument(
        "-r",
        "--ref",
        "--reference",
        dest="reference",
        required=True,
        help=("Path to reference file."))

    parser.add_argument(
        "-rl",
        "--ref-ligands",
        "--reference-ligands",
        dest="reference_ligands",
        nargs="*",
        default=None,
        help=("Path to reference ligand files."))

    parser.add_argument(
        "-o",
        "--out",
        "--output",
        dest="output",
        default="out.json",
        help=("Output file name. The output will be saved as a JSON file. "
              "default: out.json"))

    # The "auto" default is not listed in choices: argparse does not validate
    # string defaults against choices, so "auto" is only reachable by
    # omitting the option.
    parser.add_argument(
        "-mf",
        "--mdl-format",
        "--model-format",
        dest="model_format",
        default="auto",
        choices=["pdb", "mmcif", "cif"],
        help=("Format of model file. Inferred from path if not given."))

    parser.add_argument(
        "-rf",
        "--reference-format",
        "--ref-format",
        dest="reference_format",
        default="auto",
        choices=["pdb", "mmcif", "cif"],
        help=("Format of reference file. Inferred from path if not given."))

    parser.add_argument(
        "-ft",
        "--fault-tolerant",
        dest="fault_tolerant",
        default=False,
        action="store_true",
        help=("Fault tolerant parsing."))

    parser.add_argument(
        "-rna",
        "--residue-number-alignment",
        dest="residue_number_alignment",
        default=False,
        action="store_true",
        help=("Make alignment based on residue number instead of using "
              "a global BLOSUM62-based alignment (NUC44 for nucleotides)."))

    parser.add_argument(
        "-ec",
        "--enforce-consistency",
        dest="enforce_consistency",
        default=False,
        action="store_true",
        help=("Enforce consistency of residue names between the reference "
              "binding site and the model. By default residue name "
              "discrepancies are reported but the program proceeds. "
              "If this is set to True, the program will fail with an error "
              "message if the residues names differ. "
              "Note: more binding site mappings may be explored during "
              "scoring, but only inconsistencies in the selected mapping are "
              "reported."))

    parser.add_argument(
        "-sm",
        "--substructure-match",
        dest="substructure_match",
        default=False,
        action="store_true",
        help=("Allow incomplete target ligands."))

    parser.add_argument(
        "-gcm",
        "--global-chain-mapping",
        dest="global_chain_mapping",
        default=False,
        action="store_true",
        help=("Use a global chain mapping."))

    parser.add_argument(
        "-c",
        "--chain-mapping",
        nargs="+",
        dest="chain_mapping",
        help=("Custom mapping of chains between the reference and the model. "
              "Each separate mapping consist of key:value pairs where key "
              "is the chain name in reference and value is the chain name in "
              "model. Only has an effect if global-chain-mapping flag is set."))

    parser.add_argument(
        "-ra",
        "--rmsd-assignment",
        dest="rmsd_assignment",
        default=False,
        action="store_true",
        help=("Use RMSD for ligand assignment."))

    # The help text names the key that is actually written to the JSON
    # output ("lddt_pli", with an underscore).
    parser.add_argument(
        "--lddt-pli",
        dest="lddt_pli",
        default=False,
        action="store_true",
        help=("Compute lDDT-PLI score and store as key \"lddt_pli\"."))

    parser.add_argument(
        "--rmsd",
        dest="rmsd",
        default=False,
        action="store_true",
        help=("Compute RMSD score and store as key \"rmsd\"."))

    # Radii need an explicit type: without it, values given on the command
    # line would be passed on as strings while the defaults are floats.
    parser.add_argument(
        "--radius",
        dest="radius",
        type=float,
        default=4.0,
        help=("Inclusion radius for the binding site. Any residue with atoms "
              "within this distance of the ligand will be included in the "
              "binding site."))

    parser.add_argument(
        "--lddt-pli-radius",
        dest="lddt_pli_radius",
        type=float,
        default=6.0,
        help=("lDDT inclusion radius for lDDT-PLI."))

    parser.add_argument(
        "--lddt-lp-radius",
        dest="lddt_lp_radius",
        type=float,
        default=4.0,
        help=("lDDT inclusion radius for lDDT-LP."))

    parser.add_argument(
        '-v',
        '--verbosity',
        dest="verbosity",
        type=int,
        default=3,
        help="Set verbosity level. Defaults to 3 (INFO).")

    parser.add_argument(
        "--n-max-naive",
        dest="n_max_naive",
        required=False,
        default=12,
        type=int,
        help=("If number of chains in model and reference are below or equal "
              "that number, the global chain mapping will naively enumerate "
              "all possible mappings. A heuristic is used otherwise."))

    return parser.parse_args()


def _LoadStructure(structure_path, format="auto", fault_tolerant=False):
    """Load a polymer structure from an mmCIF or PDB file.

    :param structure_path: Path of the file to load.
    :param format: "pdb", "mmcif" or "cif". With "auto", the format is taken
                   from the (lower-cased) file suffix, ignoring a trailing
                   ".gz".
    :param fault_tolerant: Passed on to the OST loader.
    :returns: A new entity without hydrogens, named after *structure_path*.
    :raises Exception: If the file does not exist, the format cannot be
                       determined or is not supported, or the loaded
                       structure contains no residues.
    """
    if not os.path.exists(structure_path):
        raise Exception(f"file not found: {structure_path}")

    fmt = format
    if fmt == "auto":
        # Guess the format from the last suffix that is not "gz".
        parts = structure_path.split(".")
        if parts[-1] == "gz":
            parts.pop()
        if len(parts) < 2:
            raise Exception(f"Could not determine format of file "
                            f"{structure_path}.")
        fmt = parts[-1].lower()

    if fmt == "pdb":
        loaded = io.LoadPDB(structure_path, fault_tolerant=fault_tolerant)
    elif fmt in ("mmcif", "cif"):
        loaded = io.LoadMMCIF(structure_path, fault_tolerant=fault_tolerant)
    else:
        raise Exception(f"Unknown/ unsupported file extension found for "
                        f"file {structure_path}.")

    # An empty structure is an error, regardless of the loader used.
    if len(loaded.residues) == 0:
        raise Exception(f"No residues found in file: {structure_path}")

    # Keep heavy atoms only and remember the origin of the structure.
    heavy_atoms = loaded.Select("ele != H")
    result = ost.mol.CreateEntityFromView(heavy_atoms,
                                          include_exlusive_atoms=False)
    result.SetName(structure_path)
    return result


def _LoadLigands(ligands):
    """
    Load every ligand file of a list of file names.

    Returns a list of entities of the same size and order as the input, or
    None if no list was given.
    """
    if ligands is None:
        return None

    loaded = []
    for ligand_path in ligands:
        loaded.append(_LoadLigand(ligand_path))
    return loaded


def _LoadLigand(file):
    """
    Read one SDF file and return its content as a new entity.

    Hydrogen atoms are left out of the returned entity.
    """
    heavy_atoms = ost.io.LoadEntity(file, format="sdf").Select("ele != H")
    return ost.mol.CreateEntityFromView(heavy_atoms, False)


def _Validate(structure, ligands, legend, fault_tolerant=False):
    """Check a polymer structure for unexpected ligand residues.

    Nothing is checked when *ligands* is None, i.e. when ligands are not
    provided separately. Otherwise every residue of *structure* flagged as
    ligand is a problem.

    :param structure: The polymer structure to check.
    :param ligands: Separately provided ligands, or None.
    :param legend: Label of the structure used in messages (e.g. "model").
    :param fault_tolerant: If True, problems are logged as warnings. If
                           False, the first problem raises a ValueError.
    """
    if ligands is None:
        return

    for residue in structure.residues:
        if not residue.is_ligand:
            continue
        msg = "Ligand residue %s found in %s polymer structure" %(
            residue.qualified_name, legend)
        if not fault_tolerant:
            raise ValueError(msg)
        ost.LogWarning(msg)


def _QualifiedResidueNotation(r):
    """Describe residue *r* as a parsable, dot-separated string:
    ChainName.ResidueNumber.InsertionCode.

    NUL characters surrounding the insertion code are removed, so a residue
    without insertion code ends with a trailing dot."""
    number = r.number
    insertion_code = number.ins_code.strip("\u0000")
    template = "{cname}.{rnum}.{ins_code}"
    return template.format(cname=r.chain.name,
                           rnum=number.num,
                           ins_code=insertion_code)


def _ParseChainMapping(chain_mapping):
    """Convert "ref:mdl" command line items into a {ref: mdl} dictionary.

    :param chain_mapping: List of "ref:mdl" strings, or None.
    :returns: The mapping as dictionary, or None if no mapping was given.
    :raises ValueError: If an item contains no ":" separator.
    """
    if chain_mapping is None:
        return None

    mapping = {}
    for item in chain_mapping:
        fields = item.split(":")
        if len(fields) < 2:
            raise ValueError(f"Invalid chain mapping '{item}': expected "
                             f"'<reference chain>:<model chain>'")
        # Only the first two fields are used.
        mapping[fields[0]] = fields[1]
    return mapping


def _NameLigands(given_ligands, given_paths, scorer_ligands, legend):
    """Assign an output name to every ligand known to the scorer.

    :param given_ligands: Ligand entities loaded from SDF files, or None if
                          the ligands were detected in the structure.
    :param given_paths: Paths the entities in *given_ligands* were loaded
                        from (same order); unused if *given_ligands* is None.
    :param scorer_ligands: The ligands as exposed by the scorer.
    :param legend: "model" or "reference", used in error messages.
    :returns: A tuple (names, name_map). *names* is the list reported in the
              output, *name_map* maps each scorer ligand to its name.
    :raises RuntimeError: If the scorer ligands cannot be matched to the
                          given ligands.
    """
    if given_ligands is None:
        # Ligands were detected in the structure: name them by residue.
        name_map = {lig: _QualifiedResidueNotation(lig)
                    for lig in scorer_ligands}
        return list(name_map.values()), name_map

    if len(given_ligands) == len(scorer_ligands):
        # One ligand per file: name ligands by path.
        names = list(given_paths)
    elif len(given_ligands) < len(scorer_ligands):
        # Multi-ligand SDF files were given: name ligands by path:idx.
        names = []
        for ligand, filename in zip(given_ligands, given_paths):
            assert isinstance(ligand, ost.mol.EntityHandle)
            for i, _ in enumerate(ligand.residues):
                names.append(f"{filename}:{i}")
        # NOTE(review): pairing names with scorer ligands by position
        # assumes the scorer keeps one ligand per residue, in input order -
        # confirm against the LigandScorer documentation.
        if len(names) != len(scorer_ligands):
            raise RuntimeError("Cannot match %s ligands to input files: "
                               "%d ligand(s) read but %d in the scorer" % (
                legend, len(names), len(scorer_ligands)))
    else:
        # This should never happen and would be a bug
        raise RuntimeError("Fewer ligands in the %s scorer "
                           "(%d) than given (%d)" % (
            legend, len(scorer_ligands), len(given_ligands)))

    return names, dict(zip(scorer_ligands, names))


def _SerializeDetails(details, model_ligands_map, reference_ligands_map):
    """Convert per-ligand score details into JSON serializable dictionaries.

    Ligands are replaced by their output names, residues by their qualified
    notation and the transform by a list of rows of 4 values. The input
    dictionaries are copied and left untouched.

    :param details: Two-level dictionary of detail dictionaries, as taken
                    from the scorer.
    :returns: Dictionary of serializable details, keyed by model ligand name.
    """
    serialized = {}
    for per_chain in details.values():
        for raw in per_chain.values():
            entry = dict(raw)
            model_key = model_ligands_map[entry["model_ligand"]]
            entry["reference_ligand"] = reference_ligands_map[
                entry.pop("target_ligand")]
            entry["model_ligand"] = model_key
            transform_data = entry["transform"].data
            entry["transform"] = [transform_data[i:i + 4]
                                  for i in range(0, len(transform_data), 4)]
            for key in ("bs_ref_res", "bs_ref_res_mapped",
                        "bs_mdl_res_mapped"):
                entry[key] = [_QualifiedResidueNotation(r)
                              for r in entry[key]]
            entry["inconsistent_residues"] = [
                "%s-%s" % (_QualifiedResidueNotation(x),
                           _QualifiedResidueNotation(y))
                for x, y in entry["inconsistent_residues"]]
            serialized[model_key] = entry
    return serialized


def _Process(model, model_ligands, reference, reference_ligands, args):
    """Score the model ligands against the reference and collect the output.

    :param model: Model structure.
    :param model_ligands: Model ligands loaded from SDF files, or None to
                          use the ligands detected in *model*.
    :param reference: Reference structure.
    :param reference_ligands: Reference ligands loaded from SDF files, or
                              None to use the ligands detected in
                              *reference*.
    :param args: Parsed command line arguments.
    :returns: Dictionary with the keys "model_ligands" and
              "reference_ligands", plus "lddt_pli" and/or "rmsd" if the
              respective score was requested.
    :raises ValueError: If a chain mapping item is malformed.
    :raises RuntimeError: If scorer ligands cannot be matched to the given
                          ligand files.
    """
    scorer = ligand_scoring.LigandScorer(
        model=model,
        target=reference,
        model_ligands=model_ligands,
        target_ligands=reference_ligands,
        resnum_alignments=args.residue_number_alignment,
        check_resnames=args.enforce_consistency,
        rename_ligand_chain=True,
        substructure_match=args.substructure_match,
        global_chain_mapping=args.global_chain_mapping,
        rmsd_assignment=args.rmsd_assignment,
        radius=args.radius,
        lddt_pli_radius=args.lddt_pli_radius,
        lddt_lp_radius=args.lddt_lp_radius,
        n_max_naive=args.n_max_naive,
        custom_mapping=_ParseChainMapping(args.chain_mapping)
    )

    out = dict()

    # The name maps are built for every input mode, so that scores can be
    # reported for multi-ligand SDF files as well.
    out["model_ligands"], model_ligands_map = _NameLigands(
        model_ligands, args.model_ligands, scorer.model_ligands, "model")
    out["reference_ligands"], reference_ligands_map = _NameLigands(
        reference_ligands, args.reference_ligands, scorer.target_ligands,
        "reference")

    if args.lddt_pli:
        out["lddt_pli"] = _SerializeDetails(scorer.lddt_pli_details,
                                            model_ligands_map,
                                            reference_ligands_map)

    if args.rmsd:
        out["rmsd"] = _SerializeDetails(scorer.rmsd_details,
                                        model_ligands_map,
                                        reference_ligands_map)

    return out


def _Main():
    """Run the action: parse arguments, load inputs, score, write JSON.

    On success the results are written to the output file with "status" set
    to SUCCESS. If anything fails, the output file only contains "status"
    set to FAILURE and the "traceback", and the exception is re-raised.
    """

    def _WriteJson(payload):
        # Single place that defines how the output file is written.
        with open(args.output, 'w') as fh:
            json.dump(payload, fh, indent=4, sort_keys=False)

    args = _ParseArgs()
    ost.PushVerbosityLevel(args.verbosity)
    try:
        # Polymer structures, reference first
        reference = _LoadStructure(args.reference,
                                   format=args.reference_format,
                                   fault_tolerant=args.fault_tolerant)
        model = _LoadStructure(args.model,
                               format=args.model_format,
                               fault_tolerant=args.fault_tolerant)

        # Separately provided ligands (None if not given)
        model_ligands = _LoadLigands(args.model_ligands)
        reference_ligands = _LoadLigands(args.reference_ligands)

        # Sanity checks on the polymer structures
        _Validate(model, model_ligands, "model",
                  fault_tolerant=args.fault_tolerant)
        _Validate(reference, reference_ligands, "reference",
                  fault_tolerant=args.fault_tolerant)

        results = _Process(model, model_ligands, reference,
                           reference_ligands, args)
        results["status"] = "SUCCESS"
        _WriteJson(results)

    except Exception:
        # Leave a machine readable failure report, then propagate the error.
        _WriteJson({"status": "FAILURE",
                    "traceback": traceback.format_exc()})
        raise


# Run the action only when executed as a script, not when imported.
if __name__ == '__main__':
    _Main()
