Loading sails/_docstring_utils.py +0 −2 Original line number Diff line number Diff line Loading @@ -277,9 +277,7 @@ stft_funcs = ['apply_sliding_window', '_set_detrend', '_set_mode', '_set_frange', 'sw_periodogram', 'periodogram', 'sw_multitaper', 'multitaper', 'glm_periodogram', 'glm_multitaper', Loading sails/stft.py +228 −119 Original line number Diff line number Diff line Loading @@ -533,8 +533,13 @@ def _proc_get_freq_inds(freqvals, fmin, fmax): (freqvals <= fmax) return fidx def _proc_get_time_range(nperseg, nstep, fs, xlen): # Create time window vector noverlap = nperseg - nstep return np.arange(nperseg/2, xlen - nperseg/2 + 1, nperseg - noverlap)/float(fs) def _proc_trim_freq_range(result, freqvals, fmin, fmax): def _proc_trim_freq_range(result, freqvals, fmin=None, fmax=None, axis=-1): """Trim an FFT output to desired frequency range. This helper function assumes that we want to trim the final axis. Loading @@ -546,6 +551,7 @@ def _proc_trim_freq_range(result, freqvals, fmin, fmax): freqvals : vector Vector of frequency values with length matching the final axis of result %(freq_range)s' %(axis)s' Returns ------- Loading @@ -555,10 +561,16 @@ def _proc_trim_freq_range(result, freqvals, fmin, fmax): New frequency array matching the final axis of output """ if fmin is None and fmax is None: # Just passing through return freqs, result logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax)) fmin = freqvals[0] if fmin is None else fmin fmax = freqvals[-1] if fmax is None else fmax fidx = _proc_get_freq_inds(freqvals, fmin, fmax) result = result[..., fidx] freqs = freqvals[fidx] result = np.compress(fidx, result, axis=axis) freqs = np.compress(fidx, freqvals) logging.debug('fft trimmed output shape {0}'.format(result.shape)) return freqs, result Loading Loading @@ -1060,57 +1072,15 @@ class GLMIRASAConfig(GLMPeriodogramConfig, IRASAConfig): # These functions take input data, run the option handling and execute whatever # computations are needed @dataclass class SpectrumResult: @set_verbose def sw_periodogram(x, # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=None, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') f, t, p = compute_stft(x, **config.stft_args) logging.debug(p.shape) def __init__(self, f, t, pxx, config=None): if return_config: return f, t, p, config else: return f, t, p self.f = f self.t = t self.spectrum = pxx self.config = config @set_verbose Loading Loading @@ -1162,71 +1132,11 @@ def periodogram(x, average='mean', logging.info('Averaging across first dim of result using method {0}'.format(config.average)) p = _proc_apply_average(p, config.average, axis=0) t = None if config.average is None else t logging.info('Returning spectrum of shape {0}'.format(p.shape)) if return_config: return f, p, config else: return f, p @set_verbose def sw_multitaper(x, # Multitaper core num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast', # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None): """Compute a multi-tapered power spectrum across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(multitaper_core)s' %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : MultiTaperConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, apply_tapers=apply_tapers, freq_resolution=freq_resolution, average=None, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, mode=mode, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) if return_config: return f, t, p, config else: return f, t, p return SpectrumResult(f, t, p, config=config) @set_verbose Loading Loading @@ -1284,11 +1194,9 @@ def multitaper(x, average='mean', f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) p = _proc_apply_average(p, config.average, axis=axis) t = None if config.average is None else t if return_config: return f, p, config else: return f, p return SpectrumResult(f, t, p, config=config) # ----------------------------------------------------------------------- Loading Loading @@ -1936,6 +1844,122 @@ def glm_multitaper(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand @set_verbose def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='median', #bootstrap_osc=None, bs_average='mean', # Periodogram args average='median', # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None, avg2=False): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(irasa)s` %(average)s` %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand logging.info('Setting config options') config = IRASAConfig(x.shape[axis], method=method, resample_factors=resample_factors, aperiodic_average=aperiodic_average, input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=None, fmax=None, output_axis='time_first') fmin = 0 if fmin is None else fmin fmax = fs / 2 if fmax is None else fmax if resample_factors is None and method == 'original': resample_factors = np.linspace(1.1, 1.9, 17) elif resample_factors is None and method == 'modified': resample_factors = np.exp(np.linspace(np.log(0.33), np.log(2), 7)) logging.info('Resample factors defined : {}'.format(resample_factors)) resample_factors = np.round(resample_factors, 4) if resample_factors.min() * config.fmax < fmax: msg = 'Requested fmax exceeds the range of valid resamplings. Specified fmax : {} Max valid fmax : {}' logging.warning(msg.format(fmax, resample_factors.min() * config.fmax)) if resample_factors.max() * config.fmin > fmin: # Not sure that this would ever trigger msg = 'Requested fmin exceeds the range of valid resamplings. Specified fmin : {} Min valid fmin : {}' logging.warning(msg.format(fmin, resample_factors.max() * config.fmin)) # Compute IRASA for each individual windowed sliding window data segment. x = _proc_roll_input(x, axis=config.axis) # Put time to final dimension yy = apply_sliding_window(x, **config.sliding_window_args) print('a : {}'.format(np.sum(yy**2))) freqs, pxx_full = compute_fft(yy, **config.fft_args) t = _proc_get_time_range(config.nperseg, config.nstep, config.fs, config.input_len) t = None if config.average is None else t for ii, rf in enumerate(resample_factors): rat = fractions.Fraction(str(rf)) up, down = rat.numerator, rat.denominator logging.info('Resampling by {}, factor {} of {}'.format(rf, ii+1, len(resample_factors))) zz = signal.resample_poly(yy, up, down, axis=-1) # This changes the variance of the signal... freqs, pxx_resampled = compute_fft(zz, **config.fft_args) if ii == 0: pxx_aperiodic = pxx_resampled[None, ...] else: pxx_aperiodic = np.concatenate((pxx_aperiodic, pxx_resampled[None, ...]), axis=0) # Average across resamplings logging.info('Averaging across {0} resamplings using method {1}'.format(pxx_aperiodic.shape[0], config.aperiodic_average)) pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.aperiodic_average, axis=0) pxx_oscillatory = pxx_full - pxx_aperiodic # Get trimmed frequency range out_freqs, pxx_aperiodic = _proc_trim_freq_range(pxx_aperiodic, freqs, fmin, fmax) out_freqs, pxx_oscillatory = _proc_trim_freq_range(pxx_oscillatory, freqs, fmin, fmax) logging.info('Averaging across {0} time segments using method {1}'.format(pxx_aperiodic.shape[0], config.average)) pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.average, axis=0) pxx_oscillatory = _proc_apply_average(pxx_oscillatory, config.average, axis=0) if pxx_aperiodic.ndim > 1: pxx_aperiodic = _proc_unroll_output(pxx_aperiodic, pxx_aperiodic.ndim-1, output_axis='auto') pxx_oscillatory = _proc_unroll_output(pxx_oscillatory, pxx_aperiodic.ndim-1, output_axis='auto') aperiodic = SpectrumResult(out_freqs, t, pxx_aperiodic, config=config) oscillatory = SpectrumResult(out_freqs, t, pxx_oscillatory, config=config) return aperiodic, oscillatory @set_verbose def irasa(x, method='original', resample_factors=None, aperiodic_average='median', #bootstrap_osc=None, bs_average='mean', Loading Loading @@ -2061,6 +2085,91 @@ def irasa(x, method='original', resample_factors=None, aperiodic_average='median return freqs, aperiodic_pxx, oscillatory_pxx @set_verbose def glm_irasa_v3(x, # GLM args reg_categorical=None, reg_ztrans=None, reg_unitmax=None, contrasts=None, fit_method='pinv', fit_intercept=True, # IRASA args method='original', resample_factors=None, aperiodic_average='median', # Periodogram args average='median', # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(glmperiodogram)s' %(irasa)s' %(average)s' %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand logging.info('Setting config options') aperiodic, oscillatory = irasa_v3(x, reg_categorical=reg_categorical, reg_ztrans=reg_ztrans, reg_unitmax=reg_unitmax, contrasts=contrasts, fit_method=fit_method, fit_intercept=fit_intercept, # IRASA args method=method, resample_factors=resample_factors, aperiodic_average=aperiodic_average, # Periodogram args average=average, # STFT window args nperseg=nperseg, noverlap=noverlap, window_type=window_type, detrend=detrend, # FFT core args nfft=nfft, axis=axis, return_onesided=return_onesided, mode=mode, scaling=scaling, fs=fs, fmin=fmin, fmax=fmax, # misc return_config=return_config, verbose=verbose) # Transform inputs into predicable, sanity checked dictionaries logging.info('Processing Conditions, Covariates and Confounds') for ind, pxx in enumerate((aperiodic, oscillatory)): pxx.config.reg_categorical = _process_input_covariate(pxx.config.reg_categorical, pxx.config.input_len) pxx.config.reg_ztrans = _process_input_covariate(pxx.config.reg_ztrans, pxx.config.input_len) pxx.config.reg_unitmax = _process_input_covariate(pxx.config.reg_unitmax, pxx.config.input_len) # Fit model model, des, data = _glm_fit_glmtools(pxx.spectrum, pxx.config.reg_categorical, pxx.config.reg_ztrans, pxx.config.reg_unitmax, pxx.config, contrasts=pxx.config.contrasts, fit_intercept=pxx.config.fit_intercept) if ind == 0: glm_aperiodic = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config) else: glm_oscillatory = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config) return glm_aperiodic, glm_oscillatory @set_verbose def glm_irasa(x, # GLM args Loading sails/tests/test_stft.py +19 −19 Original line number Diff line number Diff line Loading @@ -21,9 +21,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): for ii in range(5): xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=2**(4+ii)) f2, pxx2 = periodogram(xx, nperseg=2**(4+ii)) pxx2 = periodogram(xx, nperseg=2**(4+ii)) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_window_type(self): """Ensure window type results are consistent.""" Loading @@ -35,9 +35,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): xx = np.random.randn(4096,) win = window_tests[ii] if window_tests[ii] is not None else np.ones((128,)) / 128 f, pxx = signal.welch(xx, nperseg=128, window=win) f2, pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii]) pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_nfft(self): """Ensure nfft results are consistent.""" Loading @@ -46,9 +46,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): for ii in range(5): xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nfft=2**(ii+4), nperseg=16) f2, pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16) pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_scaling(self): """Ensure scaling results are consistent.""" Loading @@ -59,9 +59,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): for ii in range(len(scaling_tests)): xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=128, scaling=scaling_tests[ii]) f2, pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii]) pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_sided(self): """Ensure scaling results are consistent.""" Loading @@ -73,9 +73,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): print(side_tests[ii]) xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=128, return_onesided=side_tests[ii]) f2, pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii]) pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_detrend(self): """Ensure scaling results are consistent.""" Loading @@ -87,9 +87,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): print(detrend_tests[ii]) xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=128, detrend=detrend_tests[ii]) f2, pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii]) pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_average(self): """Ensure scaling results are consistent.""" Loading @@ -106,9 +106,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): # they use a median with bias correction referencing a paper # https://github.com/scipy/scipy/blob/v1.11.3/scipy/signal/_spectral_py.py#L2037 avg = average_tests[ii] if ii == 0 else average_tests[ii] + '_bias' f2, pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG') pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG') assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) class TestBasicIRASA(unittest.TestCase): Loading @@ -116,12 +116,12 @@ class TestBasicIRASA(unittest.TestCase): def test_canary_irasa(self): """Ensure irasa runs.""" from ..stft import irasa, periodogram from ..stft import irasa_v3, periodogram # Run test 5 times for ii in range(5): xx = np.random.randn(4096,) f, pxx = periodogram(xx, nperseg=2**(4+ii), average='median') f2, aperiodic, oscillatory = irasa(xx, nperseg=2**(4+ii)) assert(np.all(f == f2)) assert(np.allclose(pxx, aperiodic+oscillatory)) pxx = periodogram(xx, nperseg=2**(4+ii), average='mean') aperiodic, oscillatory = irasa_v3(xx, nperseg=2**(4+ii), average='mean') assert(np.all(pxx.f == aperiodic.f)) assert(np.allclose(pxx.spectrum, aperiodic.spectrum + oscillatory.spectrum)) Loading
sails/_docstring_utils.py +0 −2 Original line number Diff line number Diff line Loading @@ -277,9 +277,7 @@ stft_funcs = ['apply_sliding_window', '_set_detrend', '_set_mode', '_set_frange', 'sw_periodogram', 'periodogram', 'sw_multitaper', 'multitaper', 'glm_periodogram', 'glm_multitaper', Loading
sails/stft.py +228 −119 Original line number Diff line number Diff line Loading @@ -533,8 +533,13 @@ def _proc_get_freq_inds(freqvals, fmin, fmax): (freqvals <= fmax) return fidx def _proc_get_time_range(nperseg, nstep, fs, xlen): # Create time window vector noverlap = nperseg - nstep return np.arange(nperseg/2, xlen - nperseg/2 + 1, nperseg - noverlap)/float(fs) def _proc_trim_freq_range(result, freqvals, fmin, fmax): def _proc_trim_freq_range(result, freqvals, fmin=None, fmax=None, axis=-1): """Trim an FFT output to desired frequency range. This helper function assumes that we want to trim the final axis. Loading @@ -546,6 +551,7 @@ def _proc_trim_freq_range(result, freqvals, fmin, fmax): freqvals : vector Vector of frequency values with length matching the final axis of result %(freq_range)s' %(axis)s' Returns ------- Loading @@ -555,10 +561,16 @@ def _proc_trim_freq_range(result, freqvals, fmin, fmax): New frequency array matching the final axis of output """ if fmin is None and fmax is None: # Just passing through return freqs, result logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax)) fmin = freqvals[0] if fmin is None else fmin fmax = freqvals[-1] if fmax is None else fmax fidx = _proc_get_freq_inds(freqvals, fmin, fmax) result = result[..., fidx] freqs = freqvals[fidx] result = np.compress(fidx, result, axis=axis) freqs = np.compress(fidx, freqvals) logging.debug('fft trimmed output shape {0}'.format(result.shape)) return freqs, result Loading Loading @@ -1060,57 +1072,15 @@ class GLMIRASAConfig(GLMPeriodogramConfig, IRASAConfig): # These functions take input data, run the option handling and execute whatever # computations are needed @dataclass class SpectrumResult: @set_verbose def sw_periodogram(x, # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=None, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') f, t, p = compute_stft(x, **config.stft_args) logging.debug(p.shape) def __init__(self, f, t, pxx, config=None): if return_config: return f, t, p, config else: return f, t, p self.f = f self.t = t self.spectrum = pxx self.config = config @set_verbose Loading Loading @@ -1162,71 +1132,11 @@ def periodogram(x, average='mean', logging.info('Averaging across first dim of result using method {0}'.format(config.average)) p = _proc_apply_average(p, config.average, axis=0) t = None if config.average is None else t logging.info('Returning spectrum of shape {0}'.format(p.shape)) if return_config: return f, p, config else: return f, p @set_verbose def sw_multitaper(x, # Multitaper core num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast', # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None): """Compute a multi-tapered power spectrum across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(multitaper_core)s' %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : MultiTaperConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand config = MultiTaperConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), time_bandwidth=time_bandwidth, num_tapers=num_tapers, apply_tapers=apply_tapers, freq_resolution=freq_resolution, average=None, fs=fs, window_type=None, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, mode=mode, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='auto') f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) if return_config: return f, t, p, config else: return f, t, p return SpectrumResult(f, t, p, config=config) @set_verbose Loading Loading @@ -1284,11 +1194,9 @@ def multitaper(x, average='mean', f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args) p = _proc_apply_average(p, config.average, axis=axis) t = None if config.average is None else t if return_config: return f, p, config else: return f, p return SpectrumResult(f, t, p, config=config) # ----------------------------------------------------------------------- Loading Loading @@ -1936,6 +1844,122 @@ def glm_multitaper(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand @set_verbose def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='median', #bootstrap_osc=None, bs_average='mean', # Periodogram args average='median', # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None, avg2=False): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(irasa)s` %(average)s` %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand logging.info('Setting config options') config = IRASAConfig(x.shape[axis], method=method, resample_factors=resample_factors, aperiodic_average=aperiodic_average, input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=None, fmax=None, output_axis='time_first') fmin = 0 if fmin is None else fmin fmax = fs / 2 if fmax is None else fmax if resample_factors is None and method == 'original': resample_factors = np.linspace(1.1, 1.9, 17) elif resample_factors is None and method == 'modified': resample_factors = np.exp(np.linspace(np.log(0.33), np.log(2), 7)) logging.info('Resample factors defined : {}'.format(resample_factors)) resample_factors = np.round(resample_factors, 4) if resample_factors.min() * config.fmax < fmax: msg = 'Requested fmax exceeds the range of valid resamplings. Specified fmax : {} Max valid fmax : {}' logging.warning(msg.format(fmax, resample_factors.min() * config.fmax)) if resample_factors.max() * config.fmin > fmin: # Not sure that this would ever trigger msg = 'Requested fmin exceeds the range of valid resamplings. Specified fmin : {} Min valid fmin : {}' logging.warning(msg.format(fmin, resample_factors.max() * config.fmin)) # Compute IRASA for each individual windowed sliding window data segment. x = _proc_roll_input(x, axis=config.axis) # Put time to final dimension yy = apply_sliding_window(x, **config.sliding_window_args) print('a : {}'.format(np.sum(yy**2))) freqs, pxx_full = compute_fft(yy, **config.fft_args) t = _proc_get_time_range(config.nperseg, config.nstep, config.fs, config.input_len) t = None if config.average is None else t for ii, rf in enumerate(resample_factors): rat = fractions.Fraction(str(rf)) up, down = rat.numerator, rat.denominator logging.info('Resampling by {}, factor {} of {}'.format(rf, ii+1, len(resample_factors))) zz = signal.resample_poly(yy, up, down, axis=-1) # This changes the variance of the signal... freqs, pxx_resampled = compute_fft(zz, **config.fft_args) if ii == 0: pxx_aperiodic = pxx_resampled[None, ...] else: pxx_aperiodic = np.concatenate((pxx_aperiodic, pxx_resampled[None, ...]), axis=0) # Average across resamplings logging.info('Averaging across {0} resamplings using method {1}'.format(pxx_aperiodic.shape[0], config.aperiodic_average)) pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.aperiodic_average, axis=0) pxx_oscillatory = pxx_full - pxx_aperiodic # Get trimmed frequency range out_freqs, pxx_aperiodic = _proc_trim_freq_range(pxx_aperiodic, freqs, fmin, fmax) out_freqs, pxx_oscillatory = _proc_trim_freq_range(pxx_oscillatory, freqs, fmin, fmax) logging.info('Averaging across {0} time segments using method {1}'.format(pxx_aperiodic.shape[0], config.average)) pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.average, axis=0) pxx_oscillatory = _proc_apply_average(pxx_oscillatory, config.average, axis=0) if pxx_aperiodic.ndim > 1: pxx_aperiodic = _proc_unroll_output(pxx_aperiodic, pxx_aperiodic.ndim-1, output_axis='auto') pxx_oscillatory = _proc_unroll_output(pxx_oscillatory, pxx_aperiodic.ndim-1, output_axis='auto') aperiodic = SpectrumResult(out_freqs, t, pxx_aperiodic, config=config) oscillatory = SpectrumResult(out_freqs, t, pxx_oscillatory, config=config) return aperiodic, oscillatory @set_verbose def irasa(x, method='original', resample_factors=None, aperiodic_average='median', #bootstrap_osc=None, bs_average='mean', Loading Loading @@ -2061,6 +2085,91 @@ def irasa(x, method='original', resample_factors=None, aperiodic_average='median return freqs, aperiodic_pxx, oscillatory_pxx @set_verbose def glm_irasa_v3(x, # GLM args reg_categorical=None, reg_ztrans=None, reg_unitmax=None, contrasts=None, fit_method='pinv', fit_intercept=True, # IRASA args method='original', resample_factors=None, aperiodic_average='median', # Periodogram args average='median', # STFT window args nperseg=None, noverlap=None, window_type='hann', detrend='constant', # FFT core args nfft=None, axis=-1, return_onesided=True, mode='psd', scaling='density', fs=1.0, fmin=None, fmax=None, # misc return_config=False, verbose=None): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values %(glmperiodogram)s' %(irasa)s' %(average)s' %(stft_window_user)s' %(fft_user)s' return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) %(verbose)s' Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand logging.info('Setting config options') aperiodic, oscillatory = irasa_v3(x, reg_categorical=reg_categorical, reg_ztrans=reg_ztrans, reg_unitmax=reg_unitmax, contrasts=contrasts, fit_method=fit_method, fit_intercept=fit_intercept, # IRASA args method=method, resample_factors=resample_factors, aperiodic_average=aperiodic_average, # Periodogram args average=average, # STFT window args nperseg=nperseg, noverlap=noverlap, window_type=window_type, detrend=detrend, # FFT core args nfft=nfft, axis=axis, return_onesided=return_onesided, mode=mode, scaling=scaling, fs=fs, fmin=fmin, fmax=fmax, # misc return_config=return_config, verbose=verbose) # Transform inputs into predicable, sanity checked dictionaries logging.info('Processing Conditions, Covariates and Confounds') for ind, pxx in enumerate((aperiodic, oscillatory)): pxx.config.reg_categorical = _process_input_covariate(pxx.config.reg_categorical, pxx.config.input_len) pxx.config.reg_ztrans = _process_input_covariate(pxx.config.reg_ztrans, pxx.config.input_len) pxx.config.reg_unitmax = _process_input_covariate(pxx.config.reg_unitmax, pxx.config.input_len) # Fit model model, des, data = _glm_fit_glmtools(pxx.spectrum, pxx.config.reg_categorical, pxx.config.reg_ztrans, pxx.config.reg_unitmax, pxx.config, contrasts=pxx.config.contrasts, fit_intercept=pxx.config.fit_intercept) if ind == 0: glm_aperiodic = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config) else: glm_oscillatory = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config) return glm_aperiodic, glm_oscillatory @set_verbose def glm_irasa(x, # GLM args Loading
sails/tests/test_stft.py +19 −19 Original line number Diff line number Diff line Loading @@ -21,9 +21,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): for ii in range(5): xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=2**(4+ii)) f2, pxx2 = periodogram(xx, nperseg=2**(4+ii)) pxx2 = periodogram(xx, nperseg=2**(4+ii)) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_window_type(self): """Ensure window type results are consistent.""" Loading @@ -35,9 +35,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): xx = np.random.randn(4096,) win = window_tests[ii] if window_tests[ii] is not None else np.ones((128,)) / 128 f, pxx = signal.welch(xx, nperseg=128, window=win) f2, pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii]) pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_nfft(self): """Ensure nfft results are consistent.""" Loading @@ -46,9 +46,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): for ii in range(5): xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nfft=2**(ii+4), nperseg=16) f2, pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16) pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_scaling(self): """Ensure scaling results are consistent.""" Loading @@ -59,9 +59,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): for ii in range(len(scaling_tests)): xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=128, scaling=scaling_tests[ii]) f2, pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii]) pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_sided(self): """Ensure scaling results are consistent.""" Loading @@ -73,9 +73,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): print(side_tests[ii]) xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=128, return_onesided=side_tests[ii]) f2, pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii]) pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_detrend(self): """Ensure scaling results are consistent.""" Loading @@ -87,9 +87,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): print(detrend_tests[ii]) xx = np.random.randn(4096,) f, pxx = signal.welch(xx, nperseg=128, detrend=detrend_tests[ii]) f2, pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii]) pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii]) assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) def test_simple_periodogram_average(self): """Ensure scaling results are consistent.""" Loading @@ -106,9 +106,9 @@ class TestSTFTAgainstScipy(unittest.TestCase): # they use a median with bias correction referencing a paper # https://github.com/scipy/scipy/blob/v1.11.3/scipy/signal/_spectral_py.py#L2037 avg = average_tests[ii] if ii == 0 else average_tests[ii] + '_bias' f2, pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG') pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG') assert(np.allclose(pxx, pxx2)) assert(np.allclose(pxx, pxx2.spectrum)) class TestBasicIRASA(unittest.TestCase): Loading @@ -116,12 +116,12 @@ class TestBasicIRASA(unittest.TestCase): def test_canary_irasa(self): """Ensure irasa runs.""" from ..stft import irasa, periodogram from ..stft import irasa_v3, periodogram # Run test 5 times for ii in range(5): xx = np.random.randn(4096,) f, pxx = periodogram(xx, nperseg=2**(4+ii), average='median') f2, aperiodic, oscillatory = irasa(xx, nperseg=2**(4+ii)) assert(np.all(f == f2)) assert(np.allclose(pxx, aperiodic+oscillatory)) pxx = periodogram(xx, nperseg=2**(4+ii), average='mean') aperiodic, oscillatory = irasa_v3(xx, nperseg=2**(4+ii), average='mean') assert(np.all(pxx.f == aperiodic.f)) assert(np.allclose(pxx.spectrum, aperiodic.spectrum + oscillatory.spectrum))