PyTorch如何檢查模型的參數量及模型檔案大小?
4 min read · Jul 9, 2021
1. torchinfo
安裝
pip install torchinfo
使用
from torchinfo import summary

# ConvNet is not defined in this snippet; presumably it is the reader's own
# nn.Module subclass — substitute your model here.
model = ConvNet()
batch_size = 16
# torchinfo's input_size is the full shape of one input batch, batch dimension
# first. The remaining dims presumably mean (channels=1, height=28, width=28)
# for this example model — adjust them to whatever your model expects.
input_shape = (batch_size, 1, 28, 28)
summary(model, input_size=input_shape)
輸出
====================================================================
Layer Input Shape Output Shape Param # Mult-Adds
====================================================================
.
.
.
略
.
.
.
====================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
========================================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
========================================================================================================================================
2. thop
# Install first: pip install thop
# Official minimal example.
import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
# Dummy input tensor of shape (1, 3, 224, 224); named dummy_input rather than
# `input` so the builtin input() is not shadowed.
dummy_input = torch.randn(1, 3, 224, 224)
# thop's profile runs the model on `inputs` and returns
# (multiply-accumulate count, parameter count).
macs, params = profile(model, inputs=(dummy_input,))
# Official detailed example: print the parameter count and MAC count of every
# torchvision classification model as a Markdown table.
# (To profile only some models, filter model_names, e.g. keep names
# containing "inception".)
import torch
from torchvision import models
from thop.profile import profile

if hasattr(models, "list_models"):
    # torchvision >= 0.14 also exposes lowercase helper functions
    # (get_model, get_weight, list_models, ...) in models.__dict__; calling
    # them with no arguments fails, so ask torchvision for the registered
    # model builders instead. module=models keeps only the models defined
    # directly in torchvision.models (not detection/quantization/...).
    model_names = sorted(models.list_models(module=models))
else:
    # Older torchvision: every public lowercase callable is a model builder.
    model_names = sorted(
        name
        for name in models.__dict__
        if name.islower()
        and not name.startswith("__")
        and callable(models.__dict__[name])
    )

# thop counts multiply-accumulate operations (MACs), so the column is labelled
# MACs; FLOPs are roughly 2 * MACs.
print("%s | %s | %s" % ("Model", "Params(M)", "MACs(G)"))
print("---|---|---")

device = "cuda" if torch.cuda.is_available() else "cpu"

for name in model_names:
    model = getattr(models, name)().to(device)
    # Inception models are profiled at 299x299; all others at 224x224.
    dsize = (1, 3, 299, 299) if "inception" in name else (1, 3, 224, 224)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs,), verbose=False)
    # Report parameters in millions and MACs in billions.
    print(f"{name} | {total_params / 1000 ** 2:.2f} | {total_ops / 1000 ** 3:.2f}")