#!/usr/bin/env python
# -*- coding: utf-8 -*-
# dnacurve.py

# Copyright (c) 1994-2008, Christoph Gohlke
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
# * Neither the name of the copyright holders nor the names of any
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""DNA Curvature Analysis.

Calculates the global 3D structure of a B-DNA molecule from its nucleotide
sequence according to the dinucleotide wedge model.  Analyzes local bending
angles and macroscopic curvature at each nucleotide.

For command line usage run ``python dnacurve.py --help``

:Authors: `Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`__

:Version: 20080625

Requirements
------------

*  `Python 2.5 <http://www.python.org>`__
*  `Numpy 1.1 <http://numpy.scipy.org>`__
*  `Matplotlib 0.98 <http://matplotlib.sourceforge.net>`__

References
----------

(1) Bending and curvature calculations in B-DNA.
    Goodsell DS, Dickerson RE. Nucleic Acids Res 22, 5497-503, 1994.
    See also http://mgl.scripps.edu/people/goodsell/research/bend/
(2) Curved DNA without A-A: experimental estimation of all 16 DNA wedge angles.
    Bolshoy A et al. Proc Natl Acad Sci USA 88, 2312-6, 1991.
(3) A comparison of six DNA bending models.
    Tan RK and Harvey SC. J Biomol Struct Dyn 5, 497-512, 1987.
(4) Curved DNA: design, synthesis, and circularization.
    Ulanovsky L et al. Proc Natl Acad Sci USA 83, 862-6, 1986.
(5) The ten helical twist angles of B-DNA.
    Kabsch W, Sander C, and Trifonov EN. Nucleic Acids Res 10, 1097-1104, 1982.

Examples
--------

>>> from dnacurve import CurvedDNA
>>> result = CurvedDNA("ATGCAAATTG"*5, "trifonov", name="Example")
>>> print result.curvature[:, 18:22]
[[ 0.58061616  0.58163338  0.58277938  0.583783  ]
 [ 0.08029914  0.11292516  0.07675816  0.03166286]
 [ 0.57923902  0.57580064  0.57367815  0.57349872]]
>>> result.save_csv("_test.csv")
>>> result.save_pdb("_test.pdb")
>>> result.plot("_test.png", dpi=160)

