#!/usr/bin/env python3

r'''Reproject calibration images to show just the chessboard in its nominal
coordinates

SYNOPSIS

  D=../mrcal-doc-external/2022-11-05--dtla-overpass--samyang--alpha7/2-f22-infinity;

  analyses/mrcal-reproject-to-chessboard \
    --image-directory $D/images          \
    --outdir /tmp/frames-splined         \
    $D/splined.cameramodel

Outliers are highlighted. This is useful to visualize calibration errors. A
perfect solve would display exactly the same calibration grid with every frame.
In reality, we see the small errors in the calibration, and how they affect the
individual chessboard corner observations. The intent of this tool is to be able
to see any unmodeled chessboard deformations. It's not yet clear if this tool
can do that effectively

'''


import sys
import argparse
import re
import os

def parse_args():
    r'''Parse the commandline and return the resulting argparse namespace

    The module docstring is used as the --help description. --outdir must name
    an existing directory. --image-path-prefix and --image-directory may not be
    given together: if both are present we complain on stderr and exit with
    status 1

    '''

    def existing_directory(d):
        # argparse "type" callback for --outdir. parser.error() reports the
        # problem and exits, so we only return for a valid directory
        if not os.path.isdir(d):
            parser.error(f"--outdir requires an existing directory as the arg, but got '{d}'")
        return d

    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        description     = __doc__)

    parser.add_argument('--image-path-prefix',
                        help='''If given, we prepend the given prefix to the image paths. Exclusive with
                        --image-directory''')

    parser.add_argument('--image-directory',
                        help='''If given, we extract the filenames from the image paths in the solve, and use
                        the given directory to find those filenames. Exclusive
                        with --image-path-prefix''')

    parser.add_argument('--outdir',
                        default='.',
                        type=existing_directory,
                        help='''Directory to write the output into. If omitted,
                        we use the current directory''')

    parser.add_argument('--resolution-px-m',
                        type    = float,
                        default = 1000,
                        help='''The resolution of the output image in pixels/m.
                        Defaults to 1000''')

    parser.add_argument('--force', '-f',
                        default=False,
                        action='store_true',
                        help='''By default existing files are not overwritten. Pass --force to overwrite them
                        without complaint''')

    parser.add_argument('model',
                        type=str,
                        help='''The input camera model. We use its
                        optimization_inputs''')

    parsed = parser.parse_args()

    # The two ways of locating the images can't be combined
    only_one_image_option = \
        parsed.image_path_prefix is None or \
        parsed.image_directory   is None
    if not only_one_image_option:
        print("--image-path-prefix and --image-directory are mutually exclusive",
              file=sys.stderr)
        sys.exit(1)

    return parsed

args = parse_args()

# The arguments are parsed before the imports below so that --help works without
# building stuff. This is what lets me generate the manpages and README


import numpy as np
import numpysane as nps
import cv2
import mrcal

# Load the model, and pull out everything we need from its optimization_inputs.
# Any missing piece is reported on stderr, and we exit with status 1
try:
    model = mrcal.cameramodel(args.model)
except Exception as e:
    print(f"Couldn't load camera model '{args.model}': {e}", file=sys.stderr)
    sys.exit(1)

optimization_inputs = model.optimization_inputs()
if optimization_inputs is None:
    print("The given model MUST have the optimization_inputs for this tool to be useful",
          file=sys.stderr)
    sys.exit(1)

# Indexed as [iobservation, iy_corner, ix_corner, ...]. Element 2 of the last
# axis is the weight of each corner observation
observations = optimization_inputs.get('observations_board')
if observations is None:
    print("The given model MUST have chessboard observations for this tool to be useful",
          file=sys.stderr)
    sys.exit(1)

imagepaths = optimization_inputs.get('imagepaths')
if imagepaths is None:
    print("The given model MUST have the image paths for this tool to be useful",
          file=sys.stderr)
    sys.exit(1)

object_height_n,object_width_n = observations.shape[1:3]
object_spacing  = optimization_inputs['calibration_object_spacing']

# I cover a space of N+1 squares wide/tall: N-1 between all the corners + 1 on
# either side. I span squares -1..N inclusive
Nx = int(args.resolution_px_m * object_spacing * (object_width_n  + 1) + 0.5)
Ny = int(args.resolution_px_m * object_spacing * (object_height_n + 1) + 0.5)

