Commit 3b941c1b authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Add condition regressor options to simple and sklearn fits

parent 54b16045
Loading
Loading
Loading
Loading
Loading
+13 −7
Original line number Diff line number Diff line
@@ -1650,7 +1650,7 @@ def _process_input_covariate(cov, input_len):
    return ret


def _specify_design(covariates, confounds, config, fit_constant=True):
def _specify_design(conditions, covariates, confounds, config, fit_constant=True):
    """Create a design matrix.

    Parameters
@@ -1680,6 +1680,12 @@ def _specify_design(covariates, confounds, config, fit_constant=True):
        logging.info("Adding constant")
        X.append(np.ones((config.nwindows,)))
        Xlabels.append('Constant')

    # Add conditions
    for idx, var in enumerate(conditions.keys()):
        logging.info("Adding condition '{0}'".format(var))
        X.append(_process_regressor(conditions[var], config, mode='condition'))
        Xlabels.append(var)
    # Add covariates
    for idx, var in enumerate(covariates.keys()):
        logging.info("Adding covariate '{0}'".format(var))
@@ -1720,7 +1726,7 @@ def _run_prefit_checks(data, design_matrix, contrasts):
    assert(design_matrix.shape[1] == contrasts.shape[0])


def _glm_fit_simple(pxx, covariates, confounds, config, fit_method='pinv', fit_constant=True):
def _glm_fit_simple(pxx, conditions, covariates, confounds, config, fit_method='pinv', fit_constant=True):
    """Fit a GLM using a standard OLS fitting method.

    Parameters
@@ -1747,7 +1753,7 @@ def _glm_fit_simple(pxx, covariates, confounds, config, fit_method='pinv', fit_c

    """
    # Prepare GLM components
    design_matrix, contrasts, Xlabels = _specify_design(covariates, confounds,
    design_matrix, contrasts, Xlabels = _specify_design(conditions, covariates, confounds,
                                                        config, fit_constant=fit_constant)

    # Check we're probably good to go
@@ -1770,7 +1776,7 @@ def _glm_fit_simple(pxx, covariates, confounds, config, fit_method='pinv', fit_c
    return copes, varcopes


def _glm_fit_sklearn_estimator(pxx, covariates, confounds, config, fit_method, fit_constant=True):
def _glm_fit_sklearn_estimator(pxx, conditions, covariates, confounds, config, fit_method, fit_constant=True):
    """Fit a GLM using a sklearn-like estimator object.

    Parameters
@@ -1798,7 +1804,7 @@ def _glm_fit_sklearn_estimator(pxx, covariates, confounds, config, fit_method, f
    """
    logging.info('Running sklearn GLM fit')
    # Prepare GLM components
    design_matrix, contrasts, Xlabels = _specify_design(covariates, confounds,
    design_matrix, contrasts, Xlabels = _specify_design(conditions, covariates, confounds,
                                                        config, fit_constant=fit_constant)

    # Check we're probably good to go
@@ -1998,7 +2004,7 @@ def glm_periodogram(X, conditions=None, covariates=None, confounds=None,
    # Compute model - each method MUST assign copes, varcopes and extras
    if fit_method in ['pinv', 'lstsq']:
        logging.info('Running numpy GLM fit')
        copes, varcopes = _glm_fit_simple(p, covariates, confounds, config,
        copes, varcopes = _glm_fit_simple(p, conditions, covariates, confounds, config,
                                          fit_method=fit_method,
                                          fit_constant=fit_constant)
        extras = None
@@ -2010,7 +2016,7 @@ def glm_periodogram(X, conditions=None, covariates=None, confounds=None,
                                                    fit_constant=fit_constant)
    elif _is_sklearn_estimator(fit_method):
        logging.info('Running sklearn GLM fit with {0}'.format(_glm_fit_sklearn_estimator))
        copes, varcopes, extras = _glm_fit_sklearn_estimator(p, covariates, confounds, config,
        copes, varcopes, extras = _glm_fit_sklearn_estimator(p, conditions, covariates, confounds, config,
                                                             fit_method=fit_method,
                                                             fit_constant=fit_constant)
    else: