Commit d59a55e4 authored by Andrew Quinn's avatar Andrew Quinn
Browse files

working glm_irasa_v3

parent ab0ab7df
Loading
Loading
Loading
Loading
+39 −25
Original line number Diff line number Diff line
@@ -1065,6 +1065,16 @@ class GLMIRASAConfig(GLMPeriodogramConfig, IRASAConfig):
        """Set user picks and fill rest with sensible defaults."""
        super().__post_init__()

    @property
    def irasa_args(self):
        """Get keyword arguments for a call to compute_multitaper_stft."""
        args = {}
        for key in ['method', 'resample_factors', 'aperiodic_average',
                    'average', 'nperseg', 'noverlap', 'window_type',
                    'detrend', 'nfft', 'axis', 'return_onesided', 'mode',
                    'scaling', 'fs', 'fmin', 'fmax']:
            args[key] = getattr(self, key)
        return args

# ------------------------------------------------------------------------
# Top-level computation functions
@@ -2132,10 +2142,10 @@ def glm_irasa_v3(x,
    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    logging.info('Setting config options')
    aperiodic, oscillatory = irasa_v3(x, reg_categorical=reg_categorical, reg_ztrans=reg_ztrans, 
                                      reg_unitmax=reg_unitmax, contrasts=contrasts, 
                                      fit_method=fit_method, fit_intercept=fit_intercept,
    config = GLMIRASAConfig(x.shape[axis],reg_ztrans=reg_ztrans,
                            reg_unitmax=reg_unitmax,
                            contrasts=contrasts,
                            fit_intercept=fit_intercept,
                            # IRASA args
                            method=method, resample_factors=resample_factors, 
                            aperiodic_average=aperiodic_average,
@@ -2146,26 +2156,30 @@ def glm_irasa_v3(x,
                            window_type=window_type, detrend=detrend,
                            # FFT core args
                            nfft=nfft, axis=axis, return_onesided=return_onesided, mode=mode,
                                      scaling=scaling, fs=fs, fmin=fmin, fmax=fmax,
                                      # misc
                                      return_config=return_config, verbose=verbose) 
                            scaling=scaling, fs=fs, fmin=fmin, fmax=fmax)
    
    logging.info('Setting config options')
    aperiodic, oscillatory = irasa_v3(x, **config.irasa_args)



    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    config.reg_categorical = _process_input_covariate(config.reg_categorical, config.input_len)
    config.reg_ztrans = _process_input_covariate(config.reg_ztrans, config.input_len)
    config.reg_unitmax = _process_input_covariate(config.reg_unitmax, config.input_len)
    
    for ind, pxx in enumerate((aperiodic, oscillatory)):
        pxx.config.reg_categorical = _process_input_covariate(pxx.config.reg_categorical, pxx.config.input_len)
        pxx.config.reg_ztrans = _process_input_covariate(pxx.config.reg_ztrans, pxx.config.input_len)
        pxx.config.reg_unitmax = _process_input_covariate(pxx.config.reg_unitmax, pxx.config.input_len)

        # Fit model
        model, des, data = _glm_fit_glmtools(pxx.spectrum, pxx.config.reg_categorical, pxx.config.reg_ztrans,
                                             pxx.config.reg_unitmax, pxx.config,
                                             contrasts=pxx.config.contrasts,
                                             fit_intercept=pxx.config.fit_intercept)
        model, des, data = _glm_fit_glmtools(pxx.spectrum, config.reg_categorical, config.reg_ztrans,
                                             config.reg_unitmax, config,
                                             contrasts=config.contrasts,
                                             fit_intercept=config.fit_intercept)
        if ind == 0:
            glm_aperiodic = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config)
            glm_aperiodic = GLMSpectrumResult(pxx.f, model, des, data, config=config)
        else:
            glm_oscillatory = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config)
            glm_oscillatory = GLMSpectrumResult(pxx.f, model, des, data, config=config)

    return glm_aperiodic, glm_oscillatory