Commit ca781017 authored by cgohil8's avatar cgohil8 Committed by Andrew Quinn
Browse files

Minor improvements and fixed bugs in wavelet function.

parent d379154c
Loading
Loading
Loading
Loading
+35 −10
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from scipy import signal
import numpy as np


def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
def morlet(x, freqs, sample_rate, win_len=4, ncycles=5, ret_basis=False,
           ret_mode='power', normalise=False):
    """Compute a morlet wavelet time-frequency transform on a univariate dataset.

@@ -14,7 +14,7 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
        Array of frequency values in Hz
    sample_rate : scalar
        Sampling frequency of data in Hz
    window_len : scalar
    win_len : scalar
        Length of wavelet window
    ncycles : int
        Width of wavelets in number of cycles
@@ -22,6 +22,8 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
        Boolean flag indicating whether to return the basis set alongside the transform.
    ret_mode : {'power','amplitude','complex'}
        Flag indicating whether which form of the wavelet transform to return.
    normalise : {None,'simple','tallon','wikipedia','mne'}
        Flag indicating which normalisation factor to apply to the wavelet basis.

    Returns
    -------
@@ -32,10 +34,13 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
    if orig_dim == 1:
        x = x[np.newaxis, :]

    if 0 in freqs:
        raise ValueError("0 cannot be in freqs.")

    cwt = np.zeros((x.shape[0], len(freqs), x.shape[1]), dtype=complex)

    # Get morlet basis
    mlt = get_morlet_basis(freqs, ncycles, sample_rate, normalise=normalise)
    mlt = get_morlet_basis(freqs, ncycles, win_len, sample_rate, normalise)

    for jj in range(x.shape[0]):
        for ii in range(len(freqs)):
@@ -59,19 +64,39 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
        return cwt


def wavelet_csd(x, freqs, sample_rate):
def cross_morlet(x, freqs, sample_rate, win_len=4, ncycles=5, normalise=False):
    """Compute a morlet cross wavelet time-frequency transform on a multivariate dataset.

    Parameters
    ----------
    x : vector array_like
        Time-series to compute cross wavelet transform from.
    freqs : array_like
        Array of frequency values in Hz.
    sample_rate : scalar
        Sampling frequency of data in Hz.
    win_len : scalar
        Length of wavelet window.
    ncycles : int
        Width of wavelets in number of cycles.
    normalise : {None,'simple','tallon','wikipedia','mne'}
        Flag indicating which normalisation factor to apply to the wavelet basis.

    wt = morlet(x, freqs, sample_rate, ret_mode='complex')
    Returns
    -------
    4D array
        Array containing morlet cross wavelet transformed data [nfreqs x nsamples x nchannels x nchannels].
    """
    wt = morlet(x, freqs, sample_rate, win_len=win_len, ncycles=ncycles,
                ret_mode='complex', normalise=normalise)
    S = np.zeros((wt.shape[0], wt.shape[0], wt.shape[1], wt.shape[2]), dtype=complex)
    for ii in range(wt.shape[1]):
        for jj in range(wt.shape[2]):
            S[:, :, ii, jj] = np.dot(wt[:, ii, jj, np.newaxis], wt[np.newaxis, :, ii, jj].conj())

    return S



def get_morlet_basis(freq, ncycles, sample_rate, normalise=False, win_len=5):
def get_morlet_basis(freq, ncycles, win_len, sample_rate, normalise):
    """Compute a morlet wavelet basis set based on specified parameters.

    Parameters
@@ -80,12 +105,12 @@ def get_morlet_basis(freq, ncycles, sample_rate, normalise=False, win_len=5):
        Array of frequency values in Hz
    ncycles : int
        Width of wavelets in number of cycles
    win_len : scalar
        Length of wavelet window
    sample_rate : scalar
        Sampling frequency of data in Hz
    normalise : {None,'simple','tallon','wikipedia','mne'}
        Flag indicating which normalisation factor to apply to the wavelet basis.
    win_len : float
        Window length duration factor

    Returns
    -------