Loading test/test_features.py +3 −5 Original line number Diff line number Diff line Loading @@ -90,18 +90,16 @@ class TestMemmap(unittest.TestCase): class TestShuffleTransitions(unittest.TestCase): def test_shuffle_transitions(self): rb = ReplayBuffer(64,{"a": {}}) rb = ReplayBuffer(64, {"a": { "dtype": np.int64 }}) a = np.arange(64) a = np.arange(64, dtype=np.int64) rb.add(a=a) s1 = rb.get_all_transitions()["a"] s2 = rb.get_all_transitions(shuffle=True)["a"] self.assertFalse((s1 == s2).all()) s = np.intersect1d(s1, s2) np.testing.assert_allclose(np.ravel(s), np.ravel(s1)) self.assertFalse(set(np.ravel(s1)) ^ set(np.ravel(s2))) if __name__ == '__main__': Loading Loading
test/test_features.py +3 −5 Original line number Diff line number Diff line Loading @@ -90,18 +90,16 @@ class TestMemmap(unittest.TestCase): class TestShuffleTransitions(unittest.TestCase): def test_shuffle_transitions(self): rb = ReplayBuffer(64,{"a": {}}) rb = ReplayBuffer(64, {"a": { "dtype": np.int64 }}) a = np.arange(64) a = np.arange(64, dtype=np.int64) rb.add(a=a) s1 = rb.get_all_transitions()["a"] s2 = rb.get_all_transitions(shuffle=True)["a"] self.assertFalse((s1 == s2).all()) s = np.intersect1d(s1, s2) np.testing.assert_allclose(np.ravel(s), np.ravel(s1)) self.assertFalse(set(np.ravel(s1)) ^ set(np.ravel(s2))) if __name__ == '__main__': Loading