How to convert a PyTorch ModuleList into a ModuleDict?

Yanwei Liu
1 min read · Aug 21, 2020


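The difference between named_children() and children(): both iterate over the direct children of the ModuleList (here, Sequential blocks). named_children() yields (name, module) pairs, where each name is the child's string index ('0', '1', ...). children() yields only the Sequential modules themselves.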

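for name, module in tasks.named_children():
    print(type(name), type(module))
# <class 'str'> <class 'torch.nn.modules.container.Sequential'>

for i in tasks.children():
    print(type(i))
# <class 'torch.nn.modules.container.Sequential'>

# modules() is recursive: it yields the ModuleList itself, then each Sequential and every layer inside it
for idx, m in enumerate(tasks.modules()):
    print(type(idx), type(m))
# <class 'int'> <class 'torch.nn.modules.container.ModuleList'>, then Sequential, then each layer, e.g. <class 'torch.nn.modules.conv.Conv2d'>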
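A self-contained version of the comparison above, as a sketch. The article does not show what tasks contains, so the two small Sequential heads below are a hypothetical stand-in chosen only to make the three iterators' behavior visible.

```python
import torch.nn as nn

# Hypothetical stand-in for the article's `tasks`: a ModuleList of two Sequential heads
tasks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU()),
    nn.Sequential(nn.Conv2d(16, 1, 1)),
])

# named_children(): direct children only, as (name, module) pairs;
# for a ModuleList the names are the string indices '0', '1', ...
named = list(tasks.named_children())
for name, module in named:
    print(name, type(module).__name__)

# children(): the same direct children, without their names
children = list(tasks.children())

# modules(): recursive pre-order walk that starts with the ModuleList itself,
# then each Sequential followed by the layers inside it
all_modules = list(tasks.modules())
for idx, m in enumerate(all_modules):
    print(idx, type(m).__name__)
```

Only named_children() gives you the keys needed to build a dict; modules() is the wrong tool for the conversion because it also returns the container and every nested layer.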
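# First, pull out the (name, module) pairs with named_children() and turn them into a plain dict with dict()
task = dict((name, module) for name, module in tasks.named_children())
# Then wrap that dict with nn.ModuleDict to get a ModuleDict (keys are the string indices '0', '1', ...)
tasks = nn.ModuleDict(task)
# Iterating a ModuleDict yields its string keys, which can in turn key a nested ModuleDict
# (Multiply and num_channels come from the author's own model and are not defined here)
unit = nn.ModuleDict({t: nn.ModuleDict({s: Multiply(num_channels) for s in tasks}) for t in tasks})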
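The whole conversion can be sketched end to end as below. This is an illustration under stated assumptions, not the author's code: the Linear layers, the task names "segmentation" and "depth", and ChannelWiseMultiply (a hypothetical stand-in for the article's Multiply, assumed here to be a learnable per-channel scale) are all made up for the example.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the original ModuleList
tasks = nn.ModuleList([nn.Linear(8, 4), nn.Linear(8, 2)])

# Step 1: named_children() -> plain dict {'0': Linear, '1': Linear}
task = dict((name, module) for name, module in tasks.named_children())
# Step 2: wrap it in nn.ModuleDict so the modules stay registered as submodules
task_dict = nn.ModuleDict(task)

# Optional: meaningful keys instead of '0', '1' (task names are hypothetical)
task_names = ["segmentation", "depth"]
named_dict = nn.ModuleDict(dict(zip(task_names, tasks)))


class ChannelWiseMultiply(nn.Module):
    """Hypothetical stand-in for the article's Multiply: one learnable scale per channel."""

    def __init__(self, num_channels):
        super().__init__()
        self.param = nn.Parameter(torch.ones(num_channels))

    def forward(self, x):
        # x has shape (N, C, H, W); scale each channel by its own weight
        return self.param.view(1, -1, 1, 1) * x


# Iterating a ModuleDict yields its string keys, so they can key the nested ModuleDict
num_channels = 16
unit = nn.ModuleDict({
    t: nn.ModuleDict({s: ChannelWiseMultiply(num_channels) for s in named_dict})
    for t in named_dict
})
```

The ModuleDict holds the very same module objects as the ModuleList, so no weights are copied and parameters() reports the same tensors. On older PyTorch releases a plain dict passed to ModuleDict was not guaranteed to keep its insertion order; passing a collections.OrderedDict avoids the question.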
