"""Convert keras model to tflite."""
import argparse
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.lite.python import interpreter as interpreter_wrapper
from util import plugin
import pdb
import glob
import random
def rep_data_gen():
# Get list of all images in train directory
image_path = '/root/autodl-tmp/mai22-real-time-video-sr/data/REDS/train/train_sharp_bicubic/X4'
jpg_file_list = glob.glob(image_path + '/**/*.jpg', recursive=True)
JPG_file_list = glob.glob(image_path + '/**/*.JPG', recursive=True)
png_file_list = glob.glob(image_path + '/**/*.png', recursive=True)
bmp_file_list = glob.glob(image_path + '/**/*.bmp', recursive=True)
quant_image_list = jpg_file_list + JPG_file_list + png_file_list + bmp_file_list
dataset_list = quant_image_list
quant_num = 10 # TODO: Replace 200s with an automatic way of reading network input size
for i in range(quant_num):
pick_me = random.choice(dataset_list)
image = tf.io.read_file(pick_me)
if pick_me.endswith('.jpg') or pick_me.endswith('.JPG'):
image = tf.io.decode_jpeg(image, channels=3)
elif pick_me.endswith('.png'):
image = tf.io.decode_png(image, channels=3)
elif pick_me.endswith('.bmp'):
image = tf.io.decode_bmp(image, channels=3)
image = tf.cast(image, tf.float32) / 255.
image = tf.expand_dims(image, 0)
yield [image]
def _parse_argument():
"""Return arguments for conversion."""
parser = argparse.ArgumentParser(description='Conversion.')
parser.add_argument('--model_path', help='Path of model file.', type=str, required=True)
parser.add_argument('--model_name', help='Name of model class.', type=str, required=True)
parser.add_argument(
'--input_shapes', help='Series of the input shapes split by `:`.', required=True
)
parser.add_argument('--ckpt_path', help='Path of checkpoint.', type=str, required=True)
parser.add_argument('--output_tflite', help='Path of output tflite.', type=str, required=True)
args = parser.parse_args()
return args
def main(args):
"""Run main function for converting keras model to tflite.
Args:
args: A `dict` contain augments.
"""
# Prepare model
model_builder = plugin.plugin_from_file(args.model_path, args.model_name, tf.keras.Model)
model = model_builder()
# Load checkpoint
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(args.ckpt_path).expect_partial()
input_tensors = []
for input_shape in args.input_shapes.split(':'):
input_shape = list(map(int, input_shape.split(',')))
input_shape = [None if x == -1 else x for x in input_shape]
input_tensor = Input(shape=input_shape[1:], batch_size=input_shape[0])
input_tensors.append(input_tensor)
# Save loaded model as 'Saved Model' format
model.save(f'{args.model_name}_pd')
# Load model from 'Saved Model' file
model = tf.keras.models.load_model(f'{args.model_name}_pd')
model(input_tensors)
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape(input_tensors[0].shape)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
# Configure converter and convert the keras model
converter.experimental_new_converter = True
converter.experimental_new_quantizer = True
# This enables quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# This sets the representative dataset for quantization
converter.representative_dataset = rep_data_gen
# This ensures that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops.
tf.lite.OpsSet.TFLITE_BUILTINS_INT8, # enable INT8 inference mode
]
# For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.
converter.target_spec.supported_types = [tf.int8]
# These set the input/output tensors to tf.uint8
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
# Apply the convert
tflite_model = converter.convert()
# Save the tflite
with open(args.output_tflite, 'wb') as f:
f.write(tflite_model)
print(f'Optimized TFlite saved in:{args.output_tflite}')
# Get input output details of tflite model to know how to preprocess images
interpreter = interpreter_wrapper.Interpreter(model_path=f'{args.output_tflite}', num_threads=32)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
if __name__ == '__main__':
arguments = _parse_argument()
main(arguments)