Loading test/test_jax.py +6 −3 Original line number Diff line number Diff line Loading @@ -7,12 +7,15 @@ from cpprb import ReplayBuffer from cpprb.PyReplayBuffer import NstepBuffer is_win = sys.platform.startswith("win") if not is_win: has_jax: bool = False try: import jax.numpy as jnp has_jax = True except ImportError: pass @unittest.skipIf(is_win, "JAX doesn't support Windows") @unittest.skipUnless(has_jax, "JAX is not installed") class TestJAX(unittest.TestCase): def test_nstep_buffer(self): buffer = NstepBuffer({"obs": {}, "rew": {}, "done": {}, "next_obs": {}}, Loading Loading
test/test_jax.py +6 −3 Original line number Diff line number Diff line Loading @@ -7,12 +7,15 @@ from cpprb import ReplayBuffer from cpprb.PyReplayBuffer import NstepBuffer is_win = sys.platform.startswith("win") if not is_win: has_jax: bool = False try: import jax.numpy as jnp has_jax = True except ImportError: pass @unittest.skipIf(is_win, "JAX doesn't support Windows") @unittest.skipUnless(has_jax, "JAX is not installed") class TestJAX(unittest.TestCase): def test_nstep_buffer(self): buffer = NstepBuffer({"obs": {}, "rew": {}, "done": {}, "next_obs": {}}, Loading