Commit 6a756cbc authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Merge branch 'cluster_perm' into 'master'

Cluster perm

See merge request ajquinn/glmtools!11
parents 0a983bb0 2beb1cd9
Loading
Loading
Loading
Loading
+27 −1
Original line number Diff line number Diff line
@@ -61,6 +61,7 @@ class AbstractModelFit(AbstractAnam):

        # Store a copy of the design matrix
        self.design_matrix = design.design_matrix
        self.regressor_list = design.regressor_list

        # Compute number of valid observations (observations with NaNs are ignored)
        self.good_observations = np.isnan(data.sum(axis=1)) == False  # noqa: E712
@@ -177,10 +178,35 @@ class AbstractModelFit(AbstractAnam):
            Array containing t-statistic estimates

        """
        return get_tstats(self.copes, self.varcopes,
        return get_tstats(self.copes, self.varcopes.copy(),
                          varcope_smoothing=varcope_smoothing, smoothing_window=np.hanning,
                          smooth_dims=None, sigma_hat=sigma_hat)

    def project_range(self, contrast, nsteps=2, values=None, mean_ind=0):
        """Get model prediction for a range of values across one regressor."""

        steps = np.linspace(self.design_matrix[:, contrast].min(),
                            self.design_matrix[:, contrast].max(),
                            nsteps)
        pred = np.zeros((nsteps, *self.betas.shape[1:]))

        # Run projection
        for ii in range(nsteps):
            if nsteps == 1:
                coeff = 0
            else:
                coeff = steps[ii]
            pred[ii, ...] = self.betas[mean_ind, ...] + coeff*self.betas[contrast, ...]

        # Compute label values
        if nsteps > 1:
            scale = self.regressor_list[contrast].values_orig
            llabels = np.linspace(scale.min(), scale.max(), nsteps)
        else:
            llabels = ['Mean']

        return pred, llabels

    @property
    def num_observations(self):

+68 −1
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ from . import fit
import sys
from copy import deepcopy

from scipy import ndimage
from scipy import ndimage, stats


class Permutation:
@@ -512,6 +512,73 @@ def get_sensor_time_clusters(X, A, cft=5, pwr=1):
    return C3, cinds, cexts


class MNEClusterPermutation3(Permutation):

    def _get_null_dims(self, dims, **kwargs):
        return (self.nperms,)

    def _extract_perm_stat(self, null, **kwargs):
        from mne.stats.cluster_level import _find_clusters as mne_find_clusters

        threshold = self.perm_args.get('cluster_forming_threshold')
        adjacency = self.perm_args.get('adjacency', None)

        flatt = null.flatten()
        clus, cstats = mne_find_clusters(flatt,
                                         threshold,
                                         adjacency=adjacency)

        if len(clus) == 0:
            print('No clusters')
            return 0

        print('Found {0} clusters - {1} is largest'.format(len(cstats), np.abs(cstats).max()))

        if len(cstats) == 0:
            return 0
        else:
            return np.abs(cstats).max()

    def get_obs_clusters(self, data, fit=fit.OLSModel):
        from mne.stats.cluster_level import _find_clusters as mne_find_clusters
        from mne.stats.cluster_level import _reshape_clusters as mne_reshape_clusters

        f = fit(self._design, data)
        obs = getattr(f, self.perm_metric)[self.contrast_idx, ...]

        threshold = self.perm_args.get('cluster_forming_threshold')
        adjacency = self.perm_args.get('adjacency')

        flatt = obs.flatten()
        clus, cstats = mne_find_clusters(flatt,
                                         threshold,
                                         adjacency=adjacency)

        clus = mne_reshape_clusters(clus, obs.shape)

        return obs, clus, cstats

    def get_sig_clusters(self, thresh, data):
        # Find sig clusters
        obs, clusters, cstats = self.get_obs_clusters(data)
        thresh = self.get_thresh([thresh])
        sigs = np.abs(cstats) > thresh
        sig_inds = np.where(sigs)[0]

        # Collate info from sig clusters
        out = []
        for idx in sig_inds:
            clu = clusters[idx]
            pval = stats.percentileofscore(self.nulls, cstats[idx])
            out.append((cstats[idx], pval, clu))

        # Sort from largest to smallest
        I = np.argsort([c[0] for c in out])[::-1]
        out = [out[ii] for ii in I]

        return out, obs


class MNEClusterPermutation2(Permutation):

    def _get_null_dims(self, dims, **kwargs):
+2 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ class AbstractRegressor(AbstractAnam):

    def normalise_values(self):

        self.values_orig = self.values.copy()

        if self.preproc == 'z':
            self.values = (self.values - self.values.mean()) / self.values.std()
        elif self.preproc == 'demean':