Commit 1319a780 authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Major update - relax glmtools dependency and start to finalise API

parent 803447fc
Loading
Loading
Loading
Loading
+157 −177
Original line number Diff line number Diff line
import importlib
import numpy as np
import glmtools as glm
from scipy import signal, stats
from scipy import fft as sp_fft
from scipy import signal, stats
from scipy.signal import signaltools

# ------------------------------------------------------------------
@@ -10,7 +10,7 @@ from scipy.signal import signaltools
# These functions are stand-alone data processors


def window_vector(x, nperseg, nstep, window=None, detrend_func=None):
def window_vector(x, nperseg, nstep, window=None, detrend_func=None, padded=False):
    """Strongly inspired by scipy.signal.spectral._fft_helper.
    Will window the last axis of input.

@@ -23,11 +23,12 @@ def window_vector(x, nperseg, nstep, window=None, detrend_func=None):
    # y.shape == nperseg + (nseg-1)*nstep
    # nadd = (-(y.shape[-1]-nperseg) % nstep) % nperseg
    # y = np.r_[y, np.zeros(nadd,)]
    padded = True
    if padded:
        nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        y = np.concatenate((x, np.zeros(zeros_shape)), axis=-1)
    else:
        y = x

    # Strided array
    # https://github.com/scipy/scip/y/blob/v1.5.1/scipy/signal/spectral.py#L1896
@@ -47,8 +48,9 @@ def window_vector(x, nperseg, nstep, window=None, detrend_func=None):
    return y_window


def compute_windowed_fft(x, fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend='constant', return_onesided=True,
def compute_windowed_fft(x, fs=1.0, window='hann', fmin=None, fmax=None,
                     nperseg=None, noverlap=None,
                     tfft=None, detrend='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False, return_config=False, config=None, output_roll='auto'):
    """Assuming you want a two-sided PSD."""
@@ -61,12 +63,12 @@ def compute_windowed_fft(x, fs=1.0, window='hann', nperseg=None, noverlap=None,
                             padded=padded)

    # ---- Work start here

    x = _proc_roll_input(x, axis=config['axis'])

    # window inputs
    y = window_vector(x, config['nperseg'], config['nstep'],
                      window=config['win'], detrend_func=config['detrend_func'])
                      window=config['win'], detrend_func=config['detrend_func'],
                      padded=config['padded'])

    # Compute FFT
    if config['side'] == 'twosided':
@@ -82,10 +84,15 @@ def compute_windowed_fft(x, fs=1.0, window='hann', nperseg=None, noverlap=None,
    # Apply scaling
    result = _proc_spectrum_scaling(result, config['scale'], config['side'], config['mode'], config['nfft'])

    # Create time window vector
    time = np.arange(config['nperseg']/2, x.shape[-1] - config['nperseg']/2 + 1,
                     config['nperseg'] - config['noverlap'])/float(config['fs'])

    freqs = config['freqvals']
    # Trim frequency range to specified limits
    fidx = (config['freqvals'] >= config['fmin']) & \
           (config['freqvals'] <= config['fmax'])
    result = result[..., fidx]
    freqs = config['freqvals'][fidx]

    # Final two axes are now [..., time x freq]
    if output_roll == 'auto':
@@ -96,12 +103,6 @@ def compute_windowed_fft(x, fs=1.0, window='hann', nperseg=None, noverlap=None,
        # Put time at front and freq in original position
        result = np.rollaxis(result, -2, 0)
        result = np.rollaxis(result, -1, config['axis']+1)
    #if output_roll == 'auto':
    #    # Final two axis are now [.... time, freq] - roll time to front
    #    result = _proc_unroll_output(result, -2, config['axis'])
    #elif output_roll == 'glm':
    #    # Final two axis are now [.... time, freq] - roll time to front
    #    result = _proc_unroll_output(result, -2, config['axis'])

    if return_config:
        return freqs, time, result, config
@@ -212,6 +213,8 @@ def _triage_scaling(scaling, fs, win):
        scale = 1.0 / (fs * (win*win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    elif scaling is None:
        scale = 1.0
    else:
        raise ValueError('Unknown scaling: %r' % scaling)
    return scale
@@ -245,7 +248,16 @@ def _triage_mode(mode):
                          .format(mode, modelist))


def set_options(input_len, input_complex=False,
def _triage_frange(fmin, fmax, fs):
    if fmin is None:
        fmin = 0
    if fmax is None:
        fmax = fs/2

    return fmin, fmax


def set_options(input_len, input_complex=False, fmin=None, fmax=None,
                fs=1.0, window='hann', nperseg=None, noverlap=None,
                nfft=None, detrend='constant', return_onesided=True,
                scaling='density', axis=-1, mode='psd', boundary=None,
@@ -267,6 +279,8 @@ def set_options(input_len, input_complex=False,

    freqvals = _triage_freqvalues(nfft, fs, sides)

    fmin, fmax = _triage_frange(fmin, fmax, fs)

    _triage_mode(mode)

    opts = {'input_len': input_len,
@@ -283,6 +297,8 @@ def set_options(input_len, input_complex=False,
            'side': sides,
            'mode': mode,
            'axis': axis,
            'fmin': fmin,
            'fmax': fmax,
            'detrend_func': detrend_func,
            'freqvals': freqvals}

@@ -298,7 +314,7 @@ def set_options(input_len, input_complex=False,

def psd(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density',
        axis=-1, average='mean'):
        axis=-1, average='mean', fmin=None, fmax=None):
    """Compute Periodogram from successive FFTs."""

    f, t, p = compute_windowed_fft(x, fs=fs, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
@@ -317,18 +333,65 @@ def psd(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
# Conditioned Spectrogram Functions


def process_regressor(Y, config, mode='nuisance'):
    """Y is [nregs x nsamples]. Nuisance is scaled 0->1 and covariate is z-transformed."""
def _flatten(X):
    """Flatten all dimensions after first in prep for regression"""
    if X.ndim == 1:
        return X[:, np.newaxis]
    else:
        return X.reshape(X.shape[0], np.prod(X.shape[1:]))


def _unflatten(X, orig_shape):
    """Restore dimensions after regression."""
    new_shape = tuple((X.shape[0], *orig_shape[1:]))
    return X.reshape(*new_shape)


def _is_sklearn_estimator(fit_method, strict=False):
    """Check (in the duck sense) if object is a skearn fitter."""
    test1 = hasattr(fit_method, 'fit') and callable(getattr(fit_method, 'fit'))
    test2 = hasattr(fit_method, 'get_params') and callable(getattr(fit_method, 'get_params'))
    test3 = hasattr(fit_method, 'set_params') and callable(getattr(fit_method, 'set_params'))

    return test1 and test2 and test3


def compute_ols_varcopes(design_matrix, data, contrasts, betas):
    """Compute variance of cope estimates."""

    # Compute varcopes
    varcopes = np.zeros((contrasts.shape[0], data.shape[1]))

    # Compute varcopes
    residue_forming_matrix = np.linalg.pinv(design_matrix.T.dot(design_matrix))
    var_forming_matrix = np.diag(np.linalg.multi_dot([contrasts,
                                                     residue_forming_matrix,
                                                     contrasts.T]))

    resid = data - design_matrix.dot(betas)

    # This is equivalent to >> np.diag( resid.T.dot(resid) )
    resid_dots = np.einsum('ij,ji->i', resid.T, resid)
    del resid
    dof_error = data.shape[0] - np.linalg.matrix_rank(design_matrix)
    V = resid_dots / dof_error
    varcopes = var_forming_matrix[:, None] * V[None, :]

    return varcopes


def process_regressor(Y, config, mode='confound'):
    """Y is [nregs x nsamples]. Confound is scaled 0->1 and covariate is z-transformed."""
    if Y.ndim == 1:
        Y = Y[np.newaxis, :]

    noverlap = _triage_noverlap(config['noverlap'], config['nperseg'])

    windowed = window_vector(Y, config['nperseg'], config['noverlap'])
    windowed = window_vector(Y, config['nperseg'], config['noverlap'],
                             window=config['win'], #detrend_func=config['detrend_func'],
                             padded=config['padded'])

    y = np.nansum(windowed, axis=-1)

    if mode == 'nuisance':
    if mode == 'confound':
        y = y - y.min(axis=-1)[:, np.newaxis]
        y = y / y.max(axis=-1)[:, np.newaxis]
    elif mode == 'covariate':
@@ -337,170 +400,87 @@ def process_regressor(Y, config, mode='nuisance'):
    return y


def psd_glm(X, covariates=None, cov_types=None, fitmethod='pinv',
            fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
def psd_glm(X, covariates=None, confounds=None, fitmethod='pinv',
            fit_constant=True, fmin=None, fmax=None, fs=1.0, window='hann',
            nperseg=None, fit_method='pinv', noverlap=None, nfft=None,
            detrend='constant', return_onesided=True, scaling='density',
            axis=-1, average='mean', mode='psd'):
    """Compute a Power Spectrum with a General Linear Model."""

    # Option housekeeping
    if covariates is None:
        covariates = {}

    if cov_types is None:
        cov_types = ['covariate'] * len(covariates)
    if confounds is None:
        confounds = {}

    if axis == -1:
        axis = X.ndim - 1

    # Set configuration
    config = set_options(X.shape[axis], input_complex=np.iscomplexobj(X),
                         fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                         nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                         fs=fs, fmin=fmin, fmax=fmax, window=window,
                         nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                         detrend=detrend, return_onesided=return_onesided,
                         scaling=scaling, axis=axis, mode=mode)

    # Compute STFT
    f, t, p = compute_windowed_fft(X, config=config, output_roll='glm')

    info = {}
    covars = list(covariates.keys())

    DC = glm.design.DesignConfig()
    DC.add_regressor(name='Mean', rtype='Constant')
    info['nuisance'] = np.zeros((p.shape[0], ))
    for idx, var in enumerate(covars):
        info[var] = process_regressor(covariates[var], config, mode=cov_types[idx])[0, :]
        if cov_types[idx] == 'nuisance':
            #info['nuisance'] = np.max(np.c_[info['nuisance'], info[var]], axis=1)
            max_ind = np.where(info[var] > info['nuisance'])
            if len(max_ind) > 0:
                info['nuisance'][max_ind] = info[var][max_ind]
        DC.add_regressor(name=var, rtype='Parametric', datainfo=var, preproc=None)
    weights = None#1 - info['nuisance']

    info['num_observations'] = p.shape[0]  # Should always be first axis here
    DC.add_simple_contrasts()
    #import pdb; pdb.set_trace()
    des = DC.design_from_datainfo(info)

    from .orthogonalise import symmetric_orthonormal
    #des.design_matrix = symmetric_orthonormal(des.design_matrix.T)[0].T

    data = glm.data.TrialGLMData(data=p, **info)

    model = glm.fit.OLSModel(des, data, fit_args={'method': fitmethod, 'weights': weights})


    return model, f, des


if __name__ == '__main__':

    #import matplotlib.pyplot as plt

    #nperseg = None
    #noverlap = None
    #nfft = None
    #window = 'hann'
    #detrend = False
    #scaling = 'density'

    #fs = 512
    #seconds = 600
    #mm = 3

    #if mm == 1:
    #    a = emd.utils.ar_simulate(12, fs, seconds, r=.95)[:, 0] * 5e-4
    #    b = emd.utils.ar_simulate(41, fs, seconds, r=.97)[:, 0] * 5e-4

    #    covariate = np.abs(signal.hilbert(b))
    #    covariate = covariate / np.max(covariate)
    #elif mm == 2:

    #    a = np.sin(2*np.pi*19*np.linspace(0, seconds, seconds*fs))
    #    a = (np.linspace(1, 0, seconds*fs) > 0) * a

    #    b = np.sin(2*np.pi*42*np.linspace(0, seconds, seconds*fs))
    #    covariate = (np.linspace(-1, 1, seconds*fs))
    #    b = covariate * b
    #else:
    #    # a = np.sin(2*np.pi*12*np.linspace(0,seconds,seconds*fs))
    #    a = stats.zscore(emd.utils.ar_simulate(12, fs, seconds, r=.98)[:, 0])
    #    a += stats.zscore(emd.utils.ar_simulate(0.1, fs, seconds, r=.97)[:, 0])*2
    #    b = stats.zscore(emd.utils.ar_simulate(10, fs, seconds, r=.96)[:, 0])

    #    nuisance = np.sin(2*np.pi*0.125*np.linspace(0, seconds, seconds*fs)) > 0
    #    b = nuisance * b

    #    covariate = np.random.randn(*nuisance.shape)

    #x = a + b + np.random.randn(*a.shape)

    #nperseg = 512
    #f, p = psd(x, nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
    #f, p2 = psd(a, nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
    #f2, beta, pred = glm_psd(x, nuisance=nuisance, covariate=covariate,
    #                          nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)

    #plt.figure()
    #plt.plot(f, p, linewidth=2)
    #plt.plot(f, p2, linewidth=2)
    #plt.plot(f2, pred[0, :], ':', linewidth=2)
    #plt.plot(f2, pred[1, :], linewidth=2)
    #plt.plot(f2, pred[2, :], linewidth=2)
    #plt.xlim(0, 25)
    #plt.ylim(0, 0.75)
    #plt.legend(['Total', 'Target', 'Mean', 'Covariate', 'Nuisance'])

    #%% -----------

    from scipy import signal
    import matplotlib.pyplot as plt
    rng = np.random.default_rng()
    #Generate a test signal, a 2 Vrms sine wave at 1234 Hz, corrupted by
    #0.001 V**2/Hz of white noise sampled at 10 kHz.
    fs = 10e3
    N = 1e5
    amp = 2*np.sqrt(2)
    freq = 1234.0
    noise_power = 0.001 * fs / 2
    time = np.arange(N) / fs
    x = amp*np.sin(2*np.pi*freq*time)
    x += rng.normal(scale=np.sqrt(noise_power), size=time.shape)

    #Compute and plot the power spectral density.
    f, Pxx_den = psd(x, fs, nperseg=1024)
    plt.semilogy(f, Pxx_den)
    plt.ylim([0.5e-3, 1])
    plt.xlabel('frequency [Hz]')
    plt.ylabel('PSD [V**2/Hz]')
    plt.show()

    #If we average the last half of the spectral density, to exclude the
    #peak, we can recover the noise power on the signal.
    np.mean(Pxx_den[256:])
    #0.0009924865443739191

    #Now compute and plot the power spectrum.
    f, Pxx_spec = psd(x, fs, 'flattop', 1024, scaling='spectrum')
    plt.figure()
    plt.semilogy(f, np.sqrt(Pxx_spec))
    plt.xlabel('frequency [Hz]')
    plt.ylabel('Linear spectrum [V RMS]')
    plt.show()
    #The peak height in the power spectrum is an estimate of the RMS
    #amplitude.

    np.sqrt(Pxx_spec.max())
    #2.0077340678640727

    #If we now introduce a discontinuity in the signal, by increasing the
    #amplitude of a small portion of the signal by 50, we can see the
    #corruption of the mean average power spectral density, but using a
    #median average better estimates the normal behaviour.
    x[int(N//2):int(N//2)+10] *= 50.
    f, Pxx_den = psd(x, fs, nperseg=1024)
    f_med, Pxx_den_med = psd(x, fs, nperseg=1024, average='median')
    plt.semilogy(f, Pxx_den, label='mean')
    plt.semilogy(f_med, Pxx_den_med, label='median')
    plt.ylim([0.5e-3, 1])
    plt.xlabel('frequency [Hz]')
    plt.ylabel('PSD [V**2/Hz]')
    plt.legend()
    plt.show()
    # Specify design
    X = []
    if fit_constant:
        X.append(np.ones((p.shape[0],)))
    # Add covariates
    for idx, var in enumerate(covariates.keys()):
        X.append(process_regressor(covariates[var], config, mode='covariate')[0, :])
    # Add confounds
    for idx, var in enumerate(confounds.keys()):
        X.append(process_regressor(covariates[var], config, mode='confound')[0, :])

    design_matrix = np.vstack(X).T
    contrasts = np.eye(design_matrix.shape[1])

    # Prepare data
    orig_shape = p.shape
    p = _flatten(p)

    assert(p.shape[0] == design_matrix.shape[0])
    assert(design_matrix.shape[1] == contrasts.shape[0])

    # Compute model
    extras = None
    copes = None  # Can compute separately if fit doesn't handle this
    if fit_method == 'pinv':
        betas = np.linalg.pinv(design_matrix).dot(p)
    elif fit_method == 'lstsq':
        betas, resids, rank, s = np.linalg.lstsq(design_matrix, p)
    elif fit_method == 'glmtools':
        import glmtools as glm
        des = glm.design.GLMDesign.initialise_from_matrices(design_matrix, contrasts)
        data = glm.data.TrialGLMData(data=p)
        model = glm.fit.OLSModel(des, data)
        betas = model.betas
        copes = model.copes
        varcopes = model.varcopes
        extras = (model, des, data)
    elif _is_sklearn_estimator(fit_method):
        fit_method.fit(design_matrix, p)
        if hasattr(fit_method, 'coef_'):
            betas = fit_method.coef_.T
        else:
            # Somtimes this is stored in a sub model...
            betas = fit_method.estimator_.coef_.T
        extras = (fit_method)
    else:
        raise ValueError('fit_method not recognised')

    if copes is None:
        # Compute contrasts
        copes = contrasts.dot(betas)
        varcopes = compute_ols_varcopes(design_matrix, p, contrasts, betas)

    # Preserve original input shape
    copes = _unflatten(copes, orig_shape)
    varcopes = _unflatten(varcopes, orig_shape)

    return f, copes, varcopes, extras