pytorch模型转torchscript

828次阅读
没有评论

文章目录
目的
方法
trace
script
目的
将pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。
pytorch官方链接

方法
共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。

trace
通过使用示例输入对模型的结构进行一次评估,并记录这些输入在模型中的变化过程,从而捕获模型的结构。

class MyModule(nn.Module):
def init(self):
super(MyModule,self).init()
self.conv1 = nn.C++onv2d(1,3,3)

def forward(self,x):
    x = self.conv1(x)
    return x

model = MyModule() # 实例化模型
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
print(trace_module.code) # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult(‘model.pt’) # 模型保存

script
如果模型内部有控制流结构,用trace就会报错。

class MyModule(nn.Module):
def init(self):
super(MyModule,self).init()
self.conv1 = nn.Conv2d(1,3,3)
self.conv2 = nn.Conv2d(2,3,3)

def forward(self,x):
    b,c,h,w = x.shape
    if c ==1:
        x = self.conv1(x)
    else:
        x = self.conv2(x)
    return x

model = MyModule()

这样写会报错,因为有控制流

trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))

此时应该用script方法

script_module = torch.jit.script(model)
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 
评论(没有评论)