tflite model input data type mismatch?

Jarvis Irving

New member
  1. Mobile device: OnePlus 7T
  2. App version: 5.0.3
  3. TF version: 2.11

I have converted my pretrained model to an int8-quantized TFLite model, but I ran into the problems shown in the following screenshots. I suspect that the input data type (possibly uint8/int8) conflicts with the float32 format the input is required to have.

Can anyone provide some help or advice? Thank you in advance! :)

  • Input and output details of my TFLite model:

  • Problems: screenshots question_2.jpg, question_3.jpg, question_4.jpg, question_1.jpg
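If `interpreter.get_input_details()` reports a `uint8` or `int8` input tensor, float32 pixel data has to be quantized with that tensor's scale and zero point before it is fed to the interpreter, and integer outputs have to be dequantized afterwards. A minimal NumPy sketch, assuming the scale and zero point have already been read from the tensor's `quantization` field; the helper names and the parameter values below are hypothetical:

```python
import numpy as np


def quantize_input(image, scale, zero_point, dtype=np.uint8):
    """Convert float pixels to the integer domain of a quantized input tensor.

    TFLite's affine scheme is real = scale * (q - zero_point), so
    q = round(real / scale) + zero_point, clamped to the dtype's range.
    """
    info = np.iinfo(dtype)
    q = np.round(np.asarray(image, dtype=np.float32) / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)


def dequantize_output(q, scale, zero_point):
    """Map a quantized output tensor back to float32 values."""
    return (np.asarray(q).astype(np.float32) - zero_point) * scale


# Hypothetical parameters; the real ones come from
# interpreter.get_input_details()[0]['quantization'] -> (scale, zero_point)
scale, zero_point = 1.0 / 255.0, 0
pixels = np.array([0.0, 0.2, 1.0], dtype=np.float32)
q = quantize_input(pixels, scale, zero_point)
print(q, q.dtype)
print(dequantize_output(q, scale, zero_point))
```

If the client can only supply float32 data, the other option is to leave `inference_input_type` and `inference_output_type` at their float32 default, in which case the converter keeps float inputs/outputs and quantizes inside the model.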

Jarvis Irving

New member
I attempted to correct my code based on the links you provided. However, when I ran the model through NNAPI, I encountered an issue with node 98 of the TFLite model. Strangely, this node is not present in the converted model :(
I have attached a zip file of the TFLite model for your reference; you can view it with Netron. Do you have any suggestions on how to address this issue? Thank you.

  • Code for converting to TFLite:
    """Convert keras model to tflite."""
    import argparse
    import tensorflow as tf
    from tensorflow.keras.layers import Input
    from tensorflow.lite.python import interpreter as interpreter_wrapper
    from util import plugin
    import pdb
    import glob
    import random
    def rep_data_gen():
        # Get list of all images in train directory
        image_path = '/root/autodl-tmp/mai22-real-time-video-sr/data/REDS/train/train_sharp_bicubic/X4'
        jpg_file_list = glob.glob(image_path + '/**/*.jpg', recursive=True)
        JPG_file_list = glob.glob(image_path + '/**/*.JPG', recursive=True)
        png_file_list = glob.glob(image_path + '/**/*.png', recursive=True)
        bmp_file_list = glob.glob(image_path + '/**/*.bmp', recursive=True)
        quant_image_list = jpg_file_list + JPG_file_list + png_file_list + bmp_file_list
        dataset_list = quant_image_list
        quant_num = 10  # number of calibration images; they are fed at their native resolution, without resizing
        for i in range(quant_num):
            pick_me = random.choice(dataset_list)
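            image = tf.io.read_file(pick_me)
            if pick_me.endswith('.jpg') or pick_me.endswith('.JPG'):
                image = tf.image.decode_jpeg(image, channels=3)
            elif pick_me.endswith('.png'):
                image = tf.image.decode_png(image, channels=3)
            elif pick_me.endswith('.bmp'):
                image = tf.image.decode_bmp(image, channels=3)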
            image = tf.cast(image, tf.float32) / 255.
            image = tf.expand_dims(image, 0)
            yield [image]
    def _parse_argument():
        """Return arguments for conversion."""
        parser = argparse.ArgumentParser(description='Conversion.')
        parser.add_argument('--model_path', help='Path of model file.', type=str, required=True)
        parser.add_argument('--model_name', help='Name of model class.', type=str, required=True)
        parser.add_argument(
            '--input_shapes', help='Series of the input shapes split by `:`.', required=True
        )
        parser.add_argument('--ckpt_path', help='Path of checkpoint.', type=str, required=True)
        parser.add_argument('--output_tflite', help='Path of output tflite.', type=str, required=True)
        args = parser.parse_args()
        return args
    def main(args):
        """Run main function for converting keras model to tflite.

        Args:
            args: An `argparse.Namespace` containing the parsed arguments.
        """
        # Prepare model
        model_builder = plugin.plugin_from_file(args.model_path, args.model_name, tf.keras.Model)
        model = model_builder()
        # Load checkpoint
        ckpt = tf.train.Checkpoint(model=model)
        ckpt.restore(args.ckpt_path).expect_partial()
        input_tensors = []
        for input_shape in args.input_shapes.split(':'):
            input_shape = list(map(int, input_shape.split(',')))
            input_shape = [None if x == -1 else x for x in input_shape]
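            input_tensor = Input(shape=input_shape[1:], batch_size=input_shape[0])
            input_tensors.append(input_tensor)
        # Trace the model on the fixed-shape inputs so the saved signature uses them
        model = tf.keras.Model(inputs=input_tensors, outputs=model(input_tensors))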
        # Save loaded model in 'SavedModel' format
        model.save(f'{args.model_name}_pd')
        # Load model from 'Saved Model' file
        model = tf.keras.models.load_model(f'{args.model_name}_pd')
        concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])   
        # Configure converter and convert the keras model
        converter.experimental_new_converter = True
        converter.experimental_new_quantizer = True
        # This enables quantization
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # This sets the representative dataset for quantization
        converter.representative_dataset = rep_data_gen
        # Allow int8 builtin kernels, but also fall back to float builtins and Select TF ops;
        # keep only TFLITE_BUILTINS_INT8 to make the converter fail on ops it can't quantize
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
            tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops.
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8, # enable INT8 inference mode
        ]
        # For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.
        converter.target_spec.supported_types = [tf.int8]   
        # These set the input/output tensors to tf.uint8
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8   
        # Apply the convert
        tflite_model = converter.convert()
        # Save the tflite
        with open(args.output_tflite, 'wb') as f:
            f.write(tflite_model)
        print(f'Optimized TFLite saved in: {args.output_tflite}')
        # Get input output details of tflite model to know how to preprocess images
        interpreter = interpreter_wrapper.Interpreter(model_path=f'{args.output_tflite}', num_threads=32)
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    if __name__ == '__main__':
        arguments = _parse_argument()
        main(arguments)
  • Model structure screenshot: model_1.png
  • Problem screenshot: question_5.jpg
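The representative dataset is what the converter uses to choose the scale and zero point of each activation tensor, including the uint8 input. The sketch below is a simplified min/max calibration in NumPy, meant only to illustrate the idea; the real TFLite converter's calibration (range nudging, per-channel weight quantization, and so on) may differ, and the function name and calibration batch are hypothetical:

```python
import numpy as np


def affine_params(samples, qmin=0, qmax=255):
    """Simplified min/max calibration for one activation tensor.

    Returns (scale, zero_point) with real = scale * (q - zero_point).
    """
    # Widen the observed range so that 0.0 is exactly representable.
    lo = min(float(np.min(samples)), 0.0)
    hi = max(float(np.max(samples)), 0.0)
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    # Integer that represents real 0.0, clamped to the valid range.
    zero_point = int(round(qmin - lo / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


# Hypothetical calibration batch: 10 RGB frames scaled to [0, 1],
# like the images rep_data_gen yields (the shape here is arbitrary).
rng = np.random.default_rng(0)
batch = rng.uniform(0.0, 1.0, size=(10, 8, 8, 3)).astype(np.float32)
print("uint8 input:", affine_params(batch))
print("int8 input:", affine_params(batch, qmin=-128, qmax=127))
```

In this sketch, inputs in [0, 1] give a uint8 zero point of 0 and a scale of at most 1/255, so a float image would need to be scaled up by roughly 255 rather than fed as-is once the input tensor is uint8.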



Jarvis Irving

New member
I decomposed the model and tested each module step by step, and finally identified the root cause of the problem: the custom pixel normalization shown in the following code. It seems that the tf.reshape operations combined with tf.keras.layers.LayerNormalization are incompatible with the TFLite GPU Delegate and NNAPI.

I replaced this part directly with tf.keras.layers.LayerNormalization, and it works perfectly fine :)
  • Custom normalization
  • Python:
    def PixelNorm(inputs, epsilon=1e-8):
        # default inputs shape: [B, H, W, C]
        # the static channel dimension must be known (tf.shape() never returns None)
        assert inputs.shape[-1] is not None, "Inputs shape must have known channel dimension."
        # dynamic shape, used to restore the original shape afterwards
        inputs_shape = tf.shape(inputs)
        # flatten inputs to shape [B x H x W, C]
        flat_inputs = tf.reshape(inputs, [-1, inputs_shape[-1]])
        # LayerNormalization: normalize every pixel across the channel dimension
        norm_inputs = tf.keras.layers.LayerNormalization(axis=-1, epsilon=epsilon)(flat_inputs)
        # reshape norm_inputs back to the original shape [B, H, W, C]
        outputs = tf.reshape(norm_inputs, inputs_shape)
        return outputs
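Normalizing each pixel across its channels does not actually need the flatten step: layer normalization over the last axis of a [B, H, W, C] tensor gives the same result as flattening to [B x H x W, C], normalizing, and reshaping back. That is consistent with the fix described above, where applying tf.keras.layers.LayerNormalization directly drops the tf.reshape calls. A NumPy check of the equivalence, ignoring the layer's learned gamma and beta (which start at 1 and 0); the helper name is hypothetical:

```python
import numpy as np


def layer_norm_last_axis(x, epsilon=1e-8):
    """Layer normalization over the last axis, without the learned gamma/beta."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 5, 3)).astype(np.float32)  # [B, H, W, C]

# PixelNorm-style: flatten to [B*H*W, C], normalize, reshape back
via_reshape = layer_norm_last_axis(x.reshape(-1, x.shape[-1])).reshape(x.shape)
# Direct: normalize the 4-D tensor along its channel axis
direct = layer_norm_last_axis(x)

print(np.allclose(via_reshape, direct))  # → True
```

Working on the 4-D tensor directly also avoids the dynamic tf.shape/tf.reshape ops, which the thread suggests were what the GPU Delegate and NNAPI could not handle.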