TensorFlow SaveModel¶
为了将训练好的机器学习模型部署到各个目标平台(如服务器、移动端、嵌入式设备和浏览器等), 我们的第一步往往是将训练好的整个模型完整导出(序列化)为一系列标准格式的文件。在此基础上, 我们才可以在不同的平台上使用相对应的部署工具来部署模型文件。
TensorFlow 提供了统一模型导出格式
SaveModel
, 这是我们在 TensorFlow 2 中主要使用的导出格式。 这样我们可以以这一格式为中介,将训练好的模型部署到多种平台上.同时,基于历史原因,Keras 的 Sequential 和 Functional 模式也有自有的模型导出格式。
1.tf.train.Checkpoint: 变量的保存与恢复¶
很多时候,希望在模型训练完成后能将训练好的参数(变量)保存起来,这样在需要使用模型的其他地方载入模型和参数, 就能直接得到训练好的模型,保存模型有很多中方式:
Python 的序列化模块
pickle
存储model.variables
- 然而,TensorFlow 的变量类型
ResourceVariable
并不能被序列化- 语法:
import pickle
1.1 tf.train.Checkpoint 介绍¶
tf.train.Checkpoint
简介TensorFlow 提供了
tf.train.Checkpoint
这一强大的变量保存与恢复类,提供的方法可以保存和恢复 TensorFlow 中的大部分对象, 比如下面类的实例都可以被保存:tf.keras.optimizer
tf.Variable
tf.keras.Layer
tf.keras.Model
- Checkpointable State 的对象
tf.train.Checkpoint
使用方法方法:
save()
restore()
语法:
# 保存训练好的模型, 先声明一个 Checkpoint model = TrainedModel() checkpoint = tf.train.Checkpoint(myAwesomeModel = model, myAwesomeOptimizer = optimizer) checkpoint.save(save_path_with_prefix) # 载入保存的训练模型 model_to_be_restored = MyModel() # 待恢复参数的同一模型 checkpoint = tf.train.Checkpoint(myAwesomeModel = model_to_be_restored) checkpoint.restore(save_path_with_prefix_and_index) # 为了载入最近的一个模型文件, 返回目录下最近一次检查点的文件名 tf.train.latest_checkpoint(save_path)
Note
参数:
myAwesomeModel
: 待保存的模型 model 所取的任意键名,在恢复变量时还将使用这一键名myAwesomeOptimizer
: 待保存的模型 optimizer 所取的任意键名,在恢复变量时还将使用这一键名save_path_with_prefix
: 保存文件的目录+前缀save_path_with_prefix_and_index
: 之前保存的文件目录+前缀+序号
checkpoint.save("./model_save/model.ckpt")
: 会在模型保存的文件夹中生成三个文件:checkpoint
model.ckpt-1.index
model.ckpt-1.data-00000-of-00001
checkpoint.restore("./model/save/model.ckpt-1")
- 载入前缀为
model.ckpt
、序号为1
的文件来恢复模型
- 载入前缀为
1.2 tf.train.Checkpoint 代码框架¶
1.train.py 模型训练阶段
# 训练好的模型 model = MyModel() # 实例化 Checkpoint, 指定保存对象为 model(如果需要保存 Optimizer 的参数也可以加入) checkpoint = tf.train.Checkpoint(myModel = model) manager = tf.train.CheckpointManager(checkpoint, directory = "./save", checkpoint_name = "model.ckpt", max_to_keep = 10) # ...(模型训练代码) # 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次) if manager: manager.save(checkpoint_number = 100) else: checkpoint.save("./save/model.ckpt")
2.test.py 模型使用阶段
# 要使用的模型 model = MyModel() # 实例化 Checkpoint, 指定恢复对象为 model checkpoint = tf.train.Checkpoint(myModel = model) # 从文件恢复模型参数 checkpoint.restore(tf.train.latest_checkpoint("./save)) # ...(模型使用代码)
Note
tf.train.Checkpoint
(检查点)只保存模型的参数,不保存模型的计算过程, 因此一般用于在具有的模型源码时恢复之前训练好的模型参数。如果需要导出模型(无须源代码也能运行模型)。
2.使用 SaveModel 完整导出模型¶
作为模型导出格式的 SaveModel
包含了一个 TensorFlow 程序的完整信息: 不仅包含参数的权值,还包含计算的流程(计算图)。
当模型导出为 SaveModel 文件时,无须模型的源代码即可再次运行模型, 这使得 SaveModel
尤其适用于模型的分享和部署。
Keras 模型均可以方便地导出为 SaveModel
格式。不过需要注意的是,因为 SaveModel
基于计算图,
所以对于通过继承 tf.keras.Model
类建立的 Keras 模型来说,需要导出为 SaveModel
格式的方法(比如 call) 都需要
使用 @tf.function
修饰。
语法:
# 保存 tf.saved_model.save(model, "保存的目标文件夹名称") # 载入 model = tf.saved_model.load("保存的目标文件夹名称")
示例:
pass
3.Keras 自有的模型导出格式¶
示例:
curl -LO https://raw.githubcontent.com/keras-team/keras/master/examples/mnist_cnn.pymodel.save("mnist_cnn.h5")import keras keras.models.load_model("mnist_cnn.h5")