Commit ac1d0d38 authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Merge branch 'improved_wavelet' into 'master'

Improved wavelet

See merge request !71
parents 8ec9ee8b 9c1b60b1
Loading
Loading
Loading
Loading
Loading
+160 −0
Original line number Diff line number Diff line
Tutorial 11 - Morlet Wavelet Decomposition
=======================================================

In this tutorial, we will look at describing time-frequency dynamics in a
signal using a Morlet Wavelet decomposition.

For this tutorial, we will use the same MEG example data which we have used in previous
tutorials.

We start by importing our modules and finding and loading the example data.

.. code-block:: python

    import os
    from os.path import join

    import h5py

    import numpy as np
    import matplotlib.pyplot as plt

    import sails

    plt.style.use('ggplot')


SAILS will automatically detect the example data if downloaded into your home
directory. If you've used a different location, you can specify this in an
environment variable named ``SAILS_EXAMPLE_DATA``.

.. code-block:: python

    # Specify environment variable with example data location
    # This is only necessary if you have not checked out the
    # example data into $HOME/sails-example-data
    #os.environ['SAILS_EXAMPLE_DATA'] = '/path/to/sails-example-data'

    # Locate and validate example data directory
    example_path = sails.find_example_path()

    # Load data using h5py
    data_path = os.path.join(sails.find_example_path(), 'meg_occipital_ve.hdf5')
    X = h5py.File(data_path, 'r')['X'][:, 122500:130000, 0]

    sample_rate = 250

    time_vect = np.linspace(0, X.shape[1] // sample_rate, X.shape[1])

The Morlet Wavelet transform first defines a set of wavelet functions to use as
am adaptive basis set. These wavelets are simple burst-like oscillations
created to a pre-defined set of parameters.

.. code-block:: python

    # Wavelet frequencies in Hz
    freqs = [10]
    # Number of cycles within the oscillatory event
    ncycles = 5
    # Length of window in seconds
    win_len = 10

    # Compute wavelets
    mlt = sails.wavelet.get_morlet_basis(freqs, ncycles, win_len, sample_rate, normalise=False)

``mlt`` is now a list of wavelet basis functions. In this case, we have a
single 10Hz basis. This wavelet is a complex-valued array (in the same way that
a Fourier transform returns a complex valued result). To visualise the wavelet,
we can plot the real and imaginary parts of the complex function.

.. code-block:: python

    plt.figure()
    plt.plot(mlt[0].real)
    plt.plot(mlt[0].imag)
    plt.legend(['Real', 'Imag'])

Note that the real and imaginary components are the same apart from a 90 degree
phase shift in the time-series. This shift allow the wavelet transform to
estimate the full amplitude envelope and phase of the underlying signal.

The ``ncycles`` parameter is critical for defining the time-frequency
resolution of the wavelet transform. A small value (typically less than 5) will
lead to high temporal resolution but low frequency resolution whilst a large
value (typically greater than 7) will have a low temporal resolutoin and a high
frequency resolution. We will look into this in more detail later - for now, we
can see that this parameter simply changes the number of cycles of oscillation
present in our basis wavelets.

.. code-block:: python

    # Compute wavelets
    mlt3 = sails.wavelet.get_morlet_basis(freqs, 3, win_len, sample_rate, normalise=False)
    mlt5 = sails.wavelet.get_morlet_basis(freqs, 5, win_len, sample_rate, normalise=False)
    mlt7 = sails.wavelet.get_morlet_basis(freqs, 7, win_len, sample_rate, normalise=False)

    plt.figure()
    for idx, mlt in enumerate([mlt3, mlt5, mlt7]):
        y = mlt[0].real
        t = np.arange((len(y))) - len(y)/2 # zero-centre the wavelet
        plt.plot(t, y + idx*2)
    plt.legend(['3-cycles', '5-cycles', '7-cycles'])

We can see that the frequency of the oscillation in each wavelet is unchanged
whilst the number of cycles is modified by changing ``ncycles``.

We will often compute wavelets for a range of frequencies rather than just one.
Here we pass in an array of frequency values to compute wavelets for.

.. code-block:: python

    freqs = [3, 6, 9, 12, 15]

    # Compute wavelets
    mlt = sails.wavelet.get_morlet_basis(freqs, ncycles, win_len, sample_rate, normalise=False)

    plt.figure()
    for ii in range(len(freqs)):
        y = mlt[ii].real
        t = np.arange((len(y))) - len(y)/2 # zero-centre the wavelet
        plt.plot(t, y + ii*2)
    plt.legend(freqs)

This time, we see that changing frequency keeps a consistent number of cycles
in each wavelet but modifies the oscillatory period.

To compute the wavelet transform itself, each wavelet basis function is
convolved across the dataset. In this instance (as the wavelet function is
symmetric and the input time-series are real values) this convolution is
similar to computing the correlation between the basis function and the
time-series at each point in time.

Let's compute the wavelet transform at 10Hz on our real data.

.. code-block:: python

    freqs = [10]
    cwt = sails.wavelet.morlet(X[0, :], freqs, sample_rate)

    plt.figure()
    plt.subplot(211)
    plt.plot(X.T)
    plt.subplot(212)
    plt.plot(cwt.T)

We can see that the wavelet power tracks the amplitude of the oscillations
visible in the original time-series.

Finally, let's compute a full wavelet transform across a wider range of frequencies

.. code-block:: python

    freqs = np.linspace(1, 20, 38)
    cwt = sails.wavelet.morlet(X[0, :], freqs, sample_rate, normalise='tallon')

    plt.figure()
    plt.subplot(211)
    plt.plot(time_vect, X.T)
    plt.xlim(time_vect[0], time_vect[-1])
    plt.subplot(212)
    plt.pcolormesh(time_vect, freqs, cwt)
+3 −2
Original line number Diff line number Diff line
@@ -13,10 +13,11 @@ class MorletBasisTests(unittest.TestCase):
        sample_rate = 128
        freqs = np.linspace(1, 50, 49)
        ncycles = 5
        win_len = 5

        from ..wavelet import get_morlet_basis

        mlt = get_morlet_basis(freqs, ncycles, sample_rate)
        mlt = get_morlet_basis(freqs, ncycles, win_len, sample_rate)

        # Check we have same number of morlet basis waves as input freqs
        assert(len(mlt) == freqs.shape[0])
@@ -26,7 +27,7 @@ class MorletBasisTests(unittest.TestCase):
        assert(np.all(np.diff([len(m) for m in mlt]) <= 0))

        # Example with faster sample_rate
        mlt2 = get_morlet_basis(freqs, ncycles, sample_rate*2)
        mlt2 = get_morlet_basis(freqs, ncycles, win_len, sample_rate*2)
        check = np.c_[[len(m) for m in mlt], [len(m) for m in mlt2]]

        # Morlets from double sample rate should be twice as long (or off-by-one)
+106 −17
Original line number Diff line number Diff line
from scipy import signal
#!/usr/bin/python

# vim: set expandtab ts=4 sw=4:

import numpy as np
from scipy import signal


def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
           ret_mode='power', normalise=False):
def morlet(x, freqs, sample_rate, win_len=4, ncycles=5, ret_basis=False, ret_mode='power', normalise='wikipedia'):
    """Compute a morlet wavelet time-frequency transform on a univariate dataset.

    Parameters
@@ -14,7 +17,7 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
        Array of frequency values in Hz
    sample_rate : scalar
        Sampling frequency of data in Hz
    window_len : scalar
    win_len : scalar
        Length of wavelet window
    ncycles : int
        Width of wavelets in number of cycles
@@ -22,22 +25,34 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
        Boolean flag indicating whether to return the basis set alongside the transform.
    ret_mode : {'power', 'amplitude', 'complex'}
        Flag indicating whether which form of the wavelet transform to return.
    normalise : {None, 'simple', 'tallon', 'wikipedia', 'mne'}
        Flag indicating which normalisation factor to apply to the wavelet
        basis. See `sails.wavelet.get_morlet_basis` for details.
        Default = 'wikipedia'.

    Returns
    -------
    2D array
        Array containing morlet wavelet transformed data [nfreqs x nsamples]

    """
    orig_dim = x.ndim
    if orig_dim == 1:
        x = x[np.newaxis, :]

    cwt = np.zeros((len(freqs), *x.shape[:]), dtype=complex)
    if 0 in freqs:
        raise ValueError("0 cannot be in freqs.")

    cwt = np.zeros((x.shape[0], len(freqs), x.shape[1]), dtype=complex)

    # Get morlet basis
    mlt = get_morlet_basis(freqs, ncycles, sample_rate, normalise=normalise)
    mlt = get_morlet_basis(freqs, ncycles, win_len, sample_rate, normalise)

    for jj in range(x.shape[0]):
        for ii in range(len(freqs)):
        a = signal.convolve(x, mlt[ii].real, mode='same', method='fft')
        b = signal.convolve(x, mlt[ii].imag, mode='same', method='fft')
        cwt[ii, ...] = a+1j*b
            a = signal.convolve(x[jj, :], mlt[ii].real, mode='same', method='fft')
            b = signal.convolve(x[jj, :], mlt[ii].imag, mode='same', method='fft')
            cwt[jj, ii, :] = a+1j*b

    if ret_mode == 'power':
        cwt = np.power(np.abs(cwt), 2)
