PyTorch 模型保存

1.PyTorch 保存模型

import torch

1.1 方法 1:保存和加载模型参数

# 保存模型
torch.save(the_model.state_dict(), PATH)

# 加载模型
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

1.2 方法 2:保存和加载整个模型

# 保存模型
torch.save(the_model, PATH)

# 加载模型
the_model = torch.load(PATH)