# shape (Nframes,6)
rt_ref_frame    = optimization_inputs['frames_rt_toref']
rt_cam_ref      = optimization_inputs['extrinsics_rt_fromref']
ifcice          = optimization_inputs['indices_frame_camintrinsics_camextrinsics']
lensmodel       = optimization_inputs['lensmodel']
intrinsics_data = optimization_inputs['intrinsics']

# The board warp is optional. If the solve doesn't have it, I pass None and let
# mrcal.ref_calibration_object() use its default
calobject_warp  = optimization_inputs.get('calobject_warp')

# Points on the board, in board coordinates, sampled densely enough to produce
# one point for each pixel of the output image
#
# shape (Ny,Nx,3)
p_frame = \
    mrcal.ref_calibration_object(object_width_n,object_height_n,
                                 object_spacing,
                                 x_corner0      = -1,
                                 x_corner1      = object_width_n,
                                 Nx             = Nx,
                                 y_corner0      = -1,
                                 y_corner1      = object_height_n,
                                 Ny             = Ny,
                                 calobject_warp = calobject_warp)

# Reproject each board observation into the nominal board coordinates, mark the
# outlier corners, and write the result into --outdir. Any failure is reported
# on stderr, and we exit with status 1

# Outlier markers. (0,0,255) is red in the BGR channel order used by OpenCV
red           = (0,0,255)
circle_radius = 10
circle_width  = 2

# The output image samples linspace(-1,object_width_n,Nx) horizontally and
# linspace(-1,object_height_n,Ny) vertically. These are the resulting
# pixels-per-square scale factors
pixels_per_square_x = (Nx-1) / (object_width_n +1)
pixels_per_square_y = (Ny-1) / (object_height_n+1)

for i,(iframe, icamintrinsics, icamextrinsics) in enumerate(ifcice):

    # Board coords -> reference coords -> camera coords
    p_ref = \
        mrcal.transform_point_rt(rt_ref_frame[iframe], p_frame)
    if icamextrinsics >= 0:
        p_cam = \
            mrcal.transform_point_rt(rt_cam_ref[icamextrinsics], p_ref)
    else:
        # icamextrinsics < 0: no extrinsics to apply, so the camera coords ARE
        # the reference coords
        p_cam = p_ref

    # The pixel in the input image corresponding to each pixel in the output
    # image
    #
    # shape (Ny,Nx,2)
    q = mrcal.project(p_cam, lensmodel, intrinsics_data[icamintrinsics]).astype(np.float32)

    imagepath = imagepaths[i]
    if args.image_path_prefix is not None:
        imagepath = f"{args.image_path_prefix}/{imagepath}"
    elif args.image_directory is not None:
        imagepath = f"{args.image_directory}/{os.path.basename(imagepath)}"

    # Catch Exception, not everything: a bare "except" would also swallow
    # KeyboardInterrupt and SystemExit
    try:
        image = mrcal.load_image(imagepath,
                                 bits_per_pixel = 24,
                                 channels       = 3)
    except Exception as e:
        print(f"Couldn't read image at '{imagepath}': {e}", file=sys.stderr)
        sys.exit(1)

    image_out = mrcal.transform_image(image,q)

    # One weight per chessboard corner. weight<=0 marks an outlier
    #
    # shape (object_height_n,object_width_n)
    weight       = observations[i,:,:,2]
    mask_outlier = weight<=0

    # Mark all the outliers
    for iy,ix in np.argwhere(mask_outlier):
        # iy,ix index corners. I need to convert them to pixels in my image.
        # Corner 0 sits one square in from the edge of the output image
        qx = (ix+1) * pixels_per_square_x
        qy = (iy+1) * pixels_per_square_y

        # Red circle around the outlier
        cv2.circle(image_out,
                   center    = np.round(np.array((qx,qy))).astype(int),
                   radius    = circle_radius,
                   color     = red,
                   thickness = circle_width)

    # The output filename is the basename of the input image, so inputs in
    # different directories that share a filename map to the same output file
    imagepath_out = f"{args.outdir}/{os.path.basename(imagepath)}"
    if not args.force and os.path.exists(imagepath_out):
        print(f"File already exists: '{imagepath_out}'. Not overwriting; pass --force to overwrite. Giving up.",
              file=sys.stderr)
        sys.exit(1)
    mrcal.save_image(imagepath_out, image_out)
    print(f"Wrote '{imagepath_out}'")
