/* fft2d.cpp

Implementation of the rfft3d class for the 'fft2d' DLL.

The rfft3d class implements fast 3D auto- and cross-correlation using the 
`fft2d library by Takuya Ooura <http://www.kurims.kyoto-u.ac.jp/~ooura/fft.html>`_.

Refer to the header file 'fft2d.h' for documentation.

:Author: 
  `Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`_
   
:Organization:
  Laboratory for Fluorescence Dynamics, University of California, Irvine

:Version: 2015.07.17
   
License
-------
Copyright (c) 2015, Christoph Gohlke
Copyright (c) 2015, The Regents of the University of California
Produced at the Laboratory for Fluorescence Dynamics.
All rights reserved.
   
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holders nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
*/

#include <math.h>
#include <algorithm> 

#include "fft2d.h"

/* Exported module globals with C linkage.
   FFT2D_VERSION: version string, initialized from FFT2D_VERSION_STR.
   FFT2D_THREADS / FFT3D_THREADS: initialized to the compile-time maximum thread
   counts.  NOTE(review): neither is read in this file; presumably they are
   consumed by the fft2d library sources or set by callers -- confirm.
   NOTE(review): if FFT2D_VERSION_STR is a string literal, binding it to a
   non-const `char*` is deprecated in C++03 and ill-formed in C++11; changing it
   to `const char*` requires the matching declaration in fft2d.h to change too. */
extern "C" FFT2D_API char* FFT2D_VERSION = FFT2D_VERSION_STR;
extern "C" FFT2D_API int FFT2D_THREADS = FFT2D_MAX_THREADS;
extern "C" FFT2D_API int FFT3D_THREADS = FFT3D_MAX_THREADS;

/** Helper functions **/

/* Return 1 if `n` is a power of two greater than one (2, 4, 8, ...), else 0.
   Note that 1 (= 2**0), zero and negative values all yield 0. */
inline
int ispow2(ssize_t n) {
    if (n <= 1)
        return 0;
    /* a power of two has a single bit set; clearing its lowest set bit leaves zero */
    return ((n & (n - 1)) == 0) ? 1 : 0;
}

/* Replace each complex value in `a` by itself times its conjugate, i.e. |z|^2.

   `a` is indexed a[i][j][k] with interleaved (re, im) pairs along k, so the
   innermost extent `nk` counts doubles (two per complex value).  The real part
   receives re*re + im*im and the imaginary part is set to 0. */
inline
void complex_multiply(double*** a, const ssize_t ni, const ssize_t nj, const ssize_t nk)
{
    for (ssize_t i = 0; i < ni; i++) {
        double** plane = a[i];
        for (ssize_t j = 0; j < nj; j++) {
            double* row = plane[j];
            for (ssize_t k = 0; k < nk; k += 2) {
                const double re = row[k];
                const double im = row[k + 1];
                row[k] = re*re + im*im;
                row[k + 1] = 0.0;
            }
        }
    }
}

/* In place, replace each complex value in `a` by a * conj(b).

   Both arrays are indexed [i][j][k] with interleaved (re, im) pairs along k;
   `nk` counts doubles.  All four inputs of a pair are read before either output
   is written, so `a` and `b` may refer to the same storage. */
inline
void complex_multiply(double*** a, double*** b, const ssize_t ni, const ssize_t nj, const ssize_t nk)
{
    for (ssize_t i = 0; i < ni; i++) {
        for (ssize_t j = 0; j < nj; j++) {
            double* x = a[i][j];
            const double* y = b[i][j];
            for (ssize_t k = 0; k < nk; k += 2) {
                const double yr = y[k];
                const double yi = y[k + 1];
                const double xr = x[k];
                const double xi = x[k + 1];
                /* (xr + i*xi) * (yr - i*yi) */
                x[k] = xr * yr + xi * yi;
                x[k + 1] = xi * yr - xr * yi;
            }
        }
    }
}

/** C++ API **/

/* Construct a 3D correlator for arrays of shape (shape0, shape1, shape2).

   Every extent must be a power of two in the range [2, 2147483646]; otherwise
   FFT2D_VALUE_ERROR is thrown.  `mode` is a bit mask of FFT3D_MODE_* flags;
   FFT3D_MODE_CC makes alloc_() pre-allocate the second buffer used by
   crosscorrelate.  Buffer allocation failure throws FFT2D_MEMORY_ERROR. */
rfft3d::rfft3d(ssize_t shape0, ssize_t shape1, ssize_t shape2, int mode)
{
    const ssize_t extents[3] = {shape0, shape1, shape2};
    for (int axis = 0; axis < 3; axis++) {
        const ssize_t n = extents[axis];
        if (n < 2 || n > 2147483646 || !ispow2(n))
            throw FFT2D_VALUE_ERROR;
    }

    /* start with no buffers so free_() is safe if alloc_() fails part way */
    a_ = NULL;
    b_ = NULL;
    w_ = NULL;
    t_ = NULL;
    ip_ = NULL;

    mode_ = mode;
    n0_ = shape0;
    n1_ = shape1;
    n2_ = shape2;
    np_ = n0_ * n1_ * n2_;

    alloc_();
}

