Loading sails/periodogram.py +291 −73 Original line number Diff line number Diff line import warnings import logging from dataclasses import dataclass import numpy as np from scipy import fft as sp_fft from scipy import signal, stats from scipy.signal import signaltools from scipy.signal.windows import dpss logging.basicConfig(level=logging.DEBUG) # ------------------------------------------------------------------ # Low level computations # # These functions are stand-alone data processors # These functions are stand-alone data processors which are usable on their own # Inputs are not sanity checked and documentation may point elsewhere but these # are fast and flexible for expert users. # # Most users will interact with these via the high level functions and option # handlers below. def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, padded=False): Loading Loading @@ -47,6 +57,7 @@ def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, pad # y.shape == nperseg + (nseg-1)*nstep # nadd = (-(y.shape[-1]-nperseg) % nstep) % nperseg # y = np.r_[y, np.zeros(nadd,)] logging.info('delay embedding {0} {1} {2}'.format(x.shape, nperseg, nstep)) if padded: nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg zeros_shape = list(x.shape[:-1]) + [nadd] Loading @@ -62,25 +73,53 @@ def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, pad strides = y.strides[:-1]+(step*y.strides[-1], y.strides[-1]) y_window = np.lib.stride_tricks.as_strided(y, shape=shape, strides=strides) logging.info('delay embedding {0} '.format(y_window.shape, (y_window**2).sum())) if detrend_func is not None: logging.info('delay embedding - detrending {0} '.format(detrend_func)) y_window = detrend_func(y_window) if window is not None: # Apply windowing logging.info('delay embedding - windowing {0} '.format(window.sum())) y_window = window * y_window logging.info('delay embedding - end {0} {1}'.format(y_window.shape, (y_window**2).sum())) return y_window def compute_stft(x, # kwargs from signal.spectral._spectral_helper fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', boundary=None, padded=False, # kwargs from sails fmin=None, fmax=None, return_config=False, config=None, output_roll='auto'): def compute_fft(x, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=-0.5, fmax=0.5): """Compute, trim and post-process an FFT on last dimension of input array.""" # Compute FFT if side == 'twosided': func = sp_fft.fft else: x = x.real func = sp_fft.rfft logging.info('fft - start {0} {1} {2} {3}'.format(side, func, x.shape, (x**2).sum())) result = func(x, nfft) logging.info('fft - {0} {1}'.format(result.shape, (result**2).sum())) # Apply spectrum mode selection result = _proc_spectrum_mode(result, mode, axis=axis) # Apply scaling result = _proc_spectrum_scaling(result, scale, side, mode, nfft) # Get frequency values freqvals = _set_freqvalues(nfft, fs, side) # Trim frequency range to specified limits fidx = (freqvals >= fmin) & \ (freqvals <= fmax) result = result[..., fidx] freqs = freqvals[fidx] return result, freqs def compute_stft(x, nperseg=256, nstep=256, window=None, detrend_func=None, padded=False, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=0, fmax=0.5, output_axis='auto'): """Compute a short-time Fourier transform to a dataset. Parameters Loading Loading @@ -150,7 +189,7 @@ def compute_stft(x, Dictionary of values specifying all parameters of a STFT set by set_options. Values in config override all other user specified options. output_roll : {'auto', 'glm'} output_axis : {'auto', 'glm'} Flag indicating where to roll the time and frequencies dimensions to in output array. 'auto' will return the transformed dimensions back the position of the transformed input, 'glm' will roll the time windows to Loading @@ -172,55 +211,84 @@ def compute_stft(x, """ if config is None: config = set_options(x.shape[axis], input_complex=np.iscomplexobj(x), fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, mode=mode, boundary=boundary, padded=padded) # ---- Work start here x = _proc_roll_input(x, axis=config['axis']) if axis == -1: axis = x.ndim-1 x = _proc_roll_input(x, axis=axis) # window inputs y = apply_delay_embedding(x, config['nperseg'], config['nstep'], window=config['win'], detrend_func=config['detrend_func'], padded=config['padded']) y = apply_delay_embedding(x, nperseg, nstep, detrend_func=detrend_func, window=window, padded=padded) # Compute FFT if config['side'] == 'twosided': func = sp_fft.fft else: y = y.real func = sp_fft.rfft result = func(y, config['nfft']) # Run actual FFT print(scale) result, freqs = compute_fft(y, nfft=nfft, axis=axis, side=side, mode=mode, scale=scale, fs=fs, fmin=fmin, fmax=fmax) # Apply spectrum mode selection result = _proc_spectrum_mode(result, config['mode'], axis=config['axis']) # Create time window vector noverlap = nperseg - nstep time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, nperseg - noverlap)/float(fs) # Apply scaling result = _proc_spectrum_scaling(result, config['scale'], config['side'], config['mode'], config['nfft']) # Final two axes are now [..., time x freq] result = _proc_unroll_output(result, axis, output_axis=output_axis) # Create time window vector time = np.arange(config['nperseg']/2, x.shape[-1] - config['nperseg']/2 + 1, config['nperseg'] - config['noverlap'])/float(config['fs']) return freqs, time, result # Trim frequency range to specified limits fidx = (config['freqvals'] >= config['fmin']) & \ (config['freqvals'] <= config['fmax']) result = result[..., fidx] freqs = config['freqvals'][fidx] # Final two axes are now [..., time x freq] result = _proc_unroll_output(result, config['axis'], output_roll=output_roll) def compute_multitaper_stft(x, freq_resolution=1, num_tapers='auto', time_bandwidth=5, nperseg=256, nstep=256, window=None, detrend_func=None, padded=False, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=0, fmax=0.5, output_axis='auto'): seconds_perseg = nperseg / fs time_half_bandwidth = int(seconds_perseg * freq_resolution / 2) if num_tapers == 'auto': num_tapers = 2 * time_half_bandwidth - 1 logging.info('multitaper {0} {1} {2} {3}'.format(time_bandwidth, num_tapers, freq_resolution, time_half_bandwidth)) tapers, ratios = dpss(nperseg, time_bandwidth, num_tapers, return_ratios=True) taper_weights = np.ones((num_tapers,)) / num_tapers # ---- Work start here if axis == -1: axis = x.ndim-1 x = _proc_roll_input(x, axis=axis) # delay embedding - don't apply window function... y = apply_delay_embedding(x, nperseg, nstep, detrend_func=detrend_func, window=None, padded=padded) # Apply tapers via broadcasting to_shape = np.r_[np.ones((len(y.shape)-1),), num_tapers, nperseg].astype(int) z = y[..., np.newaxis, :] * np.broadcast_to(tapers, to_shape) logging.info('multitaper data {0}'.format(z.shape)) logging.info('multitaper seg ss {0}'.format((z[0, 0, -1, :]**2).sum())) # Run actual FFT result, freqs = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode, scale=scale, fs=fs, fmin=fmin, fmax=fmax) logging.info('multitaper fft data {0}'.format(result.shape)) logging.info('multitaper fft data power {0}'.format((result[0, -1, :, :]**2).sum(axis=1))) # Average over tapers - could be high level option? mean or median? result = np.average(result, weights=taper_weights, axis=-2) # PERIODOGRAM SCALING DOESNT WORK!:!??! result = result/ fs # Create time window vector noverlap = nperseg - nstep time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, nperseg - noverlap)/float(fs) # Final two axes are now [..., time x freq] - return them to requested position result = _proc_unroll_output(result, axis, output_axis=output_axis) if return_config: return freqs, time, result, config else: return freqs, time, result # Helpers # Helpers - private functions assisting low-level processors def _proc_roll_input(x, axis=-1): """Move axis to be transformed to final position.""" Loading @@ -229,13 +297,14 @@ def _proc_roll_input(x, axis=-1): return x def _proc_unroll_output(result, axis, output_roll='auto'): def _proc_unroll_output(result, axis, output_axis='auto'): """Move STFT'd dimensions to user specified position.""" if output_roll == 'auto': print('unroll {0} {1} {2}'.format(result.shape, axis, output_axis)) if output_axis == 'auto': # Return time and freq back to original position result = np.rollaxis(result, -2, axis) result = np.rollaxis(result, -1, axis+1) elif output_roll == 'glm': elif output_axis == 'glm': # Put time at front and freq in original position result = np.rollaxis(result, -2, 0) result = np.rollaxis(result, -1, axis+1) Loading @@ -245,6 +314,7 @@ def _proc_unroll_output(result, axis, output_roll='auto'): def _proc_spectrum_mode(pxx, mode, axis=-1): """Apply specified transformation to STFT result.""" logging.info('fft spectrum mode - {0} {1}'.format(mode, (pxx**2).sum())) if mode == 'magnitude': pxx = np.abs(pxx) elif mode == 'psd': Loading @@ -258,7 +328,7 @@ def _proc_spectrum_mode(pxx, mode, axis=-1): pxx = np.unwrap(pxx, axis=axis) elif mode == 'complex': pass logging.info('fft spectrum mode - {0} {1}'.format(mode, (pxx**2).sum())) return pxx Loading @@ -269,6 +339,7 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft): consistent with time-dimension. """ logging.info('fft scaling - {0} {1} {2} {3} {4}'.format(mode, side, nfft, scale, (pxx**2).sum())) pxx *= scale if side == 'onesided' and mode == 'psd': if nfft % 2: Loading @@ -276,6 +347,7 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft): else: # Last point is unpaired Nyquist freq point, don't double pxx[..., 1:-1] *= 2 logging.info('fft scaling - {0} {1} {2}'.format(mode, scale, (pxx**2).sum())) return pxx Loading Loading @@ -334,6 +406,7 @@ def _set_noverlap(noverlap, nperseg): def _set_scaling(scaling, fs, win): """Set scaling to be applied to FFT output.""" print('setting scaling {0} {1} {2}'.format(scaling, fs, win)) if scaling == 'density': scale = 1.0 / (fs * (win*win).sum()) elif scaling == 'spectrum': Loading Loading @@ -383,6 +456,105 @@ def _set_frange(fmin, fmax, fs): return fmin, fmax import typing @dataclass class STFTConfig: # Data specific args input_len : int axis : int = -1 input_complex : bool = False # General FFT args fs : float = 1.0 window_type : str = 'hann' nperseg : int = None noverlap : int = None nfft : int = None detrend : typing.Union[typing.Callable, str] = 'constant' return_onesided : bool = True scaling : str = 'density' mode : str = 'psd' boundary : str = None # Not currently used... padded = bool = False fmin : float = None fmax : float = None output_axis : typing.Union[int, str] = 'auto' def __post_init__(self): self.window, self.nperseg = signal.spectral._triage_segments(self.window_type, self.nperseg, input_length=self.input_len) self.nfft = _set_nfft(self.nfft, self.nperseg) self.noverlap = _set_noverlap(self.noverlap, self.nperseg) self.nstep = self.nperseg - self.noverlap self.scale = _set_scaling(self.scaling, self.fs, self.window) self.detrend_func = _set_detrend(self.detrend, axis=self.axis) self.side = _set_onesided(self.return_onesided, self.input_complex) self.freqvals = _set_freqvalues(self.nfft, self.fs, self.side) self.fmin, self.fmax = _set_frange(self.fmin, self.fmax, self.fs) _set_mode(self.mode) print(self) @property def stft_args(self): args = {} for key in ['fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'window', 'padded', 'fmin', 'fmax', 'output_axis']: args[key] = getattr(self, key) return args @property def embedding_args(self): args = {} for key in ['nperseg', 'nstep', 'detrend_func', 'window', 'padded']: args[key] = getattr(self, key) return args @property def fft_args(self): args = {} for key in ['nfft', 'axis', 'side', 'mode', 'scale', 'fs', 'fmin', 'fmax']: args[key] = getattr(self, key) return args @dataclass class PeriodogramConfig(STFTConfig): average : str = 'mean' def __post_init__(self): super().__post_init__() @dataclass class GLMPeriodogramConfig(STFTConfig): covariates : dict = None confounds : dict = None fit_method : str = 'pinv' fit_constant : bool = True def __post_init__(self): super().__post_init__() @dataclass class MultiTaperConfig(STFTConfig): average : str = 'mean' time_bandwidth : int = 3 num_tapers : typing.Union[str, int] = 'auto' freq_resolution : int = 1 def __post_init__(self): super().__post_init__() @property def multitaper_stft_args(self): args = {} for key in ['time_bandwidth', 'num_tapers', 'freq_resolution', 'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded', 'fmin', 'fmax', 'output_axis']: args[key] = getattr(self, key) return args def set_options(input_len, # scipy.signal.spectral._spectral_helper kwargs Loading Loading @@ -526,11 +698,11 @@ def set_options(input_len, # computations are needed def psd(x, average='mean', def periodograms(x, average='mean', # General STFT kwargs fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None): axis=-1, fmin=None, fmax=None, return_config=False): """Compute Periodogram by averaging across windows in a STFT. Parameters Loading Loading @@ -586,22 +758,66 @@ def psd(x, average='mean', Power spectral density or power spectrum of x. """ f, t, p = compute_stft(x, fs=fs, window=window, nperseg=nperseg, # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis) return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='glm') f, t, p = compute_stft(x, **config.stft_args) print(p.shape) if average == 'mean': if config.average == 'mean': p = np.nanmean(p, axis=0).real elif average == 'median': elif config.average == 'median': p = np.nanmedian(p, axis=0).real else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(average)) raise ValueError(msg.format(config.average)) if return_config: return f, p.real, config else: return f, p.real def multitaper(x, time_bandwidth=5, num_tapers='auto', freq_resolution=1, average='mean', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None, return_config=True): # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution, average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='glm') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) print(p.shape) if config.average == 'mean': p = np.nanmean(p, axis=0).real elif config.average == 'median': p = np.nanmedian(p, axis=0).real else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(config.average)) if return_config: return f, p.real, config else: return f, p.real # ----------------------------------------------------------------------- # Conditioned Spectrogram Functions # GLM Spectrogram Functions def _flatten(X): Loading Loading @@ -783,7 +999,7 @@ def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True): def psd_glm(X, covariates=None, confounds=None, fit_method='pinv', fit_constant=True, # General STFT kwargs - passed to set_options fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', fmin=None, fmax=None): """Compute a Power Spectrum with a General Linear Model. Loading Loading @@ -868,14 +1084,17 @@ def psd_glm(X, covariates=None, confounds=None, fit_method='pinv', axis = X.ndim - 1 # Set configuration config = set_options(X.shape[axis], input_complex=np.iscomplexobj(X), fs=fs, fmin=fmin, fmax=fmax, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, config = GLMPeriodogramConfig(X.shape[axis], covariates=covariates, confounds=confounds, fit_method=fit_method, fit_constant=fit_constant,input_complex=np.iscomplexobj(X), fs=fs, fmin=fmin, fmax=fmax, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, mode=mode) # Compute STFT f, t, p = compute_stft(X, config=config, output_roll='glm') f, t, p = compute_stft(X, config, output_axis='auto') # Prepare data orig_shape = p.shape Loading @@ -886,8 +1105,7 @@ def psd_glm(X, covariates=None, confounds=None, fit_method='pinv', copes = None # Can compute separately if fit doesn't handle this if fit_method == 'pinv': # Design matrix pseudo-inverse method copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config, fit_method='pinv', fit_constant=fit_constant) copes, varcopes, extras = _glm_fit_simple(p, config) elif fit_method == 'lstsq': # numpy.linalg.lstsq method copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config, Loading Loading
sails/periodogram.py +291 −73 Original line number Diff line number Diff line import warnings import logging from dataclasses import dataclass import numpy as np from scipy import fft as sp_fft from scipy import signal, stats from scipy.signal import signaltools from scipy.signal.windows import dpss logging.basicConfig(level=logging.DEBUG) # ------------------------------------------------------------------ # Low level computations # # These functions are stand-alone data processors # These functions are stand-alone data processors which are usable on their own # Inputs are not sanity checked and documentation may point elsewhere but these # are fast and flexible for expert users. # # Most users will interact with these via the high level functions and option # handlers below. def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, padded=False): Loading Loading @@ -47,6 +57,7 @@ def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, pad # y.shape == nperseg + (nseg-1)*nstep # nadd = (-(y.shape[-1]-nperseg) % nstep) % nperseg # y = np.r_[y, np.zeros(nadd,)] logging.info('delay embedding {0} {1} {2}'.format(x.shape, nperseg, nstep)) if padded: nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg zeros_shape = list(x.shape[:-1]) + [nadd] Loading @@ -62,25 +73,53 @@ def apply_delay_embedding(x, nperseg, nstep, window=None, detrend_func=None, pad strides = y.strides[:-1]+(step*y.strides[-1], y.strides[-1]) y_window = np.lib.stride_tricks.as_strided(y, shape=shape, strides=strides) logging.info('delay embedding {0} '.format(y_window.shape, (y_window**2).sum())) if detrend_func is not None: logging.info('delay embedding - detrending {0} '.format(detrend_func)) y_window = detrend_func(y_window) if window is not None: # Apply windowing logging.info('delay embedding - windowing {0} '.format(window.sum())) y_window = window * y_window logging.info('delay embedding - end {0} {1}'.format(y_window.shape, (y_window**2).sum())) return y_window def compute_stft(x, # kwargs from signal.spectral._spectral_helper fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', boundary=None, padded=False, # kwargs from sails fmin=None, fmax=None, return_config=False, config=None, output_roll='auto'): def compute_fft(x, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=-0.5, fmax=0.5): """Compute, trim and post-process an FFT on last dimension of input array.""" # Compute FFT if side == 'twosided': func = sp_fft.fft else: x = x.real func = sp_fft.rfft logging.info('fft - start {0} {1} {2} {3}'.format(side, func, x.shape, (x**2).sum())) result = func(x, nfft) logging.info('fft - {0} {1}'.format(result.shape, (result**2).sum())) # Apply spectrum mode selection result = _proc_spectrum_mode(result, mode, axis=axis) # Apply scaling result = _proc_spectrum_scaling(result, scale, side, mode, nfft) # Get frequency values freqvals = _set_freqvalues(nfft, fs, side) # Trim frequency range to specified limits fidx = (freqvals >= fmin) & \ (freqvals <= fmax) result = result[..., fidx] freqs = freqvals[fidx] return result, freqs def compute_stft(x, nperseg=256, nstep=256, window=None, detrend_func=None, padded=False, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=0, fmax=0.5, output_axis='auto'): """Compute a short-time Fourier transform to a dataset. Parameters Loading Loading @@ -150,7 +189,7 @@ def compute_stft(x, Dictionary of values specifying all parameters of a STFT set by set_options. Values in config override all other user specified options. output_roll : {'auto', 'glm'} output_axis : {'auto', 'glm'} Flag indicating where to roll the time and frequencies dimensions to in output array. 'auto' will return the transformed dimensions back the position of the transformed input, 'glm' will roll the time windows to Loading @@ -172,55 +211,84 @@ def compute_stft(x, """ if config is None: config = set_options(x.shape[axis], input_complex=np.iscomplexobj(x), fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, mode=mode, boundary=boundary, padded=padded) # ---- Work start here x = _proc_roll_input(x, axis=config['axis']) if axis == -1: axis = x.ndim-1 x = _proc_roll_input(x, axis=axis) # window inputs y = apply_delay_embedding(x, config['nperseg'], config['nstep'], window=config['win'], detrend_func=config['detrend_func'], padded=config['padded']) y = apply_delay_embedding(x, nperseg, nstep, detrend_func=detrend_func, window=window, padded=padded) # Compute FFT if config['side'] == 'twosided': func = sp_fft.fft else: y = y.real func = sp_fft.rfft result = func(y, config['nfft']) # Run actual FFT print(scale) result, freqs = compute_fft(y, nfft=nfft, axis=axis, side=side, mode=mode, scale=scale, fs=fs, fmin=fmin, fmax=fmax) # Apply spectrum mode selection result = _proc_spectrum_mode(result, config['mode'], axis=config['axis']) # Create time window vector noverlap = nperseg - nstep time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, nperseg - noverlap)/float(fs) # Apply scaling result = _proc_spectrum_scaling(result, config['scale'], config['side'], config['mode'], config['nfft']) # Final two axes are now [..., time x freq] result = _proc_unroll_output(result, axis, output_axis=output_axis) # Create time window vector time = np.arange(config['nperseg']/2, x.shape[-1] - config['nperseg']/2 + 1, config['nperseg'] - config['noverlap'])/float(config['fs']) return freqs, time, result # Trim frequency range to specified limits fidx = (config['freqvals'] >= config['fmin']) & \ (config['freqvals'] <= config['fmax']) result = result[..., fidx] freqs = config['freqvals'][fidx] # Final two axes are now [..., time x freq] result = _proc_unroll_output(result, config['axis'], output_roll=output_roll) def compute_multitaper_stft(x, freq_resolution=1, num_tapers='auto', time_bandwidth=5, nperseg=256, nstep=256, window=None, detrend_func=None, padded=False, nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0, fs=1.0, fmin=0, fmax=0.5, output_axis='auto'): seconds_perseg = nperseg / fs time_half_bandwidth = int(seconds_perseg * freq_resolution / 2) if num_tapers == 'auto': num_tapers = 2 * time_half_bandwidth - 1 logging.info('multitaper {0} {1} {2} {3}'.format(time_bandwidth, num_tapers, freq_resolution, time_half_bandwidth)) tapers, ratios = dpss(nperseg, time_bandwidth, num_tapers, return_ratios=True) taper_weights = np.ones((num_tapers,)) / num_tapers # ---- Work start here if axis == -1: axis = x.ndim-1 x = _proc_roll_input(x, axis=axis) # delay embedding - don't apply window function... y = apply_delay_embedding(x, nperseg, nstep, detrend_func=detrend_func, window=None, padded=padded) # Apply tapers via broadcasting to_shape = np.r_[np.ones((len(y.shape)-1),), num_tapers, nperseg].astype(int) z = y[..., np.newaxis, :] * np.broadcast_to(tapers, to_shape) logging.info('multitaper data {0}'.format(z.shape)) logging.info('multitaper seg ss {0}'.format((z[0, 0, -1, :]**2).sum())) # Run actual FFT result, freqs = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode, scale=scale, fs=fs, fmin=fmin, fmax=fmax) logging.info('multitaper fft data {0}'.format(result.shape)) logging.info('multitaper fft data power {0}'.format((result[0, -1, :, :]**2).sum(axis=1))) # Average over tapers - could be high level option? mean or median? result = np.average(result, weights=taper_weights, axis=-2) # PERIODOGRAM SCALING DOESNT WORK!:!??! result = result/ fs # Create time window vector noverlap = nperseg - nstep time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1, nperseg - noverlap)/float(fs) # Final two axes are now [..., time x freq] - return them to requested position result = _proc_unroll_output(result, axis, output_axis=output_axis) if return_config: return freqs, time, result, config else: return freqs, time, result # Helpers # Helpers - private functions assisting low-level processors def _proc_roll_input(x, axis=-1): """Move axis to be transformed to final position.""" Loading @@ -229,13 +297,14 @@ def _proc_roll_input(x, axis=-1): return x def _proc_unroll_output(result, axis, output_roll='auto'): def _proc_unroll_output(result, axis, output_axis='auto'): """Move STFT'd dimensions to user specified position.""" if output_roll == 'auto': print('unroll {0} {1} {2}'.format(result.shape, axis, output_axis)) if output_axis == 'auto': # Return time and freq back to original position result = np.rollaxis(result, -2, axis) result = np.rollaxis(result, -1, axis+1) elif output_roll == 'glm': elif output_axis == 'glm': # Put time at front and freq in original position result = np.rollaxis(result, -2, 0) result = np.rollaxis(result, -1, axis+1) Loading @@ -245,6 +314,7 @@ def _proc_unroll_output(result, axis, output_roll='auto'): def _proc_spectrum_mode(pxx, mode, axis=-1): """Apply specified transformation to STFT result.""" logging.info('fft spectrum mode - {0} {1}'.format(mode, (pxx**2).sum())) if mode == 'magnitude': pxx = np.abs(pxx) elif mode == 'psd': Loading @@ -258,7 +328,7 @@ def _proc_spectrum_mode(pxx, mode, axis=-1): pxx = np.unwrap(pxx, axis=axis) elif mode == 'complex': pass logging.info('fft spectrum mode - {0} {1}'.format(mode, (pxx**2).sum())) return pxx Loading @@ -269,6 +339,7 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft): consistent with time-dimension. """ logging.info('fft scaling - {0} {1} {2} {3} {4}'.format(mode, side, nfft, scale, (pxx**2).sum())) pxx *= scale if side == 'onesided' and mode == 'psd': if nfft % 2: Loading @@ -276,6 +347,7 @@ def _proc_spectrum_scaling(pxx, scale, side, mode, nfft): else: # Last point is unpaired Nyquist freq point, don't double pxx[..., 1:-1] *= 2 logging.info('fft scaling - {0} {1} {2}'.format(mode, scale, (pxx**2).sum())) return pxx Loading Loading @@ -334,6 +406,7 @@ def _set_noverlap(noverlap, nperseg): def _set_scaling(scaling, fs, win): """Set scaling to be applied to FFT output.""" print('setting scaling {0} {1} {2}'.format(scaling, fs, win)) if scaling == 'density': scale = 1.0 / (fs * (win*win).sum()) elif scaling == 'spectrum': Loading Loading @@ -383,6 +456,105 @@ def _set_frange(fmin, fmax, fs): return fmin, fmax import typing @dataclass class STFTConfig: # Data specific args input_len : int axis : int = -1 input_complex : bool = False # General FFT args fs : float = 1.0 window_type : str = 'hann' nperseg : int = None noverlap : int = None nfft : int = None detrend : typing.Union[typing.Callable, str] = 'constant' return_onesided : bool = True scaling : str = 'density' mode : str = 'psd' boundary : str = None # Not currently used... padded = bool = False fmin : float = None fmax : float = None output_axis : typing.Union[int, str] = 'auto' def __post_init__(self): self.window, self.nperseg = signal.spectral._triage_segments(self.window_type, self.nperseg, input_length=self.input_len) self.nfft = _set_nfft(self.nfft, self.nperseg) self.noverlap = _set_noverlap(self.noverlap, self.nperseg) self.nstep = self.nperseg - self.noverlap self.scale = _set_scaling(self.scaling, self.fs, self.window) self.detrend_func = _set_detrend(self.detrend, axis=self.axis) self.side = _set_onesided(self.return_onesided, self.input_complex) self.freqvals = _set_freqvalues(self.nfft, self.fs, self.side) self.fmin, self.fmax = _set_frange(self.fmin, self.fmax, self.fs) _set_mode(self.mode) print(self) @property def stft_args(self): args = {} for key in ['fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'window', 'padded', 'fmin', 'fmax', 'output_axis']: args[key] = getattr(self, key) return args @property def embedding_args(self): args = {} for key in ['nperseg', 'nstep', 'detrend_func', 'window', 'padded']: args[key] = getattr(self, key) return args @property def fft_args(self): args = {} for key in ['nfft', 'axis', 'side', 'mode', 'scale', 'fs', 'fmin', 'fmax']: args[key] = getattr(self, key) return args @dataclass class PeriodogramConfig(STFTConfig): average : str = 'mean' def __post_init__(self): super().__post_init__() @dataclass class GLMPeriodogramConfig(STFTConfig): covariates : dict = None confounds : dict = None fit_method : str = 'pinv' fit_constant : bool = True def __post_init__(self): super().__post_init__() @dataclass class MultiTaperConfig(STFTConfig): average : str = 'mean' time_bandwidth : int = 3 num_tapers : typing.Union[str, int] = 'auto' freq_resolution : int = 1 def __post_init__(self): super().__post_init__() @property def multitaper_stft_args(self): args = {} for key in ['time_bandwidth', 'num_tapers', 'freq_resolution', 'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded', 'fmin', 'fmax', 'output_axis']: args[key] = getattr(self, key) return args def set_options(input_len, # scipy.signal.spectral._spectral_helper kwargs Loading Loading @@ -526,11 +698,11 @@ def set_options(input_len, # computations are needed def psd(x, average='mean', def periodograms(x, average='mean', # General STFT kwargs fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None): axis=-1, fmin=None, fmax=None, return_config=False): """Compute Periodogram by averaging across windows in a STFT. Parameters Loading Loading @@ -586,22 +758,66 @@ def psd(x, average='mean', Power spectral density or power spectrum of x. """ f, t, p = compute_stft(x, fs=fs, window=window, nperseg=nperseg, # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis) return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='glm') f, t, p = compute_stft(x, **config.stft_args) print(p.shape) if average == 'mean': if config.average == 'mean': p = np.nanmean(p, axis=0).real elif average == 'median': elif config.average == 'median': p = np.nanmedian(p, axis=0).real else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(average)) raise ValueError(msg.format(config.average)) if return_config: return f, p.real, config else: return f, p.real def multitaper(x, time_bandwidth=5, num_tapers='auto', freq_resolution=1, average='mean', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None, return_config=True): # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution, average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='glm') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) print(p.shape) if config.average == 'mean': p = np.nanmean(p, axis=0).real elif config.average == 'median': p = np.nanmedian(p, axis=0).real else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(config.average)) if return_config: return f, p.real, config else: return f, p.real # ----------------------------------------------------------------------- # Conditioned Spectrogram Functions # GLM Spectrogram Functions def _flatten(X): Loading Loading @@ -783,7 +999,7 @@ def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True): def psd_glm(X, covariates=None, confounds=None, fit_method='pinv', fit_constant=True, # General STFT kwargs - passed to set_options fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', fmin=None, fmax=None): """Compute a Power Spectrum with a General Linear Model. Loading Loading @@ -868,14 +1084,17 @@ def psd_glm(X, covariates=None, confounds=None, fit_method='pinv', axis = X.ndim - 1 # Set configuration config = set_options(X.shape[axis], input_complex=np.iscomplexobj(X), fs=fs, fmin=fmin, fmax=fmax, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, config = GLMPeriodogramConfig(X.shape[axis], covariates=covariates, confounds=confounds, fit_method=fit_method, fit_constant=fit_constant,input_complex=np.iscomplexobj(X), fs=fs, fmin=fmin, fmax=fmax, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, mode=mode) # Compute STFT f, t, p = compute_stft(X, config=config, output_roll='glm') f, t, p = compute_stft(X, config, output_axis='auto') # Prepare data orig_shape = p.shape Loading @@ -886,8 +1105,7 @@ def psd_glm(X, covariates=None, confounds=None, fit_method='pinv', copes = None # Can compute separately if fit doesn't handle this if fit_method == 'pinv': # Design matrix pseudo-inverse method copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config, fit_method='pinv', fit_constant=fit_constant) copes, varcopes, extras = _glm_fit_simple(p, config) elif fit_method == 'lstsq': # numpy.linalg.lstsq method copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config, Loading