首页 >国际 > > 正文

使用PyTorch 2.0 加速Hugging Face和TIMM库的模型

2022-12-25 03:18:19

点蓝色字关注“机器学习算法工程师”


(相关资料图)

设为星标,干货直达!

PyTorch 2.0引入了**torch.compile()**来加速模型,这篇文章我们将介绍如何使用**torch.compile()**来加速Hugging Face和TIMM库的模型。

torch.compile() 使得尝试不同的编译器后端变得容易,从而使用单行装饰器 torch.compile() 使 PyTorch 代码更快。它可以直接在 nn.Module 上工作,作为 torch.jit.script() 的直接替代品,但不需要您进行任何源代码更改。我们希望这一行代码更改能够为您已经运行的绝大多数模型提供 30%-2 倍的训练时间加速。

opt_module=torch.compile(module)

torch.compile 支持任意 PyTorch 代码、控制流、变异,并带有对动态形状的实验性支持。我们对这一发展感到非常兴奋,我们将其称为 PyTorch 2.0。

这个版本对我们来说不同的是,我们已经对一些最流行的开源 PyTorch 模型进行了基准测试,并获得了 30% 到 2 倍的大幅加速(见https://github.com/pytorch/torchdynamo/issues/681) 。

这里没有技巧,我们已经 pip 安装了流行的库,比如https://github.com/huggingface/transformers, https://github.com/huggingface/accelerate 和 https://github.com/rwightman/pytorch-image-models等流行的库,然后对它们运行 torch.compile() 就可以了。

很难同时获得性能和便利性,但这就是核心团队发现 PyTorch 2.0 如此令人兴奋的原因。Hugging Face 团队也很兴奋,用他们的话说:

TIMM 的主要维护者 Ross Wightman:“PT 2.0 开箱即用,适用于推理和训练工作负载的大多数 timm 模型,无需更改代码。”

Sylvain Gugger 是 transformers 和 accelerate 的主要维护者:“只需添加一行代码,PyTorch 2.0 就可以在训练 Transformers 模型时提供 1.5 到 2.x 的加速。这是引入混合精度训练以来最激动人心的事情!”

本教程将向您展示如何使用这些加速,这样您就可以像我们一样对 PyTorch 2.0 感到兴奋。

安装教程

对于 GPU(新一代 GPU 的性能会大大提高):

pip3installnumpy--pretorch--force-reinstall--extra-index-urlhttps://download.pytorch.org/whl/nightly/cu117

对于CPU:

pip3install--pretorch--extra-index-urlhttps://download.pytorch.org/whl/nightly/cpu

当安装好后,你可以通过以下方式来进行验证:

gitclonehttps://github.com/pytorch/pytorchcdtools/dynamopythonverify_dynamo.py

另外一种安装方式是采用docker,我们还在 PyTorch nightly 二进制文件中提供了所有必需的依赖项,您可以使用它们下载:

dockerpullghcr.io/pytorch/pytorch-nightly

对于临时实验,只需确保您的容器可以访问所有 GPU:

dockerrun--gpusall-itghcr.io/pytorch/pytorch-nightly:latest/bin/bash

使用教程

让我们从一个简单的例子开始,一步步把事情复杂化。请注意,您的 GPU 越新,您可能会看到更显着的加速。

importtorchdeffn(x,y):a=torch.sin(x).cuda()b=torch.sin(y).cuda()returna+bnew_fn=torch.compile(fn,backend="inductor")input_tensor=torch.randn(10000).to(device="cuda:0")a=new_fn()

这个例子实际上不会运行得更快,但它具有教育意义。

以 torch.cos() 和 torch.sin() 为特色的示例,它们是逐点操作的示例,因为它们在向量上逐个元素地进行操作。你可能真正想要使用的一个更著名的逐点运算是类似 torch.relu() 的东西。eager模式下的逐点操作不是最优的,因为每个操作都需要从内存中读取一个张量,进行一些更改,然后写回这些更改。

PyTorch 2.0 为您所做的最重要的优化是融合。

回到我们的示例,我们可以将 2 次读取和 2 次写入变成 1 次读取和 1 次写入,这对于较新的 GPU 来说尤其重要,因为瓶颈是内存带宽(您可以多快地向 GPU 发送数据)而不是计算(您的速度有多快) GPU 可以处理浮点运算)。

PyTorch 2.0 为您做的第二个最重要的优化是 CUDA graphs。CUDA graphs有助于消除从 python 程序启动单个内核的开销。

