Commit 803447fc authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Next iteration of GLMPSD

parent 509de06b
Loading
Loading
Loading
Loading
+340 −133
Original line number Diff line number Diff line
@@ -5,7 +5,9 @@ from scipy import fft as sp_fft
from scipy.signal import signaltools

# ------------------------------------------------------------------
# General Helper Functions
# Low level computations
#
# These functions are stand-alone data processors


def window_vector(x, nperseg, nstep, window=None, detrend_func=None):
@@ -45,6 +47,144 @@ def window_vector(x, nperseg, nstep, window=None, detrend_func=None):
    return y_window


def compute_windowed_fft(x, fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False, return_config=False, config=None, output_roll='auto'):
    """Assuming you want a two-sided PSD."""

    if config is None:
        config = set_options(x.shape[axis], input_complex=np.iscomplexobj(x),
                             fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                             nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                             scaling=scaling, axis=axis, mode=mode, boundary=boundary,
                             padded=padded)

    # ---- Work start here

    x = _proc_roll_input(x, axis=config['axis'])

    # window inputs
    y = window_vector(x, config['nperseg'], config['nstep'],
                      window=config['win'], detrend_func=config['detrend_func'])

    # Compute FFT
    if config['side'] == 'twosided':
        func = sp_fft.fft
    else:
        y = y.real
        func = sp_fft.rfft
    result = func(y, config['nfft'])

    # Apply spectrum mode selection
    result = _proc_spectrum_mode(result, config['mode'], axis=config['axis'])

    # Apply scaling
    result = _proc_spectrum_scaling(result, config['scale'], config['side'], config['mode'], config['nfft'])

    time = np.arange(config['nperseg']/2, x.shape[-1] - config['nperseg']/2 + 1,
                     config['nperseg'] - config['noverlap'])/float(config['fs'])

    freqs = config['freqvals']

    # Final two axes are now [..., time x freq]
    if output_roll == 'auto':
        # Return time and freq back to original position
        result = np.rollaxis(result, -2, config['axis'])
        result = np.rollaxis(result, -1, config['axis']+1)
    elif output_roll == 'glm':
        # Put time at front and freq in original position
        result = np.rollaxis(result, -2, 0)
        result = np.rollaxis(result, -1, config['axis']+1)
    #if output_roll == 'auto':
    #    # Final two axis are now [.... time, freq] - roll time to front
    #    result = _proc_unroll_output(result, -2, config['axis'])
    #elif output_roll == 'glm':
    #    # Final two axis are now [.... time, freq] - roll time to front
    #    result = _proc_unroll_output(result, -2, config['axis'])

    if return_config:
        return freqs, time, result, config
    else:
        return freqs, time, result


def _proc_roll_input(x, axis=-1):
    if axis != -1:
        x = np.rollaxis(x, axis, len(x.shape))
    return x


def _proc_unroll_output(z, ax, axis):
    # Output is going to have new last axis for time/window index, so a
    # negative axis index shifts down one
    if axis < 0:
        offset = 1
    else:
        offset = 0
    return np.rollaxis(z, ax, axis-offset)


def _proc_spectrum_mode(pxx, mode, axis=-1):
    if mode == 'magnitude':
        pxx = np.abs(pxx)
    elif mode == 'psd':
        pxx = (np.conjugate(pxx) * pxx).real
    elif mode in ['angle', 'phase']:
        pxx = np.angle(pxx)
        if mode == 'phase':
            # pxx has one additional dimension for time strides
            if axis < 0:
                axis -= 1
            pxx = np.unwrap(pxx, axis=axis)
    elif mode == 'complex':
        pass

    return pxx


def _proc_spectrum_scaling(pxx, scale, side, mode, nfft):
    pxx *= scale
    if side == 'onesided' and mode == 'psd':
        if nfft % 2:
            pxx[..., 1:] *= 2
        else:
            # Last point is unpaired Nyquist freq point, don't double
            pxx[..., 1:-1] *= 2
    return pxx



# ------------------------------------------------------------------
# Option handling
#
# These functions parse inputs, set defaults and return sets of configured options.

