JAXDiscipline does not work with namespaces

Originally reported in gemseo issue gemseo#1349.

Hi,

I am using the AutoJAXDiscipline of the gemseo-jax plugin to take advantage of automatic differentiation (AD), but I ran into a problem with the namespaces feature. I started from the example here and modified it to test namespaces, as follows:

from gemseo_jax.auto_jax_discipline import AutoJAXDiscipline
import numpy as np


def f(a, x=1., y=2.):
    z = x + y
    t = x * y**2
    m = a + 35

    return z, t, m


class NamespacedDisc(AutoJAXDiscipline):
    def __init__(self):
        super().__init__(function=f)
        self.inputs_namespace = ['x']
        self.outputs_namespace = ['t']


jax_disc = NamespacedDisc()

#### Comment out this block and the example runs as expected
for input_name in jax_disc.inputs_namespace:
    jax_disc.add_namespace_to_input(input_name, 'to')
for output_name in jax_disc.outputs_namespace:
    jax_disc.add_namespace_to_output(output_name, 'to')
####

input_data = {'a': np.array([3])}

results = jax_disc.execute(input_data)

print('\nResults:')
for k, v in results.items():
    print(f'{k}: {v}')

Everything works as expected once the four namespace lines are commented out. I know that not all classes support namespaces, but I wonder if there is a quick fix to make this work.
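To make the name mismatch concrete, here is a minimal pure-Python sketch of one way it can surface. It assumes that gemseo builds a namespaced name as "<namespace>:<name>" (so `x` in namespace `to` becomes `to:x`); `fake_out_func`, `add_namespace` and `grammar_names` are hypothetical stand-ins, not gemseo-jax APIs, and the exception the real plugin raises may differ.

```python
# Assumption: gemseo joins a namespace and a variable name with ":".
SEPARATOR = ":"


def add_namespace(name: str, namespace: str) -> str:
    """Hypothetical helper: 'x' in namespace 'to' -> 'to:x'."""
    return f"{namespace}{SEPARATOR}{name}"


def fake_out_func(inputs: dict) -> dict:
    """Hypothetical stand-in for jax_out_func: looks up bare names, like f(a, x, y)."""
    a, x, y = inputs["a"], inputs["x"], inputs["y"]
    return {"z": x + y, "t": x * y**2, "m": a + 35}


# Input names as the grammar would list them after add_namespace_to_input("x", "to").
grammar_names = ["a", add_namespace("x", "to"), "y"]
data = {"a": 3, "to:x": 1.0, "y": 2.0}

# Mimics the original _run: the dictionary is keyed by the namespaced names,
# so the wrapped function cannot find its "x" argument.
missing = None
try:
    fake_out_func({name: data[name] for name in grammar_names})
except KeyError as error:
    missing = error.args[0]
```

Here `missing` ends up as `"x"`: the data is stored under `to:x`, while the wrapped function still expects the bare argument name.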

I am not sure what the correct way of handling namespaces is, but modifying `_run` at line 264 of src/gemseo_jax/jax_discipline.py seems to solve the issue above. Change it from:

    def _run(self) -> None:
        output_data = self.jax_out_func({
            input_name: self.io.data[input_name]
            for input_name in self.input_grammar.names
        })

to

    def _run(self) -> None:
        # Fetch the input data once, keyed by names without namespace prefixes.
        input_data = self.io.get_input_data(with_namespaces=False)
        output_data = self.jax_out_func({
            input_name: input_data[input_name]
            for input_name in self.input_grammar.names_without_namespace
        })
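The idea behind the proposed change can be sketched in plain Python: strip the namespace prefix from each input name before calling the wrapped function, so its keys match the function's argument names again. This is only an illustration under stated assumptions. `remove_namespace` and `fake_out_func` are hypothetical, the ":" separator is assumed, and keeping the last component is a guess at how nested namespaces would behave. It emulates what `get_input_data(with_namespaces=False)` is expected to return; it is not gemseo's implementation.

```python
# Assumption: gemseo joins a namespace and a variable name with ":".
SEPARATOR = ":"


def remove_namespace(name: str) -> str:
    """Hypothetical helper: 'to:x' -> 'x'; names without a separator are unchanged."""
    return name.rsplit(SEPARATOR, 1)[-1]


def fake_out_func(inputs: dict) -> dict:
    """Hypothetical stand-in for jax_out_func, mirroring f(a, x, y) from the example."""
    a, x, y = inputs["a"], inputs["x"], inputs["y"]
    return {"z": x + y, "t": x * y**2, "m": a + 35}


# Local data with the namespaced input "to:x".
data = {"a": 3, "to:x": 1.0, "y": 2.0}

# Emulates get_input_data(with_namespaces=False): same values, bare keys.
data_without_namespaces = {
    remove_namespace(name): value for name, value in data.items()
}

output_data = fake_out_func(data_without_namespaces)
```

With the default `y = 2.0`, this gives `z = 3.0`, `t = 4.0` and `m = 38`.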