Commit ab0ab7df authored by Andrew Quinn's avatar Andrew Quinn
Browse files

big refactor and glm_irasa_v3

parent 90acc216
Loading
Loading
Loading
Loading
+0 −2
Original line number Diff line number Diff line
@@ -277,9 +277,7 @@ stft_funcs = ['apply_sliding_window',
              '_set_detrend',
              '_set_mode',
              '_set_frange',
              'sw_periodogram',
              'periodogram',
              'sw_multitaper',
              'multitaper',
              'glm_periodogram',
              'glm_multitaper',
+228 −119
Original line number Diff line number Diff line
@@ -533,8 +533,13 @@ def _proc_get_freq_inds(freqvals, fmin, fmax):
           (freqvals <= fmax)
    return fidx

def _proc_get_time_range(nperseg, nstep, fs, xlen):
    # Create time window vector
    noverlap = nperseg - nstep
    return np.arange(nperseg/2, xlen - nperseg/2 + 1, nperseg - noverlap)/float(fs)

def _proc_trim_freq_range(result, freqvals, fmin, fmax):

def _proc_trim_freq_range(result, freqvals, fmin=None, fmax=None, axis=-1):
    """Trim an FFT output to desired frequency range.

    This helper function assumes that we want to trim the final axis.
@@ -546,6 +551,7 @@ def _proc_trim_freq_range(result, freqvals, fmin, fmax):
    freqvals : vector
        Vector of frequency values with length matching the final axis of result
    %(freq_range)s'
    %(axis)s'

    Returns
    -------
@@ -555,10 +561,16 @@ def _proc_trim_freq_range(result, freqvals, fmin, fmax):
        New frequency array matching the final axis of output

    """
    if fmin is None and fmax is None:
        # Just passing through
        return freqs, result
    
    logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax))
    fmin = freqvals[0] if fmin is None else fmin
    fmax = freqvals[-1] if fmax is None else fmax
    fidx = _proc_get_freq_inds(freqvals, fmin, fmax)
    result = result[..., fidx]
    freqs = freqvals[fidx]
    result = np.compress(fidx, result, axis=axis)
    freqs = np.compress(fidx, freqvals)
    logging.debug('fft trimmed output shape {0}'.format(result.shape))

    return freqs, result
@@ -1060,57 +1072,15 @@ class GLMIRASAConfig(GLMPeriodogramConfig, IRASAConfig):
# These functions take input data, run the option handling and execute whatever
# computations are needed

@dataclass
class SpectrumResult:

@set_verbose
def sw_periodogram(x,
                   # STFT window args
                   nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                   # FFT core args
                   nfft=None, axis=-1, return_onesided=True, mode='psd',
                   scaling='density', fs=1.0, fmin=None, fmax=None,
                   # misc
                   return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=None, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='auto')

    f, t, p = compute_stft(x, **config.stft_args)
    logging.debug(p.shape)
    def __init__(self, f, t, pxx, config=None):

    if return_config:
        return f, t, p, config
    else:
        return f, t, p
        self.f = f
        self.t = t
        self.spectrum = pxx
        self.config = config


