Commit bd76897f authored by Andrew Quinn's avatar Andrew Quinn
Browse files

quirasa and glm_quirasa

parent 9ea1c7ea
Loading
Loading
Loading
Loading
+303 −8
Original line number Diff line number Diff line
@@ -234,12 +234,9 @@ def compute_fft(x, nfft=256, axis=-1,
    # Get frequency values
    freqvals = _set_freqvalues(nfft, fs, side)
    # Trim frequency range to specified limits
    logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax))
    fidx = (freqvals >= fmin) & \
           (freqvals <= fmax)
    fidx = _proc_get_freq_inds(freqvals, fmin, fmax)
    result = result[..., fidx]
    freqs = freqvals[fidx]
    logging.debug('fft trimmed output shape {0}'.format(result.shape))

    return freqs, result

@@ -668,6 +665,14 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft):
    return pxx


def _proc_get_freq_inds(freqvals, fmin, fmax):
    logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax))
    fidx = (freqvals >= fmin) & \
           (freqvals <= fmax)
    logging.debug('fft trimmed length {0}'.format(fidx.sum()))
    return fidx


# ------------------------------------------------------------------
# Config Functions
#
@@ -1581,7 +1586,7 @@ def irasa(x, resample_factors=None,
                               average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='time_first')
                               fmin=None, fmax=None, output_axis='time_first')

    if resample_factors is None:
        resample_factors = np.linspace(1.1, 1.9, 17)
