Commit 75c744dd authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Further glm_irasa updates

parent d59a55e4
Loading
Loading
Loading
Loading
+122 −62
Original line number Diff line number Diff line
@@ -539,6 +539,40 @@ def _proc_get_time_range(nperseg, nstep, fs, xlen):
    return np.arange(nperseg/2, xlen - nperseg/2 + 1, nperseg - noverlap)/float(fs)


def _proc_nan_freq_range(result, freqvals, fmin=None, fmax=None, axis=-1):
    """Set values to NaN within a desired frequency range.

    Parameters
    ----------
    result : array_like
        Spectrum result with frequency on the final axis
    freqvals : vector
        Vector of frequency values with length matching the final axis of result
    %(freq_range)s'
    %(axis)s'

    Returns
    -------
    result : array_like
        Input array with final dimension trimmed in-place
    freqs : array_like
        New frequency array matching the final axis of output

    """    
    if fmin >= fmax:
        logging.error('Selected fmin ({}) is larger than fmax ({})'.format(fmin,))
    
    logging.info('Setting data within range {0} - {1} to NaN'.format(fmin, fmax))
    fidx = _proc_get_freq_inds(freqvals, fmin, fmax)

    # Move freq axis to front to make indexing simpler
    result = np.moveaxis(result, axis, 0)
    result[fidx] = np.nan
    result = np.moveaxis(result, 0, axis)

    return result


def _proc_trim_freq_range(result, freqvals, fmin=None, fmax=None, axis=-1):
    """Trim an FFT output to desired frequency range.

@@ -565,6 +599,9 @@ def _proc_trim_freq_range(result, freqvals, fmin=None, fmax=None, axis=-1):
        # Just passing through
        return freqs, result
    
    if fmin >= fmax:
        logging.error('Selected fmin ({}) is larger than fmax ({})'.format(fmin,))
    
    logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax))
    fmin = freqvals[0] if fmin is None else fmin
    fmax = freqvals[-1] if fmax is None else fmax
@@ -602,6 +639,10 @@ def _proc_apply_average(p, average, axis=0, keepdims=False):
        Raised if option for average not recognised.

    """ 
    if average is not None:
        # Keep quite if we're just passing through
        logging.info('Applying average to dim {} using method {}'.format(axis, average))

    if average == 'mean':
        p = np.nanmean(p, axis=axis, keepdims=keepdims)
    elif average == 'median':
@@ -611,6 +652,8 @@ def _proc_apply_average(p, average, axis=0, keepdims=False):
        p = np.nanmedian(p, axis=axis, keepdims=keepdims) / bias
    elif average == 'min':
        p = np.nanmin(p, axis=axis, keepdims=keepdims)
    elif average == None:
        pass
    else:
        msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'"
        raise ValueError(msg.format(average))
@@ -1085,6 +1128,9 @@ class GLMIRASAConfig(GLMPeriodogramConfig, IRASAConfig):
@dataclass
class SpectrumResult:

    def __repr__(self):
        return 'Spectrum result size {} using {}'.format(self.spectrum.shape, type(self.config))

    def __init__(self, f, t, pxx, config=None):

        self.f = f
@@ -1101,7 +1147,7 @@ def periodogram(x, average='mean',
                nfft=None, axis=-1, return_onesided=True, mode='psd',
                scaling='density', fs=1.0, fmin=None, fmax=None,
                # misc
                return_config=False, verbose=None):
                verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