torch.compile() 支持许多不同的后端,但我们特别兴奋的一个是生成 Triton 内核(https://github.com/openai/triton,用 Python 编写的,但性能优于绝大多数手写的 CUDA 内核)的 Inductor。假设我们上面的示例名为 trig.py,我们实际上可以通过运行来检查代码生成的 triton 内核:

TORCHINDUCTOR_TRACE=1pythontrig.py

@pointwise(size_hints=[16384],filename=__file__,meta={"signature":{0:"*fp32",1:"*fp32",2:"i32"},"device":0,"constants":{},"configs":[instance_descriptor(divisible_by_16=(0,1,2),equal_to_1=())]})@triton.jitdefkernel(in_ptr0,out_ptr0,xnumel,XBLOCK:tl.constexpr):xnumel=10000xoffset=tl.program_id(0)*XBLOCKxindex=xoffset+tl.reshape(tl.arange(0,XBLOCK),[XBLOCK])xmask=xindex

你可以验证融合这两个 sins 确实发生了,因为这两个 sin 操作发生在一个单一的 Triton 内核中,并且临时变量保存在寄存器中,可以非常快速地访问。

下一步,让我们尝试一个真实的模型,比如来自 PyTorch hub 的 resnet50。

importtorchmodel=torch.hub.load("pytorch/vision:v0.10.0","resnet18",pretrained=True)opt_model=torch.compile(model,backend="inductor")model(torch.randn(1,3,64,64))

如果您实际运行,您可能会惊讶于第一次运行很慢,那是因为正在编译模型。后续运行会更快,因此在开始对模型进行基准测试之前预热模型是常见的做法。

您可能已经注意到我们如何在此处使用“inductor”显式传递编译器的名称,但它不是唯一可用的后端,您可以在 torch._dynamo.list_backends() 中运行以查看可用后端的完整列表。为了好玩,您应该尝试 aot_cudagraphs 或 nvfuser。

现在让我们做一些更有趣的事情,我们的社区经常使用来自 transformers (https://github.com/huggingface/transformers) 或 TIMM (https://github.com/rwightman/pytorch-image-models)的预训练模型和我们的设计之一PyTorch 2.0 的目标是任何新的编译器堆栈都需要开箱即用,可以与人们实际运行的绝大多数模型一起工作。因此,我们将直接从 Hugging Face hub 下载预训练模型并对其进行优化。

importtorchfromtransformersimportBertTokenizer,BertModel#Copypastedfromherehttps://huggingface.co/bert-base-uncasedtokenizer=BertTokenizer.from_pretrained("bert-base-uncased")model=BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")model=torch.compile(model)#Thisistheonlylineofcodethatwechangedtext="Replacemebyanytextyou"dlike."encoded_input=tokenizer(text,return_tensors="pt").to(device="cuda:0")output=model(**encoded_input)

如果您从模型和 encoded_input 中删除 to(device="cuda:0") ,那么 PyTorch 2.0 将生成 C++ 内核,这些内核将针对在您的 CPU 上运行进行优化。你可以检查 Triton 或 C++ 内核的 BERT,它们显然比我们上面的三角函数示例更复杂,但如果你了解 PyTorch,你也可以类似地浏览它并理解。

相同的代码也可以https://github.com/huggingface/accelerate 和 DDP 一起使用。

同样让我们尝试一个 TIMM 示例:

importtimmimporttorchmodel=timm.create_model("resnext101_32x8d",pretrained=True,num_classes=2)opt_model=torch.compile(model,backend="inductor")opt_model(torch.randn(64,3,7,7))

我们使用 PyTorch 的目标是构建一个广度优先的编译器,该编译器将加速人们在开源中运行的绝大多数实际模型。Hugging Face Hub 最终成为我们非常有价值的基准测试工具,确保我们所做的任何优化实际上都有助于加速人们想要运行的模型。

本文翻译自https://pytorch.org/blog/Accelerating-Hugging-Face-and-TIMM-models/

上一篇: 下一篇:
x
推荐阅读

使用PyTorch 2.0 加速Hugging Face和TIMM库的模型

2022-12-25

赤峰黄金(600988)12月23日主力资金净卖出8064.43万元

2022-12-24

新鲜牛肉怎么保存不坏 新鲜牛肉保存不坏的方法介绍

2022-12-23

养老规划要及早

2022-12-23

全球最资讯丨【机器学习】集成学习代码练习(随机森林、GBDT、XGBoost、LightGBM等)

2022-12-23

同仁堂: 同仁堂第九届董事会第十六次会议决议公告 快报

2022-12-22

环球头条:马来西亚山体滑坡事故遇难人数升至30人

2022-12-22

微动态丨2022 ZOL推荐 | 惠普Z系列大师本 ZBook Studio G9提供强劲生产力 获奖

2022-12-22

世界热推荐:梅花生物: 梅花生物2023年员工持股计划(草案)摘要

2022-12-21

天天精选!他们连夜分装紧俏药品缓解配药难,一盒布洛芬为4名患者退了烧

2022-12-21

实施好积极的财政政策和稳健的货币政策|今日热闻

2022-12-21

【世界速看料】胜通能源: 国元证券股份有限公司关于胜通能源股份有限公司2022年持续督导培训情况的报告

2022-12-20

每日速讯:北京地铁今起缩短发车间隔!10号线早高峰2到3分钟一趟

2022-12-20

浙江湖州:优化核酸检测与查验政策,做好医疗救治资源储备

2022-12-19

【环球时快讯】迎丰股份:拟收购绍兴布泰100%股权

2022-12-19

国产操作系统推出移动固态硬盘,双系统一插即用 天天快播

2022-12-19

许昌市建安区税务局打造“3公里”办税缴费服务圈

2022-12-19

唏嘘!阿杜:真怀念与老詹巅峰对决 已经四年没碰面|当前热讯

2022-12-18

特斯拉(TSLA.US)股价腰斩 分析师和股东将枪口对准马斯克

2022-12-17

中央经济工作会议:适时实施渐进式延迟法定退休年龄政策_实时焦点

2022-12-16

本周盘点(12.12-12.16):金枫酒业周涨3.70%,主力资金合计净流出911.77万元 焦点快报

2022-12-16

两市ETF两融余额减少1.9亿元 当前讯息

2022-12-16

马斯克抛售价值35亿美元的特斯拉股票 投资者其参与Twitter事务表示担忧

2022-12-15

桂东电力董秘回复:闽商石业借资是根据其发展需要作出的安排,闽商石业全体股东均按出资比例同等条件借资 当前短讯

2022-12-15

涨停雷达:新零售个股异动 青岛金王触及涨停

2022-12-15

每日观点:宝钢超级13Cr产品独家供货海南福山油田CCUS重点项目

2022-12-14

世界关注:副省长张广智莅济调研全域旅游工作

2022-12-14

每日快报!马应龙:12月13日融券卖出金额87.47万元,占当日流出金额的1.18%

2022-12-14

凉拌腐竹的做法(正确泡腐竹的方法)

2022-12-13

鞍钢股份:12月12日融券卖出金额27.49万元,占当日流出金额的0.4%

2022-12-13

报告:预计今年中国将接待入境游客超2000万人次 当前焦点

2022-12-12

将毕业设计“写”在乡村田间地头_环球快消息

2022-12-12

【环球时快讯】北京今起云量增多 明后天风力加大最高温或降至冰点以下

2022-12-11

歌尔股份:涉及具体客户或项目名称的问题,不便于评论

2022-12-09

白云山(600332)12月7日主力资金净买入1.46亿元 天天快资讯

2022-12-08

股票行情快报:阳光照明(600261)12月6日主力资金净卖出18.85万元 全球快资讯

2022-12-06

科创板解禁潮在即 高成长增添“惜售”底气

2022-07-20

总投资3172.5亿元 石家庄提前超额完成年度目标任务

2022-03-20

石家庄海关共签发RCEP原产地证书864份 货值3.9亿元

2022-03-20

蚌埠海关累计签发RCEP原产地证书35份 涉及金额2583.09万元

2022-03-20

绥化望奎以工业化思维为引领 推动肉类加工制造产业腾飞

2022-03-20

衡阳耒阳免费发放油茶苗 助推油茶产业稳步发展

2022-03-20

郴州安仁文旅项目集中开工 总投资1000万元

2022-03-20

2022年郴州计划重点推进文旅项目101个 总投资354亿元

2022-03-20

宿州泗县深入推进文旅融合发展 擦亮城市品牌

2022-03-20

汽车零部件产业“领头羊” 锦州力争一季度“开门红”

2022-03-20

油价或有望冲击“九元”大关 宁波新能源汽车市场如何

2022-03-20

从水塘到“云”端 全国最大高邮鸭养殖基地实现智慧养殖

2022-03-20

淡季不忘引流 京郊民宿市场有望迎来回暖

2022-03-20

镇江乡村一二三产业融合发展 闯出“镇江之路”

2022-03-20

总投资30亿元 盐城东台8个重大产业项目相继开工

2022-03-20

去年南京规上信息软件业企业实现营收7577.28亿元 同比增长10.3%

2022-03-20

2021年南京农业保险保费收入53.07亿元 同比增长19.13%

2022-03-20

安阳本土确诊病例上升至26例

2022-01-10

3次推迟婚期 满洲里抗疫民警兑现承诺:“我回来娶你了!”

2022-01-10

上海公安民警在岗位上迎接2022年“中国人民警察节”

2022-01-10

郑州核酸检测为中小学生开辟“绿色通道”

2022-01-10

反扒便衣警察“小曹”:藏在人海中的隐形“守护者”

2022-01-10

哥哥移植肾脏给病重弟弟 已在上海顺利康复

2022-01-10

网友与人裸聊被敲诈10万余元 被告人获刑5年

2022-01-10

1月10日起天津市暂停开展旅行社旅游业务活动

2022-01-10

“3·28”特大跨境电信网络诈骗案公开审理

2022-01-10

忠诚履职 守护万家灯火

2022-01-10

奥密克戎病例已涉天津、安阳 “动态清零”必须坚持!

2022-01-10

专家协作成功完成亲体肾移植 同“肾”兄弟顺利康复

2022-01-10

著名指挥陈燮阳携苏州交响乐团“相约北京”

2022-01-10

中国热科院选育出4个木薯新品种

2022-01-10

北京疾控:12月9日以来途经或旅居天津市人员请立即报备

2022-01-10

河南安阳本轮疫情累计报告确诊病例26例

2022-01-10

许勤批示黑土地保护不力问题:加快形成黑土地保护长效机制

2022-01-10

【挑战365天正能量速写画】第041期:当警娃难,当双警家庭的警娃更难

2022-01-10

重庆姐弟坠亡案两被告人5个月间聊天记录曝光

2022-01-10

因疫情防控措施落实不力 江苏金湖一超市被红牌警告

2022-01-10

江歌案一审判决刘鑫赔偿近70万元 有何依据?专家解读

2022-01-10

广东肇庆“毒驾连撞5车致1死”肇事司机被批捕

2022-01-10

一线工作近22年的缉毒警:我知道坏的是毒品不是人性

2022-01-10

青海保障门源地震后生活必需品应急物资

2022-01-10

江西最大文物倒卖案宣判:倒卖国家二级文物 9人获刑

2022-01-10

呼和浩特:寒假期间有条件的学校要开展校内托管服务

2022-01-10

广西东兴口岸恢复通关 入境需网上预约

2022-01-10

天津米面油存量由20天提高至30天 超市菜市场进货量翻倍

2022-01-10

天津市委市政府致全市父老乡亲的慰问信:我们一定能够打赢

2022-01-10

北京市十五届人大五次会议胜利闭幕

2022-01-10

“中国最后一个原始部落”翁丁老寨火灾原因公布

2022-01-10

天津:划定封控区 全市开展全员核酸检测

2022-01-10

重庆姐弟被生父扔下坠亡案上诉期结束 一审法院暂未收到两被告人上诉状

2022-01-10

子夜直击,天津寒天战“疫”

2022-01-10

兰州名师话“美育”:“尚乐立人”分层培优 以“美”润教

2022-01-10

中国边疆“北方第一所”:9名民警守护“生命禁区”

2022-01-10

江歌母亲江秋莲:尊重法院判决,法律认定在我意料之中

2022-01-10

河南安阳9日12时至24时新增11例本土确诊病例

2022-01-10

辟谣!网传“封控区管控区相继解封”通知并非西安

2022-01-10

铁路公安以110幅优秀书画作品庆祝人民警察节

2022-01-10

“中国最后一个原始部落”翁丁老寨火灾原因公布

2022-01-10

天津:划定封控区 全市开展全员核酸检测

2022-01-10

重庆姐弟被生父扔下坠亡案上诉期结束 一审法院暂未收到两被告人上诉状

2022-01-10

子夜直击,天津寒天战“疫”

2022-01-10

兰州名师话“美育”:“尚乐立人”分层培优 以“美”润教

2022-01-10

中国边疆“北方第一所”:9名民警守护“生命禁区”

2022-01-10

江歌母亲江秋莲:尊重法院判决,法律认定在我意料之中

2022-01-10