@@ -1590,7 +1595,6 @@ def irasa(x, resample_factors=None,
    logging.info('Starting computation')
    for ii, rf in enumerate(resample_factors):
        rat = fractions.Fraction(str(rf))
        print(rat)
        up, down = rat.numerator, rat.denominator

        y = resample_poly(x, up, down, axis=config.axis)
@@ -1598,15 +1602,22 @@ def irasa(x, resample_factors=None,

        f, t, Y = compute_stft(y, **config.stft_args)
        Y = apply_average(Y, config.average, axis=0, keepdims=True)
        freqsY = _set_freqvalues(config.nfft, config.fs*up, config.side)
        fidxY = _proc_get_freq_inds(freqsY, fmin, fmax)

        f, t, Z = compute_stft(z, **config.stft_args)
        Z = apply_average(Z, config.average, axis=0, keepdims=True)
        freqsZ = _set_freqvalues(config.nfft, config.fs*down, config.side)
        fidxZ = _proc_get_freq_inds(freqsZ, fmin, fmax)

        if ii == 0:
            pxx = np.sqrt(Y * Z)
            valid_freqs = fidxY + fidxZ
        else:
            pxx = np.concatenate((pxx, np.sqrt(Y * Z)), axis=0)
            valid_freqs = valid_freqs + fidxY + fidxZ

    #aperiodic_pxx = np.median(pxx, axis=0, keepdims=False)
    aperiodic_pxx = np.median(pxx, axis=0, keepdims=False)

    f, t, full_pxx = compute_stft(x, **config.stft_args)
@@ -1614,7 +1625,125 @@ def irasa(x, resample_factors=None,
                             config.average,
                             axis=0)

    return full_pxx - aperiodic_pxx, aperiodic_pxx
    return full_pxx - aperiodic_pxx, aperiodic_pxx, valid_freqs, pxx


@set_verbose
def quirasa(x, resample_factors=None,
            # Periodogram args
            average='mean',
            # General STFT kwargs
            fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
            detrend='constant', return_onesided=True, mode='psd', scaling='density',
            axis=-1, fmin=None, fmax=None, return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    average : { 'mean', 'median' }, optional
        Method to use when averaging periodograms. Defaults to 'mean'.
    fs : float, optional
        Sampling frequency of the `x` time series. Defaults to 1.0.
    window_type : str or tuple or array_like, optional
        Desired window to use. If `window` is a string or tuple, it is
        passed to `get_window` to generate the window values, which are
        DFT-even by default. See `get_window` for a list of windows and
        required parameters. If `window` is array_like it will be used
        directly as the window and its length must be nperseg. Defaults
        to a Hann window.
    nperseg : int, optional
        Length of each segment. Defaults to None, but if window is str or
        tuple, is set to 256, and if window is array_like, is set to the
        length of the window.
    noverlap : int, optional
        Number of points to overlap between segments. If `None`,
        ``noverlap = nperseg // 2``. Defaults to `None`.
    nfft : int, optional
        Length of the FFT used, if a zero padded FFT is desired. If
        `None`, the FFT length is `nperseg`. Defaults to `None`.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. If `detrend` is a
        string, it is passed as the `type` argument to the `detrend`
        function. If it is a function, it takes a segment and returns a
        detrended segment. If `detrend` is `False`, no detrending is
        done. Defaults to 'constant'.
    return_onesided : bool, optional
        If `True`, return a one-sided spectrum for real data. If
        `False` return a two-sided spectrum. Defaults to `True`, but for
        complex data, a two-sided spectrum is always returned.
    scaling : { 'density', 'spectrum' }, optional
        Selects between computing the power spectral density ('density')
        where `Pxx` has units of V**2/Hz and computing the power
        spectrum ('spectrum') where `Pxx` has units of V**2, if `x`
        is measured in V and `fs` is measured in Hz. Defaults to
        'density'
    axis : int, optional
        Axis along which the periodogram is computed; the default is
        over the last axis (i.e. ``axis=-1``).
    fmin : float or None, optional
        Minimum frequency value to return (Default value = 0)
    fmax : float or None, optional
        Maximum frequency value to return (Default value = 0.5)
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    logging.info('Setting config options')
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=None, fmax=None, output_axis='time_first')

    if resample_factors is None:
        resample_factors = np.linspace(0.1, 1.9, 15)
    resample_factors = np.round(resample_factors, 4)

    logging.info('Starting computation')
    for ii, rf in enumerate(resample_factors):
        rat = fractions.Fraction(str(rf))
        up, down = rat.numerator, rat.denominator

        y = resample_poly(x, up, down, axis=config.axis)

        f, t, Y = compute_stft(y, **config.stft_args)
        Y = apply_average(Y, config.average, axis=0, keepdims=True)
        freqsY = _set_freqvalues(config.nfft, config.fs*up, config.side)
        fidxY = _proc_get_freq_inds(freqsY, fmin, fmax)
        #f, Y = periodogram(y, average=config.average, **config.stft_args)

        if ii == 0:
            pxx = Y
            #valid_freqs = fidxY
        else:
            pxx = np.concatenate((pxx, Y), axis=0)
            #valid_freqs = valid_freqs + fidxY

    aperiodic_pxx = np.median(pxx, axis=0, keepdims=False)

    f, t, full_pxx = compute_stft(x, **config.stft_args)
    full_pxx = apply_average(full_pxx,
                             config.average,
                             axis=0)

    return full_pxx - aperiodic_pxx, aperiodic_pxx, pxx


def apply_average(X, method, axis=0, keepdims=False):
@@ -1623,6 +1752,8 @@ def apply_average(X, method, axis=0, keepdims=False):
        X = np.nanmean(X, axis=axis, keepdims=keepdims)
    elif method == 'median':
        X = np.nanmedian(X, axis=axis, keepdims=keepdims)
    elif method == 'min':
        X = np.nanmin(X, axis=axis, keepdims=keepdims)
    elif method is None:
        pass
    else:
@@ -2327,7 +2458,7 @@ def glm_irasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
              contrasts=None, fit_method='pinv', fit_intercept=True,
              ret_class=True,
              # IRASA kwargs
              resample_factors=None, average='median',
              resample_factors=None, average='min',
              # General STFT kwargs
              fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
              detrend='constant', return_onesided=True, scaling='density',
@@ -2445,6 +2576,10 @@ def glm_irasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,

    if resample_factors is None:
        resample_factors = np.linspace(1.1, 1.9, 17)
        p = np.log([1.1, 2.5])
        edges = np.linspace(p[0], p[1], 15)
        resample_factors = np.exp(edges)

    resample_factors = np.round(resample_factors, 4)

    logging.info('Starting computation')
@@ -2496,6 +2631,166 @@ def glm_irasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
    return model_aperiodic, model


@set_verbose
def glm_quirasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
                contrasts=None, fit_method='pinv', fit_intercept=True,
                ret_class=True,
                # Periodogram args
                average='median', resample_factors=None,
                # General STFT kwargs
                fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
                detrend='constant', return_onesided=True, mode='psd', scaling='density',
                axis=-1, fmin=None, fmax=None, return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    average : { 'mean', 'median' }, optional
        Method to use when averaging periodograms. Defaults to 'mean'.
    fs : float, optional
        Sampling frequency of the `x` time series. Defaults to 1.0.
    window_type : str or tuple or array_like, optional
        Desired window to use. If `window` is a string or tuple, it is
        passed to `get_window` to generate the window values, which are
        DFT-even by default. See `get_window` for a list of windows and
        required parameters. If `window` is array_like it will be used
        directly as the window and its length must be nperseg. Defaults
        to a Hann window.
    nperseg : int, optional
        Length of each segment. Defaults to None, but if window is str or
        tuple, is set to 256, and if window is array_like, is set to the
        length of the window.
    noverlap : int, optional
        Number of points to overlap between segments. If `None`,
        ``noverlap = nperseg // 2``. Defaults to `None`.
    nfft : int, optional
        Length of the FFT used, if a zero padded FFT is desired. If
        `None`, the FFT length is `nperseg`. Defaults to `None`.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. If `detrend` is a
        string, it is passed as the `type` argument to the `detrend`
        function. If it is a function, it takes a segment and returns a
        detrended segment. If `detrend` is `False`, no detrending is
        done. Defaults to 'constant'.
    return_onesided : bool, optional
        If `True`, return a one-sided spectrum for real data. If
        `False` return a two-sided spectrum. Defaults to `True`, but for
        complex data, a two-sided spectrum is always returned.
    scaling : { 'density', 'spectrum' }, optional
        Selects between computing the power spectral density ('density')
        where `Pxx` has units of V**2/Hz and computing the power
        spectrum ('spectrum') where `Pxx` has units of V**2, if `x`
        is measured in V and `fs` is measured in Hz. Defaults to
        'density'
    axis : int, optional
        Axis along which the periodogram is computed; the default is
        over the last axis (i.e. ``axis=-1``).
    fmin : float or None, optional
        Minimum frequency value to return (Default value = 0)
    fmax : float or None, optional
        Maximum frequency value to return (Default value = 0.5)
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    """
    # Option housekeeping
    if axis == -1:
        axis = X.ndim - 1

    if X.ndim != 1 and fit_method in ['pinv', 'lstsq']:
        msg = "Data input should be vector for 'pinv' and 'lstsq' fits - data shape {0} was passed in"
        logging.error(msg.format(X.shape))
        logging.error("Use fit_method='glmtools' for multdimensional data")
        raise ValueError("Fit methods 'pinv' and 'lstsq' not implemented for multidimensional data")

    # Set configuration
    logging.info('Setting config options')
    config = GLMPeriodogramConfig(X.shape[axis], reg_ztrans=reg_ztrans,
                                  reg_unitmax=reg_unitmax,
                                  fit_method=fit_method, contrasts=contrasts,
                                  fit_intercept=fit_intercept,
                                  input_complex=np.iscomplexobj(X), fs=fs,
                                  fmin=fmin, fmax=fmax,
                                  window_type=window_type, nperseg=nperseg,
                                  noverlap=noverlap,
                                  nfft=nfft, detrend=detrend,
                                  return_onesided=return_onesided,
                                  scaling=scaling, axis=axis, mode=mode,
                                  output_axis='time_first')
    print(config)

    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    reg_categorical = _process_input_covariate(reg_categorical, config.input_len)
    reg_ztrans = _process_input_covariate(reg_ztrans, config.input_len)
    reg_unitmax = _process_input_covariate(reg_unitmax, config.input_len)

    # Compute STFT
    logging.info('Computing sliding window periodogram')
    f, t, p = compute_stft(X, **config.stft_args)

    # Compute model - each method MUST assign copes, varcopes and extras
    model, des, data = _glm_fit_glmtools(p, reg_categorical, reg_ztrans,
                                         reg_unitmax, config,
                                         contrasts=contrasts,
                                         fit_intercept=fit_intercept)

    if resample_factors is None:
        resample_factors = np.linspace(0.1, 1.9, 15)
    resample_factors = np.round(resample_factors, 4)

    logging.info('Starting computation')
    for ii, rf in enumerate(resample_factors):
        rat = fractions.Fraction(str(rf))
        up, down = rat.numerator, rat.denominator

        y = _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax,
                             up, down, axis=config.axis)
        y, y_categorical, y_ztrans, y_unitmax = y

        f, t, Y = compute_stft(y, **config.stft_args)
        modelY, desY, dataY = _glm_fit_glmtools(Y, y_categorical, y_ztrans,
                                                y_unitmax, config,
                                                contrasts=contrasts,
                                                fit_intercept=fit_intercept)


        if ii == 0:
            betas = modelY.betas[np.newaxis, ...]
            copes = modelY.copes[np.newaxis, ...]
            varcopes = modelY.varcopes[np.newaxis, ...]
        else:
            betas = np.concatenate((betas, modelY.betas[np.newaxis, ...]), axis=0)
            copes = np.concatenate((copes, modelY.copes[np.newaxis, ...]), axis=0)
            varcopes = np.concatenate((varcopes, modelY.varcopes[np.newaxis, ...]), axis=0)

    model_aperiodic = deepcopy(model)
    model_aperiodic.betas = apply_average(betas, average, axis=0)
    model_aperiodic.copes = apply_average(copes, average, axis=0)
    model_aperiodic.varcopes = apply_average(varcopes, average, axis=0)

    model.betas = model.betas - model_aperiodic.betas
    model.copes = model.copes - model_aperiodic.copes
    model.varcopes = model.varcopes - model_aperiodic.varcopes

    return model_aperiodic, model


def _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=0):
    y = resample_poly(X, up, down, axis=axis)