def _triage_freqvalues(nfft, fs, sides):

    if sides == 'twosided':
        freqs = sp_fft.fftfreq(nfft, 1/fs)
    elif sides == 'onesided':
        freqs = sp_fft.rfftfreq(nfft, 1/fs)

    return freqs


def _triage_onesided(return_onesided, input_complex):

    if return_onesided:
        if input_complex:
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
        else:
            sides = 'onesided'
    else:
        sides = 'twosided'

    return sides


def _triage_nfft(nfft, nperseg):

    if nfft is None:
@@ -98,9 +238,19 @@ def _triage_detrend(detrend, axis):
    return detrend_func


def set_options(input_len, nperseg=None, noverlap=None,
                nfft=None, window='hann', mode='psd', fs=1,
                detrend=None, scaling='density'):
def _triage_mode(mode):
    modelist = ['psd', 'complex', 'magnitude', 'angle', 'phase']
    if mode not in modelist:
        raise ValueError('unknown value for mode {}, must be one of {}'
                          .format(mode, modelist))


def set_options(input_len, input_complex=False,
                fs=1.0, window='hann', nperseg=None, noverlap=None,
                nfft=None, detrend='constant', return_onesided=True,
                scaling='density', axis=-1, mode='psd', boundary=None,
                padded=False):

    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal.spectral._triage_segments(window, nperseg, input_length=input_len)

@@ -113,59 +263,53 @@ def set_options(input_len, nperseg=None, noverlap=None,

    detrend_func = _triage_detrend(detrend, axis=-1)

    return nperseg, noverlap, nstep, nfft, win, scale, detrend_func
    sides = _triage_onesided(return_onesided, input_complex)

    freqvals = _triage_freqvalues(nfft, fs, sides)

def compute_windowed_fft(x, nperseg=None, noverlap=None,
                         nfft=None, fs=1, window='hann', mode='psd',
                         detrend=None, scaling='density'):
    """Assuming you want a two-sided PSD."""

    opts = set_options(x.shape[-1], nperseg=nperseg, noverlap=noverlap,
                       nfft=nfft, window=window, fs=fs,
                       detrend=detrend, scaling=scaling)
    nperseg, noverlap, nstep, nfft, win, scale, detrend_func = opts
    _triage_mode(mode)

    # ---- Work start here
    opts = {'input_len': input_len,
            'input_complex': input_complex,
            'fs': fs,
            'win': win,
            'nperseg': nperseg,
            'noverlap': noverlap,
            'nstep': nstep,
            'nfft': nfft,
            'scale': scale,
            'boundary': boundary,
            'padded': padded,
            'side': sides,
            'mode': mode,
            'axis': axis,
            'detrend_func': detrend_func,
            'freqvals': freqvals}

    # window inputs
    y = window_vector(x, nperseg, nstep, window=win, detrend_func=detrend_func)

    # Compute FFT
    result = sp_fft.fft(y, nfft)

    if mode == 'psd':
        # Square the result for PSD - ignore for STFT
        result = np.conjugate(result) * result
        result = result.real  # All real anyway at this point

    # Apply scaling
    result *= scale

    time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1,
                     nperseg - noverlap)/float(fs)

    freqs = sp_fft.fftfreq(nfft, 1/fs)
    return opts

    return freqs, time, result

# ------------------------------------------------------------------------
# Top-level computation functions
#
# These functions take input data, run the option handling and execute whatever
# computations are needed


def psd(x, nperseg=None, noverlap=None, average='mean',
        nfft=None, fs=1, window='hann',
        detrend=None, scaling='density'):
    """Compute what is basically a Welch's Periodogram."""
def  psd(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density',
        axis=-1, average='mean'):
    """Compute Periodogram from successive FFTs."""

    print(noverlap)
    f, t, p = compute_windowed_fft(x, fs=fs, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                                   window=window, detrend=detrend, scaling=scaling)
                                   axis=axis, window=window, detrend=detrend, scaling=scaling)

    if average == 'mean':
        p = np.nanmean(p, axis=0)
        p = np.nanmean(p, axis=0).real
    elif average == 'median':
        p = np.nanmedian(p, axis=0)
        p = np.nanmedian(p, axis=0).real
    elif average is None:
        pass

    return f, p.real

