pytorch模型转torchscript

1,610次阅读
没有评论

文章目录
目的
方法
trace
script
目的
将pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。
pytorch官方链接

方法
共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。

trace
通过使用示例输入对模型的结构进行一次评估,并记录这些输入在模型中的变化过程,从而捕获模型的结构。

class MyModule(nn.Module):
def init(self):
super(MyModule,self).init()
self.conv1 = nn.C++onv2d(1,3,3)

def forward(self,x):
    x = self.conv1(x)
    return x

model = MyModule() # 实例化模型
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
print(trace_module.code) # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult(‘model.pt’) # 模型保存

script
如果模型内部有控制流结构,用trace就会报错。

class MyModule(nn.Module):
def init(self):
super(MyModule,self).init()
self.conv1 = nn.Conv2d(1,3,3)
self.conv2 = nn.Conv2d(2,3,3)

def forward(self,x):
    b,c,h,w = x.shape
    if c ==1:
        x = self.conv1(x)
    else:
        x = self.conv2(x)
    return x

model = MyModule()

这样写会报错,因为有控制流

trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))

此时应该用script方法

script_module = torch.jit.script(model)
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 0
评论(没有评论)

文心AIGC

2024 年 1 月
1234567
891011121314
15161718192021
22232425262728
293031  
文心AIGC
文心AIGC
人工智能ChatGPT,AIGC指利用人工智能技术来生成内容,其中包括文字、语音、代码、图像、视频、机器人动作等等。被认为是继PGC、UGC之后的新型内容创作方式。AIGC作为元宇宙的新方向,近几年迭代速度呈现指数级爆发,谷歌、Meta、百度等平台型巨头持续布局
文章搜索
热门文章
最新评论
王光卫博客 王光卫博客 用户思维很有必要对用户进行数据分析
王光卫博客 王光卫博客 我们活得居然不如AI,唉
王光卫博客 王光卫博客 这又得开始存钱了
王光卫博客 王光卫博客 正在找能理解中国古汉语的AI :cry:
□惊叹号!! □惊叹号!! 可以领券
一路向北 一路向北 已经添加
一路向北 一路向北 申请友情链接: 名称:烙馍省钱 网址:https://tb-m.luomor.com/ 已添加文心AIGC
热评文章