Commit f4eaa88e authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Add broaddcast or loop option to multitaper, fix simple GLM fits and ensure...

Add broaddcast or loop option to multitaper, fix simple GLM fits and ensure default args and returns are consistent
parent db7a6c52
Loading
Loading
Loading
Loading
Loading
+82 −45
Original line number Diff line number Diff line
@@ -306,6 +306,7 @@ def compute_stft(x, fs=1.0, fmin=0, fmax=0.5, axis=-1,


def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwidth=5,
                            apply_tapers='broadcast',
                            # Data kwargs
                            fs=1.0, fmin=0, fmax=0.5, axis=-1,
                            # window kwargs
@@ -415,6 +416,7 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi
                             detrend_func=detrend_func,
                             window=None, padded=padded)

    if apply_tapers == 'broadcast':
        # Apply tapers - via broadcasting to avoid loops
        to_shape = np.r_[np.ones((len(y.shape)-1),), num_tapers, nperseg].astype(int)
        z = y[..., np.newaxis, :] * np.broadcast_to(tapers, to_shape)
@@ -428,6 +430,27 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi
        # Average over tapers - could be high level option? mean or median?
        result = np.average(result, weights=taper_weights, axis=-2)

    elif apply_tapers == 'loop':
        # Apply tapers in a loop - slower but uses much less RAM
        to_shape = np.r_[np.ones((len(y.shape)-1),), nperseg].astype(int)
        for ii in range(num_tapers):
            logging.debug('running taper: {0}'.format(ii))
            z = y * np.broadcast_to(tapers[ii, :], to_shape)
            # Run actual FFT
            freqs, taper_result = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode,
                                        scale=scale, fs=fs, fmin=fmin, fmax=fmax)
            logging.debug('tapered and fftd data shape {0}'.format(taper_result.shape))

            # Run an incremental average so we don't have to store anything
            # https://math.stackexchange.com/questions/106700/incremental-averaging/1836447
            if ii == 0:
                result = taper_result
            else:
                result = result + (taper_result - result) / (ii + 1)
    else:
        raise ValueError("'apply_tapers' option '{0}' not recognised. Use one of 'broadcast' or 'loop''".format(apply_tapers))


    # Periodogram Scaling
    result = result / fs

@@ -442,6 +465,16 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi
    return freqs, time, result


def compute_spectral_matrix_fft(psd):
    # psd should be [channels x freq] complex for now
    S = np.zeros((psd.shape[0], psd.shape[0], psd.shape[1]), dtype=complex)

    for ii in range(psd.shape[1]):
        S[:, :, ii] = np.dot(psd[:, ii, np.newaxis], psd[np.newaxis, :, ii].conj())

    return S


# Helpers - private functions assisting low-level processors

def _proc_roll_input(x, axis=-1):
@@ -834,7 +867,7 @@ def _set_mode(mode):
    """
    modelist = ['psd', 'complex', 'magnitude', 'angle', 'phase']
    if mode not in modelist:
        raise ValueError('unknown value for mode {}, must be one of {}'
        raise ValueError("Invalid value ('{}') for mode, must be one of {}"
                         .format(mode, modelist))


@@ -958,6 +991,7 @@ class MultiTaperConfig(STFTConfig):
    time_bandwidth: int = 3
    num_tapers: typing.Union[str, int] = 'auto'
    freq_resolution: int = 1
    apply_tapers: str = 'broadcast'

    def __post_init__(self):
        super().__post_init__()
@@ -967,9 +1001,9 @@ class MultiTaperConfig(STFTConfig):
        """ """
        args = {}
        for key in ['time_bandwidth', 'num_tapers', 'freq_resolution',
                    'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func',
                    'side', 'scale', 'axis', 'mode',
                    'padded', 'fmin', 'fmax', 'output_axis']:
                    'apply_tapers', 'fs', 'nperseg', 'nstep', 'nfft',
                    'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded',
                    'fmin', 'fmax', 'output_axis']:
            args[key] = getattr(self, key)
        return args

@@ -1123,7 +1157,7 @@ def set_options(input_len,
def sw_periodogram(x,
                   # General STFT kwargs
                   fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
                   detrend='constant', return_onesided=True, scaling='density',
                   detrend='constant', return_onesided=True, mode='psd', scaling='density',
                   axis=-1, fmin=None, fmax=None, return_config=False):
    """Compute Periodogram by averaging across windows in a STFT.

@@ -1194,7 +1228,7 @@ def sw_periodogram(x,
    # unspecified options given the data in-hand
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=None, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='auto')