@set_verbose
@@ -1162,71 +1132,11 @@ def periodogram(x, average='mean',

    logging.info('Averaging across first dim of result using method {0}'.format(config.average))
    p = _proc_apply_average(p, config.average, axis=0)
    t = None if config.average is None else t

    logging.info('Returning spectrum of shape {0}'.format(p.shape))
    if return_config:
        return f, p, config
    else:
        return f, p


@set_verbose
def sw_multitaper(x,
                  # Multitaper core
                  num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast',
                  # STFT window args
                  nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                  # FFT core args
                  nfft=None, axis=-1, return_onesided=True, mode='psd',
                  scaling='density', fs=1.0, fmin=None, fmax=None,
                  # misc
                  return_config=False, verbose=None):
    """Compute a multi-tapered power spectrum across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    %(multitaper_core)s'
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : MultiTaperConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    config = MultiTaperConfig(x.shape[axis],
                              input_complex=np.any(np.iscomplex(x)),
                              time_bandwidth=time_bandwidth,
                              num_tapers=num_tapers, apply_tapers=apply_tapers,
                              freq_resolution=freq_resolution, average=None,
                              fs=fs, window_type=None, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft, detrend=detrend,
                              return_onesided=return_onesided, mode=mode,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax,
                              output_axis='auto')

    f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args)

    if return_config:
        return f, t, p, config
    else:
        return f, t, p
    return SpectrumResult(f, t, p, config=config)


@set_verbose
@@ -1284,11 +1194,9 @@ def multitaper(x, average='mean',
    f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args)

    p = _proc_apply_average(p, config.average, axis=axis)
    t = None if config.average is None else t

    if return_config:
        return f, p, config
    else:
        return f, p
    return SpectrumResult(f, t, p, config=config)


# -----------------------------------------------------------------------
@@ -1936,6 +1844,122 @@ def glm_multitaper(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
# Config object stores options in one place and sets sensible defaults for
# unspecified options given the data in-hand


@set_verbose
def irasa_v3(x, method='original', resample_factors=None, aperiodic_average='median',
          #bootstrap_osc=None, bs_average='mean',
          # Periodogram args
          average='median',
          # STFT window args
          nperseg=None, noverlap=None, window_type='hann', detrend='constant',
          # FFT core args
          nfft=None, axis=-1, return_onesided=True, mode='psd',
          scaling='density', fs=1.0, fmin=None, fmax=None,
          # misc
          return_config=False, verbose=None, avg2=False):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    %(irasa)s`
    %(average)s`
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    logging.info('Setting config options')
    config = IRASAConfig(x.shape[axis], method=method,
                         resample_factors=resample_factors, aperiodic_average=aperiodic_average,
                         input_complex=np.any(np.iscomplex(x)),
                         average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                         noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                         return_onesided=return_onesided, scaling=scaling, axis=axis,
                         fmin=None, fmax=None, output_axis='time_first')

    fmin = 0 if fmin is None else fmin
    fmax = fs / 2 if fmax is None else fmax

    if resample_factors is None and method == 'original':
        resample_factors = np.linspace(1.1, 1.9, 17)
    elif resample_factors is None and method == 'modified':
        resample_factors = np.exp(np.linspace(np.log(0.33), np.log(2), 7))
    logging.info('Resample factors defined : {}'.format(resample_factors))
    resample_factors = np.round(resample_factors, 4)

    if resample_factors.min() * config.fmax < fmax:
        msg = 'Requested fmax exceeds the range of valid resamplings. Specified fmax : {} Max valid fmax : {}'
        logging.warning(msg.format(fmax, resample_factors.min() * config.fmax))

    if resample_factors.max() * config.fmin > fmin:
        # Not sure that this would ever trigger
        msg = 'Requested fmin exceeds the range of valid resamplings. Specified fmin : {} Min valid fmin : {}'
        logging.warning(msg.format(fmin, resample_factors.max() * config.fmin))

    # Compute IRASA for each individual windowed sliding window data segment.
    x = _proc_roll_input(x, axis=config.axis)  # Put time to final dimension
    yy = apply_sliding_window(x, **config.sliding_window_args)
    print('a : {}'.format(np.sum(yy**2)))
    freqs, pxx_full = compute_fft(yy, **config.fft_args)
    t = _proc_get_time_range(config.nperseg, config.nstep, config.fs, config.input_len)
    t = None if config.average is None else t

    for ii, rf in enumerate(resample_factors):
        rat = fractions.Fraction(str(rf))
        up, down = rat.numerator, rat.denominator
        logging.info('Resampling by {}, factor {} of {}'.format(rf, ii+1, len(resample_factors)))

        zz = signal.resample_poly(yy, up, down, axis=-1)  # This changes the variance of the signal...
        freqs, pxx_resampled = compute_fft(zz, **config.fft_args)

        if ii == 0:
            pxx_aperiodic = pxx_resampled[None, ...]
        else:
            pxx_aperiodic = np.concatenate((pxx_aperiodic, pxx_resampled[None, ...]), axis=0)

    # Average across resamplings
    logging.info('Averaging across {0} resamplings using method {1}'.format(pxx_aperiodic.shape[0], 
                                                                            config.aperiodic_average))
    pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.aperiodic_average, axis=0)
    pxx_oscillatory = pxx_full - pxx_aperiodic

    # Get trimmed frequency range
    out_freqs, pxx_aperiodic = _proc_trim_freq_range(pxx_aperiodic, freqs, fmin, fmax)
    out_freqs, pxx_oscillatory = _proc_trim_freq_range(pxx_oscillatory, freqs, fmin, fmax)

    logging.info('Averaging across {0} time segments using method {1}'.format(pxx_aperiodic.shape[0], config.average))
    pxx_aperiodic = _proc_apply_average(pxx_aperiodic, config.average, axis=0)
    pxx_oscillatory = _proc_apply_average(pxx_oscillatory, config.average, axis=0)

    if pxx_aperiodic.ndim > 1:
        pxx_aperiodic = _proc_unroll_output(pxx_aperiodic, pxx_aperiodic.ndim-1, output_axis='auto')
        pxx_oscillatory = _proc_unroll_output(pxx_oscillatory, pxx_aperiodic.ndim-1, output_axis='auto')

    aperiodic = SpectrumResult(out_freqs, t, pxx_aperiodic, config=config)
    oscillatory = SpectrumResult(out_freqs, t, pxx_oscillatory, config=config)

    return aperiodic, oscillatory


@set_verbose
def irasa(x, method='original', resample_factors=None, aperiodic_average='median',
          #bootstrap_osc=None, bs_average='mean',
@@ -2061,6 +2085,91 @@ def irasa(x, method='original', resample_factors=None, aperiodic_average='median
    return freqs, aperiodic_pxx, oscillatory_pxx


@set_verbose
def glm_irasa_v3(x,
              # GLM args
              reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
              contrasts=None, fit_method='pinv', fit_intercept=True,
              # IRASA args
              method='original', resample_factors=None, aperiodic_average='median',
              # Periodogram args
              average='median',
              # STFT window args
              nperseg=None, noverlap=None, window_type='hann', detrend='constant',
              # FFT core args
              nfft=None, axis=-1, return_onesided=True, mode='psd',
              scaling='density', fs=1.0, fmin=None, fmax=None,
              # misc
              return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    %(glmperiodogram)s'
    %(irasa)s'
    %(average)s'
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    logging.info('Setting config options')
    aperiodic, oscillatory = irasa_v3(x, reg_categorical=reg_categorical, reg_ztrans=reg_ztrans, 
                                      reg_unitmax=reg_unitmax, contrasts=contrasts, 
                                      fit_method=fit_method, fit_intercept=fit_intercept,
                                      # IRASA args
                                      method=method, resample_factors=resample_factors, 
                                      aperiodic_average=aperiodic_average,
                                      # Periodogram args
                                      average=average,
                                      # STFT window args
                                      nperseg=nperseg, noverlap=noverlap, 
                                      window_type=window_type, detrend=detrend,
                                      # FFT core args
                                      nfft=nfft, axis=axis, return_onesided=return_onesided, mode=mode,
                                      scaling=scaling, fs=fs, fmin=fmin, fmax=fmax,
                                      # misc
                                      return_config=return_config, verbose=verbose) 

    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    for ind, pxx in enumerate((aperiodic, oscillatory)):
        pxx.config.reg_categorical = _process_input_covariate(pxx.config.reg_categorical, pxx.config.input_len)
        pxx.config.reg_ztrans = _process_input_covariate(pxx.config.reg_ztrans, pxx.config.input_len)
        pxx.config.reg_unitmax = _process_input_covariate(pxx.config.reg_unitmax, pxx.config.input_len)

        # Fit model
        model, des, data = _glm_fit_glmtools(pxx.spectrum, pxx.config.reg_categorical, pxx.config.reg_ztrans,
                                             pxx.config.reg_unitmax, pxx.config,
                                             contrasts=pxx.config.contrasts,
                                             fit_intercept=pxx.config.fit_intercept)
        if ind == 0:
            glm_aperiodic = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config)
        else:
            glm_oscillatory = GLMSpectrumResult(pxx.f, model, des, data, config=pxx.config)

    return glm_aperiodic, glm_oscillatory


@set_verbose
def glm_irasa(x,
              # GLM args
+19 −19
Original line number Diff line number Diff line
@@ -21,9 +21,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
        for ii in range(5):
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=2**(4+ii))
            f2, pxx2 = periodogram(xx, nperseg=2**(4+ii))
            pxx2 = periodogram(xx, nperseg=2**(4+ii))

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_window_type(self):
        """Ensure window type results are consistent."""
@@ -35,9 +35,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            xx = np.random.randn(4096,)
            win = window_tests[ii] if window_tests[ii] is not None else np.ones((128,)) / 128
            f, pxx = signal.welch(xx, nperseg=128, window=win)
            f2, pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_nfft(self):
        """Ensure nfft results are consistent."""
@@ -46,9 +46,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
        for ii in range(5):
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nfft=2**(ii+4), nperseg=16)
            f2, pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16)
            pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16)

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_scaling(self):
        """Ensure scaling results are consistent."""
@@ -59,9 +59,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
        for ii in range(len(scaling_tests)):
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=128, scaling=scaling_tests[ii])
            f2, pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_sided(self):
        """Ensure scaling results are consistent."""
@@ -73,9 +73,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            print(side_tests[ii])
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=128, return_onesided=side_tests[ii])
            f2, pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_detrend(self):
        """Ensure scaling results are consistent."""
@@ -87,9 +87,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            print(detrend_tests[ii])
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=128, detrend=detrend_tests[ii])
            f2, pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_average(self):
        """Ensure scaling results are consistent."""
@@ -106,9 +106,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            # they use a median with bias correction referencing a paper
            # https://github.com/scipy/scipy/blob/v1.11.3/scipy/signal/_spectral_py.py#L2037
            avg = average_tests[ii] if ii == 0 else average_tests[ii] + '_bias'
            f2, pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG')
            pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG')

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))


class TestBasicIRASA(unittest.TestCase):
@@ -116,12 +116,12 @@ class TestBasicIRASA(unittest.TestCase):

    def test_canary_irasa(self):
        """Ensure irasa runs."""
        from ..stft import irasa, periodogram
        from ..stft import irasa_v3, periodogram

        # Run test 5 times
        for ii in range(5):
            xx = np.random.randn(4096,)
            f, pxx = periodogram(xx, nperseg=2**(4+ii), average='median')
            f2, aperiodic, oscillatory = irasa(xx, nperseg=2**(4+ii))
            assert(np.all(f == f2))
            assert(np.allclose(pxx, aperiodic+oscillatory))
            pxx = periodogram(xx, nperseg=2**(4+ii), average='mean')
            aperiodic, oscillatory = irasa_v3(xx, nperseg=2**(4+ii), average='mean')
            assert(np.all(pxx.f == aperiodic.f))
            assert(np.allclose(pxx.spectrum, aperiodic.spectrum + oscillatory.spectrum))