JAXDiscipline does not work with namespaces
From gemseo issue (gemseo#1349 (closed))
Hi,
I am using the AutoJAXDiscipline of the gemseo-jax plugin to take advantage of AD derivation, but I have run into a problem with the namespaces feature. I followed the example here at the beginning and modified it to test namespaces, as follows:

    from gemseo_jax.auto_jax_discipline import AutoJAXDiscipline
    import numpy as np

    def f(a, x=1., y=2.):
        z = x + y
        t = x * y**2
        m = a + 35
        return z, t, m

    class NamespacedDisc(AutoJAXDiscipline):
        def __init__(self):
            super().__init__(function=f)
            self.inputs_namespace = ['x']
            self.outputs_namespace = ['t']

    jax_disc = NamespacedDisc()

    #### Comment this
    for input in jax_disc.inputs_namespace:
        jax_disc.add_namespace_to_input(input, 'to')
    for output in jax_disc.outputs_namespace:
        jax_disc.add_namespace_to_output(output, 'to')
    ####

    input_data = {'a': np.array([3])}
    results = jax_disc.execute(input_data)
    print('\nResults:')
    for k, v in results.items():
        print(f'{k}: {v}')

You can see that everything works as expected if you comment out the four lines with namespaces. I know that not all classes support namespaces, but I wonder if there is a quick fix to make it work.
I am not sure what the correct way of handling namespaces is, but what seems to solve the issue above is changing line 264 in src/gemseo_jax/jax_discipline.py from:
    def _run(self) -> None:
        output_data = self.jax_out_func({
            input_name: self.io.data[input_name]
            for input_name in self.input_grammar.names
        })
to
    def _run(self) -> None:
        output_data = self.jax_out_func({
            input_name: self.io.get_input_data(with_namespaces=False)[input_name]
            for input_name in self.input_grammar.names_without_namespace
        })
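To illustrate what I think is going on, here is a minimal sketch in plain Python that does not use GEMSEO at all. The ":" separator, the "to:x" key format, and the helpers strip_namespace and jax_like_function are assumptions made up for illustration; they are not the plugin's actual code. The idea is that the wrapped function only knows the bare argument names, so presumably the namespaced keys have to be stripped before the data is handed over.

```python
# Assumed namespace separator (hypothetical; chosen to mimic names like "to:x").
SEP = ":"


def strip_namespace(name: str) -> str:
    """Return the part of a name after the last separator, or the name itself."""
    return name.rsplit(SEP, 1)[-1]


def jax_like_function(data: dict) -> dict:
    """Stand-in for the wrapped function: it only knows the bare names."""
    a, x, y = data["a"], data["x"], data["y"]
    return {"z": x + y, "t": x * y**2, "m": a + 35}


# Data as it might be stored once 'x' carries the 'to' namespace.
namespaced_data = {"a": 3.0, "to:x": 1.0, "y": 2.0}

# Passing the namespaced keys through unchanged fails: there is no key "x".
try:
    jax_like_function(namespaced_data)
    failed = False
except KeyError:
    failed = True

# Stripping the namespaces first lets the function find its arguments.
bare_data = {strip_namespace(k): v for k, v in namespaced_data.items()}
outputs = jax_like_function(bare_data)

print(failed)
print(outputs)
```

The sketch only covers the input side. The outputs would presumably need the reverse mapping (bare name back to namespaced name) before being stored, which I have not looked into.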