为了方便模型在设备端的部署,有时我们需要将其转换为Tensorflow Lite的模型格式。 现有的转换方法主要会走以下的路径:
- 通过torch.onnx.export转换为ONNX的模型
- 通过onnx2tensorflow转换为tensorflow frozen model
- 通过tensorflow的TFLiteConverter转换为TFLite的模型
这条路径存在以下的不足:
- 转换路径较长,很容易产生问题
- 无法支持量化模型的转换
- 无法支持LSTM的模型
- onnx2tf的模型存在很多冗余的OP
为了解决上述的问题,我们实现了从PyTorch到TFLite的直接转换器。
- 支持PyTorch 1.6+
- 支持量化模型
- 支持LSTM
- 包含连续transpose和reshape消除、无用op删除等大量的优化pass
- 纯Python编写,易于维护
- operators: 转换器的大部分组件
- tflite : TFLite相关的类
- base.py : TFLite基础数据结构
- custom.py : TFLite自定义算子
- generated_ops.py : 从TFLite schema生成的Wrapper类
- transformable.py : 可转换算子,如BatchNorm、Conv2d等由多个TFLite算子组成的复合算子
- torch : PyTorch相关的类
- base.py : TorchScript解析所需的基础数据结构
- aten.py : ATen相关算子的翻译
- quantized.py : Quantized相关算子的翻译
- base.py : 通用算子的定义
- graph.py : 计算图相关的基础设施
- op_version.py : 设置算子版本
- optimize.py : 计算图优化
- tflite : TFLite相关的类
- schemas: schemas相关
- tflite : TFLite相关的schema
- schema_generated.py : TFLite schema 解析器
- torch : PyTorch相关的schema
- aten_schema.py : 从ATen schema生成的Wrapper类
- quantized_schema.py : 从Quantized schema生成的Wrapper类
- torchvision_schema.py : 从Torchvision schema生成的Wrapper类
- tflite : TFLite相关的schema
- base: 入口类TFLiteConverter