#!/usr/bin/env python
# -*- coding: utf8 -*-
#
#    Project: Azimuthal integration 
#             https://forge.epn-campus.eu/projects/azimuthal
#
#    File: "$Id$"
#
#    Copyright (C) European Synchrotron Radiation Facility, Grenoble, France
#
#    Principal author:       Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

"""
pyFAI-calib

A tool for determining the geometry of a detector using a reference sample.

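usage:
  python2.6 pyFAI-calib [-g=100] [-nd | -d] [-pix=172e-6,172e-6] [-spline=/path/to/file.spline] [-min[=100]] [-debug] inputFile.edf [inputFile2.edf ...]

-g=     size of the gap (in pixels) between two consecutive rings (by default size/20).
        Increase the value if an arc is not complete;
        decrease the value if arcs are mixed together.

-nd     avoid diagonal expansion of the massifs (default).
-d      allow diagonal expansion of the massifs.
-pix=   pixel size in meter: one value for square pixels, or two comma-separated
        values. If neither -pix nor -spline is given, the pixel size (or a spline
        file) is asked for interactively.
-spline=/path/to/file.spline
        use this spline file for the detector (ignored if the file does not exist).
-min    automatically remove the background (or -min=100 if you know the value).
-debug  enable debug logging.
-h      print this help and exit.
--version
        print the pyFAI version and exit.

Several input files are averaged into merged.edf before calibration.
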
"""

__author__ = "Jerome Kieffer"
__contact__ = "Jerome.Kieffer@ESRF.eu"
__license__ = "GPLv3+"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
__date__ = "21/12/2011"
# Conventional spelling of the status attribute.  The historical misspelling
# ``__satus__`` is kept as an alias so code that already reads it still works.
__status__ = "development"
__satus__ = __status__

import os, sys, gc, threading, time, logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pyFAI.calib")
import numpy
from numpy import sin, cos, arccos, sqrt, floor, ceil, radians, degrees, pi
import fabio
import matplotlib
import pylab
from scipy.optimize import fmin, leastsq, fmin_slsqp, anneal
from scipy.interpolate import interp2d
import pyFAI
from pyFAI.geometryRefinement import GeometryRefinement
from pyFAI.peakPicker import PeakPicker
from pyFAI.utils import averageImages
from  matplotlib.path import Path
import matplotlib.path as mpath
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt


# Optional remote debugging console: when the rfoo package is installed, its
# rconsole server is spawned as soon as this module is imported.
# NOTE(review): this opens a remote Python console whenever rfoo is present;
# consider making it opt-in (e.g. tied to -debug).
try:
    from rfoo.utils import rconsole
except ImportError:
    logging.info("No socket opened for debugging -> install rfoo")
else:
    # Only the import is guarded: an ImportError raised while starting the
    # server must not be misreported as "rfoo is not installed".
    rconsole.spawn_server()


def _is_flag(arg, name):
    """Return True if the command-line argument *arg* is the option *name*.

    The option may be written with one extra leading dash (``-d`` or ``--d``)
    and may carry a suffix such as ``=value``.  Unlike a ``str.find`` test, an
    argument that merely contains *name* after some other character (e.g. an
    input file called ``s-default.edf``) is not taken for the option.
    """
    return arg.startswith(name) or arg.startswith("-" + name)


def _edit_geometry(peakPicker, geoRef):
    """Let the user change each geometry parameter (and its bounds) from the keyboard.

    For every parameter the prompt shows the current ``<name>_min``,
    ``<name>`` and ``<name>_max`` values of *geoRef*.  The prompt and the
    setters to call when one value or three values are entered are handed to
    ``peakPicker.readFloatFromKeyboard``.
    """
    # (attribute name, label shown to the user, unit), in prompting order.
    parameters = (("dist", "Distance", "meter"),
                  ("poni1", "Poni1", "meter"),
                  ("poni2", "Poni2", "meter"),
                  ("rot1", "Rot1", "rad"),
                  ("rot2", "Rot2", "rad"),
                  ("rot3", "Rot3", "rad"))
    for name, label, unit in parameters:
        prompt = "Enter %s in %s (or %s_min[%.3f] %s[%.3f] %s_max[%.3f]):\t " % (
            label, unit,
            name, getattr(geoRef, name + "_min"),
            name, getattr(geoRef, name),
            name, getattr(geoRef, name + "_max"))
        setter = getattr(geoRef, "set_" + name)
        setters = {1: [setter],
                   3: [getattr(geoRef, "set_%s_min" % name),
                       setter,
                       getattr(geoRef, "set_%s_max" % name)]}
        peakPicker.readFloatFromKeyboard(prompt, setters)


def main():
    """Run the interactive pyFAI-calib calibration tool.

    Options and input files are read from ``sys.argv`` (see the module
    docstring).  The image (or the average of several images, written to
    ``merged.edf``) is loaded into a PeakPicker.  The user picks rings in its
    GUI, and the geometry is refined until chi2 stops decreasing.  The user
    may then adjust parameters and refine again.  Finally the image is
    integrated in 1D and 2D, both results are plotted, and timings are printed.

    Files: ``<basename>.npt`` is loaded if present and passed to
    ``PeakPicker.finish``; ``<basename>.poni`` is loaded if present and
    saved after every refinement; ``<basename>.xy`` / ``<basename>.azim``
    are passed to the integrators.

    Exits the process (``sys.exit``) on ``-h``, ``--version``, ``-test``,
    when no input file is given, or when the typed pixel size is unreadable.
    """
    pixelSize = None
    gaussianWidth = None
    listInputFiles = []
    # Cross-shaped neighbourhood by default: no diagonal expansion.
    labelPattern = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
    splineFile = None
    cutBackground = None
    for arg in sys.argv[1:]:
        if arg == "-test":
            # NOTE(review): no ``test`` function is defined in this script;
            # report that cleanly instead of crashing with a NameError.
            selfTest = globals().get("test")
            if selfTest is None:
                logger.error("No self-test is available in this script")
                sys.exit(1)
            selfTest()
            sys.exit(0)
        elif _is_flag(arg, "-debug"):  # must be tested before "-d"
            logger.setLevel(logging.DEBUG)
        elif _is_flag(arg, "-g="):
            gaussianWidth = float(arg.split("=")[1])
        elif _is_flag(arg, "-nd"):
            labelPattern = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
        elif _is_flag(arg, "-d"):
            # Three independent rows; ``[[1] * 3] * 3`` would alias one list.
            labelPattern = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
        elif _is_flag(arg, "-min"):
            cutBackground = True
            if "=" in arg:
                cutBackground = float(arg.split("=")[1])
        elif _is_flag(arg, "-spline="):
            # Split only once: the path itself may contain "=".
            splineFile = arg.split("=", 1)[1]
            if not os.path.isfile(splineFile):
                logger.warning("Spline file %s does not exist: ignored", splineFile)
                splineFile = None
        elif _is_flag(arg, "-pix="):
            pixels = arg.split("=")[1]
            if "," in pixels:
                pixelSize = [float(i) for i in pixels.split(",")[:2]]
            else:
                pixelSize = [float(pixels), float(pixels)]
        elif arg.startswith("--version"):
            print(pyFAI.version)
            sys.exit(0)
        elif _is_flag(arg, "-h"):
            print(__doc__)
            sys.exit(0)
        elif os.path.isfile(arg):
            listInputFiles.append(arg)

    if not listInputFiles:
        logger.info(__doc__)
        sys.exit(1)
    if len(listInputFiles) == 1 and not cutBackground:
        inputFile = listInputFiles[0]
    else:
        # Several images and/or background removal: work on their average.
        inputFile = averageImages(listInputFiles, "merged.edf", minimum=cutBackground)

    peakPicker = PeakPicker(inputFile)
    if gaussianWidth is not None:
        peakPicker.massif.setValleySize(gaussianWidth)
    else:
        peakPicker.massif.initValleySize()
    # Label the massifs in a separate thread, started before the prompts below.
    # NOTE(review): the thread is never joined here; presumably PeakPicker
    # waits for the labelling itself -- confirm.
    labelingThread = threading.Thread(target=peakPicker.massif.getLabeledMassif, args=(labelPattern,))
    labelingThread.start()

    if pixelSize is None and splineFile is None:
        pixelSize = [1.5e-5, 1.5e-5]
        ans = raw_input("Please enter the pixel size (in meter, C order, i.e. %.2e %.2e) or a spline file: " % tuple(pixelSize)).strip()
        fields = ans.split()
        if os.path.isfile(ans):
            splineFile = ans
        elif 1 <= len(fields) <= 2:
            try:
                values = [float(field) for field in fields]
            except ValueError:
                logger.error("error in reading pixel size")
                sys.exit(1)
            # A single value means square pixels.
            pixelSize = values * 2 if len(values) == 1 else values
        elif fields:
            logger.warning("Could not understand %r: using the default pixel size", ans)

    basename = os.path.splitext(inputFile)[0]
    peakPicker.gui(True)
    datafile = basename + ".npt"
    if os.path.isfile(datafile):
        peakPicker.load(datafile)
    data = peakPicker.finish(datafile)
    if os.name == "nt":
        logger.info("We are under windows, matplotlib is not able to display too many images without crashing, this is why the window showing the diffraction image is closed")
        peakPicker.closeGUI()
    if splineFile:
        geoRef = GeometryRefinement(data, dist=0.1, splineFile=splineFile)
    else:
        geoRef = GeometryRefinement(data, dist=0.1, pixel1=pixelSize[0], pixel2=pixelSize[1])
    paramfile = basename + ".poni"
    if os.path.isfile(paramfile):
        geoRef.load(paramfile)
    print(geoRef)

    # Sentinel above any finite chi2 so that refinement runs at least once
    # (``sys.maxint`` is Python-2 only and is not an upper bound for floats).
    previous = float("inf")
    finished = False
    fig2 = None
    while not finished:
        # Refine until chi2 stops decreasing.
        while previous > geoRef.chi2():
            previous = geoRef.chi2()
            geoRef.refine2(1000000)
            print(geoRef)
        print(geoRef)
        geoRef.save(paramfile)
        # Presumably drops cached 2theta / solid-angle / chi arrays so the
        # timed calls below regenerate them for the refined geometry.
        geoRef.del_ttha()
        geoRef.del_dssa()
        geoRef.del_chia()
        t0 = time.time()
        tth = geoRef.twoThetaArray(peakPicker.shape)
        t1 = time.time()
        dsa = geoRef.solidAngleArray(peakPicker.shape)
        t2 = time.time()
        geoRef.chiArray(peakPicker.shape)
        t2a = time.time()
        geoRef.cornerArray(peakPicker.shape)
        t2b = time.time()
        if os.name == "nt":
            logger.info("We are under windows, matplotlib is not able to display too many images without crashing, this is why little information is displayed")
        else:
            peakPicker.contour(tth)
            if fig2 is None:
                fig2 = pylab.plt.figure()
                sp = fig2.add_subplot(111)
            else:
                # Replace the previous solid-angle image instead of stacking.
                sp.images.pop()
            sp.imshow(dsa)
            fig2.show()

        change = raw_input("Modify parameters ?\t ").strip()
        if not change or change.lower().startswith("n"):
            finished = True
        else:
            _edit_geometry(peakPicker, geoRef)
            previous = float("inf")

    fig3 = pylab.plt.figure()
    xrpd = fig3.add_subplot(111)
    fig4 = pylab.plt.figure()
    xrpd2 = fig4.add_subplot(111)

    t3 = time.time()
    a, b = geoRef.xrpd(peakPicker.data, 1024, basename + ".xy")
    t4 = time.time()
    img = geoRef.xrpd2(peakPicker.data, 400, 360, basename + ".azim")[0]
    t5 = time.time()
    print("Timings:\n two theta array generation %.3fs\n diff Solid Angle  %.3fs\n"
          " chi array generation %.3fs\n"
          " corner coordinate array %.3fs\n"
          " 1D Azimuthal integration: %.3fs\n"
          " 2D Azimuthal integration: %.3fs"
          % (t1 - t0, t2 - t1, t2a - t2, t2b - t2a, t4 - t3, t5 - t4))
    xrpd.plot(a, b)
    fig3.show()
    # Shift by the minimum (+1e-3) so the logarithm is defined everywhere.
    xrpd2.imshow(numpy.log(img - img.min() + 1e-3))
    fig4.show()

    raw_input("Press enter to quit")

if __name__ == "__main__":
    # Start the interactive calibration only when run as a script, not on import.
    main()
