Skip to content

Use numpy.testing in Python tests

Brian Ward requested to merge tests/use-numpy-asserts into main

This changes test.py to use numpy.testing.assert_* rather than the (equivalent) assert numpy.*

Why? Because the test output is much better if the assertions fail, e.g.:

            constrained_theta[0] = 100
    
>           assert np.allclose(constrained_theta, x)
E           assert False
E            +  where False = <function allclose at 0x7f3d259914c0>(array([100.]), array([0.2300275]))
E            +    where <function allclose at 0x7f3d259914c0> = np.allclose

vs after

            constrained_theta[0] = 100
    
>           np.testing.assert_allclose(constrained_theta, x)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=0
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 99.71472635
E           Max relative difference: 349.54060868
E            x: array([100.])
E            y: array([0.285274])

The above output is from running pytest test.py. The difference is even more stark if you don't use Pytest. Then, the results are

test bernoulli
Traceback (most recent call last):
  File "/home/brian/Dev/cpp/bridgestan/test/test.py", line 177, in <module>
    test_bernoulli()
  File "/home/brian/Dev/cpp/bridgestan/test/test.py", line 35, in test_bernoulli
    assert np.allclose(constrained_theta, x)
AssertionError

and

test bernoulli
Traceback (most recent call last):
  File "/home/brian/Dev/cpp/bridgestan/test/test.py", line 175, in <module>
    test_bernoulli()
  File "/home/brian/Dev/cpp/bridgestan/test/test.py", line 35, in test_bernoulli
    np.testing.assert_allclose(constrained_theta, x)
  File "/home/brian/miniconda3/envs/bridgestan/lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 1527, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/brian/miniconda3/envs/bridgestan/lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 844, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
Max absolute difference: 99.71290823
Max relative difference: 347.32067308
 x: array([100.])
 y: array([0.287092])

Merge request reports