Loading sails/stft.py +82 −45 Original line number Diff line number Diff line Loading @@ -306,6 +306,7 @@ def compute_stft(x, fs=1.0, fmin=0, fmax=0.5, axis=-1, def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast', # Data kwargs fs=1.0, fmin=0, fmax=0.5, axis=-1, # window kwargs Loading Loading @@ -415,6 +416,7 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi detrend_func=detrend_func, window=None, padded=padded) if apply_tapers == 'broadcast': # Apply tapers - via broadcasting to avoid loops to_shape = np.r_[np.ones((len(y.shape)-1),), num_tapers, nperseg].astype(int) z = y[..., np.newaxis, :] * np.broadcast_to(tapers, to_shape) Loading @@ -428,6 +430,27 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi # Average over tapers - could be high level option? mean or median? result = np.average(result, weights=taper_weights, axis=-2) elif apply_tapers == 'loop': # Apply tapers in a loop - slower but uses much less RAM to_shape = np.r_[np.ones((len(y.shape)-1),), nperseg].astype(int) for ii in range(num_tapers): logging.debug('running taper: {0}'.format(ii)) z = y * np.broadcast_to(tapers[ii, :], to_shape) # Run actual FFT freqs, taper_result = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode, scale=scale, fs=fs, fmin=fmin, fmax=fmax) logging.debug('tapered and fftd data shape {0}'.format(taper_result.shape)) # Run an incremental average so we don't have to store anything # https://math.stackexchange.com/questions/106700/incremental-averaging/1836447 if ii == 0: result = taper_result else: result = result + (taper_result - result) / (ii + 1) else: raise ValueError("'apply_tapers' option '{0}' not recognised. Use one of 'broadcast' or 'loop''".format(apply_tapers)) # Periodogram Scaling result = result / fs Loading @@ -442,6 +465,16 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi return freqs, time, result def compute_spectral_matrix_fft(psd): # psd should be [channels x freq] complex for now S = np.zeros((psd.shape[0], psd.shape[0], psd.shape[1]), dtype=complex) for ii in range(psd.shape[1]): S[:, :, ii] = np.dot(psd[:, ii, np.newaxis], psd[np.newaxis, :, ii].conj()) return S # Helpers - private functions assisting low-level processors def _proc_roll_input(x, axis=-1): Loading Loading @@ -834,7 +867,7 @@ def _set_mode(mode): """ modelist = ['psd', 'complex', 'magnitude', 'angle', 'phase'] if mode not in modelist: raise ValueError('unknown value for mode {}, must be one of {}' raise ValueError("Invalid value ('{}') for mode, must be one of {}" .format(mode, modelist)) Loading Loading @@ -958,6 +991,7 @@ class MultiTaperConfig(STFTConfig): time_bandwidth: int = 3 num_tapers: typing.Union[str, int] = 'auto' freq_resolution: int = 1 apply_tapers: str = 'broadcast' def __post_init__(self): super().__post_init__() Loading @@ -967,9 +1001,9 @@ class MultiTaperConfig(STFTConfig): """ """ args = {} for key in ['time_bandwidth', 'num_tapers', 'freq_resolution', 'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded', 'fmin', 'fmax', 'output_axis']: 'apply_tapers', 'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded', 'fmin', 'fmax', 'output_axis']: args[key] = getattr(self, key) return args Loading Loading @@ -1123,7 +1157,7 @@ def set_options(input_len, def sw_periodogram(x, # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False): """Compute Periodogram by averaging across windows in a STFT. Loading Loading @@ -1194,7 +1228,7 @@ def sw_periodogram(x, # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=None, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') Loading @@ -1210,7 +1244,7 @@ def sw_periodogram(x, def periodogram(x, average='mean', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False): """Compute Periodogram by averaging across windows in a STFT. Loading Loading @@ -1283,7 +1317,7 @@ def periodogram(x, average='mean', # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='time_first') Loading @@ -1291,9 +1325,9 @@ def periodogram(x, average='mean', logging.debug(p.shape) if config.average == 'mean': p = np.nanmean(p, axis=0).real p = np.nanmean(p, axis=0) elif config.average == 'median': p = np.nanmedian(p, axis=0).real p = np.nanmedian(p, axis=0) elif config.average is None: pass else: Loading @@ -1306,11 +1340,11 @@ def periodogram(x, average='mean', return f, p def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, apply_tapers='broadcast', # General STFT kwargs fs=1.0, nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None, return_config=True): detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False): """Compute a multi-tapered power spectrum across windows in a STFT. Parameters Loading Loading @@ -1383,26 +1417,28 @@ def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution, average=None, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, time_bandwidth=time_bandwidth, num_tapers=num_tapers, apply_tapers=apply_tapers, freq_resolution=freq_resolution, average=None, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, mode=mode, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) if return_config: return f, t, p.real, config return f, t, p, config else: return f, t, p.real return f, t, p def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_resolution=1, def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_resolution=1, apply_tapers='broadcast', # General STFT kwargs fs=1.0, nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None, return_config=True): detrend='constant', return_onesided=True, scaling='density', mode='psd', axis=-1, fmin=None, fmax=None, return_config=False): """Compute a multi-tapered power spectrum averaged across windows in a STFT. Parameters Loading Loading @@ -1477,27 +1513,29 @@ def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_reso # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution, average=average, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, time_bandwidth=time_bandwidth, num_tapers=num_tapers, apply_tapers=apply_tapers, freq_resolution=freq_resolution, average=average, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, mode=mode, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='time_first') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) if config.average == 'mean': p = np.nanmean(p, axis=0).real p = np.nanmean(p, axis=0) elif config.average == 'median': p = np.nanmedian(p, axis=0).real p = np.nanmedian(p, axis=0) else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(config.average)) if return_config: return f, p.real, config return f, p, config else: return f, p.real return f, p # ----------------------------------------------------------------------- Loading Loading @@ -1998,11 +2036,10 @@ def glm_periodogram(X, covariates=None, confounds=None, fit_method='pinv', f, t, p = compute_stft(X, **config.stft_args) # Compute model - each method MUST assign copes, varcopes and extras if fit_method == 'pinv': copes, varcopes, extras = _glm_fit_simple(p, config) elif fit_method == 'lstsq': if fit_method in ['pinv', 'lstsq']: copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config, fit_method='lstsq', fit_constant=fit_constant) fit_method=fit_method, fit_constant=fit_constant) elif fit_method == 'glmtools': copes, varcopes, extras = _glm_fit_glmtools(p, covariates, confounds, config, fit_constant=fit_constant) Loading Loading
sails/stft.py +82 −45 Original line number Diff line number Diff line Loading @@ -306,6 +306,7 @@ def compute_stft(x, fs=1.0, fmin=0, fmax=0.5, axis=-1, def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast', # Data kwargs fs=1.0, fmin=0, fmax=0.5, axis=-1, # window kwargs Loading Loading @@ -415,6 +416,7 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi detrend_func=detrend_func, window=None, padded=padded) if apply_tapers == 'broadcast': # Apply tapers - via broadcasting to avoid loops to_shape = np.r_[np.ones((len(y.shape)-1),), num_tapers, nperseg].astype(int) z = y[..., np.newaxis, :] * np.broadcast_to(tapers, to_shape) Loading @@ -428,6 +430,27 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi # Average over tapers - could be high level option? mean or median? result = np.average(result, weights=taper_weights, axis=-2) elif apply_tapers == 'loop': # Apply tapers in a loop - slower but uses much less RAM to_shape = np.r_[np.ones((len(y.shape)-1),), nperseg].astype(int) for ii in range(num_tapers): logging.debug('running taper: {0}'.format(ii)) z = y * np.broadcast_to(tapers[ii, :], to_shape) # Run actual FFT freqs, taper_result = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode, scale=scale, fs=fs, fmin=fmin, fmax=fmax) logging.debug('tapered and fftd data shape {0}'.format(taper_result.shape)) # Run an incremental average so we don't have to store anything # https://math.stackexchange.com/questions/106700/incremental-averaging/1836447 if ii == 0: result = taper_result else: result = result + (taper_result - result) / (ii + 1) else: raise ValueError("'apply_tapers' option '{0}' not recognised. Use one of 'broadcast' or 'loop''".format(apply_tapers)) # Periodogram Scaling result = result / fs Loading @@ -442,6 +465,16 @@ def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwi return freqs, time, result def compute_spectral_matrix_fft(psd): # psd should be [channels x freq] complex for now S = np.zeros((psd.shape[0], psd.shape[0], psd.shape[1]), dtype=complex) for ii in range(psd.shape[1]): S[:, :, ii] = np.dot(psd[:, ii, np.newaxis], psd[np.newaxis, :, ii].conj()) return S # Helpers - private functions assisting low-level processors def _proc_roll_input(x, axis=-1): Loading Loading @@ -834,7 +867,7 @@ def _set_mode(mode): """ modelist = ['psd', 'complex', 'magnitude', 'angle', 'phase'] if mode not in modelist: raise ValueError('unknown value for mode {}, must be one of {}' raise ValueError("Invalid value ('{}') for mode, must be one of {}" .format(mode, modelist)) Loading Loading @@ -958,6 +991,7 @@ class MultiTaperConfig(STFTConfig): time_bandwidth: int = 3 num_tapers: typing.Union[str, int] = 'auto' freq_resolution: int = 1 apply_tapers: str = 'broadcast' def __post_init__(self): super().__post_init__() Loading @@ -967,9 +1001,9 @@ class MultiTaperConfig(STFTConfig): """ """ args = {} for key in ['time_bandwidth', 'num_tapers', 'freq_resolution', 'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded', 'fmin', 'fmax', 'output_axis']: 'apply_tapers', 'fs', 'nperseg', 'nstep', 'nfft', 'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded', 'fmin', 'fmax', 'output_axis']: args[key] = getattr(self, key) return args Loading Loading @@ -1123,7 +1157,7 @@ def set_options(input_len, def sw_periodogram(x, # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False): """Compute Periodogram by averaging across windows in a STFT. Loading Loading @@ -1194,7 +1228,7 @@ def sw_periodogram(x, # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=None, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') Loading @@ -1210,7 +1244,7 @@ def sw_periodogram(x, def periodogram(x, average='mean', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False): """Compute Periodogram by averaging across windows in a STFT. Loading Loading @@ -1283,7 +1317,7 @@ def periodogram(x, average='mean', # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='time_first') Loading @@ -1291,9 +1325,9 @@ def periodogram(x, average='mean', logging.debug(p.shape) if config.average == 'mean': p = np.nanmean(p, axis=0).real p = np.nanmean(p, axis=0) elif config.average == 'median': p = np.nanmedian(p, axis=0).real p = np.nanmedian(p, axis=0) elif config.average is None: pass else: Loading @@ -1306,11 +1340,11 @@ def periodogram(x, average='mean', return f, p def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, apply_tapers='broadcast', # General STFT kwargs fs=1.0, nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None, return_config=True): detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False): """Compute a multi-tapered power spectrum across windows in a STFT. Parameters Loading Loading @@ -1383,26 +1417,28 @@ def sw_multitaper(x, num_tapers='auto', time_bandwidth=5, freq_resolution=1, # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution, average=None, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, time_bandwidth=time_bandwidth, num_tapers=num_tapers, apply_tapers=apply_tapers, freq_resolution=freq_resolution, average=None, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, mode=mode, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) if return_config: return f, t, p.real, config return f, t, p, config else: return f, t, p.real return f, t, p def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_resolution=1, def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_resolution=1, apply_tapers='broadcast', # General STFT kwargs fs=1.0, nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, fmin=None, fmax=None, return_config=True): detrend='constant', return_onesided=True, scaling='density', mode='psd', axis=-1, fmin=None, fmax=None, return_config=False): """Compute a multi-tapered power spectrum averaged across windows in a STFT. Parameters Loading Loading @@ -1477,27 +1513,29 @@ def multitaper(x, average='mean', time_bandwidth=5, num_tapers='auto', freq_reso # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, freq_resolution=freq_resolution, average=average, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, time_bandwidth=time_bandwidth, num_tapers=num_tapers, apply_tapers=apply_tapers, freq_resolution=freq_resolution, average=average, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, mode=mode, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='time_first') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) if config.average == 'mean': p = np.nanmean(p, axis=0).real p = np.nanmean(p, axis=0) elif config.average == 'median': p = np.nanmedian(p, axis=0).real p = np.nanmedian(p, axis=0) else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(config.average)) if return_config: return f, p.real, config return f, p, config else: return f, p.real return f, p # ----------------------------------------------------------------------- Loading Loading @@ -1998,11 +2036,10 @@ def glm_periodogram(X, covariates=None, confounds=None, fit_method='pinv', f, t, p = compute_stft(X, **config.stft_args) # Compute model - each method MUST assign copes, varcopes and extras if fit_method == 'pinv': copes, varcopes, extras = _glm_fit_simple(p, config) elif fit_method == 'lstsq': if fit_method in ['pinv', 'lstsq']: copes, varcopes, extras = _glm_fit_simple(p, covariates, confounds, config, fit_method='lstsq', fit_constant=fit_constant) fit_method=fit_method, fit_constant=fit_constant) elif fit_method == 'glmtools': copes, varcopes, extras = _glm_fit_glmtools(p, covariates, confounds, config, fit_constant=fit_constant) Loading