Commit e2d19d6b authored by Yamada Hiroyuki's avatar Yamada Hiroyuki
Browse files

Fix: Test JAX only when installed

parent 0d78759d
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -7,12 +7,15 @@ from cpprb import ReplayBuffer
from cpprb.PyReplayBuffer import NstepBuffer


is_win = sys.platform.startswith("win")
if not is_win:
has_jax: bool = False
try:
    import jax.numpy as jnp
    has_jax = True
except ImportError:
    pass


@unittest.skipIf(is_win, "JAX doesn't support Windows")
@unittest.skipUnless(has_jax, "JAX is not installed")
class TestJAX(unittest.TestCase):
    def test_nstep_buffer(self):
        buffer = NstepBuffer({"obs": {}, "rew": {},  "done": {}, "next_obs": {}},