How to convert a PyTorch ModuleList into a ModuleDict?

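The difference between named_children() and children(): named_children() iterates over the direct children of the ModuleList and yields each child's name together with the module itself (here every child is a Sequential), whereas children() yields only the modules. For a ModuleList, the names are the children's indices as strings ("0", "1", ...).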

for name, module in tasks.named_children():
    print(type(name), type(module))
# <class 'str'> <class 'torch.nn.modules.container.Sequential'>
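To see what named_children() actually yields, here is a minimal, self-contained sketch. It assumes tasks is an nn.ModuleList of two small nn.Sequential blocks; the layer types and sizes are made up for illustration. Because ModuleList registers its children under their positions, the names come back as the strings "0" and "1".

```python
import torch.nn as nn

# Hypothetical stand-in for `tasks`: a ModuleList holding two Sequential blocks
tasks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
])

# named_children() yields (name, module) pairs for the direct children only;
# ModuleList names each child by its index, stored as a string
names = [name for name, _ in tasks.named_children()]
print(names)  # ['0', '1']
```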

for i in tasks.children():
    print(type(i))
# <class 'torch.nn.modules.container.Sequential'>
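children() walks the same direct children but drops the names. In the sketch below, which again uses a hypothetical two-block tasks, the modules returned by children() are the very same objects that named_children() pairs with names, not copies:

```python
import torch.nn as nn

# Hypothetical `tasks`: a ModuleList of two Sequential blocks
tasks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
])

unnamed = list(tasks.children())
named = [module for _, module in tasks.named_children()]

# Same length, same order, and identical module instances
same = len(unnamed) == len(named) and all(a is b for a, b in zip(unnamed, named))
print(same)  # True
```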

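modules(), by contrast, is recursive: it yields tasks itself first and then every nested submodule, down to the individual layers inside each Sequential.

for idx, m in enumerate(tasks.modules()):
    print(type(idx), type(m))
# <class 'int'> <class 'torch.nn.modules.container.ModuleList'>   (tasks itself)
# <class 'int'> <class 'torch.nn.modules.container.Sequential'>
# <class 'int'> <class 'torch.nn.modules.conv.Conv2d'>, or whichever layer types the Sequential blocks contain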
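Unlike named_children() and children(), modules() descends depth-first into every submodule, starting with the container itself. Its named counterpart, named_modules(), makes the traversal order visible through dotted names. A sketch, assuming a hypothetical tasks made of two Conv2d + ReLU blocks:

```python
import torch.nn as nn

# Hypothetical `tasks`: a ModuleList of two Sequential blocks
tasks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
])

# modules() yields tasks itself first, then every nested module (pre-order, depth-first)
all_modules = list(tasks.modules())
print(all_modules[0] is tasks)  # True
print(len(all_modules))         # 7: the ModuleList, 2 Sequentials, 2 Conv2d, 2 ReLU

# named_modules() follows the same order; "" is the root container
print([name for name, _ in tasks.named_modules()])
# ['', '0', '0.0', '0.1', '1', '1.0', '1.1']
```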
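import torch.nn as nn

# First take the (name, module) pairs from named_children() and collect them into a plain dict with dict()
task = dict((name, module) for name, module in tasks.named_children())
# Then use nn.ModuleDict on the keys of the task dict, which finally gives a ModuleDict.
# nn.ModuleDict only accepts string keys, so iterate over task (its keys are "0", "1", ...),
# not over the ModuleList itself, whose elements are modules.
# Multiply and num_channels come from the surrounding model code, which is not shown here.
unit = nn.ModuleDict({t: nn.ModuleDict({s: Multiply(num_channels) for s in task}) for t in task})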
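Because nn.ModuleDict also accepts an iterable of (name, module) pairs, the intermediate dict is optional: the output of named_children() can be passed straight in. The sketch below assumes a hypothetical two-block tasks, shows both routes, and optionally re-keys the blocks with hypothetical task names via zip. To make the nested line runnable it also defines a hypothetical Multiply (a learnable per-channel scale) and a hypothetical num_channels; the real Multiply is not shown in the article, so this stand-in is an assumption, not the author's module.

```python
import torch
import torch.nn as nn

# Hypothetical `tasks`: a ModuleList of two Sequential blocks
tasks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()),
])

# Route 1 (as in the article): build a plain dict first, then wrap it
task = dict((name, module) for name, module in tasks.named_children())
task_dict = nn.ModuleDict(task)

# Route 2: pass the (name, module) pairs to nn.ModuleDict directly
direct = nn.ModuleDict(tasks.named_children())

# Optional: hypothetical descriptive keys instead of "0", "1"
renamed = nn.ModuleDict(zip(["semseg", "depth"], tasks))


class Multiply(nn.Module):
    """Hypothetical stand-in: scales each channel by a learnable weight."""

    def __init__(self, num_channels):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))

    def forward(self, x):
        # x has shape (N, C, H, W); broadcast the weight over N, H and W
        return x * self.weight.view(1, -1, 1, 1)


num_channels = 8  # hypothetical value
# One Multiply per (outer task, inner task) pair, keyed by string names
unit = nn.ModuleDict({t: nn.ModuleDict({s: Multiply(num_channels) for s in task}) for t in task})

print(list(direct.keys()))       # ['0', '1']
print(list(renamed.keys()))      # ['semseg', 'depth']
print(sorted(unit["0"].keys()))  # ['0', '1']
```

Route 2 avoids the temporary dict; either way the ModuleDict holds references to the same Sequential objects that live in tasks, so no weights are copied.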

Written by

Machine Learning / Deep Learning / Python / Flutter cakeresume.com/yanwei-liu
