全国免费咨询:

13245491521

VR图标白色 VR图标黑色
X

中高端软件定制开发服务商

与我们取得联系

13245491521     13245491521

2019-07-06_两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试

您的位置:首页 >> 新闻 >> 行业资讯

两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试 机器之心报道 参与:思源 你的模型到底有多少参数,每秒的浮点运算到底有多少,这些你都知道吗?近日,GitHub 开源了一个小工具,它可以统计 PyTorch 模型的参数量与每秒浮点运算数(FLOPs)。有了这两种信息,模型大小控制也就更合理了。 其实模型的参数量好算,但浮点运算数并不好确定,我们一般也就根据参数量直接估计计算量了。但是像卷积之类的运算,它的参数量比较小,但是运算量非常大,它是一种计算密集型的操作。反观全连接结构,它的参数量非常多,但运算量并没有显得那么大。 此外,机器学习还有很多结构没有参数但存在计算,例如最大池化和 Dropout 等。因此,PyTorch-OpCounter 这种能直接统计 FLOPs 的工具还是非常有吸引力的。 PyTorch-OpCounter GitHub 地址:https://github.com/Lyken17/pytorch-OpCounter OpCouter PyTorch-OpCounter 的安装和使用都非常简单,并且还能定制化统计规则,因此那些特殊的运算也能自定义地统计进去。 我们可以使用 pip 简单地完成安装:pip install thop。不过 GitHub 上的代码总是最新的,因此也可以从 GitHub 上的脚本安装。 对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成: from torchvision.models import resnet50from thop import profile model = resnet50()input = torch.randn(1, 3, 224, 224)flops, params = profile(model, inputs=(input, )) 我们测试了一下 DenseNet-121,用 OpCouter 统计了参数量与运算量。API 的输出如下所示,它会告诉我们具体统计了哪些结构,它们的配置又是什么样的。 最后输出的浮点运算数和参数量分别为如下所示,换算一下就能知道 DenseNet-121 的参数量约有 798 万,计算量约有 2.91 GFLOPs。 flops: 2914598912.0parameters: 7978856.0 OpCouter 是怎么算的 我们可能会疑惑,OpCouter 到底是怎么统计的浮点运算数。其实它的统计代码在项目中也非常可读,从代码上看,目前该工具主要统计了视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示: def count_conv2d(m, x, y): x = x[0] cin = m.in_channels cout = m.out_channels kh, kw = m.kernel_size batch_size = x.size()[0] out_h = y.size(2) out_w = y.size(3) # ops per output element # kernel_mul = kh * kw * cin # kernel_add = kh * kw * cin - 1 kernel_ops = multiply_adds * kh * kw bias_ops = 1 if m.bias is not None else 0 ops_per_element = kernel_ops + bias_ops # total ops # num_out_elements = y.numel() output_elements = batch_size * out_w * out_h * cout total_ops = output_elements * ops_per_element * cin // m.groups m.total_ops=torch.Tensor([int(total_ops)]) 总体而言,模型会计算每一个卷积核发生的乘加运算数,再推广到整个卷积层级的总乘加运算数。 定制你的运算统计 有一些运算统计还没有加进去,如果我们知道该怎样算,那么就可以写个自定义函数。 class YourModule(nn.Module): # your definitiondef count_your_model(model, x, y): # your rule here input = torch.randn(1, 3, 224, 224)flops, params = profile(model, inputs=(input, ),custom_ops={YourModule:count_your_model}) 最后,作者利用这个工具统计了各种流行视觉模型的参数量与 FLOPs 量: 深度Pro 理论详解 | 工程实践 | 产业分析 | 行研报告 机器之心最新上线深度内容栏目,汇总AI深度好文,详解理论、工程、产业与应用。这里的每一篇文章,都需要深度阅读15分钟。 今日深度推荐 爱奇艺短视频分类技术解析 CVPR 2019提前看:少样本学习专题 万字综述,核心开发者全面解读PyTorch内部机制 点击图片,进入小程序深度Pro栏目 PC点击阅读原文,访问官网 更适合深度阅读 www.jiqizhixin.com/insight 每日重要论文、教程、资讯、报告也不想错过? 点击订阅每日精选 阅读原文

上一篇:2024-07-02_请大家先暂停手上的事,来庆祝《提案者2》上线了 下一篇:2023-11-05_重新审视Transformer:倒置更有效,真实世界预测的新SOTA出现了

TAG标签:

20
网站开发网络凭借多年的网站建设经验,坚持以“帮助中小企业实现网络营销化”为宗旨,累计为4000多家客户提供品质建站服务,得到了客户的一致好评。如果您有网站建设网站改版域名注册主机空间手机网站建设网站备案等方面的需求...
请立即点击咨询我们或拨打咨询热线:13245491521 13245491521 ,我们会详细为你一一解答你心中的疑难。
项目经理在线

相关阅读 更多>>

猜您喜欢更多>>

我们已经准备好了,你呢?
2022我们与您携手共赢,为您的企业营销保驾护航!

不达标就退款

高性价比建站

免费网站代备案

1对1原创设计服务

7×24小时售后支持

 

全国免费咨询:

13245491521

业务咨询:13245491521 / 13245491521

节假值班:13245491521()

联系地址:

Copyright © 2019-2025      ICP备案:沪ICP备19027192号-6 法律顾问:律师XXX支持

在线
客服

技术在线服务时间:9:00-20:00

在网站开发,您对接的直接是技术员,而非客服传话!

电话
咨询

13245491521
7*24小时客服热线

13245491521
项目经理手机

微信
咨询

加微信获取报价