文章目录
目的
方法
trace
script
目的
将pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。
pytorch官方链接
方法
共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。
trace
通过使用示例输入对模型的结构进行一次评估,并记录这些输入在模型中的变化过程,从而捕获模型的结构。
class MyModule(nn.Module):
def init(self):
super(MyModule,self).init()
self.conv1 = nn.C++onv2d(1,3,3)
def forward(self,x):
x = self.conv1(x)
return x
model = MyModule() # 实例化模型
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
print(trace_module.code) # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult(‘model.pt’) # 模型保存
script
如果模型内部有控制流结构,用trace就会报错。
class MyModule(nn.Module):
def init(self):
super(MyModule,self).init()
self.conv1 = nn.Conv2d(1,3,3)
self.conv2 = nn.Conv2d(2,3,3)
def forward(self,x):
b,c,h,w = x.shape
if c ==1:
x = self.conv1(x)
else:
x = self.conv2(x)
return x
model = MyModule()
这样写会报错,因为有控制流
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
此时应该用script方法
script_module = torch.jit.script(model)
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))