Loading sails/stft.py +327 −1 Original line number Diff line number Diff line Loading @@ -42,13 +42,15 @@ Worker functions: import logging import typing import warnings import fractions from dataclasses import dataclass from functools import wraps from copy import deepcopy import numpy as np from scipy import fft as sp_fft from scipy import stats from scipy.signal import signaltools from scipy.signal import signaltools, resample_poly from scipy.signal.windows import dpss try: Loading Loading @@ -1497,6 +1499,138 @@ def multitaper(x, average='mean', num_tapers='auto', return f, p @set_verbose def irasa(x, resample_factors=None, # Periodogram args average='mean', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False, verbose=None): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values average : { 'mean', 'median' }, optional Method to use when averaging periodograms. Defaults to 'mean'. fs : float, optional Sampling frequency of the `x` time series. Defaults to 1.0. window_type : str or tuple or array_like, optional Desired window to use. If `window` is a string or tuple, it is passed to `get_window` to generate the window values, which are DFT-even by default. See `get_window` for a list of windows and required parameters. If `window` is array_like it will be used directly as the window and its length must be nperseg. Defaults to a Hann window. nperseg : int, optional Length of each segment. Defaults to None, but if window is str or tuple, is set to 256, and if window is array_like, is set to the length of the window. noverlap : int, optional Number of points to overlap between segments. If `None`, ``noverlap = nperseg // 2``. Defaults to `None`. nfft : int, optional Length of the FFT used, if a zero padded FFT is desired. If `None`, the FFT length is `nperseg`. Defaults to `None`. detrend : str or function or `False`, optional Specifies how to detrend each segment. If `detrend` is a string, it is passed as the `type` argument to the `detrend` function. If it is a function, it takes a segment and returns a detrended segment. If `detrend` is `False`, no detrending is done. Defaults to 'constant'. return_onesided : bool, optional If `True`, return a one-sided spectrum for real data. If `False` return a two-sided spectrum. Defaults to `True`, but for complex data, a two-sided spectrum is always returned. scaling : { 'density', 'spectrum' }, optional Selects between computing the power spectral density ('density') where `Pxx` has units of V**2/Hz and computing the power spectrum ('spectrum') where `Pxx` has units of V**2, if `x` is measured in V and `fs` is measured in Hz. Defaults to 'density' axis : int, optional Axis along which the periodogram is computed; the default is over the last axis (i.e. ``axis=-1``). fmin : float or None, optional Minimum frequency value to return (Default value = 0) fmax : float or None, optional Maximum frequency value to return (Default value = 0.5) return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand logging.info('Setting config options') config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='time_first') if resample_factors is None: resample_factors = np.linspace(1.1, 1.9, 17) resample_factors = np.round(resample_factors, 4) logging.info('Starting computation') for ii, rf in enumerate(resample_factors): rat = fractions.Fraction(str(rf)) print(rat) up, down = rat.numerator, rat.denominator y = resample_poly(x, up, down, axis=config.axis) z = resample_poly(x, down, up, axis=config.axis) f, t, Y = compute_stft(y, **config.stft_args) Y = apply_average(Y, config.average, axis=0, keepdims=True) f, t, Z = compute_stft(z, **config.stft_args) Z = apply_average(Z, config.average, axis=0, keepdims=True) if ii == 0: pxx = np.sqrt(Y * Z) else: pxx = np.concatenate((pxx, np.sqrt(Y * Z)), axis=0) aperiodic_pxx = np.median(pxx, axis=0, keepdims=False) f, t, full_pxx = compute_stft(x, **config.stft_args) full_pxx = apply_average(full_pxx, config.average, axis=0) return full_pxx - aperiodic_pxx, aperiodic_pxx def apply_average(X, method, axis=0, keepdims=False): # Average over sliding windows. if method == 'mean': X = np.nanmean(X, axis=axis, keepdims=keepdims) elif method == 'median': X = np.nanmedian(X, axis=axis, keepdims=keepdims) elif method is None: pass else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(method)) return X # ----------------------------------------------------------------------- # GLM Spectrogram Functions Loading Loading @@ -2186,3 +2320,195 @@ def glm_multitaper(X, reg_ztrans=None, reg_unitmax=None, fit_method='pinv', fit_ raise ValueError('fit_method not recognised') return f, copes, varcopes, extras @set_verbose def glm_irasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, contrasts=None, fit_method='pinv', fit_intercept=True, ret_class=True, # IRASA kwargs resample_factors=None, average='median', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', fmin=None, fmax=None, verbose=None): """Compute a Power Spectrum with a General Linear Model. Parameters ---------- x : array_like Time series of measurement values reg_ztrans : dict or None Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None) reg_unitmax : dict or None Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None) fit_method : {'pinv', 'lstsq', 'glmtools', sklearn estimator instance} Specifies how the GLM parameters will be estimated. * `pinv` uses the design matrix psuedo-inverse method * `lstsq` uses np.linalg.lstsq. * `glmtools` uses the OLSModel from the glmtools package. * A parametrised instance of a sklearn estimator is used if specified here. (Default value = 'pinv') fit_intercept : bool Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True) fs : float, optional Sampling frequency of the `x` time series. Defaults to 1.0. nperseg : int, optional Length of each segment. Defaults to None, but if window is str or tuple, is set to 256, and if window is array_like, is set to the length of the window. noverlap : int, optional Number of points to overlap between segments. If `None`, ``noverlap = nperseg // 2``. Defaults to `None`. nfft : int, optional Length of the FFT used, if a zero padded FFT is desired. If `None`, the FFT length is `nperseg`. Defaults to `None`. detrend : str or function or `False`, optional Specifies how to detrend each segment. If `detrend` is a string, it is passed as the `type` argument to the `detrend` function. If it is a function, it takes a segment and returns a detrended segment. If `detrend` is `False`, no detrending is done. Defaults to 'constant'. return_onesided : bool, optional If `True`, return a one-sided spectrum for real data. If `False` return a two-sided spectrum. Defaults to `True`, but for complex data, a two-sided spectrum is always returned. scaling : { 'density', 'spectrum' }, optional Selects between computing the power spectral density ('density') where `Pxx` has units of V**2/Hz and computing the power spectrum ('spectrum') where `Pxx` has units of V**2, if `x` is measured in V and `fs` is measured in Hz. Defaults to 'density' axis : int, optional Axis along which the periodogram is computed; the default is over the last axis (i.e. ``axis=-1``). fmin : float or None, optional Minimum frequency value to return (Default value = 0) fmax : float or None, optional Maximum frequency value to return (Default value = 0.5) return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. extras : tuple Additional model information depending on the fit method used. """ # Option housekeeping if axis == -1: axis = X.ndim - 1 if X.ndim != 1 and fit_method in ['pinv', 'lstsq']: msg = "Data input should be vector for 'pinv' and 'lstsq' fits - data shape {0} was passed in" logging.error(msg.format(X.shape)) logging.error("Use fit_method='glmtools' for multdimensional data") raise ValueError("Fit methods 'pinv' and 'lstsq' not implemented for multidimensional data") # Set configuration logging.info('Setting config options') config = GLMPeriodogramConfig(X.shape[axis], reg_ztrans=reg_ztrans, reg_unitmax=reg_unitmax, fit_method=fit_method, contrasts=contrasts, fit_intercept=fit_intercept, input_complex=np.iscomplexobj(X), fs=fs, fmin=fmin, fmax=fmax, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, mode=mode, output_axis='time_first') print(config) # Transform inputs into predicable, sanity checked dictionaries logging.info('Processing Conditions, Covariates and Confounds') reg_categorical = _process_input_covariate(reg_categorical, config.input_len) reg_ztrans = _process_input_covariate(reg_ztrans, config.input_len) reg_unitmax = _process_input_covariate(reg_unitmax, config.input_len) # Compute STFT logging.info('Computing sliding window periodogram') f, t, p = compute_stft(X, **config.stft_args) # Compute model - each method MUST assign copes, varcopes and extras model, des, data = _glm_fit_glmtools(p, reg_categorical, reg_ztrans, reg_unitmax, config, contrasts=contrasts, fit_intercept=fit_intercept) if resample_factors is None: resample_factors = np.linspace(1.1, 1.9, 17) resample_factors = np.round(resample_factors, 4) logging.info('Starting computation') for ii, rf in enumerate(resample_factors): rat = fractions.Fraction(str(rf)) print(rat) up, down = rat.numerator, rat.denominator y = _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=config.axis) y, y_categorical, y_ztrans, y_unitmax = y f, t, Y = compute_stft(y, **config.stft_args) modelY, desY, dataY = _glm_fit_glmtools(Y, y_categorical, y_ztrans, y_unitmax, config, contrasts=contrasts, fit_intercept=fit_intercept) z = _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=config.axis) z, z_categorical, z_ztrans, z_unitmax = z f, t, Z = compute_stft(z, **config.stft_args) modelZ, desZ, dataZ = _glm_fit_glmtools(Z, z_categorical, z_ztrans, z_unitmax, config, contrasts=contrasts, fit_intercept=fit_intercept) if ii == 0: betas = np.sqrt(modelY.betas * modelZ.betas)[np.newaxis, ...] copes = np.sqrt(modelY.copes * modelZ.copes)[np.newaxis, ...] varcopes = np.sqrt(modelY.varcopes * modelZ.varcopes)[np.newaxis, ...] else: new_betas = np.sqrt(modelY.betas * modelZ.betas)[np.newaxis, ...] betas = np.concatenate((betas, new_betas), axis=0) new_copes = np.sqrt(modelY.copes * modelZ.copes)[np.newaxis, ...] copes = np.concatenate((copes, new_copes), axis=0) new_varcopes = np.sqrt(modelY.varcopes * modelZ.varcopes)[np.newaxis, ...] varcopes = np.concatenate((varcopes, new_varcopes), axis=0) model_aperiodic = deepcopy(model) model_aperiodic.betas = apply_average(betas, average, axis=0) model_aperiodic.copes = apply_average(copes, average, axis=0) model_aperiodic.varcopes = apply_average(varcopes, average, axis=0) model.betas = model.betas - model_aperiodic.betas model.copes = model.copes - model_aperiodic.copes model.varcopes = model.varcopes - model_aperiodic.varcopes return model_aperiodic, model def _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=0): y = resample_poly(X, up, down, axis=axis) out_categorical = reg_categorical.copy() for key, val in reg_categorical.items(): out_categorical[key] = resample_poly(val, up, down, axis=0) out_ztrans = reg_ztrans.copy() for key, val in reg_ztrans.items(): out_ztrans[key] = resample_poly(val, up, down, axis=0) out_unitmax = reg_unitmax.copy() for key, val in reg_unitmax.items(): out_unitmax[key] = resample_poly(val, up, down, axis=0) return y, out_categorical, out_ztrans, out_unitmax Loading
sails/stft.py +327 −1 Original line number Diff line number Diff line Loading @@ -42,13 +42,15 @@ Worker functions: import logging import typing import warnings import fractions from dataclasses import dataclass from functools import wraps from copy import deepcopy import numpy as np from scipy import fft as sp_fft from scipy import stats from scipy.signal import signaltools from scipy.signal import signaltools, resample_poly from scipy.signal.windows import dpss try: Loading Loading @@ -1497,6 +1499,138 @@ def multitaper(x, average='mean', num_tapers='auto', return f, p @set_verbose def irasa(x, resample_factors=None, # Periodogram args average='mean', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, mode='psd', scaling='density', axis=-1, fmin=None, fmax=None, return_config=False, verbose=None): """Compute Periodogram by averaging across windows in a STFT. Parameters ---------- x : array_like Time series of measurement values average : { 'mean', 'median' }, optional Method to use when averaging periodograms. Defaults to 'mean'. fs : float, optional Sampling frequency of the `x` time series. Defaults to 1.0. window_type : str or tuple or array_like, optional Desired window to use. If `window` is a string or tuple, it is passed to `get_window` to generate the window values, which are DFT-even by default. See `get_window` for a list of windows and required parameters. If `window` is array_like it will be used directly as the window and its length must be nperseg. Defaults to a Hann window. nperseg : int, optional Length of each segment. Defaults to None, but if window is str or tuple, is set to 256, and if window is array_like, is set to the length of the window. noverlap : int, optional Number of points to overlap between segments. If `None`, ``noverlap = nperseg // 2``. Defaults to `None`. nfft : int, optional Length of the FFT used, if a zero padded FFT is desired. If `None`, the FFT length is `nperseg`. Defaults to `None`. detrend : str or function or `False`, optional Specifies how to detrend each segment. If `detrend` is a string, it is passed as the `type` argument to the `detrend` function. If it is a function, it takes a segment and returns a detrended segment. If `detrend` is `False`, no detrending is done. Defaults to 'constant'. return_onesided : bool, optional If `True`, return a one-sided spectrum for real data. If `False` return a two-sided spectrum. Defaults to `True`, but for complex data, a two-sided spectrum is always returned. scaling : { 'density', 'spectrum' }, optional Selects between computing the power spectral density ('density') where `Pxx` has units of V**2/Hz and computing the power spectrum ('spectrum') where `Pxx` has units of V**2, if `x` is measured in V and `fs` is measured in Hz. Defaults to 'density' axis : int, optional Axis along which the periodogram is computed; the default is over the last axis (i.e. ``axis=-1``). fmin : float or None, optional Minimum frequency value to return (Default value = 0) fmax : float or None, optional Maximum frequency value to return (Default value = 0.5) return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. config : PeriodogramConfig, optional Configuration object containing all parameters used to compute spectrum, optionally returned based on value of `return_config`. """ # Config object stores options in one place and sets sensible defaults for # unspecified options given the data in-hand logging.info('Setting config options') config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)), average=average, fs=fs, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode, return_onesided=return_onesided, scaling=scaling, axis=axis, fmin=fmin, fmax=fmax, output_axis='time_first') if resample_factors is None: resample_factors = np.linspace(1.1, 1.9, 17) resample_factors = np.round(resample_factors, 4) logging.info('Starting computation') for ii, rf in enumerate(resample_factors): rat = fractions.Fraction(str(rf)) print(rat) up, down = rat.numerator, rat.denominator y = resample_poly(x, up, down, axis=config.axis) z = resample_poly(x, down, up, axis=config.axis) f, t, Y = compute_stft(y, **config.stft_args) Y = apply_average(Y, config.average, axis=0, keepdims=True) f, t, Z = compute_stft(z, **config.stft_args) Z = apply_average(Z, config.average, axis=0, keepdims=True) if ii == 0: pxx = np.sqrt(Y * Z) else: pxx = np.concatenate((pxx, np.sqrt(Y * Z)), axis=0) aperiodic_pxx = np.median(pxx, axis=0, keepdims=False) f, t, full_pxx = compute_stft(x, **config.stft_args) full_pxx = apply_average(full_pxx, config.average, axis=0) return full_pxx - aperiodic_pxx, aperiodic_pxx def apply_average(X, method, axis=0, keepdims=False): # Average over sliding windows. if method == 'mean': X = np.nanmean(X, axis=axis, keepdims=keepdims) elif method == 'median': X = np.nanmedian(X, axis=axis, keepdims=keepdims) elif method is None: pass else: msg = "'average' value of '{0}' not recognised - please use 'mean' or 'median'" raise ValueError(msg.format(method)) return X # ----------------------------------------------------------------------- # GLM Spectrogram Functions Loading Loading @@ -2186,3 +2320,195 @@ def glm_multitaper(X, reg_ztrans=None, reg_unitmax=None, fit_method='pinv', fit_ raise ValueError('fit_method not recognised') return f, copes, varcopes, extras @set_verbose def glm_irasa(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None, contrasts=None, fit_method='pinv', fit_intercept=True, ret_class=True, # IRASA kwargs resample_factors=None, average='median', # General STFT kwargs fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', fmin=None, fmax=None, verbose=None): """Compute a Power Spectrum with a General Linear Model. Parameters ---------- x : array_like Time series of measurement values reg_ztrans : dict or None Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None) reg_unitmax : dict or None Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None) fit_method : {'pinv', 'lstsq', 'glmtools', sklearn estimator instance} Specifies how the GLM parameters will be estimated. * `pinv` uses the design matrix psuedo-inverse method * `lstsq` uses np.linalg.lstsq. * `glmtools` uses the OLSModel from the glmtools package. * A parametrised instance of a sklearn estimator is used if specified here. (Default value = 'pinv') fit_intercept : bool Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True) fs : float, optional Sampling frequency of the `x` time series. Defaults to 1.0. nperseg : int, optional Length of each segment. Defaults to None, but if window is str or tuple, is set to 256, and if window is array_like, is set to the length of the window. noverlap : int, optional Number of points to overlap between segments. If `None`, ``noverlap = nperseg // 2``. Defaults to `None`. nfft : int, optional Length of the FFT used, if a zero padded FFT is desired. If `None`, the FFT length is `nperseg`. Defaults to `None`. detrend : str or function or `False`, optional Specifies how to detrend each segment. If `detrend` is a string, it is passed as the `type` argument to the `detrend` function. If it is a function, it takes a segment and returns a detrended segment. If `detrend` is `False`, no detrending is done. Defaults to 'constant'. return_onesided : bool, optional If `True`, return a one-sided spectrum for real data. If `False` return a two-sided spectrum. Defaults to `True`, but for complex data, a two-sided spectrum is always returned. scaling : { 'density', 'spectrum' }, optional Selects between computing the power spectral density ('density') where `Pxx` has units of V**2/Hz and computing the power spectrum ('spectrum') where `Pxx` has units of V**2, if `x` is measured in V and `fs` is measured in Hz. Defaults to 'density' axis : int, optional Axis along which the periodogram is computed; the default is over the last axis (i.e. ``axis=-1``). fmin : float or None, optional Minimum frequency value to return (Default value = 0) fmax : float or None, optional Maximum frequency value to return (Default value = 0.5) return_config : bool Indicate whether parameter configuration object should be returned alongside result (Default value = False) Returns ------- freqs : ndarray Array of sample frequencies. t : ndarray Array of times corresponding to each data segment result : ndarray Array of output data, contents dependent on *mode* kwarg. extras : tuple Additional model information depending on the fit method used. """ # Option housekeeping if axis == -1: axis = X.ndim - 1 if X.ndim != 1 and fit_method in ['pinv', 'lstsq']: msg = "Data input should be vector for 'pinv' and 'lstsq' fits - data shape {0} was passed in" logging.error(msg.format(X.shape)) logging.error("Use fit_method='glmtools' for multdimensional data") raise ValueError("Fit methods 'pinv' and 'lstsq' not implemented for multidimensional data") # Set configuration logging.info('Setting config options') config = GLMPeriodogramConfig(X.shape[axis], reg_ztrans=reg_ztrans, reg_unitmax=reg_unitmax, fit_method=fit_method, contrasts=contrasts, fit_intercept=fit_intercept, input_complex=np.iscomplexobj(X), fs=fs, fmin=fmin, fmax=fmax, window_type=window_type, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=axis, mode=mode, output_axis='time_first') print(config) # Transform inputs into predicable, sanity checked dictionaries logging.info('Processing Conditions, Covariates and Confounds') reg_categorical = _process_input_covariate(reg_categorical, config.input_len) reg_ztrans = _process_input_covariate(reg_ztrans, config.input_len) reg_unitmax = _process_input_covariate(reg_unitmax, config.input_len) # Compute STFT logging.info('Computing sliding window periodogram') f, t, p = compute_stft(X, **config.stft_args) # Compute model - each method MUST assign copes, varcopes and extras model, des, data = _glm_fit_glmtools(p, reg_categorical, reg_ztrans, reg_unitmax, config, contrasts=contrasts, fit_intercept=fit_intercept) if resample_factors is None: resample_factors = np.linspace(1.1, 1.9, 17) resample_factors = np.round(resample_factors, 4) logging.info('Starting computation') for ii, rf in enumerate(resample_factors): rat = fractions.Fraction(str(rf)) print(rat) up, down = rat.numerator, rat.denominator y = _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=config.axis) y, y_categorical, y_ztrans, y_unitmax = y f, t, Y = compute_stft(y, **config.stft_args) modelY, desY, dataY = _glm_fit_glmtools(Y, y_categorical, y_ztrans, y_unitmax, config, contrasts=contrasts, fit_intercept=fit_intercept) z = _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=config.axis) z, z_categorical, z_ztrans, z_unitmax = z f, t, Z = compute_stft(z, **config.stft_args) modelZ, desZ, dataZ = _glm_fit_glmtools(Z, z_categorical, z_ztrans, z_unitmax, config, contrasts=contrasts, fit_intercept=fit_intercept) if ii == 0: betas = np.sqrt(modelY.betas * modelZ.betas)[np.newaxis, ...] copes = np.sqrt(modelY.copes * modelZ.copes)[np.newaxis, ...] varcopes = np.sqrt(modelY.varcopes * modelZ.varcopes)[np.newaxis, ...] else: new_betas = np.sqrt(modelY.betas * modelZ.betas)[np.newaxis, ...] betas = np.concatenate((betas, new_betas), axis=0) new_copes = np.sqrt(modelY.copes * modelZ.copes)[np.newaxis, ...] copes = np.concatenate((copes, new_copes), axis=0) new_varcopes = np.sqrt(modelY.varcopes * modelZ.varcopes)[np.newaxis, ...] varcopes = np.concatenate((varcopes, new_varcopes), axis=0) model_aperiodic = deepcopy(model) model_aperiodic.betas = apply_average(betas, average, axis=0) model_aperiodic.copes = apply_average(copes, average, axis=0) model_aperiodic.varcopes = apply_average(varcopes, average, axis=0) model.betas = model.betas - model_aperiodic.betas model.copes = model.copes - model_aperiodic.copes model.varcopes = model.varcopes - model_aperiodic.varcopes return model_aperiodic, model def _resample_helper(X, reg_categorical, reg_ztrans, reg_unitmax, up, down, axis=0): y = resample_poly(X, up, down, axis=axis) out_categorical = reg_categorical.copy() for key, val in reg_categorical.items(): out_categorical[key] = resample_poly(val, up, down, axis=0) out_ztrans = reg_ztrans.copy() for key, val in reg_ztrans.items(): out_ztrans[key] = resample_poly(val, up, down, axis=0) out_unitmax = reg_unitmax.copy() for key, val in reg_unitmax.items(): out_unitmax[key] = resample_poly(val, up, down, axis=0) return y, out_categorical, out_ztrans, out_unitmax