@@ -1111,22 +1157,22 @@ def periodogram(x, average='mean',
    %(average)s`
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.
    result : sails.stft.SpectrumResult
        Object containing the following attributes

        spectrum : ndarray
            The fitted frequency spectrum
        f : ndarray
            The frequency axis values for the fitted spectrum
        t : {ndarray | None}
            The time axis values for the fitted spectrum or 
            None if the spectrum has been averaged across segments.
        config : object
            The configuration object specifying how the spectrum was computed

    """
    # Config object stores options in one place and sets sensible defaults for
@@ -1159,7 +1205,7 @@ def multitaper(x, average='mean',
               nfft=None, axis=-1, return_onesided=True, mode='psd',
               scaling='density', fs=1.0, fmin=None, fmax=None,
               # misc
               return_config=False, verbose=None):
               verbose=None):
    """Compute a multi-tapered power spectrum averaged across windows in a STFT.

    Parameters
@@ -1170,22 +1216,22 @@ def multitaper(x, average='mean',
    %(multitaper_core)s'
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : MultiTaperConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.
    result : sails.stft.SpectrumResult
        Object containing the following attributes

        spectrum : ndarray
            The fitted frequency spectrum
        f : ndarray
            The frequency axis values for the fitted spectrum
        t : {ndarray | None}
            The time axis values for the fitted spectrum or 
            None if the spectrum has been averaged across segments.
        config : object
            The configuration object specifying how the spectrum was computed

    """
    # Config object stores options in one place and sets sensible defaults for
@@ -1216,6 +1262,9 @@ def multitaper(x, average='mean',
@dataclass
class GLMSpectrumResult:

    def __repr__(self):
        return 'GLMSpectrum result size {} using {}'.format(self.model.copes.shape, self.config)

    def __init__(self, f, model, design, data, config=None):

        self.f = f
@@ -1451,7 +1500,6 @@ def _run_prefit_checks(data, design_matrix, contrasts):
    -------
    None


    """
    # Make sure we're set for model fitting
    assert(data.shape[0] == design_matrix.shape[0])
@@ -1665,7 +1713,7 @@ def glm_periodogram(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,

    Returns
    -------
    GLMSpectrumResult : object
    result : sails.stft.GLMSpectrumResult
        Object containing the fitted GLM Periodogram

    """
@@ -1864,9 +1912,9 @@ def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='med
          nperseg=None, noverlap=None, window_type='hann', detrend='constant',
          # FFT core args
          nfft=None, axis=-1, return_onesided=True, mode='psd',
          scaling='density', fs=1.0, fmin=None, fmax=None,
          scaling='density', fs=1.0, fmin=None, fmax=None, fmask=None,
          # misc
          return_config=False, verbose=None, avg2=False):
          return_config=False, verbose=None, avg2=False, output_axis='auto'):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
@@ -1884,15 +1932,20 @@ def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='med

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.
    aperiodic, oscillatory : sails.stft.GLMSpectrumResult
        Objects containing the following attributes for the GLM fitted on the
        aperiodic and oscillatory parts of the spectrum

        model : object
            The fitted frequency spectrum
        design : object
            The model design matrix
        data : object
            The modelled data
        f : ndarray
            The frequency axis values for the fitted spectrum
        config : object
            The configuration object specifying how the spectrum was computed

    """
    # Config object stores options in one place and sets sensible defaults for
@@ -1941,6 +1994,11 @@ def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='med
        zz = signal.resample_poly(yy, up, down, axis=-1)  # This changes the variance of the signal...
        freqs, pxx_resampled = compute_fft(zz, **config.fft_args)

        if fmask is not None:
            pxx_resampled = _proc_nan_freq_range(pxx_resampled, freqs, 
                                                 fmask[0]*rf, fmask[1]*rf, 
                                                 axis=-1)

        if ii == 0:
            pxx_aperiodic = pxx_resampled[None, ...]
        else:
@@ -1960,9 +2018,9 @@ def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='med
    pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.average, axis=0)
    pxx_oscillatory = _proc_apply_average(pxx_oscillatory, config.average, axis=0)

    if pxx_aperiodic.ndim > 1:
        pxx_aperiodic = _proc_unroll_output(pxx_aperiodic, pxx_aperiodic.ndim-1, output_axis='auto')
        pxx_oscillatory = _proc_unroll_output(pxx_oscillatory, pxx_aperiodic.ndim-1, output_axis='auto')
    if pxx_aperiodic.ndim > 1 and output_axis is not None:
        pxx_aperiodic = _proc_unroll_output(pxx_aperiodic, pxx_aperiodic.ndim-1, output_axis=output_axis)
        pxx_oscillatory = _proc_unroll_output(pxx_oscillatory, pxx_aperiodic.ndim-1, output_axis=output_axis)

    aperiodic = SpectrumResult(out_freqs, t, pxx_aperiodic, config=config)
    oscillatory = SpectrumResult(out_freqs, t, pxx_oscillatory, config=config)
@@ -2108,7 +2166,7 @@ def glm_irasa_v3(x,
              nperseg=None, noverlap=None, window_type='hann', detrend='constant',
              # FFT core args
              nfft=None, axis=-1, return_onesided=True, mode='psd',
              scaling='density', fs=1.0, fmin=None, fmax=None,
              scaling='density', fs=1.0, fmin=None, fmax=None, fmask=None,
              # misc
              return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.
@@ -2129,15 +2187,19 @@ def glm_irasa_v3(x,

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.
    aperiodic, oscillatory : sails.stft.SpectrumResult
        Objects containing the following attributes for the 
        aperiodic and oscillatory parts of the spectrum

        spectrum : ndarray
            The fitted frequency spectrum
        f : ndarray
            The frequency axis values for the fitted spectrum
        t : {ndarray | None}
            The time axis values for the fitted spectrum or 
            None if the spectrum has been averaged across segments.
        config : object
            The configuration object specifying how the spectrum was computed

    """
    # Config object stores options in one place and sets sensible defaults for
@@ -2150,7 +2212,7 @@ def glm_irasa_v3(x,
                            method=method, resample_factors=resample_factors, 
                            aperiodic_average=aperiodic_average,
                            # Periodogram args
                            average=average,
                            average=None,
                            # STFT window args
                            nperseg=nperseg, noverlap=noverlap, 
                            window_type=window_type, detrend=detrend,
@@ -2159,9 +2221,7 @@ def glm_irasa_v3(x,
                            scaling=scaling, fs=fs, fmin=fmin, fmax=fmax)
    
    logging.info('Setting config options')
    aperiodic, oscillatory = irasa_v3(x, **config.irasa_args)


    aperiodic, oscillatory = irasa_v3(x, output_axis='time_first', fmask=fmask, **config.irasa_args)

    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
@@ -2169,19 +2229,19 @@ def glm_irasa_v3(x,
    config.reg_ztrans = _process_input_covariate(config.reg_ztrans, config.input_len)
    config.reg_unitmax = _process_input_covariate(config.reg_unitmax, config.input_len)
    
    for ind, pxx in enumerate((aperiodic, oscillatory)):
    ret = []
    for pxx in [aperiodic, oscillatory]:
        logging.info('Processing Conditions, Covariates and Confounds')

        # Fit model
        model, des, data = _glm_fit_glmtools(pxx.spectrum, config.reg_categorical, config.reg_ztrans,
                                             config.reg_unitmax, config,
                                             contrasts=config.contrasts,
                                             fit_intercept=config.fit_intercept)
        if ind == 0:
            glm_aperiodic = GLMSpectrumResult(pxx.f, model, des, data, config=config)
        else:
            glm_oscillatory = GLMSpectrumResult(pxx.f, model, des, data, config=config)
        
    return glm_aperiodic, glm_oscillatory
        ret.append(GLMSpectrumResult(pxx.f, model, des, data, config=config))

    return ret


@set_verbose