import tensorflow as tf
# NOTE(review): tensorflow.python.* is a private TF module; its API may change
# between releases.
from tensorflow.python.pywrap_mlir import import_graphdef

BATCH_SIZE = 8
INPUTS = 4
OUTPUTS = 3


def main():
    """Build a tiny Keras model and print its training step lowered to MLIR.

    One training step is traced into a concrete ``tf.function`` with fixed
    input signatures. That function is then converted to MLIR twice, and both
    results are printed so they can be compared:

    1. via the public ``tf.mlir.experimental.convert_function`` API, and
    2. by importing the traced graph's ``GraphDef`` with ``import_graphdef``.
    """
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(INPUTS,), batch_size=BATCH_SIZE),
            tf.keras.layers.Dense(OUTPUTS),
        ]
    )

    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )

    @tf.function()
    def train_step_func(x, y):
        # Trace Keras' built-in train_step (forward pass, loss, gradients,
        # optimizer update) into a single graph.
        model.train_step((x, y))

    # Fixed shapes/dtypes yield exactly one concrete trace: a float32 feature
    # batch and int32 class-index labels, as the sparse loss above expects.
    concrete_func = train_step_func.get_concrete_function(
        x=tf.TensorSpec(shape=(BATCH_SIZE, INPUTS), dtype=tf.float32),
        y=tf.TensorSpec(shape=(BATCH_SIZE,), dtype=tf.int32),
    )

    # Path 1: public (experimental) conversion of the concrete function.
    mlir = tf.mlir.experimental.convert_function(
        concrete_func, pass_pipeline="tf-standard-pipeline"
    )
    print(mlir)

    # Path 2: import the raw GraphDef of the traced graph. add_shapes=True
    # records each node's inferred output shapes in the GraphDef.
    mlir_tf = import_graphdef(
        concrete_func.graph.as_graph_def(add_shapes=True),
        "tf-standard-pipeline",  # same pass pipeline as path 1
        False,  # presumably show_debug_info; confirm for your TF version
    )
    print(mlir_tf)


if __name__ == "__main__":
    main()
  • There is a PR waiting for approval, https://github.com/tensorflow/tensorflow/pull/49380, that lets you specify input and output options; with it, all inputs and outputs will appear as arguments of the `main` function.

  • Thanks, I'll keep an eye on that and try it once it's merged!

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment