Commit c072b030 authored by Andrew Quinn's avatar Andrew Quinn
Browse files

big periodogram update... WIP

parent 4c7b7522
Loading
Loading
Loading
Loading
+291 −73
Original line number Diff line number Diff line
import warnings
import logging
from dataclasses import dataclass

import numpy as np
from scipy import fft as sp_fft
from scipy import signal, stats
from scipy.signal import signaltools
from scipy.signal.windows import dpss

logging.basicConfig(level=logging.DEBUG)

# ------------------------------------------------------------------
# Low level computations
#
# These functions are stand-alone data processors which are usable on their own.
# Inputs are not sanity checked and documentation may point elsewhere, but these
# functions are fast and flexible for expert users.
#
# Most users will interact with these via the high level functions and option
# handlers below.


def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, padded=False):
@@ -47,6 +57,7 @@ def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, pad
    # y.shape == nperseg + (nseg-1)*nstep
    # nadd = (-(y.shape[-1]-nperseg) % nstep) % nperseg
    # y = np.r_[y, np.zeros(nadd,)]
    logging.info('delay embedding {0} {1} {2}'.format(x.shape, nperseg, nstep))
    if padded:
        nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
@@ -62,25 +73,53 @@ def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, pad
    strides = y.strides[:-1]+(step*y.strides[-1], y.strides[-1])
    y_window = np.lib.stride_tricks.as_strided(y, shape=shape, strides=strides)

    logging.info('delay embedding {0} '.format(y_window.shape, (y_window**2).sum()))

    if detrend_func is not None:
        logging.info('delay embedding - detrending {0} '.format(detrend_func))
        y_window = detrend_func(y_window)

    if window is not None:
        # Apply windowing
        logging.info('delay embedding - windowing {0} '.format(window.sum()))
        y_window = window * y_window

    logging.info('delay embedding - end {0} {1}'.format(y_window.shape, (y_window**2).sum()))
    return y_window


def compute_stft(x,
                 # kwargs from signal.spectral._spectral_helper
                 fs=1.0, window='hann',
                 nperseg=None, noverlap=None, nfft=None, detrend='constant',
                 return_onesided=True, scaling='density', axis=-1, mode='psd',
                 boundary=None, padded=False,
                 # kwargs from sails
                 fmin=None, fmax=None, return_config=False, config=None,
                 output_roll='auto'):
def compute_fft(x, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=-0.5, fmax=0.5):
    """Run an FFT over the final dimension of an array, then post-process and trim it.

    Parameters
    ----------
    x : ndarray
        Data to transform. The transform is taken over the last dimension.
    nfft : int
        Transform length handed to the FFT routine.
    axis : int
        Forwarded to the spectrum-mode helper only; the FFT call itself does
        not receive it.
    side : {'onesided', 'twosided'}
        'twosided' uses a full FFT. Any other value takes the real part of
        `x` and uses a real-input FFT.
    mode : str
        Spectrum mode forwarded to the mode and scaling helpers.
    scale : float
        Scaling factor forwarded to the scaling helper.
    fs : float
        Sampling frequency used to build the frequency vector.
    fmin, fmax : float
        Inclusive limits of the frequency range to keep.

    Returns
    -------
    result : ndarray
        Processed spectrum with the last dimension restricted to the kept
        frequencies.
    freqs : ndarray
        Frequency values matching the last dimension of `result`.

    """
    # Select the transform; the one-sided path discards any imaginary part.
    two_sided = side == 'twosided'
    if not two_sided:
        x = x.real
    transform = sp_fft.fft if two_sided else sp_fft.rfft

    logging.info('fft - start {0} {1} {2} {3}'.format(side, transform, x.shape, (x**2).sum()))
    spectrum = transform(x, nfft)
    logging.info('fft - {0} {1}'.format(spectrum.shape, (spectrum**2).sum()))

    # Mode selection first, then scaling
    spectrum = _proc_spectrum_mode(spectrum, mode, axis=axis)
    spectrum = _proc_spectrum_scaling(spectrum, scale, side, mode, nfft)

    # Build the frequency vector and keep only the requested band
    all_freqs = _set_freqvalues(nfft, fs, side)
    keep = np.logical_and(all_freqs >= fmin, all_freqs <= fmax)

    return spectrum[..., keep], all_freqs[keep]


def compute_stft(x, nperseg=256, nstep=256, window=None, detrend_func=None,
                 padded=False, nfft=256, axis=-1, side='onesided', mode='psd',
                 scale=1.0, fs=1.0, fmin=0, fmax=0.5, output_axis='auto'):
    """Compute a short-time Fourier transform to a dataset.

    Parameters
@@ -150,7 +189,7 @@ def compute_stft(x,
        Dictionary of values specifying all parameters of a STFT set by
        set_options. Values in config override all other user specified
        options.
    output_roll : {'auto', 'glm'}
    output_axis : {'auto', 'glm'}
        Flag indicating where to roll the time and frequencies dimensions to in
        output array. 'auto' will return the transformed dimensions back the
        position of the transformed input, 'glm' will roll the time windows to
@@ -172,55 +211,84 @@ def compute_stft(x,


    """
    if config is None:
        config = set_options(x.shape[axis], input_complex=np.iscomplexobj(x),
                             fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                             nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                             scaling=scaling, axis=axis, mode=mode, boundary=boundary,
                             padded=padded)

    # ---- Work start here
    x = _proc_roll_input(x, axis=config['axis'])
    if axis == -1:
        axis = x.ndim-1
    x = _proc_roll_input(x, axis=axis)

    # window inputs
    y = apply_delay_embedding(x, config['nperseg'], config['nstep'],
                              window=config['win'], detrend_func=config['detrend_func'],
                              padded=config['padded'])
    y = apply_delay_embedding(x, nperseg, nstep, detrend_func=detrend_func, window=window, padded=padded)

    # Compute FFT
    if config['side'] == 'twosided':
        func = sp_fft.fft
    else:
        y = y.real
        func = sp_fft.rfft
    result = func(y, config['nfft'])
    # Run actual FFT
    print(scale)
    result, freqs = compute_fft(y, nfft=nfft, axis=axis, side=side, mode=mode,
                                scale=scale, fs=fs, fmin=fmin, fmax=fmax)

    # Apply spectrum mode selection
    result = _proc_spectrum_mode(result, config['mode'], axis=config['axis'])
    # Create time window vector
    noverlap = nperseg - nstep
    time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1,
                     nperseg - noverlap)/float(fs)

    # Apply scaling
    result = _proc_spectrum_scaling(result, config['scale'], config['side'], config['mode'], config['nfft'])
    # Final two axes are now [..., time x freq]
    result = _proc_unroll_output(result, axis, output_axis=output_axis)

    # Create time window vector
    time = np.arange(config['nperseg']/2, x.shape[-1] - config['nperseg']/2 + 1,
                     config['nperseg'] - config['noverlap'])/float(config['fs'])
    return freqs, time, result

    # Trim frequency range to specified limits
    fidx = (config['freqvals'] >= config['fmin']) & \
           (config['freqvals'] <= config['fmax'])
    result = result[..., fidx]
    freqs = config['freqvals'][fidx]

    # Final two axes are now [..., time x freq]
    result = _proc_unroll_output(result, config['axis'], output_roll=output_roll)
def compute_multitaper_stft(x, freq_resolution=1, num_tapers='auto', time_bandwidth=5,
                            nperseg=256, nstep=256, window=None, detrend_func=None,
                            padded=False, nfft=256, axis=-1, side='onesided', mode='psd',
                            scale=1.0, fs=1.0, fmin=0, fmax=0.5, output_axis='auto'):
    """Compute a multitaper short-time Fourier transform of a dataset.

    The data are cut into segments, each segment is multiplied by a set of
    DPSS tapers, every tapered segment is Fourier transformed and the
    resulting spectra are averaged across tapers.

    Parameters
    ----------
    x : ndarray
        Data to transform.
    freq_resolution : float
        Combined with `nperseg` and `fs` to derive the taper count when
        `num_tapers` is 'auto'.
    num_tapers : int or 'auto'
        Number of DPSS tapers. 'auto' uses ``2 * int(nperseg / fs *
        freq_resolution / 2) - 1``.
    time_bandwidth : float
        Passed to `scipy.signal.windows.dpss` as its second (NW) argument.
    nperseg, nstep : int
        Segment length and step between segment starts, in samples.
    window : ignored
        Accepted so the signature matches `compute_stft`; no window is
        applied here because the DPSS tapers take that role.
    detrend_func : callable or None
        Forwarded to `apply_delay_embedding`.
    padded : bool
        Forwarded to `apply_delay_embedding`.
    nfft, side, mode, scale, fs, fmin, fmax
        Forwarded to `compute_fft`.
    axis : int
        Axis of `x` to transform. Negative values count from the end.
    output_axis : {'auto', 'glm'}
        Forwarded to `_proc_unroll_output`.

    Returns
    -------
    freqs : ndarray
        Frequency values returned by `compute_fft`.
    time : ndarray
        Segment time vector: starts at ``nperseg / 2`` samples, advances by
        `nstep` samples, and is divided by `fs`.
    result : ndarray
        Taper-averaged spectra, arranged by `_proc_unroll_output`.

    """
    seconds_perseg = nperseg / fs
    time_half_bandwidth = int(seconds_perseg * freq_resolution / 2)
    if num_tapers == 'auto':
        num_tapers = 2 * time_half_bandwidth - 1
    logging.info('multitaper {0} {1} {2} {3}'.format(time_bandwidth, num_tapers, freq_resolution, time_half_bandwidth))

    # tapers has shape (num_tapers, nperseg)
    tapers = dpss(nperseg, time_bandwidth, num_tapers)
    # Uniform weights for now; kept explicit so other weightings can slot in.
    taper_weights = np.ones((num_tapers,)) / num_tapers

    # ---- Work start here
    # Normalise any negative axis, not only -1, before rolling
    if axis < 0:
        axis += x.ndim
    x = _proc_roll_input(x, axis=axis)

    # delay embedding - don't apply window function...
    y = apply_delay_embedding(x, nperseg, nstep, detrend_func=detrend_func, window=None, padded=padded)

    # Apply tapers via broadcasting: [..., seg, 1, sample] * [taper, sample]
    z = y[..., np.newaxis, :] * tapers
    logging.info('multitaper data {0}'.format(z.shape))
    # Index relative to the trailing axes so the debug summary works for any
    # number of leading dimensions (a fixed 4-index form fails on 1-D input).
    first_seg_last_taper = (0,) * (z.ndim - 2) + (-1,)
    logging.info('multitaper seg ss {0}'.format((z[first_seg_last_taper]**2).sum()))

    # Run actual FFT
    result, freqs = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode,
                                scale=scale, fs=fs, fmin=fmin, fmax=fmax)
    logging.info('multitaper fft data {0}'.format(result.shape))
    last_seg = (0,) * (result.ndim - 3) + (-1,)
    logging.info('multitaper fft data power {0}'.format((result[last_seg]**2).sum(axis=1)))

    # Average over tapers - could be high level option? mean or median?
    result = np.average(result, weights=taper_weights, axis=-2)

    # NOTE(review): the original author flagged that periodogram scaling is
    # not yet working for the multitaper case; this division is kept as-is.
    result = result / fs

    # Create time window vector
    time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, nstep)/float(fs)

    # Final two axes are now [..., time x freq] - return them to requested position
    result = _proc_unroll_output(result, axis, output_axis=output_axis)

    return freqs, time, result


# Helpers
# Helpers - private functions assisting low-level processors

def _proc_roll_input(x, axis=-1):
    """Move axis to be transformed to final position."""
@@ -229,13 +297,14 @@ def _proc_roll_input(x, axis=-1):
    return x


def _proc_unroll_output(result, axis, output_roll='auto'):
def _proc_unroll_output(result, axis, output_axis='auto'):
    """Move STFT'd dimensions to user specified position."""
    if output_roll == 'auto':
    print('unroll {0} {1} {2}'.format(result.shape, axis, output_axis))
    if output_axis == 'auto':
        # Return time and freq back to original position
        result = np.rollaxis(result, -2, axis)
        result = np.rollaxis(result, -1, axis+1)
    elif output_roll == 'glm':
    elif output_axis == 'glm':
        # Put time at front and freq in original position
        result = np.rollaxis(result, -2, 0)
        result = np.rollaxis(result, -1, axis+1)
@@ -245,6 +314,7 @@ def _proc_unroll_output(result, axis, output_roll='auto'):

def _proc_spectrum_mode(pxx, mode, axis=-1):
    """Apply specified transformation to STFT result."""
    logging.info('fft spectrum mode - {0} {1}'.format(mode, (pxx**2).sum()))
    if mode == 'magnitude':
        pxx = np.abs(pxx)
    elif mode == 'psd':
@@ -258,7 +328,7 @@ def _proc_spectrum_mode(pxx, mode, axis=-1):
            pxx = np.unwrap(pxx, axis=axis)
    elif mode == 'complex':
        pass

    logging.info('fft spectrum mode - {0} {1}'.format(mode, (pxx**2).sum()))
    return pxx


@@ -269,6 +339,7 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft):
    consistent with time-dimension.

    """
    logging.info('fft scaling - {0} {1} {2} {3} {4}'.format(mode, side, nfft, scale, (pxx**2).sum()))
    pxx *= scale
    if side == 'onesided' and mode == 'psd':
        if nfft % 2:
@@ -276,6 +347,7 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft):
        else:
            # Last point is unpaired Nyquist freq point, don't double
            pxx[..., 1:-1] *= 2
    logging.info('fft scaling - {0} {1} {2}'.format(mode, scale, (pxx**2).sum()))
    return pxx


@@ -334,6 +406,7 @@ def _set_noverlap(noverlap, nperseg):

def _set_scaling(scaling, fs, win):
    """Set scaling to be applied to FFT output."""
    print('setting scaling {0} {1} {2}'.format(scaling, fs, win))
    if scaling == 'density':
        scale = 1.0 / (fs * (win*win).sum())
    elif scaling == 'spectrum':
@@ -383,6 +456,105 @@ def _set_frange(fmin, fmax, fs):

    return fmin, fmax

import typing
@dataclass
class STFTConfig:
    """Container for the options of a short-time Fourier transform.

    The annotated attributes are the user-facing options. After construction,
    `__post_init__` resolves unspecified options and adds derived attributes
    (`window`, `nstep`, `scale`, `detrend_func`, `side`, `freqvals`) using the
    module's ``_set_*`` helpers. The ``*_args`` properties collect the subsets
    of attributes taken by the low-level processing functions.
    """

    # Data specific args
    input_len: int
    axis: int = -1
    input_complex: bool = False
    # General FFT args
    fs: float = 1.0
    window_type: str = 'hann'
    nperseg: typing.Optional[int] = None
    noverlap: typing.Optional[int] = None
    nfft: typing.Optional[int] = None
    detrend: typing.Union[typing.Callable, str] = 'constant'
    return_onesided: bool = True
    scaling: str = 'density'
    mode: str = 'psd'
    boundary: typing.Optional[str] = None  # Not currently used...
    # Must be annotated to be a dataclass field; a chained assignment
    # (`padded = bool = False`) would also shadow `bool` in the class body.
    padded: bool = False
    fmin: typing.Optional[float] = None
    fmax: typing.Optional[float] = None
    output_axis: typing.Union[int, str] = 'auto'

    def __post_init__(self):
        """Resolve defaults and compute the derived STFT attributes."""
        # NOTE(review): _triage_segments is a private scipy helper reached via
        # the scipy.signal.spectral namespace - confirm it is available in the
        # scipy versions this package supports.
        self.window, self.nperseg = signal.spectral._triage_segments(self.window_type, self.nperseg, input_length=self.input_len)
        self.nfft = _set_nfft(self.nfft, self.nperseg)
        self.noverlap = _set_noverlap(self.noverlap, self.nperseg)
        self.nstep = self.nperseg - self.noverlap
        self.scale = _set_scaling(self.scaling, self.fs, self.window)
        self.detrend_func = _set_detrend(self.detrend, axis=self.axis)
        self.side = _set_onesided(self.return_onesided, self.input_complex)
        self.freqvals = _set_freqvalues(self.nfft, self.fs, self.side)
        self.fmin, self.fmax = _set_frange(self.fmin, self.fmax, self.fs)
        _set_mode(self.mode)
        logging.debug('%s', self)

    @property
    def stft_args(self):
        """Keyword arguments for `compute_stft`."""
        keys = ('fs', 'nperseg', 'nstep', 'nfft', 'detrend_func',
                'side', 'scale', 'axis', 'mode', 'window',
                'padded', 'fmin', 'fmax', 'output_axis')
        return {key: getattr(self, key) for key in keys}

    @property
    def embedding_args(self):
        """Keyword arguments for `apply_delay_embedding`."""
        keys = ('nperseg', 'nstep', 'detrend_func', 'window', 'padded')
        return {key: getattr(self, key) for key in keys}

    @property
    def fft_args(self):
        """Keyword arguments for `compute_fft`."""
        keys = ('nfft', 'axis', 'side', 'mode', 'scale', 'fs', 'fmin', 'fmax')
        return {key: getattr(self, key) for key in keys}


@dataclass
class PeriodogramConfig(STFTConfig):
    """STFT options plus the name of the averaging method for a periodogram."""

    # Averaging method name; the default is 'mean'
    average: str = 'mean'

    def __post_init__(self):
        """Run the base-class option setup; nothing extra is derived here."""
        super().__post_init__()


@dataclass
class GLMPeriodogramConfig(STFTConfig):
    """STFT options plus the settings describing a GLM fit."""

    # Regressor definitions; both default to None
    covariates: dict = None
    confounds: dict = None
    # Name of the fitting routine; the default is 'pinv'
    fit_method: str = 'pinv'
    # presumably toggles a constant regressor in the fit - confirm against
    # the fitting functions
    fit_constant: bool = True

    def __post_init__(self):
        """Run the base-class option setup; nothing extra is derived here."""
        super().__post_init__()


@dataclass
class MultiTaperConfig(STFTConfig):
    """STFT options plus the taper settings for a multitaper estimate."""

    # Averaging method name; the default is 'mean'
    average: str = 'mean'
    # Taper settings, exposed through `multitaper_stft_args`
    time_bandwidth: int = 3
    num_tapers: typing.Union[str, int] = 'auto'
    freq_resolution: int = 1

    def __post_init__(self):
        """Run the base-class option setup; nothing extra is derived here."""
        super().__post_init__()

    @property
    def multitaper_stft_args(self):
        """Return the taper and STFT attributes as a keyword-argument dict."""
        taper_keys = ('time_bandwidth', 'num_tapers', 'freq_resolution')
        stft_keys = ('fs', 'nperseg', 'nstep', 'nfft', 'detrend_func',
                     'side', 'scale', 'axis', 'mode',
                     'padded', 'fmin', 'fmax', 'output_axis')
        return {name: getattr(self, name) for name in taper_keys + stft_keys}


def set_options(input_len,
                # scipy.signal.spectral._spectral_helper kwargs
@@ -526,11 +698,11 @@ def set_options(input_len,
# computations are needed


def psd(x, average='mean',
def periodograms(x, average='mean',
                # General STFT kwargs
        fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
                fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
                detrend='constant', return_onesided=True, scaling='density',
        axis=-1, fmin=None, fmax=None):
                axis=-1, fmin=None, fmax=None, return_config=False):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
@@ -586,22 +758,66 @@ def psd(x, average='mean',
        Power spectral density or power spectrum of x.

    """
    f, t, p = compute_stft(x, fs=fs, window=window, nperseg=nperseg,

    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend,
                           return_onesided=return_onesided, scaling=scaling, axis=axis)
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='glm')

    f, t, p = compute_stft(x, **config.stft_args)
    print(p.shape)

    if average == 'mean':
    if config.average == 'mean':
        p = np.nanmean(p, axis=0).real
    elif average == 'median':
    elif config.average == 'median':
        p = np.nanmedian(p, axis=0).real
    else:
        msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'"
        raise ValueError(msg.format(average))
        raise ValueError(msg.format(config.average))

    if return_config:
        return f, p.real, config
    else:
        return f, p.real


def multitaper(x, time_bandwidth=5, num_tapers='auto', freq_resolution=1, average='mean',
               # General STFT kwargs
               fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
               detrend='constant', return_onesided=True, scaling='density',
               axis=-1, fmin=None, fmax=None, return_config=True):
    """Compute a multitaper spectrum by averaging across windows of a multitaper STFT.

    Parameters
    ----------
    x : ndarray
        Data to transform.
    time_bandwidth, num_tapers, freq_resolution
        Taper settings stored on the `MultiTaperConfig` and forwarded to
        `compute_multitaper_stft`.
    average : {'mean', 'median'}
        How to combine the STFT windows: `np.nanmean` or `np.nanmedian` over
        the first axis of the STFT output.
    fs, window_type, nperseg, noverlap, nfft, detrend, return_onesided, scaling
        General STFT options stored on the `MultiTaperConfig`.
    axis : int
        Axis of `x` to transform.
    fmin, fmax : float or None
        Frequency range options stored on the `MultiTaperConfig`.
    return_config : bool
        If True (the default for this function), also return the config object.

    Returns
    -------
    f : ndarray
        Frequency values from `compute_multitaper_stft`.
    p : ndarray
        Real part of the window-averaged spectrum.
    config : MultiTaperConfig
        Only returned when `return_config` is True.

    Raises
    ------
    ValueError
        If `average` is neither 'mean' nor 'median'.

    """
    # Reject a bad option up front rather than after the full transform.
    # Tuple membership keeps this a ValueError even for unhashable values.
    if average not in ('mean', 'median'):
        msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'"
        raise ValueError(msg.format(average))

    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                              time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution,
                              average=average, fs=fs, window_type=window_type,
                              nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                              detrend=detrend, return_onesided=return_onesided,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='glm')

    # output_axis='glm' is requested so the windows can be averaged over axis 0
    f, _, p = compute_multitaper_stft(x, **config.multitaper_stft_args)

    reducer = np.nanmean if config.average == 'mean' else np.nanmedian
    p = reducer(p, axis=0).real

    if return_config:
        return f, p, config
    return f, p