/* Release all internal buffers; free_() skips members that are NULL. */
rfft3d::~rfft3d()
{
    this->free_();
}

/* Allocate the buffers used by rdft3d / rdft3dsort.

   a_ (and b_ when FFT3D_MODE_CC is set) hold n0_ x n1_ x (n2_ + 2) doubles; the
   two extra doubles per row hold the sorted half-spectrum that complex_multiply
   processes over n2_ + 2 elements.  ip_, w_ and t_ are the bit-reversal work
   area, cos/sin table and scratch area passed to rdft3d.  On any allocation
   failure after a_, the buffers obtained so far are released via free_();
   every failure throws FFT2D_MEMORY_ERROR. */
void rfft3d::alloc_() {
    const int d0 = (int)n0_;
    const int d1 = (int)n1_;
    const int d2 = (int)n2_ + 2;

    /* input/output data */
    a_ = alloc_3d_double(d0, d1, d2);
    if (a_ == NULL)
        throw FFT2D_MEMORY_ERROR;  /* nothing else has been allocated yet */

    if (mode_ & FFT3D_MODE_CC) {
        b_ = alloc_3d_double(d0, d1, d2);
        if (b_ == NULL) {
            free_();
            throw FFT2D_MEMORY_ERROR;
        }
    }

    const ssize_t nmax01 = std::max(n0_, n1_);

    /* bit-reversal work area: 2 + sqrt(max(n0, n1, n2/2)) ints */
    const ssize_t nbits = std::max(nmax01, n2_ / 2);
    ip_ = alloc_1d_int(2 + (int)sqrt(nbits + 0.5));
    if (ip_ == NULL) {
        free_();
        throw FFT2D_MEMORY_ERROR;
    }
    /* per Ooura's docs, ip[0] == 0 makes the first rdft3d call build its tables */
    ip_[0] = 0;

    /* cos/sin table: 3/2 * max(n0, n1, n2) doubles */
    const ssize_t ntable = (std::max(nmax01, n2_) * 3) / 2;
    w_ = alloc_1d_double((int)ntable);
    if (w_ == NULL) {
        free_();
        throw FFT2D_MEMORY_ERROR;
    }

    /* scratch area: 8 * max(n0, n1) doubles per thread */
    const ssize_t nwork = std::max(n0_ * 8, n1_ * 8) * FFT3D_MAX_THREADS;
    t_ = alloc_1d_double((int)nwork);
    if (t_ == NULL) {
        free_();
        throw FFT2D_MEMORY_ERROR;
    }
}

/* Release every internal buffer that is currently allocated and reset its
   pointer to NULL, so calling free_() repeatedly is harmless. */
void rfft3d::free_() {
    if (t_ != NULL) {
        free_1d_double(t_);
        t_ = NULL;
    }
    if (w_ != NULL) {
        free_1d_double(w_);
        w_ = NULL;
    }
    if (ip_ != NULL) {
        free_1d_int(ip_);
        ip_ = NULL;
    }
    if (a_ != NULL) {
        free_3d_double(a_);
        a_ = NULL;
    }
    if (b_ != NULL) {
        free_3d_double(b_);
        b_ = NULL;
    }
}

/* Circular 3D auto-correlation of `data` written, centered, to `out`.

   `strides` are byte strides of `data` (NULL means C-contiguous); `out` is a
   C-contiguous n0_ x n1_ x n2_ array.  In FFT3D_MODE_FCS mode the result is
   normalized by the squared sum of the input and offset by -1; otherwise it is
   only scaled by 2 / (n0_ * n1_ * n2_), the factor that normalizes rdft3d's
   inverse transform.
   NOTE(review): in FCS mode an input summing to zero divides by zero. */
template <typename Ti, typename To>
void rfft3d::autocorrelate(Ti* data, To* out, ssize_t* strides)
{
    const int d0 = (int)n0_;
    const int d1 = (int)n1_;
    const int d2 = (int)n2_;

    /* load input into the work buffer; its sum is needed for FCS normalization */
    const double total = copy_input_(a_, data, strides);

    /* forward real DFT, then reorder into the layout complex_multiply expects */
    rdft3d(d0, d1, d2, 1, a_, t_, ip_, w_);
    rdft3dsort(d0, d1, d2, 1, a_);

    /* power spectrum: F * conj(F) */
    complex_multiply(a_, n0_, n1_, n2_ + 2);

    /* back to the spatial domain */
    rdft3dsort(d0, d1, d2, -1, a_);
    rdft3d(d0, d1, d2, -1, a_, t_, ip_, w_);

    const bool fcs = (mode_ & FFT3D_MODE_FCS) != 0;
    const double scale = fcs ? 2.0 / total / total : 2.0 / np_;
    const double offset = fcs ? -1.0 : 0.0;
    copy_output_(a_, out, scale, offset);
}

