#!/usr/bin/env python3
"""
Fitting example: fit with masks
"""

from matplotlib import pyplot as plt
import bornagain as ba
from bornagain import deg, nm, ba_fitmonitor
import model1_cylinders as model


def get_masked_simulation(P, add_masks=True):
    """
    Build the model simulation for parameter set P.

    When add_masks is true (the default), the Pac-Man detector masks from
    add_mask_to_simulation are applied before the simulation is returned.
    """
    sim = model.get_simulation(P)
    if not add_masks:
        return sim
    add_mask_to_simulation(sim)
    return sim


def add_mask_to_simulation(simulation):
    """
    Apply detector masks that make the simulated image look like Pac-Man
    from the classic arcade game.

    Only unmasked areas are simulated and subsequently used during the fit.

    Each mask is a geometrical shape (this function uses ba.Ellipse,
    ba.Polygon and ba.Rectangle; line shapes also exist) combined with a
    mask value: True excludes the covered detector bins from the simulation,
    False keeps them simulated.

    Masks are applied in order, and every later mask overrides earlier ones
    where they overlap, so the order of the calls below matters.
    """
    det = simulation.detector()

    # Start from a fully masked detector (every channel set to mask=True).
    det.maskAll()

    # Unmask a disc for Pac-Man's head.
    det.addMask(ba.Ellipse(0, 1*deg, 0.5*deg, 0.5*deg), False)

    # Mask a small disc inside the head for the eye.
    det.addMask(ba.Ellipse(0.11*deg, 1.25*deg, 0.05*deg, 0.05*deg), True)

    # Mask a triangle for the mouth; the first vertex is repeated at the end
    # to close the polygon.
    mouth = [
        [0*deg, 1*deg],
        [0.5*deg, 1.2*deg],
        [0.5*deg, 0.8*deg],
        [0*deg, 1*deg],
    ]
    det.addMask(ba.Polygon(mouth), True)

    # Unmask three small squares along the mouth's axis for Pac-Man to eat,
    # added in the same left-to-right order as before.
    for x_lo, x_hi in ((0.45, 0.55), (0.61, 0.71), (0.75, 0.85)):
        det.addMask(ba.Rectangle(x_lo*deg, 0.95*deg, x_hi*deg, 1.05*deg), False)


if __name__ == '__main__':
    # Reference data to fit against, produced by the companion model module.
    real_data = model.create_real_data()

    # Pair the masked-simulation builder with the data; the trailing 1 is
    # presumably the dataset weight (confirm against the FitObjective API).
    fit_objective = ba.FitObjective()
    fit_objective.addSimulationAndData(get_masked_simulation, real_data, 1)
    # Progress printing and GISAS plotting; 10 is presumably the reporting
    # interval in iterations.
    fit_objective.initPrint(10)
    observer = ba_fitmonitor.PlotterGISAS()
    fit_objective.initPlot(10, observer)

    # Fit parameters. Start values and bounds are all written in nm, so the
    # units stay explicit and consistent (previously the bounds were bare
    # numbers while the start values carried the nm unit).
    P = ba.Parameters()
    P.add("radius", 6.*nm, min=4*nm, max=8*nm)
    P.add("height", 9.*nm, min=8*nm, max=12*nm)

    minimizer = ba.Minimizer()
    result = minimizer.minimize(fit_objective.evaluate, P)
    fit_objective.finalize(result)

    plt.show()