# -----------------------------------------------------------------------
# Conditioned Spectrogram Functions
# GLM Spectrogram Functions


def _flatten(X):
@@ -783,7 +999,7 @@ def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True):
def psd_glm(X, covariates=None, confounds=None, fit_method='pinv',
            fit_constant=True,
            # General STFT kwargs - passed to set_options
            fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
            fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
            detrend='constant', return_onesided=True, scaling='density',
            axis=-1, mode='psd', fmin=None, fmax=None):
    """Compute a Power Spectrum with a General Linear Model.
@@ -868,14 +1084,17 @@ def psd_glm(X, covariates=None, confounds=None, fit_method='pinv',
        axis = X.ndim - 1

    # Set configuration
    config = set_options(X.shape[axis], input_complex=np.iscomplexobj(X),
                         fs=fs, fmin=fmin, fmax=fmax, window=window,
                         nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                         detrend=detrend, return_onesided=return_onesided,
    config = GLMPeriodogramConfig(X.shape[axis], covariates=covariates,
                                  confounds=confounds, fit_method=fit_method,
                                  fit_constant=fit_constant,input_complex=np.iscomplexobj(X),
                                  fs=fs, fmin=fmin, fmax=fmax, window_type=window_type,
                                  nperseg=nperseg, noverlap=noverlap,
                                  nfft=nfft, detrend=detrend,
                                  return_onesided=return_onesided,
                                  scaling=scaling, axis=axis, mode=mode)

    # Compute STFT
    f, t, p = compute_stft(X, config=config, output_roll='glm')
    f, t, p = compute_stft(X, config, output_axis='auto')

    # Prepare data
    orig_shape = p.shape
@@ -886,8 +1105,7 @@ def psd_glm(X, covariates=None, confounds=None, fit_method='pinv',
    copes = None  # Can compute separately if fit doesn't handle this
    if fit_method == 'pinv':
        # Design matrix pseudo-inverse method
        copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config,
                                                  fit_method='pinv', fit_constant=fit_constant)
        copes, varcopes, extras = _glm_fit_simple(p, config)
    elif fit_method == 'lstsq':
        # numpy.linalg.lstsq method
        copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config,