"""

from __future__ import division, with_statement

import sys
import os
import math
import datetime
import optparse
import re

import numpy

# Markup of the docstrings in this module: reStructuredText, English.
# Read by documentation tools that honor __docformat__ (e.g. epydoc).
__docformat__ = "restructuredtext en"


class Model(object):
    """N-mer DNA-bending model.

    Transformation parameters and matrices for all oligonucleotides of
    certain length.

    Instance Attributes
    -------------------

    name : str
        Human readable label (at most 32 characters).

    order : int
        Order of model, i.e. length of oligonucleotides.
        Order 2 is a dinucleotide model, order 3 a trinucleotide model etc.

    rise : float
        Displacement along the Z axis.

    twist : dict
        Rotation angle in deg about the Z axis for all oligonucleotides.

    roll : dict
        Rotation angle in deg about the Y axis for all oligonucleotides.

    tilt : dict
        Rotation angle in deg about the X axis for all oligonucleotides.

    matrices : dict
        Homogeneous transformation matrices for all oligonucleotides,
        stored transposed so that row vectors can be right-multiplied.
        The key None maps to a straight step (twist 34.3, no roll or tilt).

    Examples
    --------

    >>> m = Model("AAWedge")
    >>> m = Model("Nucleosome")
    >>> m = Model(**Model.straight)
    >>> m = Model(Model.calladine, name="My Model", rise=4.0)
    >>> assert m.name=="My Model" and m.rise==4.0
    >>> Model("calladine").name
    'Calladine & Drew'
    >>> m = Model(name="Test", rise=3.38,
    ...           oligo="AA AC AG AT CA GG CG GA GC TA".split(),
    ...           twist=(34.29, )*10, roll=(0., )*10, tilt=(0., )*10)
    >>> m.save("_test.dat")
    >>> assert m.twist == Model("_test.dat").twist

    """
    straight = dict(
        name = "Straight",
        oligo =  "AA AC AG AT CA GG CG GA GC TA",
        twist = (360.0/10.5, ) * 10,
        roll = (0.0, ) * 10,
        tilt = (0.0, ) * 10,
        rise = 3.38)

    aawedge = dict(
        name = "AA Wedge",
        oligo =  "AA     AC    AG    AT    CA    GG     CG    GA    GC    TA",
        twist = (35.62, 34.4, 27.7, 31.5, 34.5, 33.67, 29.8, 36.9, 40.0, 36.0),
        roll = ( -8.40,  0.0,  0.0,  0.0,  0.0,   0.0,  0.0,  0.0,  0.0,  0.0),
        tilt = (  2.40,  0.0,  0.0,  0.0,  0.0,   0.0,  0.0,  0.0,  0.0,  0.0),
        rise = 3.38)

    trifonov = dict(
        name = "Bolshoi & Trifonov",
        oligo =  "AA     AC    AG    AT    CA    GG     CG    GA    GC    TA",
        twist = (35.62, 34.4, 27.7, 31.5, 34.5, 33.67, 29.8, 36.9, 40.0, 36.0),
        roll =  (-6.50, -0.9,  8.4,  2.6,  1.6,  1.2,   6.7, -2.7, -5.0,  0.9),
        tilt =   (3.20, -0.7, -0.3,  0.0,  3.1, -1.80,  0.0, -4.6,  0.0,  0.0),
        rise = 3.38)

    calladine = dict(
        name = "Calladine & Drew",
        oligo =  "AA    AC    AG    AT    CA    GG    CG    GA    GC    TA",
        twist = (35.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0),
        roll =   (0.0,  3.3,  3.3,  3.3,  3.3,  3.3,  3.3,  3.3,  3.3,  6.6),
        tilt =   (0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0),
        rise = 3.38)

    reversed = dict(
        name = "Reversed Calladine & Drew",
        oligo =  "AA    AC    AG    AT    CA    GG    CG    GA    GC    TA",
        twist = (35.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0, 34.0),
        roll =   (3.3,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0, -3.3),
        tilt =   (0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0),
        rise = 3.38)

    desantis = dict(
        name = "Cacchione & De Santis",
        oligo =  "AA    AC    AG    AT    CA    GG    CG    GA    GC    TA",
        twist = (35.9, 34.6, 35.6, 35.0, 34.5, 33.0, 33.7, 35.8, 33.3, 34.6),
        roll =  (-5.4, -2.4,  1.0, -7.3,  6.7,  1.3,  4.6,  2.0, -3.7,  8.0),
        tilt =  (-0.5, -2.7, -1.6,  0.0,  0.4, -0.6,  0.0, -1.7,  0.0,  0.0),
        rise = 3.38)

    nucleosome = dict(
        name = "Nucleosome Positioning",
        oligo = """
                AAA  ATA  AGA  ACA  TAA  TTA  TGA  TCA
                GAA  GTA  GGA  GCA  CAA  CTA  CGA  CCA
                AAT  ATT  AGT  ACT  TAT  TTT  TGT  TCT
                GAT  GTT  GGT  GCT  CAT  CTT  CGT  CCT
                AAG  ATG  AGG  ACG  TAG  TTG  TGG  TCG
                GAG  GTG  GGG  GCG  CAG  CTG  CGG  CCG
                AAC  ATC  AGC  ACC  TAC  TTC  TGC  TCC
                GAC  GTC  GGC  GCC  CAC  CTC  CGC  CCC""",
        roll = (0.0, 2.3, 2.7, 4.2, 1.6, 1.6, 4.4, 4.4,
                2.4, 3.0, 3.1, 4.9, 2.7, 1.8, 6.7, 4.4,
                0.6, 0.6, 4.7, 4.7, 2.3, 0.0, 4.2, 2.7,
                4.3, 3.0, 4.4, 6.1, 5.4, 4.2, 4.4, 4.4,
                4.2, 5.4, 4.4, 4.4, 1.8, 2.7, 4.4, 6.7,
                4.4, 5.3, 4.9, 6.1, 3.4, 3.4, 3.8, 3.8,
                3.0, 4.3, 6.1, 4.4, 3.0, 2.4, 4.9, 3.1,
                4.4, 4.4, 8.1, 8.1, 5.3, 4.4, 6.1, 4.9,),
        twist = (34.3, ) * 64,
        tilt =  (0.0, ) * 64,
        rise = 3.38)

    def __init__(self, model=None, **kwargs):
        """Initialize instance from predefined model, file, or arguments.

        Arguments
        ---------

        model : various types
            Name of predefined model : str (case insensitive)
                'straight', 'aawedge', 'trifonov', 'calladine', 'reversed',
                'desantis', 'nucleosome'
            Class or Dict:
                Object with attributes, or dict with keys, 'name', 'oligo',
                'twist', 'roll', 'tilt', and 'rise'
            Path name: str
                File containing model parameters, as written by `save`
            None:
                Default model 'straight'

        name : str
            Human readable label.

        oligo : str or sequence of str
            Oligonucleotide sequences separated by whitespace or as sequence.

        twist : sequence of floats
            Twist values for given oligonucleotides in degrees.

        roll : sequence of floats
            Roll values for given oligonucleotides in degrees.

        tilt : sequence of floats
            Tilt values for given oligonucleotides in degrees.

        rise : float
            Rise value.

        Keyword arguments override the corresponding parameters of `model`.
        Neither the predefined class-level parameter dicts nor a dict passed
        as `model` are modified.

        Raises
        ------

        ValueError
            If `model` can not be interpreted by any of the import functions.

        """
        if model:
            for importfunction in (self._fromname, self._fromdict,
                                   self._fromclass, self._fromfile):
                try:
                    # import functions return dictionary or raise exception;
                    # a failure simply means "try the next importer"
                    model = importfunction(model)
                    break
                except Exception:
                    pass
            else:
                raise ValueError("Can not initialize Model from %s" % model)
        else:
            model = Model.straight

        # Copy before applying overrides: `model` may be one of the shared
        # class-level dicts (e.g. Model.calladine) or the caller's own dict.
        model = dict(model)
        model.update(kwargs)

        try:
            self.oligos = model["oligo"].split()
        except AttributeError:
            # already a sequence of oligonucleotide strings
            self.oligos = model["oligo"]
        self.order = len(self.oligos[0])
        self.name = str(model["name"][:32])
        self.rise = float(model["rise"])
        self.twist = dict(zip(self.oligos, model["twist"]))
        self.roll = dict(zip(self.oligos, model["roll"]))
        self.tilt = dict(zip(self.oligos, model["tilt"]))

        self.matrices = {}
        for oligo in oligonucleotides(self.order):
            if oligo not in self.twist:
                # parameters of a missing oligo are taken from its complement
                c = complementary(oligo)
                self.twist[oligo] = self.twist[c]
                self.roll[oligo] = self.roll[c]
                self.tilt[oligo] = -self.tilt[c] # tilt reverses sign
            self.matrices[oligo] = dinucleotide_matrix(self.rise,
                self.twist[oligo], self.roll[oligo], self.tilt[oligo]).T
        self.matrices[None] = dinucleotide_matrix(self.rise, 34.3, 0.0, 0.0).T

    def __str__(self):
        """Return string representation of model.

        This is also the file format written by `save` and read by
        `_fromfile`: name line, rise line, then Oligo/Twist/Roll/Tilt rows
        whose continuation lines are indented.

        """
        if (self.order % 2):
            oligos = list(oligonucleotides(self.order))
        else:
            oligos = list(unique_oligos(self.order))

        def fmt(items, formatstr="%5.2f", sep="  "):
            # format items, wrapping rows with an 8 space indentation
            items = [formatstr % item for item in items]
            return "\n        ".join(sep.join(line) for line in chunks(items))

        result = ["%s\nRise    %.2f" % (self.name.split('\n')[0], self.rise)]
        # column width 7 for both oligo names and "%5.2f" values
        result.append("Oligo   " + fmt(oligos, "%s", " " * (7-self.order)))
        result.append("Twist   " + fmt(self.twist[i] for i in oligos))
        result.append("Roll    " + fmt(self.roll[i] for i in oligos))
        result.append("Tilt    " + fmt(self.tilt[i] for i in oligos))
        return "\n".join(result)

    def _fromfile(self, path):
        """Return model parameters as dict from file written by `save`."""
        d = {}
        with open(path, 'r') as fd:
            d["name"] = fd.readline().rstrip()
            d["rise"] = float(fd.readline().split()[-1])

            def readtuple(convert, line):
                # values after the row label, plus indented continuation lines;
                # also return the first line not belonging to this row
                alist = [convert(i) for i in line.split()[1:]]
                while 1:
                    line = fd.readline()
                    if line.startswith("     "):
                        alist.extend(convert(i) for i in line.split())
                    else:
                        break
                return tuple(alist), line

            d["oligo"], line = readtuple(str, fd.readline())
            d["twist"], line = readtuple(float, line)
            d["roll"], line = readtuple(float, line)
            d["tilt"], line = readtuple(float, line)
        return d

    def _fromname(self, name):
        """Return predefined model parameters as dict.

        Raises AttributeError for unknown names and ValueError if the name
        refers to a class attribute that is not a parameter dict.

        """
        model = getattr(Model, name.lower())
        if not isinstance(model, dict):
            raise ValueError("%s is not a predefined model" % name)
        return model

    def _fromclass(self, aclass):
        """Return model parameters as dict from class attributes."""
        return dict((a, getattr(aclass, a)) for a in Model.straight.keys())

    def _fromdict(self, adict):
        """Return `adict` if it contains all model parameter keys.

        Raises KeyError (or TypeError for non-mappings) otherwise.

        """
        for attr in Model.straight.keys():
            adict[attr]
        return adict

    def save(self, path):
        """Save model to file in the format returned by __str__."""
        with open(path, "w") as fd:
            fd.write(str(self))


class Sequence(object):
    r"""DNA nucleotide sequence.

    Instance Attributes
    -------------------

    name : str
        Human readable label.

    comment : str
        Single line description of sequence.

    Examples
    --------

    >>> Sequence("0AxT-C:G a`t~c&g\t\r")[:]
    'ATCGATCG'
    >>> seq = Sequence("ATGCAAATTG"*3, name="Test")
    >>> seq == "ATGCAAATTG"*3
    True
    >>> seq != "ATGCAAATTG"*3
    False
    >>> seq == None
    False
    >>> seq.save("_test.seq")
    >>> seq == Sequence("_test.seq")
    True

    """
    kinetoplast = """
        GATCTAGACT AGACGCTATC GATAAAGTTT AAACAGTACA ACTATCGTGC TACTCACCTG
        TTGCCAAACA TTGCAAAAAT GCAAAATTGG GCTTGTGGAC GCGGAGAGAA TTCCCAAAAA
        TGTCAAAAAA TAGGCAAAAA ATGCCAAAAA TCCCAAACTT TTTAGGTCCC TCAGGTAGGG
        GCGTTCTCCG AAAACCGAAA AATGCATGCA GAAACCCCGT TCAAAAATCG GCCAAAATCG
        CCATTTTTTC AATTTTCGTG TGAAACTAGG GGTTGGTGTA AAATAGGGGT GGGGCTCCCC
        GGGGTAATTC TGGAAATTCG GGCCCTCAGG CTAGACCGGT CAAAATTAGG CCTCCTGACC
        CGTATATTTT TGGATTTCTA AATTTTGTGG CTTTAGATGT GGGAGATTTG """
    out_of_phase_AAAAAA = "CGCGCGCAAAAAACG"
    phased_AAAAAA = "CGAAAAAACG"
    phased_GGGCCC = "GAGGGCCCTA"

    def __init__(self, arg, name="Untitled", comment=""):
        """Initialize instance from nucleotide sequence string or file name.

        If `arg` is the path of an existing file, name, comment and sequence
        are read from that file and take precedence over the `name` and
        `comment` arguments. Otherwise `arg` is the sequence itself.
        All characters other than ATCG (case insensitive) are removed.

        Raises ValueError if no nucleotides remain.

        """
        self.name = name
        self.comment = comment
        if os.path.isfile(arg):
            self._fromfile(arg)
            self.fname = os.path.split(arg)[1]
        else:
            self._sequence = arg
            self.fname = None

        # keep only the first line of name and comment, which may have been
        # read from file; the name is further limited to 32 characters
        self.name = str(self.name.split("\n")[0].strip())[:32]
        self.comment = self.comment.split("\n")[0]

        # remove all but ATCG from sequence
        nucls = dict(zip("ATCGatcg", "ATCGATCG"))
        self._sequence = "".join(nucls.get(c, "") for c in self._sequence)
        if not self._sequence:
            raise ValueError("Not a valid sequence.")

    def __getitem__(self, key):
        """Return nucleotide at position (or sub-sequence for slices)."""
        return self._sequence[key]

    def __len__(self):
        """Return number of nucleotides in the sequence."""
        return len(self._sequence)

    def __iter__(self):
        """Return iterator over nucleotides."""
        return iter(self._sequence)

    def __eq__(self, other):
        """Return result of sequence comparison.

        Anything that can not be sliced compares unequal.

        """
        try:
            return self._sequence == other[:]
        except Exception:
            return False

    def __ne__(self, other):
        """Return inverse of __eq__ (Python 2 does not derive != from ==)."""
        return not self.__eq__(other)

    def __str__(self):
        """Return string representation of sequence."""
        return "%s\n%s\n%s" % (self.name, self.comment, self.format())

    def _fromfile(self, path, maxsize=1024*1024):
        """Read name, comment and sequence from file.

        First line is the name, second line the comment, and the remainder
        (at most `maxsize` characters) the raw sequence.

        """
        with open(path, "r") as fd:
            self.name = fd.readline().rstrip()
            self.comment = fd.readline().rstrip()
            self._sequence = fd.read(maxsize)

    def save(self, path):
        """Save sequence to file in the format read by _fromfile."""
        with open(path, "w") as fd:
            fd.write(str(self))

    def format(self, block=10, line=6):
        """Return string of sequence formatted in blocks and lines.

        Each line is prefixed with the right-aligned position of its first
        nucleotide.

        """
        lines = chunks(chunks(self._sequence, block), line)
        fmt = "%%%ii %%s" % (len("%i" % ((len(lines)-1)*block*line, )), )
        for i, l in enumerate(lines):
            lines[i] = fmt % (i*line*block, " ".join(l))
        return "\n".join(lines)


class CurvedDNA(object):
    """Calculate DNA helix coordinates, local bending and curvature.

    Instance Attributes
    -------------------

    model : Instance of Model class

    sequence : Instance of Sequence class

    coordinates : 3D Numpy array
        Homogeneous coordinates at each nucleotide of:
        Index 0) helix axis.
        Index 1) phosphate of 5'-3' strand.
        Index 2) phosphate of antiparallel strand.
        Index 3) basepair normal vector.
        Index 4) smoothed basepair normal vector.

    curvature : 2D Numpy array
        Values at each nucleotide, normalized relative to curvature in
        nucleosome:
        Index 0) curvature.
        Index 1) local bend angle.
        Index 2) curvature angle.

    windows : sequence of int
        Window sizes for calculating curvature, local bend angle, and
        curvature angle.

    scales : 2D Numpy array
        Scaling factors used to normalize curvature array.

    Notes
    -----

    Atomic coordinates are centered at origin and oriented such that:
    (1) helix-axis endpoints lie on x-axis and
    (2) maximum deviation of DNA- from x-axis is along the z-axis.

    The **curvature** at nucleotide N is one over the radius of a
    circle passing through helix axis coordinates N-window, N, and
    N+window, which are separated by one respectively two helix turns.
    The three points define a triangle.  The radius is the product of
    the length of the triangle sides divided by twice the area of
    the triangle.  A window size of 10 is optimal for B-DNA.

    The **local bend angle** at nucleotide N is the angle between the
    normal vectors of basepairs N-window and N+window.  The window size
    should be one or two.

    The **curvature angle** at nucleotide N is the angle between the
    smoothed normal vectors of basepair N-window and N+window.
    The window size should be in the order of 15.

    The curvature and bend values are normalized relative to the
    DNA curvature in a nucleosome (0.0234).

    Examples
    --------

    See module examples.

    """
    p_coord = ( # cylindrical coordinates of 5' phosphate
         8.91,  # distance from axis
        -5.2,   # angle to roll axis
         2.08)  # distance from bp plane

    def __init__(self, sequence, model="trifonov", name="Untitled",
                 curvature_window=10, bend_window=2, curve_window=15):
        """Initialize instance from sequence and model.

        Arguments
        ---------

        sequence : various types
            Sequence instance, file name, or nucleotide sequence.
            See Sequence constructor documentation.

        model : various types
            Model instance, file name, class, dict, or name of
            predefined model. See Model constructor documentation.

        name: str
            Optional human readable label.

        curvature_window : int
            Window size for calculating the curvature (default 10).

        bend_window : int
            Window size for calculating local bend angles (default 2).

        curve_window : int
            Window size for calculating curvature angles (default 15).

        """
        self.model = model if isinstance(model, Model) else Model(model)
        self.sequence = sequence \
            if isinstance(sequence, Sequence) else Sequence(sequence, name)
        if len(self.sequence) < self.model.order:
            raise ValueError("Sequence must be >%i nucleotides long." % \
                             self.model.order)

        assert 0 < curvature_window < 21
        assert 0 < bend_window < 4
        assert 9 < curve_window < 21
        self.windows = [curvature_window, bend_window, curve_window]
        self._limits = [10., 10., 10.]

        self.date = datetime.datetime.now()
        self.coordinates = numpy.zeros((5, len(self), 4), dtype=numpy.float64)
        self.curvature = numpy.zeros((3, len(self)), dtype=numpy.float64)
        self.scales = numpy.ones((3, 1), dtype=numpy.float64)

        self._coordinates()
        self._reorient()
        self._center()
        self._curvature()

    def __len__(self):
        """Return number of nucleotides in sequence."""
        return len(self.sequence)

    def __str__(self):
        """Return string representation of sequence and model."""
        return "%s\n\n%s\n" % (str(self.sequence), str(self.model))

    def _coordinates(self):
        """Calculate coordinates and normal vectors from sequence and model.

        Fills self.coordinates in place. Rows 0-2 (helix axis, 5' phosphate,
        phosphate of the antiparallel strand) are homogeneous points with
        w=1. Row 3 holds basepair normal vectors. Its w component is not set
        here (zero from initialization), so the translation part of the step
        matrices does not affect them. Row 4 receives the basepair normals
        smoothed over one helix turn.
        """
        p = self.p_coord
        # cylindrical (distance, angle in deg, height) -> Cartesian x, y, z
        p = numpy.array((p[0] * math.cos(math.radians(p[1])),
                         p[0] * math.sin(math.radians(p[1])),
                         p[2]))

        xyz = self.coordinates
        xyz[0:3, :, 3] = 1.0 # homogeneous coordinates (points, w=1)
        xyz[1, :, 0:3] = p # 5' phosphate
        xyz[2, :, 0:3] = -p[0], p[1], -p[2] # phosphate of antiparallel strand
        xyz[3, :, 2] = 1.0 # basepair normal vectors

        # Step i right-multiplies all points 0..i by the (transposed) matrix
        # of the i-th oligo window. Each nucleotide therefore accumulates the
        # transformations of the steps from its own position onwards. Cost
        # grows quadratically with length because every step re-transforms
        # the whole prefix.
        # NOTE(review): matrices also has a None key -- presumably
        # dinuc_window yields None for some windows; confirm.
        matrices = self.model.matrices
        for i, seq in enumerate(dinuc_window(self.sequence, self.model.order)):
            xyz[:4, :i+1, :] = numpy.dot(xyz[:4, :i+1, :], matrices[seq])

        # Average direction vector of one helix turn,
        # calculated by smoothing the basepair normals with an 11 point
        # kernel (half weights at both ends, normalized to sum 1).
        # Only positions 5..len-6 are rescaled to unit length; the first and
        # last five smoothed vectors are left as convolved. For sequences of
        # 10 or fewer nucleotides row 4 is not filled at all.
        if len(self.sequence) > 10:
            kernel = numpy.array([.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, .5])
            kernel /= kernel.sum()
            for i in 0, 1, 2:
                xyz[4, :, i] = numpy.convolve(kernel, xyz[3, :, i], "same")
            for i in range(5, len(self)-5):
                xyz[4, i, :] /= vector_length(xyz[4, i, :])

    def _reorient(self):
        """Reorient coordinates."""
        xyz = self.coordinates[0, :, 0:3] # helix axis
        xyz = xyz - xyz[-1]
        # assert start point is at origin
        assert numpy.allclose(xyz[-1], (0, 0, 0))
        # normalized end to end vector
        e = +xyz[0]
        e_len = vector_length(e)
        e /= e_len
        # point i of maximum distance to end to end line
        x = numpy.cross(e, xyz)
        x = numpy.sum(x*x, axis=1)
        i = numpy.argmax(x)
        x = math.sqrt(x[i])
        # distance of endpoint to xyz[i]
        w = vector_length(xyz[i])
        # distance of endpoint to point on end to end line nearest to xyz[i]
        u = math.sqrt(w*w - x*x)
        # find transformation matrix
        v0 = xyz[[0, i, -1]]
        v1 = numpy.array(((0, 0, 0), (e_len-u, 0, x), (e_len, 0, 0)))
        M = superimpose_matrix(v0, v1)
        self.coordinates = numpy.dot(self.coordinates, M.T)

    def _center(self):
        """Center atomic coordinates at origin."""
        xyz = self.coordinates[0:3, :, 0:3] # helix axis and P atoms
        low = numpy.min(numpy.min(xyz, axis=1), axis=0)
        upp = numpy.max(numpy.max(xyz, axis=1), axis=0)
        self._limits = (upp - low) / 2.0
        self.coordinates[0:3, :, 0:3] -= (low + self._limits)

    def _curvature(self):
        """Calculate normalized curvature and bend angles."""
        # curvature from radius
        window = self.windows[0]
        if len(self) >= 2*window:
            result = self.curvature[0, :]
            <