# -*- coding: utf-8 -*-
# psf.py

# Copyright (c) 2008, The Regents of the University of California
# Produced by the Laboratory for Fluorescence Dynamics
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
# * Neither the name of the copyright holders nor the names of any
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""Point spread function calculations for fluorescence microscopy.

:Authors:
  `Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`__,
  Laboratory for Fluorescence Dynamics, University of California, Irvine

:Version: 20080625

Requirements
------------

* `Python 2.5 <http://www.python.org>`__
* `Numpy 1.1 <http://numpy.scipy.org>`__
* `Matplotlib 0.98 <http://matplotlib.sourceforge.net>`__

References
----------

(1) Focal volume optics and experimental artifacts in confocal fluorescence
    correlation spectroscopy.
    S T Hess, W W Webb. Biophys J (83) 2300-17, 2002.
(2) Electromagnetic description of image formation in confocal fluorescence
    microscopy.
    T D Visser, S H Wiersma. J Opt Soc Am A (11) 599-608, 1994.
(3) Photon counting histogram: one-photon excitation.
    B Huang, T D Perroud, R N Zare. Chem Phys Chem (5), 1523-31, 2004.
    Supporting information: Calculation of the observation volume profile.
(4) Gaussian approximations of fluorescence microscope point-spread function
    models.
    B Zhang, J Zerubia, J C Olivo-Marin. Appl. Optics (46) 1819-29, 2007.

Examples
--------

>>> import psf
>>> args = dict(shape=(32,32), dims=(4,4), ex_wavelen=488, em_wavelen=520,
...             num_aperture=1.2, refr_index=1.333, pinhole_radius=0.55)
>>> obsvol = psf.PSF(psf.GAUSSIAN|psf.CONFOCAL, **args)
>>> obsvol.sigma.ou
(2.5883205173739698, 1.3705931673819745)
>>> obsvol = psf.PSF(psf.ISOTROPIC|psf.CONFOCAL, **args)
>>> obsvol[0, :3]
array([ 1.        ,  0.51071049,  0.04396894])
>>> #save the image plane to file
>>> obsvol.slice(0).tofile("_test_slice.bin")
>>> #save a full 3D PSF volume to file
>>> obsvol.volume().tofile("_test_volume.bin")