@@ -173,16 +317,16 @@ def psd(x, nperseg=None, noverlap=None, average='mean',
# Conditioned Spectrogram Functions


def process_regressor(Y, nperseg, noverlap, mode='nuisance'):
def process_regressor(Y, config, mode='nuisance'):
    """Y is [nregs x nsamples]. Nuisance is scaled 0->1 and covariate is z-transformed."""
    if Y.ndim == 1:
        Y = Y[np.newaxis, :]

    noverlap = _triage_noverlap(noverlap, nperseg)
    noverlap = _triage_noverlap(config['noverlap'], config['nperseg'])

    windowed = window_vector(Y, nperseg, noverlap)
    windowed = window_vector(Y, config['nperseg'], config['noverlap'])

    y = np.nansum(windowed**2, axis=-1)
    y = np.nansum(windowed, axis=-1)

    if mode == 'nuisance':
        y = y - y.min(axis=-1)[:, np.newaxis]
@@ -193,107 +337,170 @@ def process_regressor(Y, nperseg, noverlap, mode='nuisance'):
    return y


def get_design_matrix(xlen, nperseg, noverlap, nuisance=None, covariate=None):

    design_matrix = np.ones((1, xlen))
def psd_glm(X, covariates=None, cov_types=None, fitmethod='pinv',
            fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
            detrend='constant', return_onesided=True, scaling='density',
            axis=-1, average='mean', mode='psd'):

    if covariate is not None:
        cov_reg = process_regressor(covariate, nperseg, noverlap, mode='covariate')
        design_matrix = np.vstack((design_matrix, cov_reg))
    if covariates is None:
        covariates = {}

    if nuisance is not None:
        nui_reg = process_regressor(nuisance, nperseg, noverlap, mode='nuisance')
        design_matrix = np.vstack((design_matrix, nui_reg))
    if cov_types is None:
        cov_types = ['covariate'] * len(covariates)

    return design_matrix.T
    if axis == -1:
        axis = X.ndim - 1

    config = set_options(X.shape[axis], input_complex=np.iscomplexobj(X),
                         fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                         nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                         scaling=scaling, axis=axis, mode=mode)

def cond_psd(x, covariate=None, nuisance=None,
             nperseg=None, noverlap=None, average='mean',
             nfft=None, fs=1, window='hann',
             detrend=False, scaling='density'):

    f, t, p = compute_windowed_fft(x, fs=fs, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                                   window=window, detrend=detrend, scaling=scaling)
    f, t, p = compute_windowed_fft(X, config=config, output_roll='glm')

    p = p.real
    goods = np.isnan(p) == False  # noqa: E712
    goods = np.any(goods, axis=1)
    info = {}
    covars = list(covariates.keys())

    # GLM
    design_matrix = get_design_matrix(p.shape[-2], nperseg, noverlap,
                                      covariate=covariate, nuisance=nuisance)
    design_matrix = design_matrix[:, 1:]
    # b, residuals, rank, s = np.linalg.lstsq(design_matrix[goods,:], p[goods,:])
    contrasts = np.eye(design_matrix.shape[1])
    betas, copes, varcopes = glm.fit.ols_fit(design_matrix[goods, :],
                                             p[goods, :],
                                             contrasts)
    DC = glm.design.DesignConfig()
    DC.add_regressor(name='Mean', rtype='Constant')
    info['nuisance'] = np.zeros((p.shape[0], ))
    for idx, var in enumerate(covars):
        info[var] = process_regressor(covariates[var], config, mode=cov_types[idx])[0, :]
        if cov_types[idx] == 'nuisance':
            #info['nuisance'] = np.max(np.c_[info['nuisance'], info[var]], axis=1)
            max_ind = np.where(info[var] > info['nuisance'])
            if len(max_ind) > 0:
                info['nuisance'][max_ind] = info[var][max_ind]
        DC.add_regressor(name=var, rtype='Parametric', datainfo=var, preproc=None)
    weights = None#1 - info['nuisance']

    # Get predicted spectrum from each regressor separately
    pred = np.zeros_like(betas)
    for ii in range(pred.shape[0]):
        b2 = np.zeros_like(betas)
        b2[ii, :] = betas[ii, :]
        pred[ii, :] = glm.fit._get_prediction(design_matrix[goods, :], b2)
    info['num_observations'] = p.shape[0]  # Should always be first axis here
    DC.add_simple_contrasts()
    #import pdb; pdb.set_trace()
    des = DC.design_from_datainfo(info)

    return f, betas, copes, varcopes, pred, design_matrix
    from .orthogonalise import symmetric_orthonormal
    #des.design_matrix = symmetric_orthonormal(des.design_matrix.T)[0].T

    data = glm.data.TrialGLMData(data=p, **info)

if __name__ == '__main__':

    import matplotlib.pyplot as plt

    nperseg = None
    noverlap = None
    nfft = None
    window = 'hann'
    detrend = False
    scaling = 'density'
    model = glm.fit.OLSModel(des, data, fit_args={'method': fitmethod, 'weights': weights})

    fs = 512
    seconds = 600
    mm = 3

    if mm == 1:
        a = emd.utils.ar_simulate(12, fs, seconds, r=.95)[:, 0] * 5e-4
        b = emd.utils.ar_simulate(41, fs, seconds, r=.97)[:, 0] * 5e-4
    return model, f, des

        covariate = np.abs(signal.hilbert(b))
        covariate = covariate / np.max(covariate)
    elif mm == 2:

        a = np.sin(2*np.pi*19*np.linspace(0, seconds, seconds*fs))
        a = (np.linspace(1, 0, seconds*fs) > 0) * a

        b = np.sin(2*np.pi*42*np.linspace(0, seconds, seconds*fs))
        covariate = (np.linspace(-1, 1, seconds*fs))
        b = covariate * b
    else:
        # a = np.sin(2*np.pi*12*np.linspace(0,seconds,seconds*fs))
        a = stats.zscore(emd.utils.ar_simulate(12, fs, seconds, r=.98)[:, 0])
        a += stats.zscore(emd.utils.ar_simulate(0.1, fs, seconds, r=.97)[:, 0])*2
        b = stats.zscore(emd.utils.ar_simulate(10, fs, seconds, r=.96)[:, 0])

        nuisance = np.sin(2*np.pi*0.125*np.linspace(0, seconds, seconds*fs)) > 0
        b = nuisance * b

        covariate = np.random.randn(*nuisance.shape)

    x = a + b + np.random.randn(*a.shape)

    nperseg = 512
    f, p = psd(x, nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
    f, p2 = psd(a, nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
    f2, beta, pred = cond_psd(x, nuisance=nuisance, covariate=covariate,
                              nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
if __name__ == '__main__':

    #import matplotlib.pyplot as plt

    #nperseg = None
    #noverlap = None
    #nfft = None
    #window = 'hann'
    #detrend = False
    #scaling = 'density'

    #fs = 512
    #seconds = 600
    #mm = 3

    #if mm == 1:
    #    a = emd.utils.ar_simulate(12, fs, seconds, r=.95)[:, 0] * 5e-4
    #    b = emd.utils.ar_simulate(41, fs, seconds, r=.97)[:, 0] * 5e-4

    #    covariate = np.abs(signal.hilbert(b))
    #    covariate = covariate / np.max(covariate)
    #elif mm == 2:

    #    a = np.sin(2*np.pi*19*np.linspace(0, seconds, seconds*fs))
    #    a = (np.linspace(1, 0, seconds*fs) > 0) * a

    #    b = np.sin(2*np.pi*42*np.linspace(0, seconds, seconds*fs))
    #    covariate = (np.linspace(-1, 1, seconds*fs))
    #    b = covariate * b
    #else:
    #    # a = np.sin(2*np.pi*12*np.linspace(0,seconds,seconds*fs))
    #    a = stats.zscore(emd.utils.ar_simulate(12, fs, seconds, r=.98)[:, 0])
    #    a += stats.zscore(emd.utils.ar_simulate(0.1, fs, seconds, r=.97)[:, 0])*2
    #    b = stats.zscore(emd.utils.ar_simulate(10, fs, seconds, r=.96)[:, 0])

    #    nuisance = np.sin(2*np.pi*0.125*np.linspace(0, seconds, seconds*fs)) > 0
    #    b = nuisance * b

    #    covariate = np.random.randn(*nuisance.shape)

    #x = a + b + np.random.randn(*a.shape)

    #nperseg = 512
    #f, p = psd(x, nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
    #f, p2 = psd(a, nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)
    #f2, beta, pred = glm_psd(x, nuisance=nuisance, covariate=covariate,
    #                          nperseg=nperseg, noverlap=nperseg//2, fs=fs, nfft=2048)

    #plt.figure()
    #plt.plot(f, p, linewidth=2)
    #plt.plot(f, p2, linewidth=2)
    #plt.plot(f2, pred[0, :], ':', linewidth=2)
    #plt.plot(f2, pred[1, :], linewidth=2)
    #plt.plot(f2, pred[2, :], linewidth=2)
    #plt.xlim(0, 25)
    #plt.ylim(0, 0.75)
    #plt.legend(['Total', 'Target', 'Mean', 'Covariate', 'Nuisance'])

    #%% -----------

    from scipy import signal
    import matplotlib.pyplot as plt
    rng = np.random.default_rng()
    #Generate a test signal, a 2 Vrms sine wave at 1234 Hz, corrupted by
    #0.001 V**2/Hz of white noise sampled at 10 kHz.
    fs = 10e3
    N = 1e5
    amp = 2*np.sqrt(2)
    freq = 1234.0
    noise_power = 0.001 * fs / 2
    time = np.arange(N) / fs
    x = amp*np.sin(2*np.pi*freq*time)
    x += rng.normal(scale=np.sqrt(noise_power), size=time.shape)

    #Compute and plot the power spectral density.
    f, Pxx_den = psd(x, fs, nperseg=1024)
    plt.semilogy(f, Pxx_den)
    plt.ylim([0.5e-3, 1])
    plt.xlabel('frequency [Hz]')
    plt.ylabel('PSD [V**2/Hz]')
    plt.show()

    #If we average the last half of the spectral density, to exclude the
    #peak, we can recover the noise power on the signal.
    np.mean(Pxx_den[256:])
    #0.0009924865443739191

    #Now compute and plot the power spectrum.
    f, Pxx_spec = psd(x, fs, 'flattop', 1024, scaling='spectrum')
    plt.figure()
    plt.plot(f, p, linewidth=2)
    plt.plot(f, p2, linewidth=2)
    plt.plot(f2, pred[0, :], ':', linewidth=2)
    plt.plot(f2, pred[1, :], linewidth=2)
    plt.plot(f2, pred[2, :], linewidth=2)
    plt.xlim(0, 25)
    plt.ylim(0, 0.75)
    plt.legend(['Total', 'Target', 'Mean', 'Covariate', 'Nuisance'])
    plt.semilogy(f, np.sqrt(Pxx_spec))
    plt.xlabel('frequency [Hz]')
    plt.ylabel('Linear spectrum [V RMS]')
    plt.show()
    #The peak height in the power spectrum is an estimate of the RMS
    #amplitude.

    np.sqrt(Pxx_spec.max())
    #2.0077340678640727

    #If we now introduce a discontinuity in the signal, by increasing the
    #amplitude of a small portion of the signal by 50, we can see the
    #corruption of the mean average power spectral density, but using a
    #median average better estimates the normal behaviour.
    x[int(N//2):int(N//2)+10] *= 50.
    f, Pxx_den = psd(x, fs, nperseg=1024)
    f_med, Pxx_den_med = psd(x, fs, nperseg=1024, average='median')
    plt.semilogy(f, Pxx_den, label='mean')
    plt.semilogy(f_med, Pxx_den_med, label='median')
    plt.ylim([0.5e-3, 1])
    plt.xlabel('frequency [Hz]')
    plt.ylabel('PSD [V**2/Hz]')
    plt.legend()
    plt.show()