Commit 9ea1c7ea authored by Andrew Quinn
Browse files

Finish first pass at glm-irasa

parent 1104d509
Loading
Loading
Loading
Loading
+327 −1
Original line number Diff line number Diff line
@@ -42,13 +42,15 @@ Worker functions:
import logging
import typing
import warnings
import fractions
from dataclasses import dataclass
from functools import wraps
from copy import deepcopy

import numpy as np
from scipy import fft as sp_fft
from scipy import stats
from scipy.signal import signaltools
from scipy.signal import signaltools, resample_poly
from scipy.signal.windows import dpss

try:
@@ -1497,6 +1499,138 @@ def multitaper(x, average='mean', num_tapers='auto',
        return f, p


@set_verbose
def irasa(x, resample_factors=None,
          # Periodogram args
          average='mean',
          # General STFT kwargs
          fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
          detrend='constant', return_onesided=True, mode='psd', scaling='density',
          axis=-1, fmin=None, fmax=None, return_config=False, verbose=None):
    """Split a spectrum into oscillatory and aperiodic parts by resampling.

    For every resampling factor ``h`` the input is resampled twice with
    `scipy.signal.resample_poly`: once by ``h`` and once by ``1/h``. A
    window-averaged spectrum is computed for each resampled copy and the
    geometric mean of the pair is kept. The median of these geometric means
    across all factors is used as the aperiodic estimate, which is then
    subtracted from the window-averaged spectrum of the original data.

    Parameters
    ----------
    x : ndarray
        Time series of measurement values. Must expose ``.shape``.
    resample_factors : array_like or None, optional
        Resampling factors ``h``. Values are rounded to 4 decimal places and
        converted to an exact ``up / down`` ratio. If None, 17 evenly spaced
        factors between 1.1 and 1.9 are used. Defaults to None.
    average : { 'mean', 'median' }, optional
        Method used to average the spectra across sliding windows.
        Defaults to 'mean'.
    fs : float, optional
        Sampling frequency of the `x` time series. Defaults to 1.0.
    window_type : str or tuple or array_like, optional
        Desired window to use; forwarded to the configuration object.
        Defaults to 'hann'.
    nperseg : int, optional
        Length of each segment. If None, the value is chosen by the
        configuration object. Defaults to None.
    noverlap : int, optional
        Number of points to overlap between segments. If None, the value is
        chosen by the configuration object. Defaults to None.
    nfft : int, optional
        Length of the FFT used, if a zero padded FFT is desired. If None, the
        value is chosen by the configuration object. Defaults to None.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. Defaults to 'constant'.
    return_onesided : bool, optional
        Whether a one-sided spectrum is requested. Defaults to True.
    mode : str, optional
        Type of spectrum requested from the STFT; forwarded to the
        configuration object. Defaults to 'psd'.
    scaling : { 'density', 'spectrum' }, optional
        Selects between power spectral density and power spectrum scaling.
        Defaults to 'density'.
    axis : int, optional
        Axis of `x` along which resampling and the STFT are computed.
        Defaults to -1.
    fmin : float or None, optional
        Minimum frequency value to keep; forwarded to the configuration
        object. Defaults to None.
    fmax : float or None, optional
        Maximum frequency value to keep; forwarded to the configuration
        object. Defaults to None.
    return_config : bool, optional
        Accepted for signature compatibility with the other spectrum
        functions. It is currently not used here and the configuration object
        is never returned. Defaults to False.
    verbose : optional
        Not used in the function body. NOTE(review): presumably consumed by
        the `set_verbose` decorator - confirm.

    Returns
    -------
    periodic : ndarray
        Window-averaged spectrum of `x` minus the aperiodic estimate.
    aperiodic : ndarray
        Median across resampling factors of the geometric mean of the
        up-sampled and down-sampled spectra.

    Raises
    ------
    ValueError
        If `resample_factors` is empty.

    Notes
    -----
    The frequency vector is computed internally but not returned.

    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    logging.info('Setting config options')
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='time_first')

    if resample_factors is None:
        resample_factors = np.linspace(1.1, 1.9, 17)
    # atleast_1d lets a single scalar factor be passed as well as a sequence
    resample_factors = np.atleast_1d(np.round(resample_factors, 4))
    if resample_factors.size == 0:
        raise ValueError('resample_factors must contain at least one value')

    logging.info('Starting computation')
    # Collect one geometric-mean spectrum per factor and join them once at the
    # end, rather than re-copying the growing array on every iteration.
    geometric_means = []
    for rf in resample_factors:
        # Going through str() gives the short decimal form (e.g. 23/20 for
        # 1.15) rather than the exact binary expansion of the float.
        rat = fractions.Fraction(str(rf))
        logging.debug('Resampling by factor %s', rat)
        up, down = rat.numerator, rat.denominator

        # Resample by h and by 1/h
        y = resample_poly(x, up, down, axis=config.axis)
        z = resample_poly(x, down, up, axis=config.axis)

        _, _, Y = compute_stft(y, **config.stft_args)
        Y = apply_average(Y, config.average, axis=0, keepdims=True)

        _, _, Z = compute_stft(z, **config.stft_args)
        Z = apply_average(Z, config.average, axis=0, keepdims=True)

        geometric_means.append(np.sqrt(Y * Z))

    pxx = np.concatenate(geometric_means, axis=0)
    aperiodic_pxx = np.median(pxx, axis=0, keepdims=False)

    _, _, full_pxx = compute_stft(x, **config.stft_args)
    full_pxx = apply_average(full_pxx,
                             config.average,
                             axis=0)

    return full_pxx - aperiodic_pxx, aperiodic_pxx


def apply_average(X, method, axis=0, keepdims=False):
    """Collapse one axis of an array with a NaN-aware mean or median.

    Parameters
    ----------
    X : array_like
        Data to be averaged.
    method : { 'mean', 'median', None }
        'mean' uses `np.nanmean`, 'median' uses `np.nanmedian` and None
        returns `X` untouched.
    axis : int, optional
        Axis to average over. Defaults to 0.
    keepdims : bool, optional
        Whether the averaged axis is kept with length one. Defaults to False.

    Returns
    -------
    ndarray
        The averaged data, or `X` itself when `method` is None.

    Raises
    ------
    ValueError
        If `method` is not one of the recognised options.

    """
    known_reducers = (('mean', np.nanmean), ('median', np.nanmedian))
    for label, reducer in known_reducers:
        if method == label:
            return reducer(X, axis=axis, keepdims=keepdims)

    if method is None:
        # No averaging requested - hand the input back unchanged
        return X

    msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'"
    raise ValueError(msg.format(method))


# -----------------------------------------------------------------------
# GLM Spectrogram Functions

@@ -2186,3 +2320,195 @@ def glm_multitaper(X, reg_ztrans=None, reg_unitmax=None, fit_method='pinv', fit_
        raise ValueError('fit_method not recognised')

    return f, copes, varcopes, extras


@set_verbose
def glm_irasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
              contrasts=None, fit_method='pinv', fit_intercept=True,
              ret_class=True,
              # IRASA kwargs
              resample_factors=None, average='median',
              # General STFT kwargs
              fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
              detrend='constant', return_onesided=True, scaling='density',
              axis=-1, mode='psd', fmin=None, fmax=None, verbose=None):
    """Fit a GLM spectrum and split it into aperiodic and oscillatory parts.

    A GLM is fitted to the sliding-window spectrum of the original data. For
    every resampling factor ``h`` the data and all regressors are resampled
    by ``h`` and by ``1/h``, a GLM is fitted to the spectrum of each copy and
    the geometric mean of the paired ``betas``, ``copes`` and ``varcopes`` is
    kept. Averaging those geometric means across factors gives the aperiodic
    model, which is subtracted from the model of the original data.

    Parameters
    ----------
    X : ndarray
        Time series of measurement values.
    reg_categorical : dict or None
        Regressors passed to the model fit as the categorical set.
        (Default value = None)
    reg_ztrans : dict or None
        Dictionary of covariate time series to be added as z-standardised
        regressors. (Default value = None)
    reg_unitmax : dict or None
        Dictionary of confound time series to be added as positive-valued
        unitmax regressors. (Default value = None)
    contrasts : optional
        Contrast specification forwarded to the configuration object and to
        the model fit. (Default value = None)
    fit_method : str, optional
        Stored in the configuration object and used to validate the input
        dimensionality ('pinv' and 'lstsq' require vector input). The models
        in this function are always fitted with `_glm_fit_glmtools`
        regardless of this value. (Default value = 'pinv')
    fit_intercept : bool
        Specifies whether a constant valued 'intercept' regressor is included
        in the model. (Default value = True)
    ret_class : bool
        Not used by this function. (Default value = True)
    resample_factors : array_like or None, optional
        Resampling factors ``h``. Values are rounded to 4 decimal places and
        converted to an exact ``up / down`` ratio. If None, 17 evenly spaced
        factors between 1.1 and 1.9 are used. Defaults to None.
    average : { 'mean', 'median', None }, optional
        Method used to combine the geometric-mean parameter estimates across
        resampling factors. Defaults to 'median'.
    fs : float, optional
        Sampling frequency of the `X` time series. Defaults to 1.0.
    window_type : str or tuple or array_like, optional
        Desired window to use; forwarded to the configuration object.
        Defaults to 'hann'.
    nperseg : int, optional
        Length of each segment. If None, the value is chosen by the
        configuration object. Defaults to None.
    noverlap : int, optional
        Number of points to overlap between segments. If None, the value is
        chosen by the configuration object. Defaults to None.
    nfft : int, optional
        Length of the FFT used, if a zero padded FFT is desired. If None, the
        value is chosen by the configuration object. Defaults to None.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. Defaults to 'constant'.
    return_onesided : bool, optional
        Whether a one-sided spectrum is requested. Defaults to True.
    scaling : { 'density', 'spectrum' }, optional
        Selects between power spectral density and power spectrum scaling.
        Defaults to 'density'.
    axis : int, optional
        Axis of `X` along which resampling and the STFT are computed; -1 is
        converted to ``X.ndim - 1``. Defaults to -1.
    mode : str, optional
        Type of spectrum requested from the STFT; forwarded to the
        configuration object. Defaults to 'psd'.
    fmin : float or None, optional
        Minimum frequency value to keep; forwarded to the configuration
        object. Defaults to None.
    fmax : float or None, optional
        Maximum frequency value to keep; forwarded to the configuration
        object. Defaults to None.
    verbose : optional
        Not used in the function body. NOTE(review): presumably consumed by
        the `set_verbose` decorator - confirm.

    Returns
    -------
    model_aperiodic : object
        Deep copy of the fitted model whose ``betas``, ``copes`` and
        ``varcopes`` hold the aperiodic estimates.
    model : object
        Model fitted to the original data with the aperiodic ``betas``,
        ``copes`` and ``varcopes`` subtracted.

    Raises
    ------
    ValueError
        If `X` is not a vector while `fit_method` is 'pinv' or 'lstsq', or if
        `resample_factors` is empty.

    Notes
    -----
    The frequency vector is computed internally but not returned.

    """
    # Option housekeeping
    if axis == -1:
        axis = X.ndim - 1

    if X.ndim != 1 and fit_method in ['pinv', 'lstsq']:
        msg = "Data input should be vector for 'pinv' and 'lstsq' fits - data shape {0} was passed in"
        logging.error(msg.format(X.shape))
        logging.error("Use fit_method='glmtools' for multdimensional data")
        raise ValueError("Fit methods 'pinv' and 'lstsq' not implemented for multidimensional data")

    # Set configuration
    logging.info('Setting config options')
    config = GLMPeriodogramConfig(X.shape[axis], reg_ztrans=reg_ztrans,
                                  reg_unitmax=reg_unitmax,
                                  fit_method=fit_method, contrasts=contrasts,
                                  fit_intercept=fit_intercept,
                                  input_complex=np.iscomplexobj(X), fs=fs,
                                  fmin=fmin, fmax=fmax,
                                  window_type=window_type, nperseg=nperseg,
                                  noverlap=noverlap,
                                  nfft=nfft, detrend=detrend,
                                  return_onesided=return_onesided,
                                  scaling=scaling, axis=axis, mode=mode,
                                  output_axis='time_first')
    logging.debug('Config: %s', config)

    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    reg_categorical = _process_input_covariate(reg_categorical, config.input_len)
    reg_ztrans = _process_input_covariate(reg_ztrans, config.input_len)
    reg_unitmax = _process_input_covariate(reg_unitmax, config.input_len)

    # Compute STFT
    logging.info('Computing sliding window periodogram')
    _, _, p = compute_stft(X, **config.stft_args)

    # Model of the original (non-resampled) data
    model, _, _ = _glm_fit_glmtools(p, reg_categorical, reg_ztrans,
                                    reg_unitmax, config,
                                    contrasts=contrasts,
                                    fit_intercept=fit_intercept)

    if resample_factors is None:
        resample_factors = np.linspace(1.1, 1.9, 17)
    # atleast_1d lets a single scalar factor be passed as well as a sequence
    resample_factors = np.atleast_1d(np.round(resample_factors, 4))
    if resample_factors.size == 0:
        raise ValueError('resample_factors must contain at least one value')

    logging.info('Starting computation')
    # One entry per resampling factor; stacked once after the loop instead of
    # re-copying a growing array on every iteration.
    betas, copes, varcopes = [], [], []
    for rf in resample_factors:
        # Going through str() gives the short decimal form (e.g. 23/20 for
        # 1.15) rather than the exact binary expansion of the float.
        rat = fractions.Fraction(str(rf))
        logging.debug('Resampling by factor %s', rat)
        up, down = rat.numerator, rat.denominator

        # Data and regressors resampled by h
        y, y_categorical, y_ztrans, y_unitmax = _resample_helper(
            X, reg_categorical, reg_ztrans, reg_unitmax,
            up, down, axis=config.axis)
        _, _, Y = compute_stft(y, **config.stft_args)
        modelY, _, _ = _glm_fit_glmtools(Y, y_categorical, y_ztrans,
                                         y_unitmax, config,
                                         contrasts=contrasts,
                                         fit_intercept=fit_intercept)

        # Data and regressors resampled by 1/h: up and down are swapped here.
        # Passing (up, down) again would just duplicate the first fit and the
        # geometric mean below would no longer pair h with 1/h.
        z, z_categorical, z_ztrans, z_unitmax = _resample_helper(
            X, reg_categorical, reg_ztrans, reg_unitmax,
            down, up, axis=config.axis)
        _, _, Z = compute_stft(z, **config.stft_args)
        modelZ, _, _ = _glm_fit_glmtools(Z, z_categorical, z_ztrans,
                                         z_unitmax, config,
                                         contrasts=contrasts,
                                         fit_intercept=fit_intercept)

        # Geometric mean of the paired estimates
        betas.append(np.sqrt(modelY.betas * modelZ.betas))
        copes.append(np.sqrt(modelY.copes * modelZ.copes))
        varcopes.append(np.sqrt(modelY.varcopes * modelZ.varcopes))

    # Stack along a new leading 'resampling factor' axis
    betas = np.stack(betas, axis=0)
    copes = np.stack(copes, axis=0)
    varcopes = np.stack(varcopes, axis=0)

    model_aperiodic = deepcopy(model)
    model_aperiodic.betas = apply_average(betas, average, axis=0)
    model_aperiodic.copes = apply_average(copes, average, axis=0)
    model_aperiodic.varcopes = apply_average(varcopes, average, axis=0)

    model.betas = model.betas - model_aperiodic.betas
    model.copes = model.copes - model_aperiodic.copes
    model.varcopes = model.varcopes - model_aperiodic.varcopes

    return model_aperiodic, model


def _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=0):
    """Resample the data and every regressor by the ratio ``up / down``.

    Parameters
    ----------
    X : array_like
        Data to resample along `axis`.
    reg_categorical, reg_ztrans, reg_unitmax : dict
        Regressor dictionaries. Every value is resampled along its first
        axis; the inputs themselves are not modified.
    up : int
        Upsampling factor passed to `resample_poly`.
    down : int
        Downsampling factor passed to `resample_poly`.
    axis : int, optional
        Axis of `X` to resample along. Defaults to 0.

    Returns
    -------
    tuple
        ``(y, out_categorical, out_ztrans, out_unitmax)`` - the resampled
        data followed by copies of the three dictionaries holding the
        resampled regressors under their original keys.

    """
    def _resample_regressors(regressors):
        # Work on a copy so the caller's dictionary is left untouched
        resampled = regressors.copy()
        for name in regressors:
            resampled[name] = resample_poly(regressors[name], up, down, axis=0)
        return resampled

    resampled_data = resample_poly(X, up, down, axis=axis)

    return (resampled_data,
            _resample_regressors(reg_categorical),
            _resample_regressors(reg_ztrans),
            _resample_regressors(reg_unitmax))