"""

from __future__ import division

import sys
import math
import time
import threading
import numpy
import _psf

__docformat__ = "restructuredtext en"

# Bit flags describing a PSF. Combine one calculation model and one PSF type
# with "|" (see PSF.types for the supported combinations).

# calculation model
ANISOTROPIC = 1 << 0
ISOTROPIC = 1 << 1
GAUSSIAN = 1 << 2
GAUSSLORENTZ = 1 << 3
# type of PSF
EXCITATION = 1 << 4
EMISSION = 1 << 5
WIDEFIELD = 1 << 6
CONFOCAL = 1 << 7
TWOPHOTON = 1 << 8
# modifier for the Gaussian approximations
PARAXIAL = 1 << 9


class PSF(object):
    """Calculate point spread function of various types.

    Instance Attributes
    -------------------

    type : int
        A combination of the following properties. Valid combinations are
        listed in PSF.types.

        ANISOTROPIC or ISOTROPIC or GAUSSIAN or GAUSSLORENTZ:
           Specify calculation model.

        EXCITATION or EMISSION or WIDEFIELD or CONFOCAL or TWOPHOTON:
            Specify type of PSF.

        PARAXIAL:
            Border case for Gaussian approximations.

    name : str
        A human readable label.

    data : 2D Numpy array of floats (C doubles)
        PSF values in z,r space normalized to the value at the origin.

    shape : tuple of int
        Size of the data array in pixel. Default (256, 256)

    dims : Dimensions instance
        Dimensions of the data array in px (pixel), um (micrometers),
        ou (optical units), and au (airy units).

    ex_wavelen and em_wavelen : float or None
        Excitation or emission wavelengths in micrometers if applicable.

    num_aperture : float
        Numerical aperture (NA) of the objective. Default 1.2.

    refr_index : float
        Index of refraction of the sample medium. Default 1.333 (water).

    sinalpha : float
        Sine of the half cone angle, i.e. num_aperture / refr_index.

    magnification : float or None
        Total magnification of the optical system. Default 1.0.

    underfilling : float or None
        Underfilling factor, i.e. the ratio of the radius of the objective
        back aperture to the exp(-2) radius of the excitation laser.
        Default 1.0.

    sigma : Dimensions instance or None
        Gaussian sigma parameters in px (pixel), um (micrometers),
        ou (optical units), and au (airy units) if applicable.

    pinhole : Pinhole instance or None
        Pinhole applies to confocal types only.

    expsf, empsf : PSF instance or None
        Excitation or Emission PSF objects if applicable (i.e. when calculated
        intermediately for confocal)

    time : float
        Computing time of the calculation in milliseconds.

    Notes
    -----

    Calculations of the isotropic PSFs are based on the complex integration
    representation for the diffraction near the image plane proposed by
    Richards and Wolf [1-3].

    Gaussian approximations are calculated according to [4].

    Widefield calculations are used if the pinhole radius is larger than ~8 au.

    Models for polarized excitation or emission light (ANISOTROPIC) and the
    Gaussian-Lorentzian approximation (GAUSSLORENTZ) are not implemented.

    """
    types = {
        ISOTROPIC|EXCITATION: "Excitation, Isotropic",
        ISOTROPIC|EMISSION: "Emission, Isotropic",
        ISOTROPIC|WIDEFIELD: "Widefield, Isotropic",
        ISOTROPIC|CONFOCAL: "Confocal, Isotropic",
        ISOTROPIC|TWOPHOTON: "Two-Photon, Isotropic",
        GAUSSIAN|EXCITATION: "Excitation, Gaussian",
        GAUSSIAN|EMISSION: "Emission, Gaussian",
        GAUSSIAN|WIDEFIELD: "Widefield, Gaussian", # same as Gaussian Emission
        GAUSSIAN|CONFOCAL: "Confocal, Gaussian",
        GAUSSIAN|TWOPHOTON: "Two-Photon, Gaussian",
        GAUSSIAN|EXCITATION|PARAXIAL: "Excitation, Gaussian, Paraxial",
        GAUSSIAN|EMISSION|PARAXIAL: "Emission, Gaussian, Paraxial",
        GAUSSIAN|WIDEFIELD|PARAXIAL: "Widefield, Gaussian, Paraxial",
        GAUSSIAN|CONFOCAL|PARAXIAL: "Confocal, Gaussian, Paraxial",
        GAUSSIAN|TWOPHOTON|PARAXIAL: "Two-Photon, Gaussian, Paraxial",
    }

    def __init__(self, type, shape=(256, 256), dims=(4., 4.),
        ex_wavelen=None, em_wavelen=None, num_aperture=1.2,
        refr_index=1.333, magnification=1.0, underfilling=1.0,
        pinhole_radius=None, pinhole_type=0,
        expsf=None, empsf=None, name=None):
        """Initialize the PSF object.

        Arguments
        ---------

        type, shape, num_aperture, refr_index, magnification, underfilling,
        expsf, and empsf: see PSF attributes.

        dims : sequence of float
            Dimensions of the data array in *micrometers*. Default (4., 4.)

        ex_wavelen and em_wavelen : float or None
            Excitation or emission wavelengths in *nanometers* if applicable.
            Excitation and two-photon PSFs need the excitation wavelength
            only; emission PSFs need the emission wavelength only.

        pinhole_radius : float or None
            Radius of the pinhole in *micrometers* in the object space. This
            is the back-projected radius, i.e. the real physical radius of the
            pinhole divided by the magnification of the system.

        pinhole_type : int
            Either 0 (round) or 1 (square). Default 0.

        Raises
        ------

        ValueError
            If the PSF type is invalid, a required wavelength or the pinhole
            radius is missing, num_aperture/refr_index >= 1, the emission
            wavelength is shorter than the excitation wavelength, or the
            intermediate excitation and emission PSFs are not compatible.

        NotImplementedError
            For Gaussian approximations with underfilling != 1 or with a
            square pinhole.

        RuntimeError
            If an intermediate excitation or emission PSF could not be
            calculated in its worker thread.

        """
        try:
            self.name = self.types[type]
            self.type = type
        except (KeyError, TypeError):
            # KeyError: unsupported flag combination;
            # TypeError: type is not hashable (e.g. a list)
            raise ValueError("PSF type is invalid or not supported.")

        if name:
            self.name = str(name)

        self.shape = int(shape[0]), int(shape[1])
        # Use the normalized tuple rather than the caller's sequence so that
        # dims.px is always a tuple of int (required by the "%i" formatting
        # in __str__ when shape is given as a list or array).
        self.dims = Dimensions(px=self.shape,
                               um=(float(dims[0]), float(dims[1])))

        # wavelengths are passed in nm but stored in um
        self.ex_wavelen = ex_wavelen/1e3 if ex_wavelen else None
        self.em_wavelen = em_wavelen/1e3 if em_wavelen else None
        self.num_aperture = num_aperture
        self.refr_index = refr_index
        self.magnification = magnification
        self.underfilling = underfilling
        self.sigma = None
        self.pinhole = None
        self.expsf = expsf
        self.empsf = empsf

        # Excitation and two-photon PSFs are calculated from the excitation
        # wavelength only (the emission wavelength is discarded below), so
        # they must not require an emission wavelength.
        if (not (type & (EXCITATION|TWOPHOTON))) and (em_wavelen is None):
            raise ValueError("emission wavelength not specified.")

        if (not (type&EMISSION)) and (ex_wavelen is None):
            raise ValueError("excitation wavelength not specified.")

        if ((type&CONFOCAL)) and (pinhole_radius is None):
            raise ValueError("pinhole radius not specified.")

        self.sinalpha = self.num_aperture / self.refr_index
        if (self.sinalpha >= 1.0):
            raise ValueError("quotient of the numeric aperture (%.1f) and "
                "refractive index (%.1f) is greater than 1.0 (%.2f)" % (
                self.num_aperture, self.refr_index, self.sinalpha))

        # Optical and Airy units are based on the emission wavelength for
        # emission PSFs and on the excitation wavelength otherwise.
        if type&EMISSION:
            au = (1.22 * self.em_wavelen / self.num_aperture)
            ou = zr2uv(self.dims.um, self.em_wavelen, self.sinalpha,
                       self.refr_index, self.magnification)
        else:
            au = (1.22 * self.ex_wavelen / self.num_aperture)
            ou = zr2uv(self.dims.um, self.ex_wavelen, self.sinalpha,
                       self.refr_index, 1.0)
        # dict.update adds the new units without going through the
        # rescaling Dimensions.__setitem__
        self.dims.update(ou=ou, au=(self.dims.um[0]/au, self.dims.um[1]/au))

        if pinhole_radius:
            self.pinhole = Pinhole(pinhole_radius, self.dims, pinhole_type)

        # time.clock was removed in Python 3.8; prefer perf_counter if present
        clock = getattr(time, "perf_counter", None) or time.clock
        start = clock()
        if type&GAUSSIAN:
            self.sigma = Dimensions(**self.dims)
            if self.underfilling != 1.0:
                raise NotImplementedError(
                    "Underfilling not implemented in Gaussian approximation.")

            if type&EXCITATION or type&TWOPHOTON:
                widefield = True
                self.em_wavelen = None
                self.magnification = None
                # NOTE(review): pinh_radius is not read anywhere in this
                # file; possibly self.pinhole was intended.
                self.pinh_radius = None
                lex = lem = self.ex_wavelen
                radius = 0.0
            elif type&EMISSION or type&WIDEFIELD:
                widefield = True
                self.ex_wavelen = None
                self.magnification = None
                lex = lem = self.em_wavelen
                radius = 0.0
            elif type&CONFOCAL:
                radius = self.pinhole.radius.um
                # 9.76 = 8 * 1.22, i.e. 8 Airy units of the excitation light
                if radius > 9.76 * self.ex_wavelen/self.num_aperture:
                    # use widefield approximation for pinholes > 8 AU
                    widefield = True
                    lex = lem = self.ex_wavelen
                else:
                    widefield = False
                    lex = self.ex_wavelen
                    lem = self.em_wavelen
                if self.pinhole.type:
                    raise NotImplementedError(
                        "Gaussian approximation only valid for round pinhole.")

            paraxial = bool(type&PARAXIAL)
            self.sigma.um = _psf.gaussian_sigma(lex, lem, self.num_aperture,
                                self.refr_index, radius, widefield, paraxial)
            self.data = _psf.gaussian2d(self.dims.px, self.sigma.px)

        elif type&ISOTROPIC:
            if type&EXCITATION or type&TWOPHOTON:
                self.em_wavelen = None
                self.magnification = None
                self.data = _psf.psf(0, self.shape, self.dims.ou, 1.0,
                                     self.sinalpha, self.underfilling, 1.0, 80)
            elif type&EMISSION:
                self.ex_wavelen = None
                self.underfilling = None
                self.data = _psf.psf(1, self.shape, self.dims.ou,
                                     self.magnification, self.sinalpha,
                                     1.0, 1.0, 80)
            elif type&CONFOCAL or type&WIDEFIELD:
                if em_wavelen < ex_wavelen:
                    raise ValueError("Excitation > Emission wavelength.")
                # start threads to calculate excitation and emission PSF
                # unless compatible-typed PSFs were passed in
                threads = []
                if not (self.expsf and
                        self.expsf.type == ISOTROPIC|EXCITATION):
                    threads.append(("expsf", PSFthread(ISOTROPIC|EXCITATION,
                        shape, dims, ex_wavelen, None, num_aperture,
                        refr_index, 1.0, underfilling)))
                if not (self.empsf and
                        self.empsf.type == ISOTROPIC|EMISSION):
                    threads.append(("empsf", PSFthread(ISOTROPIC|EMISSION,
                        shape, dims, None, em_wavelen, num_aperture,
                        refr_index, magnification, 1.0)))
                for a, t in threads:
                    t.start()
                for a, t in threads:
                    t.join()
                for a, t in threads:
                    # an exception raised in PSFthread.run leaves t.psf None
                    if t.psf is None:
                        raise RuntimeError("calculation of %s failed." % a)
                    setattr(self, a, t.psf)
                if not (self.expsf.iscompatible(self.empsf)):
                    raise ValueError(
                        "Excitation and Emission PSF not compatible.")
                if type&WIDEFIELD or (self.pinhole.radius.um >
                    self.ex_wavelen / self.num_aperture * 9.76):
                    # use widefield approximation for pinholes > 8 AU
                    self.data = _psf.obsvol(self.expsf.data, self.empsf.data)
                else:
                    self.data = _psf.obsvol(self.expsf.data, self.empsf.data,
                                            self.pinhole.kernel())

        if type&TWOPHOTON:
            # two-photon excitation: square the one-photon excitation PSF
            self.data *= self.data
        self.time = float(clock()-start) * 1e3

    def __getitem__(self, key):
        """Return value of data array at position."""
        return self.data[key]

    def __str__(self):
        """Return properties of PSF object as string."""
        s = [self.name]
        s.append("  Shape: (%i, %i) pixel" % self.dims.px)
        s.append("  Dimensions: %s" %
            self.dims.format(["um", "ou", "au"], ["%.2f", "%.2f", "%.2f"]))
        if self.ex_wavelen:
            s.append("  Excitation Wavelength: %.1f nm" %
                     (self.ex_wavelen * 1e3))
        if self.em_wavelen:
            s.append("  Emission Wavelength: %.1f nm" %
                     (self.em_wavelen * 1e3))
        s.append("  Numeric Aperture: %.2f" % self.num_aperture)
        s.append("  Refractive Index: %.2f" % self.refr_index)
        s.append("  Half Cone Angle: %.2f deg" %
            math.degrees(math.asin(self.sinalpha)))
        if self.magnification:
            s.append("  Magnification: %.2f" % self.magnification)
        if self.underfilling:
            s.append("  Underfilling: %.2f" % self.underfilling)
        if self.pinhole:
            s.append("  Pinhole Radius: %s" % self.pinhole.radius.format(
                ["um", "ou", "au", "px"], ["%.3f", "%.3f", "%.4f", "%.2f"]))
        if self.sigma:
            s.append("  Gauss Sigma: %s" % self.sigma.format(
                ["um", "ou", "au", "px"], ["%.3f", "%.3f", "%.3f", "%.2f"]))
        s.append("  Computing Time: %.2f ms\n" % self.time)
        return "\n".join(s)

    def iscompatible(self, other):
        """Return True if objects match dimensions and optical properties."""
        return ((self.dims.px[0] == other.dims.px[0])
            and (self.dims.px[1] == other.dims.px[1])
            and (self.dims.um[0] == other.dims.um[0])
            and (self.dims.um[1] == other.dims.um[1])
            and (self.num_aperture == other.num_aperture)
            and (self.refr_index == other.refr_index))

    def slice(self, key=slice(None)):
        """Return a z slice of the PSF with rotational symmetries applied."""
        return _psf.zr2zxy(self.data[key])

    def volume(self):
        """Return a 3D volume of the PSF with all symmetries applied.

        The shape of the returned array is
            (2*self.shape[0]-1, 2*self.shape[1]-1, 2*self.shape[1]-1)

        """
        return mirror_symmetry(_psf.zr2zxy(self.data))

    def imshow(self, subplot=111, **kwargs):
        """Log-plot PSF image using matplotlib.pyplot.

        Return the plot axis and image as returned by the module-level
        imshow function. The aspect ratio accounts for the physical pixel
        size along z and r.

        """
        title = kwargs.get("title", self.name)
        aspect = self.shape[1]/self.shape[0] * self.dims.um[0]/self.dims.um[1]
        kwargs.update(dict(data=self.data, title=title, subplot=subplot,
                           aspect=aspect))
        return imshow(**kwargs)


class PSFthread(threading.Thread):
    """Calculate point spread function in a thread.

    All positional and keyword arguments are forwarded unchanged to PSF()
    when the thread runs. The resulting PSF instance is stored in the
    ``psf`` attribute, which remains None until run() has completed
    without raising.

    """
    def __init__(self, *args, **kwargs):
        super(PSFthread, self).__init__()
        self.psf = None
        self.args, self.kwargs = args, kwargs

    def run(self):
        """Compute the PSF; executed in the new thread by start()."""
        positional, keywords = self.args, self.kwargs
        self.psf = PSF(*positional, **keywords)


class Pinhole(object):
    """Pinhole object for confocal microscopy.

    Instance Attributes
    -------------------

    name : str
        A human readable label, e.g. "Round Pinhole".

    radius : Dimensions instance
        Dimensions of the pinhole radius in px (pixel), um (micrometers),
        ou (optical units), and au (airy units).

    type : int
        Shape of the pinhole. Either 0 (round) or 1 (square).

    Examples
    --------

    >>> ph = Pinhole(0.1, dict(px=16, um=1.0))
    >>> ph.radius.px
    1.6000000000000001
    >>> ph.kernel()
    array([[ 1.        ,  1.6       ,  0.6       ],
           [ 0.8       ,  1.18578644,  0.36393202],
           [ 0.3       ,  0.36393202,  0.        ]])
    """
    types = {0: "Round", 1: "Square" }

    def __init__(self, radius, dimensions, type=0):
        """Initialize the pinhole object.

        Arguments
        ---------

        radius : float
            Pinhole radius in micrometers in object space.

        dimensions : dict
            Dimensions of the object space in "px" (pixel), "um" (micrometers),
            "ou" (optical units), and "au" (airy units). Values are either
            scalars or (z, r) pairs, of which only the r component is used.

        type : int
            Either 0 (round) or 1 (square). Default 0. Other values raise
            KeyError.

        """
        self.name = "%s Pinhole" % self.types[type]
        self.type = type
        try:
            # (z, r) pairs: keep the radial component only. items() is used
            # instead of the Python 2 only iteritems().
            dimensions = dict((k, v[1]) for k, v in dimensions.items())
        except TypeError:
            # scalar (non-indexable) values are used as given
            pass
        # Assigning the um value rescales all other units proportionally.
        self.radius = Dimensions(**dimensions)
        self.radius.um = float(radius)
        self._kernel = None

    def __str__(self):
        """Return name and radius of the pinhole as string."""
        s = [self.name]
        s.append("  Radius: %s" % str(self.radius))
        return "\n".join(s)

    def kernel(self):
        """Return convolution kernel for integration over the pinhole.

        The kernel is computed by _psf.pinhole_kernel from the radius in
        pixel on first use and cached for subsequent calls.

        """
        if self._kernel is None:
            self._kernel = _psf.pinhole_kernel(self.radius.px)
        return self._kernel


class Dimensions(dict):
    """Store dimensions in various units and perform linear conversions.

    Each key is a unit name; each value is either a scalar or a sequence
    with one entry per axis. Values can be read and written as items or
    as attributes. Assigning to an existing unit rescales all stored units
    by the same factor (per axis for sequences); assigning to a new unit
    adds it without rescaling the others.

    Examples
    --------

    >>> dim = Dimensions(px=100, um=2)
    >>> dim(50, "px", "um")
    1.0
    >>> dim.px, dim.um
    (100, 2)
    >>> dim.px = 50
    >>> dim.um
    1.0
    >>> dim.format(("um", "px"), ("%.2f", "%.1f"))
    '1.00 um, 50.0 px'
    >>> hasattr(dim, "nm")
    False

    >>> dim = Dimensions(px=(100, 200), um=(2, 8))
    >>> dim((50, 50), "px", "um")
    (1.0, 2.0)
    >>> dim.ou = (1, 2)
    >>> dim.px
    (100, 200)
    >>> dim["px"] = (50, 100)
    >>> dim.ou
    (0.5, 1.0)

    """
    __slots__ = []

    def __call__(self, value, unit, newunit):
        """Return value given in unit in another unit."""
        dim = self[unit]
        new = self[newunit]
        try:
            return value * (new/dim)
        except TypeError:
            # sequence dimensions: convert per axis
            return tuple(v*(o/u) for v, u, o in zip(value, dim, new))

    def __setitem__(self, unit, value):
        """Add a dimension or rescale all dimensions to new value."""
        try:
            dim = self[unit]
        except KeyError:
            dict.__setitem__(self, unit, value)
            return
        # Snapshot the keys; values are replaced while looping. Avoids the
        # Python 2 only iteritems().
        keys = list(self.keys())
        try:
            scale = value/dim
            for k in keys:
                dict.__setitem__(self, k, self[k] * scale)
        except TypeError:
            # sequence dimensions: rescale per axis
            scale = tuple(v/d for v, d in zip(value, dim))
            for k in keys:
                dict.__setitem__(self, k,
                                 tuple(v*s for v, s in zip(self[k], scale)))

    def __getattr__(self, unit):
        """Return value of unit.

        Unknown units raise AttributeError (not KeyError) so that hasattr()
        and getattr() with a default behave as expected.

        """
        try:
            return self[unit]
        except KeyError:
            raise AttributeError(unit)

    def __setattr__(self, unit, value):
        """Add a dimension or rescale all dimensions to new value."""
        self.__setitem__(unit, value)

    def format(self, keys, format):
        """Return formatted string.

        keys : sequence of str
            Units to include, in order.

        format : sequence of str
            One %-format per unit, applied to a scalar value or to each
            entry of a sequence value.

        """
        s = []
        for k, f in zip(keys, format):
            v = self[k]
            try:
                s.append("%s %s" % (f % v, k))
            except TypeError:
                # sequence value: format each entry
                s.append("(%s) %s" % (", ".join(f % i for i in v), k))
        return ", ".join(s)


def uv2zr(uv, wavelength, sinalpha, refr_index, magnification=1.0):
    """Convert optical coordinates u,v into axial and radial distances z,r.

    The returned z,r are expressed in the same length unit as wavelength.
    For excitation, magnification should be 1.

    >>> uv2zr((1, 1), 488, 0.9, 1.33)
    (72.094692498695736, 64.885223248826165)

    """
    radial_scale = wavelength / (
        2.0 * math.pi * sinalpha * refr_index * magnification)
    axial_scale = radial_scale / (sinalpha * magnification)
    u, v = uv[0], uv[1]
    return u * axial_scale, v * radial_scale


def zr2uv(zr, wavelength, sinalpha, refr_index, magnification=1.0):
    """Convert axial and radial distances z,r into optical coordinates u,v.

    z,r must be expressed in the same length unit as wavelength.
    For excitation, magnification should be 1.

    >>> zr2uv((1e3, 1e3), 488, 0.9, 1.33)
    (13.870646580788051, 15.411829534208946)

    """
    radial_scale = ((2.0 * math.pi * refr_index * sinalpha * magnification)
                    / wavelength)
    axial_scale = radial_scale * sinalpha * magnification
    z, r = zr[0], zr[1]
    return z * axial_scale, r * radial_scale


def mirror_symmetry(data):
    """Apply mirror symmetry along one face in each dimension.

    The input array can have any number of dimensions >= 1. Along every
    axis of length n, element i of the result equals element |i - (n-1)|
    of the input, i.e. the input fills the upper part of each axis and is
    mirrored about its first element.

    The shape of the returned array is 2*data.shape-1 in each dimension.
    The returned array is of type float64.

    >>> mirror_symmetry([0, 1])
    array([ 1.,  0.,  1.])

    >>> mirror_symmetry([[0, 1],[0, 1]])
    array([[ 1.,  0.,  1.],
           [ 1.,  0.,  1.],
           [ 1.,  0.,  1.]])

    >>> mirror_symmetry([[[0, 1],[0, 1]], [[0, 1],[0, 1]], [[0, 1],[0, 1]]])[0]
    array([[ 1.,  0.,  1.],
           [ 1.,  0.,  1.],
           [ 1.,  0.,  1.]])

    """
    data = numpy.array(data)
    if data.ndim < 1:
        raise NotImplementedError("%i-dimensional arrays not supported." %
                                  data.ndim)
    result = numpy.empty([2*i-1 for i in data.shape], numpy.float64)
    # copy the input into the upper corner of the result
    result[tuple(slice(i-1, None) for i in data.shape)] = data
    # Mirror one axis at a time over the full extent of all other axes.
    # Not-yet-filled regions copied here are overwritten once their own
    # axis is mirrored.
    for axis, n in enumerate(data.shape):
        target = [slice(None)] * data.ndim
        source = [slice(None)] * data.ndim
        target[axis] = slice(None, n-1)
        source[axis] = slice(None, n-1, -1)
        result[tuple(target)] = result[tuple(source)]
    return result


def imshow(subplot, data, title=None, sharex=None, sharey=None,
         vmin=-2.5, vmax=0, cmap=None, interpolation='bilinear', **kwargs):
    """Log-plot image using matplotlib.pyplot. Return plot axis and image.

    The base-10 logarithm of data is shown with mirror symmetry applied
    along the x and y axes. vmin and vmax are the color scale limits in
    log10 units; cmap defaults to pyplot.cm.hot. Additional keyword
    arguments are passed to pyplot.imshow.

    Requires pyplot already imported ``from matplotlib import pyplot``.

    """
    plt = sys.modules["matplotlib.pyplot"]
    if cmap is None:
        colormap = plt.cm.hot
    else:
        colormap = cmap

    axes = plt.subplot(subplot, sharex=sharex, sharey=sharey)
    if title:
        plt.title(title)
    image_data = mirror_symmetry(numpy.log10(data))
    image = plt.imshow(image_data, vmin=vmin, vmax=vmax, cmap=colormap,
                       interpolation=interpolation, **kwargs)
    plt.axis("off")
    return axes, image


if __name__ == "__main__":
    # Run the doctests embedded in the docstrings of this module.
    from doctest import testmod
    testmod(verbose=True)