Commit 54b16045 authored by Andrew Quinn's avatar Andrew Quinn
Browse files

add option for condition regressor in glm_periodogram (glmtools only)

parent c26d7bc7
Loading
Loading
Loading
Loading
Loading
+31 −11
Original line number Diff line number Diff line
@@ -573,7 +573,7 @@ def _proc_unroll_output(result, axis, output_axis='auto'):
    which require the temporal dimension in the first position.

    """
    logging.debug('Rolling outut axis {0} to position {1}'.format(axis, output_axis))
    logging.debug('Rolling output axis {0} to position {1}'.format(axis, output_axis))
    logging.debug('Pre-rolled shape {0}'.format(result.shape))
    if output_axis == 'auto':
        # Return time and freq back to original position
@@ -758,7 +758,7 @@ def _set_noverlap(noverlap, nperseg):
    Parameters
    ----------
    noverlap : int
        Desired number of overlapping samples between sucessive windows.
        Desired number of overlapping samples between successive windows.
    nperseg : int
        Window length

@@ -1579,16 +1579,20 @@ def _process_regressor(Y, config, mode='confound'):
        Processed regressor

    """
    window = None if mode == 'condition' else config.window
    windowed = apply_sliding_window(Y, config.nperseg, config.noverlap,
                                    window=config.window, padded=config.padded)
                                    window=window, padded=config.padded)

    y = np.nansum(windowed, axis=-1)
    print(config.nperseg)

    if mode == 'confound':
        y = y - y.min(axis=-1)[:, np.newaxis]
        y = y / y.max(axis=-1)[:, np.newaxis]
    if mode == 'condition':
        y = y / config.nperseg
    elif mode == 'covariate':
        y = stats.zscore(y, axis=-1)
    elif mode == 'confound':
        y = y - y.min(axis=-1)[:, np.newaxis]
        y = y / y.max(axis=-1)[:, np.newaxis]
    elif mode is None:
        pass

@@ -1665,7 +1669,7 @@ def _specify_design(covariates, confounds, config, fit_constant=True):
    design_matrix : ndarray
        [num_observations x num_regressors] matrix of regressors
    contrasts : ndarray
        [num_regressors x num_regressors] matrix of contrasts (indentity)
        [num_regressors x num_regressors] matrix of contrasts (identity)
    Xlabels : list of str
        List of regressor names

@@ -1815,7 +1819,8 @@ def _glm_fit_sklearn_estimator(pxx, covariates, confounds, config, fit_method, f
    return copes, varcopes, (fit_method)


def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True):
def _glm_fit_glmtools(pxx, conditions, covariates, confounds,
                      config, contrasts=None, fit_constant=True):
    """Fit a GLM using the glmtools package.

    Parameters
@@ -1848,6 +1853,9 @@ def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True):
    data = glm.data.TrialGLMData(data=pxx)

    # Add windowed confounds and covariates - no preproc yet
    for key, value in conditions.items():
        logging.debug('Processing Condition Regressor : {0}'.format(key))
        data.info[key] = _process_regressor(value, config, mode='condition')
    for key, value in covariates.items():
        data.info[key] = _process_regressor(value, config, mode=None)
    for key, value in confounds.items():
@@ -1857,12 +1865,19 @@ def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True):
    if fit_constant:
        logging.debug('Adding Constant Regressor')
        DC.add_regressor(name='Mean', rtype='Constant')
    for key in conditions.keys():
        logging.debug('Adding Condition : {0}'.format(key))
        DC.add_regressor(name=key, rtype='Categorical', datainfo=key, codes=[1])
    for key in covariates.keys():
        logging.debug('Adding Covariate : {0}'.format(key))
        DC.add_regressor(name=key, rtype='Parametric', datainfo=key, preproc='z')
    for key in confounds.keys():
        logging.debug('Adding Confound : {0}'.format(key))
        DC.add_regressor(name=key, rtype='Parametric', datainfo=key, preproc='unitmax')

    if contrasts is not None:
        for con in contrasts:
            DC.add_contrast(**con)
    DC.add_simple_contrasts()

    des = DC.design_from_datainfo(data.info)
@@ -1873,7 +1888,8 @@ def _glm_fit_glmtools(pxx, covariates, confounds, config, fit_constant=True):


@set_verbose
def glm_periodogram(X, covariates=None, confounds=None, fit_method='pinv', fit_constant=True,
def glm_periodogram(X, conditions=None, covariates=None, confounds=None,
                    contrasts=None, fit_method='pinv', fit_constant=True,
                    # General STFT kwargs
                    fs=1.0, window_type='hann', nperseg=None, noverlap=None, nfft=None,
                    detrend='constant', return_onesided=True, scaling='density',
@@ -1969,7 +1985,9 @@ def glm_periodogram(X, covariates=None, confounds=None, fit_method='pinv', fit_c
                                  scaling=scaling, axis=axis, mode=mode,
                                  output_axis='time_first')

    logging.info('Processing Covariates and Confounds')
    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    conditions = _process_input_covariate(conditions, config.input_len)
    covariates = _process_input_covariate(covariates, config.input_len)
    confounds = _process_input_covariate(confounds, config.input_len)

@@ -1986,7 +2004,9 @@ def glm_periodogram(X, covariates=None, confounds=None, fit_method='pinv', fit_c
        extras = None
    elif fit_method == 'glmtools':
        logging.info('Running glmtools GLM fit')
        copes, varcopes, extras = _glm_fit_glmtools(p, covariates, confounds, config,
        copes, varcopes, extras = _glm_fit_glmtools(p, conditions, covariates,
                                                    confounds, config,
                                                    contrasts=contrasts,
                                                    fit_constant=fit_constant)
    elif _is_sklearn_estimator(fit_method):
        logging.info('Running sklearn GLM fit with {0}'.format(_glm_fit_sklearn_estimator))