-
"""Reproduce lowering a Keras training step to MLIR in two ways.

The script builds a tiny single-``Dense``-layer classifier, traces one
``model.train_step`` call into a concrete ``tf.function`` with fixed input
signatures, and prints the MLIR produced by:

1. ``tf.mlir.experimental.convert_function`` (public, experimental API), and
2. ``tensorflow.python.pywrap_mlir.import_graphdef`` applied to the traced
   ``GraphDef`` (private API; its signature may change between releases).

Both conversions run the same MLIR pass pipeline.
"""

import tensorflow as tf
from tensorflow.python.pywrap_mlir import import_graphdef

BATCH_SIZE = 8
INPUTS = 4
OUTPUTS = 3

# MLIR pass pipeline applied by both conversion paths in main().
PASS_PIPELINE = "tf-standard-pipeline"


def main():
    """Build the model, trace one training step, and print both MLIR dumps."""
    # The batch dimension is fixed to match the TensorSpecs used for tracing.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(
                input_shape=(INPUTS,), batch_size=BATCH_SIZE
            ),
            tf.keras.layers.Dense(OUTPUTS),
        ]
    )
    # compile() attaches the optimizer and loss that train_step uses.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )

    @tf.function
    def train_step_func(x, y):
        # Nothing is returned, so the traced function has no outputs; the
        # graph only carries the effects of one training step.
        model.train_step((x, y))

    concrete_func = train_step_func.get_concrete_function(
        x=tf.TensorSpec(shape=(BATCH_SIZE, INPUTS), dtype=tf.float32),
        y=tf.TensorSpec(shape=(BATCH_SIZE,), dtype=tf.int32),
    )

    # Path 1: public experimental API applied to the concrete function.
    mlir = tf.mlir.experimental.convert_function(
        concrete_func, pass_pipeline=PASS_PIPELINE
    )
    print(mlir)

    # Path 2: private importer applied to the serialized graph.
    # add_shapes=True records each node's inferred output shapes in the
    # GraphDef. The third positional argument is show_debug_info.
    mlir_tf = import_graphdef(
        concrete_func.graph.as_graph_def(add_shapes=True), PASS_PIPELINE, False
    )
    print(mlir_tf)


if __name__ == "__main__":
    main()
There is a pull request, https://github.com/tensorflow/tensorflow/pull/49380, awaiting approval that lets you specify the input and output options. With it, all the inputs and outputs will appear in the signature of the `main` function.
Please register or sign in to comment