@@ -46,13 +61,68 @@ def morlet(x, freqs, sample_rate, window_len=4, ncycles=5, ret_basis=False,
    elif ret_mode != 'complex':
        raise ValueError("'ret_mode not recognised, please use one of {'power','amplitude','complex'}")

    if orig_dim == 1:
        cwt = cwt[0, ...]

    if ret_basis:
        return cwt, mlt
    else:
        return cwt


def get_morlet_basis(freq, ncycles, sample_rate, normalise=False, win_len=5):
def cross_morlet(x, freqs, sample_rate, win_len=4, ncycles=5, ret_mode='power', normalise='wikipedia'):
    """Compute a morlet cross wavelet time-frequency transform on a multivariate dataset.

    Parameters
    ----------
    x : vector array_like
        Time-series to compute cross wavelet transform from.
    freqs : array_like
        Array of frequency values in Hz.
    sample_rate : scalar
        Sampling frequency of data in Hz.
    win_len : scalar
        Length of wavelet window.
    ncycles : int
        Width of wavelets in number of cycles.
    ret_mode : {'power', 'amplitude', 'complex'}
        Flag indicating whether which form of the wavelet transform to return.
    normalise : {None, 'simple', 'tallon', 'wikipedia', 'mne'}
        Flag indicating which normalisation factor to apply to the wavelet
        basis. See `sails.wavelet.get_morlet_basis` for details.
        Default = 'wikipedia'.

    Returns
    -------
    4D array
        Array containing morlet cross wavelet transformed data [nfreqs x nsamples x nchannels x nchannels].

    """
    if ret_mode not in ['power', 'amplitude', 'complex']:
        raise ValueError("'ret_mode not recognised, please use one of {'power','amplitude','complex'}")

    # Run standard wavelet decomposition return complex values
    wt = morlet(x, freqs, sample_rate, win_len=win_len, ncycles=ncycles, ret_mode='complex', normalise=normalise)

    # Preallocate output array [nchannels x nchannels x nfreqs x ntimes]
    S = np.empty((wt.shape[0], wt.shape[0], wt.shape[1], wt.shape[2]), dtype=complex)

    # Main loop
    for ii in range(wt.shape[1]):
        for jj in range(wt.shape[2]):
            S[:, :, ii, jj] = np.dot(wt[:, ii, jj, np.newaxis], wt[np.newaxis, :, ii, jj].conj())

    if ret_mode == 'power':
        S = np.power(np.abs(S), 2)
    elif ret_mode == 'amplitude':
        S = np.abs(S)
    elif ret_mode != 'complex':
        raise ValueError("'ret_mode not recognised, please use one of {'power', 'amplitude', 'complex'}")

    return S


def get_morlet_basis(freq, ncycles, win_len, sample_rate, normalise='wikipedia'):
    """Compute a morlet wavelet basis set based on specified parameters.

    Parameters
@@ -61,20 +131,38 @@ def get_morlet_basis(freq, ncycles, sample_rate, normalise=False, win_len=5):
        Array of frequency values in Hz
    ncycles : int
        Width of wavelets in number of cycles
    win_len : scalar
        Length of wavelet window
    sample_rate : scalar
        Sampling frequency of data in Hz
    normalise : {None, 'simple', 'tallon', 'wikipedia', 'mne'}
        Flag indicating which normalisation factor to apply to the wavelet basis.
    win_len : float
        Window length duration factor
        Flag indicating which normalisation factor to apply to the wavelet
        basis (default is 'wikipedia') - can be one of:

        * None - no normalisation is applied

        * 'simple' - wavelet is normalised by its own sum

        * 'tallon' - normalisation from Tallon-Baudry et al 1997

        * 'wikipedia' normalisation from https://en.wikipedia.org/wiki/Morlet_wavelet

        * 'mne' - normalisation used in MNE-Python

    Returns
    -------
    list of vector arrays
        Complex valued arrays containing morlet wavelets

    """
    References
    ----------
    .. [1] Tallon-Baudry, C., Bertrand, O., Delpuech, C., & Pernier, J. (1997).
       Oscillatory γ-Band (30–70 Hz) Activity Induced by a Visual Search Task in
       Humans. In The Journal of Neuroscience (Vol. 17, Issue 2, pp. 722–734).
       Society for Neuroscience.
       https://doi.org/10.1523/jneurosci.17-02-00722.1997

    """
    m = []
    for ii in range(len(freq)):
        # Sigma controls the width of the gaussians applied to each wavelet. This
@@ -112,4 +200,5 @@ def get_morlet_basis(freq, ncycles, sample_rate, normalise=False, win_len=5):
            mlt = mlt / (np.sqrt(0.5) * np.linalg.norm(mlt.ravel()))

        m.append(mlt)

    return m