@@ -1210,7 +1244,7 @@ def sw_periodogram(x,
def periodogram(x, average='mean',
                # General STFT kwargs
                fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
                detrend='constant', return_onesided=True, scaling='density',
                detrend='constant', return_onesided=True, mode='psd', scaling='density',
                axis=-1, fmin=None, fmax=None, return_config=False):
    """Compute Periodogram by averaging across windows in a STFT.

@@ -1283,7 +1317,7 @@ def periodogram(x, average='mean',
    # unspecified options given the data in-hand
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='time_first')

@@ -1291,9 +1325,9 @@ def periodogram(x, average='mean',
    logging.debug(p.shape)

    if config.average == 'mean':
        p = np.nanmean(p, axis=0).real
        p = np.nanmean(p, axis=0)
    elif config.average == 'median':
        p = np.nanmedian(p, axis=0).real
        p = np.nanmedian(p, axis=0)
    elif config.average is None:
        pass
    else:
@@ -1306,11 +1340,11 @@ def periodogram(x, average='mean',
        return f, p


def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1,
def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, apply_tapers='broadcast',
                  # General STFT kwargs
                  fs=1.0, nperseg=None, noverlap=None, nfft=None,
                  detrend='constant', return_onesided=True, scaling='density',
                  axis=-1, fmin=None, fmax=None, return_config=True):
                  detrend='constant', return_onesided=True, mode='psd', scaling='density',
                  axis=-1, fmin=None, fmax=None, return_config=False):
    """Compute a multi-tapered power spectrum across windows in a STFT.

    Parameters
@@ -1383,26 +1417,28 @@ def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1,
    # unspecified options given the data in-hand
    config = MultiTaperConfig(x.shape[axis],
                              input_complex=np.any(np.iscomplex(x)),
                              time_bandwidth=time_bandwidth, num_tapers=num_tapers,
                              freq_resolution=freq_resolution, average=None, fs=fs,
                              window_type=None, nperseg=nperseg, noverlap=noverlap,
                              nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                              time_bandwidth=time_bandwidth,
                              num_tapers=num_tapers, apply_tapers=apply_tapers,
                              freq_resolution=freq_resolution, average=None,
                              fs=fs, window_type=None, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft, detrend=detrend,
                              return_onesided=return_onesided, mode=mode,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax,
                              output_axis='auto')

    f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args)

    if return_config:
        return f, t, p.real, config
        return f, t, p, config
    else:
        return f, t, p.real
        return f, t, p


def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_resolution=1,
def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_resolution=1, apply_tapers='broadcast',
               # General STFT kwargs
               fs=1.0, nperseg=None, noverlap=None, nfft=None,
               detrend='constant', return_onesided=True, scaling='density',
               axis=-1, fmin=None, fmax=None, return_config=True):
               detrend='constant', return_onesided=True, scaling='density', mode='psd',
               axis=-1, fmin=None, fmax=None, return_config=False):
    """Compute a multi-tapered power spectrum averaged across windows in a STFT.

    Parameters
@@ -1477,27 +1513,29 @@ def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_reso
    # unspecified options given the data in-hand
    config = MultiTaperConfig(x.shape[axis],
                              input_complex=np.any(np.iscomplex(x)),
                              time_bandwidth=time_bandwidth, num_tapers=num_tapers,
                              freq_resolution=freq_resolution, average=average, fs=fs,
                              window_type=None, nperseg=nperseg, noverlap=noverlap,
                              nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                              time_bandwidth=time_bandwidth,
                              num_tapers=num_tapers, apply_tapers=apply_tapers,
                              freq_resolution=freq_resolution, average=average,
                              fs=fs, window_type=None, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft, detrend=detrend,
                              return_onesided=return_onesided, mode=mode,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax,
                              output_axis='time_first')

    f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args)

    if config.average == 'mean':
        p = np.nanmean(p, axis=0).real
        p = np.nanmean(p, axis=0)
    elif config.average == 'median':
        p = np.nanmedian(p, axis=0).real
        p = np.nanmedian(p, axis=0)
    else:
        msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'"
        raise ValueError(msg.format(config.average))

    if return_config:
        return f, p.real, config
        return f, p, config
    else:
        return f, p.real
        return f, p


# -----------------------------------------------------------------------
@@ -1998,11 +2036,10 @@ def glm_periodogram(X, covariates=None, confounds=None, fit_method='pinv',
    f, t, p = compute_stft(X, **config.stft_args)

    # Compute model - each method MUST assign copes, varcopes and extras
    if fit_method == 'pinv':
        copes, varcopes, extras = _glm_fit_simple(p, config)
    elif fit_method == 'lstsq':
    if fit_method in ['pinv', 'lstsq']:
        copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config,
                                                  fit_method='lstsq', fit_constant=fit_constant)
                                                  fit_method=fit_method,
                                                  fit_constant=fit_constant)
    elif fit_method == 'glmtools':
        copes, varcopes, extras = _glm_fit_glmtools(p, covariates, confounds, config,
                                                    fit_constant=fit_constant)