Commit e669380f authored by Andrew Quinn's avatar Andrew Quinn
Browse files

Merge branch 'irasa' into 'master'

Irasa

See merge request !87
parents 741dfa26 355c6ddb
Loading
Loading
Loading
Loading
Loading
+33 −3
Original line number Diff line number Diff line
@@ -136,6 +136,22 @@ docdict['fft_core'] = docdict['nfft'] + docdict['axis'] + docdict['fft_side'] +
docdict['fft_user'] = docdict['nfft'] + docdict['axis'] + docdict['return_onesided'] + \
                      docdict['spec_mode'] + docdict['fft_scaling'] + docdict['fs'] + docdict['freq_range']

docdict['average'] = """
    average : { 'mean', 'median', 'median_bias' }, optional
        Method to use when averaging across sliding window segments in a periodograms.
        Defaults to 'mean'."""

docdict['irasa'] = """
    method : {'original', 'modified'}
        whether to compute the original implementation of IRASA or the modified update
        (default is 'original')
    resample_factors : {None, array_like}
        array of resampling factors to average across or None, in which a set
        of factors are automatically computed (default is None).
    aperiodic_average : {'mean', 'median', 'median_bias', 'min'}
        method for averaging across irregularly resampled spectra to estimate
        the aperiodic component (default is 'median')."""

docdict['nperseg'] = """
    nperseg : int
        Length of each segment. Defaults to None, but if window is str or
@@ -222,6 +238,20 @@ docdict['multitaper_core'] = """
        or to iterate through each in a loop. Broadcasting is probably faster but
        more memory intensive. (Default value = 'broadcast')"""

docdict['glmperiodogram'] = """
    reg_categorical : dict or None
        Dictionary of covariate time series to be added as binary regessors. (Default value = None)
    reg_ztrans : dict or None
        Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None)
    reg_unitmax : dict or None
        Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None)
    contrasts : dict or None
        Dictionary of contrasts to be computed in the model.
        (Default value = None, will add a simple contrast for each regressor)
    fit_intercept : bool
        Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True)"""


stft_funcs = ['apply_sliding_window',
              'compute_fft',
              'compute_stft',
@@ -241,9 +271,9 @@ stft_funcs = ['apply_sliding_window',
              '_set_detrend',
              '_set_mode',
              '_set_frange',
              'sw_periodogram',
              'periodogram',
              'sw_multitaper',
              'multitaper',
              'glm_periodogram',
              'glm_multitaper']
              'glm_multitaper',
              'irasa',
              'glm_irasa']
+21 −0
Original line number Diff line number Diff line
@@ -192,6 +192,27 @@ class AbstractLinearModel(AbstractAnam):

        return resid

    def simulate_data(self, num_samples=1000, num_realisations=1, use_cov=True):
        num_sources = self.nsignals
        # Preallocate output
        Y = np.zeros((num_sources, num_samples, num_realisations))

        for ep in range(num_realisations):

            # Create driving noise signal
            Y[:, :, ep] = np.random.randn(num_sources, num_samples)

            if use_cov:
                C = np.linalg.cholesky(self.resid_cov)
                Y[:, :, ep] = Y[:, :, ep].T.dot(C).T

            # Main Loop
            for t in range(self.order, num_samples):
                for p in range(1, self.order):
                    Y[:, t, ep] -= -self.parameters[:, :, p].dot(Y[:, t-p, ep])

        return Y


__all__.append('AbstractLinearModel')
register_class(AbstractLinearModel)
+525 −278

File changed.

Preview size limit exceeded, changes collapsed.

+30 −14
Original line number Diff line number Diff line
@@ -21,9 +21,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
        for ii in range(5):
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=2**(4+ii))
            f2, pxx2 = periodogram(xx, nperseg=2**(4+ii))
            pxx2 = periodogram(xx, nperseg=2**(4+ii))

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_window_type(self):
        """Ensure window type results are consistent."""
@@ -35,9 +35,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            xx = np.random.randn(4096,)
            win = window_tests[ii] if window_tests[ii] is not None else np.ones((128,)) / 128
            f, pxx = signal.welch(xx, nperseg=128, window=win)
            f2, pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, window_type=window_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_nfft(self):
        """Ensure nfft results are consistent."""
@@ -46,9 +46,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
        for ii in range(5):
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nfft=2**(ii+4), nperseg=16)
            f2, pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16)
            pxx2 = periodogram(xx, nfft=2**(ii+4), nperseg=16)

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_scaling(self):
        """Ensure scaling results are consistent."""
@@ -59,9 +59,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
        for ii in range(len(scaling_tests)):
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=128, scaling=scaling_tests[ii])
            f2, pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, scaling=scaling_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_sided(self):
        """Ensure scaling results are consistent."""
@@ -73,9 +73,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            print(side_tests[ii])
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=128, return_onesided=side_tests[ii])
            f2, pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, return_onesided=side_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_detrend(self):
        """Ensure scaling results are consistent."""
@@ -87,9 +87,9 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            print(detrend_tests[ii])
            xx = np.random.randn(4096,)
            f, pxx = signal.welch(xx, nperseg=128, detrend=detrend_tests[ii])
            f2, pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii])
            pxx2 = periodogram(xx, nperseg=128, detrend=detrend_tests[ii])

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))

    def test_simple_periodogram_average(self):
        """Ensure scaling results are consistent."""
@@ -106,6 +106,22 @@ class TestSTFTAgainstScipy(unittest.TestCase):
            # they use a median with bias correction referencing a paper
            # https://github.com/scipy/scipy/blob/v1.11.3/scipy/signal/_spectral_py.py#L2037
            avg = average_tests[ii] if ii == 0 else average_tests[ii] + '_bias'
            f2, pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG')
            pxx2 = periodogram(xx, nperseg=128, average=avg, verbose='DEBUG')

            assert(np.allclose(pxx, pxx2))
            assert(np.allclose(pxx, pxx2.spectrum))


class TestBasicIRASA(unittest.TestCase):
    """Test that IRASA functions run."""

    def test_canary_irasa(self):
        """Ensure irasa runs."""
        from ..stft import irasa, periodogram

        # Run test 5 times
        for ii in range(5):
            xx = np.random.randn(4096,)
            pxx = periodogram(xx, nperseg=2**(4+ii), average='mean')
            aperiodic, oscillatory = irasa(xx, nperseg=2**(4+ii), average='mean')
            assert(np.all(pxx.f == aperiodic.f))
            assert(np.allclose(pxx.spectrum, aperiodic.spectrum + oscillatory.spectrum))