/* Circular 3D cross-correlation of `data0` with `data1`, written centered to `out`.

   `strides0` / `strides1` are byte strides of the inputs (NULL means
   C-contiguous); `out` is a C-contiguous n0_ x n1_ x n2_ array.  The spectrum
   product is A * conj(B).  If b_ was not pre-allocated (FFT3D_MODE_CC unset) it
   is allocated here; failure throws FFT2D_MEMORY_ERROR.  In FFT3D_MODE_FCS mode
   the result is normalized by the product of the input sums and offset by -1;
   otherwise it is scaled by 2 / (n0_ * n1_ * n2_).
   NOTE(review): in FCS mode an input summing to zero divides by zero. */
template <typename Ti, typename To>
void rfft3d::crosscorrelate(Ti* data0, Ti* data1, To* out, ssize_t* strides0, ssize_t* strides1)
{
    const int d0 = (int)n0_;
    const int d1 = (int)n1_;
    const int d2 = (int)n2_;

    /* lazily create the second buffer */
    if (b_ == NULL) {
        b_ = alloc_3d_double(d0, d1, d2 + 2);
        if (b_ == NULL)
            throw FFT2D_MEMORY_ERROR;
    }

    const double total0 = copy_input_(a_, data0, strides0);
    const double total1 = copy_input_(b_, data1, strides1);

    /* forward transforms: b_ first, then a_ (they share t_, ip_ and w_) */
    rdft3d(d0, d1, d2, 1, b_, t_, ip_, w_);
    rdft3dsort(d0, d1, d2, 1, b_);
    rdft3d(d0, d1, d2, 1, a_, t_, ip_, w_);
    rdft3dsort(d0, d1, d2, 1, a_);

    /* cross spectrum: A * conj(B), stored in a_ */
    complex_multiply(a_, b_, n0_, n1_, n2_ + 2);

    /* back to the spatial domain */
    rdft3dsort(d0, d1, d2, -1, a_);
    rdft3d(d0, d1, d2, -1, a_, t_, ip_, w_);

    const bool fcs = (mode_ & FFT3D_MODE_FCS) != 0;
    const double scale = fcs ? 2.0 / total0 / total1 : 2.0 / np_;
    const double offset = fcs ? -1.0 : 0.0;
    copy_output_(a_, out, scale, offset);
}

/* Copy an n0_ x n1_ x n2_ input array into the first n2_ columns of `a`,
   converting each element to double, and return the sum of all elements.

   `strides` gives the byte stride of each axis of `data`; NULL means the data
   is C-contiguous with element size sizeof(T).  The walk advances one byte
   pointer: `step2` after every element, then the extra `skip1` / `skip0` needed
   to reach the start of the next row / plane. */
template <typename T>
inline
double rfft3d::copy_input_(double*** a, T* data, ssize_t* strides)
{
    ssize_t skip0 = 0;
    ssize_t skip1 = 0;
    ssize_t step2 = sizeof(T);
    if (strides != NULL) {
        skip0 = strides[0] - n1_ * strides[1];
        skip1 = strides[1] - n2_ * strides[2];
        step2 = strides[2];
    }

    const char* src = (const char*)data;
    double total = 0.0;
    for (ssize_t i = 0; i < n0_; i++, src += skip0) {
        for (ssize_t j = 0; j < n1_; j++, src += skip1) {
            double* row = a[i][j];
            for (ssize_t k = 0; k < n2_; k++, src += step2) {
                const double value = (double)*(const T*)src;
                row[k] = value;
                total += value;
            }
        }
    }
    return total;
}

/* Scale, offset and center the internal buffer into the C-contiguous output.

   Each output element is `a[src] * scale + offset`, where `src` is the output
   index shifted by half the extent along every centered axis (extents are even
   powers of two), so that a[0][0][0] lands at the center of `out`.  In
   FFT3D_MODE_TYX mode the first (temporal) axis is not shifted; only the two
   trailing axes are centered.  Every output element is written exactly once. */
