Commit 2cc26d80 authored by Rémi Huguet's avatar Rémi Huguet
Browse files

fix: compute batch size private method working statically

parent 57a40a7f
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -69,7 +69,10 @@ class Container:
            (50001, 250),
            (100001, 100)
        ]
        try:
            input_size = len(self._ths[0].samples[self.frame])
        except AttributeError:
            input_size = 0
        max_size = max(trace_size, input_size)
        for i in range(len(ref_sizes)):
            try:
+8 −0
Original line number Diff line number Diff line
@@ -189,6 +189,14 @@ def test_container_with_frame_compute_batch_size(ths):
    assert np.array_equal(b.samples, ths.samples[:s, :20])


def test_container_compute_batch_size_static_call(ths):
    s = scared.Container._compute_batch_size({}, trace_size=len(ths.samples[0, :20]))
    assert isinstance(s, int)
    c = scared.Container(ths, frame=slice(None, 20))
    b = c.batches(batch_size=s)[0]
    assert np.array_equal(b.samples, ths.samples[:s, :20])


def test_container_with_multiple_preprocess_and_frame(ths):
    @scared.preprocess
    def square(traces):