template <typename T>
inline
void rfft3d::copy_output_(double*** a, T* out, const double scale, const double offset)
{
    const ssize_t n0 = n0_;
    const ssize_t n1 = n1_;
    const ssize_t n2 = n2_;
    const ssize_t h1 = n1 / 2;
    const ssize_t h2 = n2 / 2;
    const ssize_t n12 = n1 * n2;
    /* shift along the first axis; zero leaves the temporal axis uncentered */
    const ssize_t s0 = (mode_ & FFT3D_MODE_TYX) ? 0 : n0 / 2;

    for (ssize_t i = 0; i < n0; i++) {
        double** src = a[(i + s0) % n0];
        T* dst = out + i * n12;
        /* swap the four quadrants of the (j, k) plane */
        for (ssize_t j = 0, jj = h1; j < h1; j++, jj++) {
            const ssize_t nj = j * n2;
            const ssize_t njj = jj * n2;
            for (ssize_t k = 0, kk = h2; k < h2; k++, kk++) {
                dst[nj + k] = (T)(src[jj][kk] * scale + offset);
                dst[njj + kk] = (T)(src[j][k] * scale + offset);
                dst[nj + kk] = (T)(src[jj][k] * scale + offset);
                dst[njj + k] = (T)(src[j][kk] * scale + offset);
            }
        }
    }
}

/** C API **/

/* C entry point: create an rfft3d instance.
   Returns NULL if construction fails for any reason (invalid shape, out of memory). */
rfft3d_handle rfft3d_new(ssize_t shape0, ssize_t shape1, ssize_t shape2, int mode)
{
    rfft3d_handle handle = NULL;
    try {
        handle = reinterpret_cast<rfft3d_handle>(new rfft3d(shape0, shape1, shape2, mode));
    }
    catch (...) {
        handle = NULL;
    }
    return handle;
}

/* C entry point: destroy an instance created by rfft3d_new (NULL is allowed). */
void rfft3d_del(rfft3d_handle handle)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    try {
        delete obj;
    }
    catch (...) {
        /* never let a C++ exception escape through the C interface */
    }
}

/* C entry point: forward to rfft3d::init_tables().

   As in rfft3d_del, any C++ exception is caught here so it cannot unwind
   through a C caller.  The function returns void, so it cannot report the failure.
   NOTE(review): whether init_tables() can throw is not visible in this file. */
void rfft3d_init_tables(rfft3d_handle handle)
{
    try {
        reinterpret_cast<rfft3d_handle>(handle)->init_tables();
    }
    catch (...) {
        ;
    }
}

/* C entry point: auto-correlate double input into double output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_autocorrelate_dd(rfft3d_handle handle, double* data, double* out, ssize_t* strides)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->autocorrelate(data, out, strides);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: auto-correlate float input into float output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_autocorrelate_ff(rfft3d_handle handle, float* data, float* out, ssize_t* strides)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->autocorrelate(data, out, strides);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: auto-correlate float input into double output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_autocorrelate_fd(rfft3d_handle handle, float* data, double* out, ssize_t* strides)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->autocorrelate(data, out, strides);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: auto-correlate int16 input into float output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_autocorrelate_hf(rfft3d_handle handle, int16_t* data, float* out, ssize_t* strides)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->autocorrelate(data, out, strides);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: auto-correlate int16 input into double output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_autocorrelate_hd(rfft3d_handle handle, int16_t* data, double* out, ssize_t* strides)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->autocorrelate(data, out, strides);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: cross-correlate two double inputs into double output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_crosscorrelate_dd(rfft3d_handle handle, double* data0, double* data1, double* out, ssize_t* strides0, ssize_t* strides1)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->crosscorrelate(data0, data1, out, strides0, strides1);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: cross-correlate two float inputs into float output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_crosscorrelate_ff(rfft3d_handle handle, float* data0, float* data1, float* out, ssize_t* strides0, ssize_t* strides1)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->crosscorrelate(data0, data1, out, strides0, strides1);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: cross-correlate two float inputs into double output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_crosscorrelate_fd(rfft3d_handle handle, float* data0, float* data1, double* out, ssize_t* strides0, ssize_t* strides1)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->crosscorrelate(data0, data1, out, strides0, strides1);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: cross-correlate two int16 inputs into float output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_crosscorrelate_hf(rfft3d_handle handle, int16_t* data0, int16_t* data1, float* out, ssize_t* strides0, ssize_t* strides1)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->crosscorrelate(data0, data1, out, strides0, strides1);
    }
    catch (int e) {
        status = e;
    }
    return status;
}

/* C entry point: cross-correlate two int16 inputs into double output.
   Returns 0 on success, or the integer error code thrown by rfft3d. */
int rfft3d_crosscorrelate_hd(rfft3d_handle handle, int16_t* data0, int16_t* data1, double* out, ssize_t* strides0, ssize_t* strides1)
{
    rfft3d_handle obj = reinterpret_cast<rfft3d_handle>(handle);
    int status = 0;
    try {
        obj->crosscorrelate(data0, data1, out, strides0, strides1);
    }
    catch (int e) {
        status = e;
    }
    return status;
}