0%

Pytorch网络构造Tips

关于网络的parameters

Pytorch中,每一个网络继承于nn.Module类,当实例化之后,是通过维护一下8个字典来实现各种网络功能的:

1
2
3
4
5
6
7
8
_parameters
_buffers
_backward_hooks
_forward_hooks
_forward_pre_hooks
_state_dict_hooks
_load_state_dict_pre_hooks
_modules

在网络初始化__init__过程中,所有类内变量的定义都会通过__setattr__方法在__dict__中进行注册,而nn.Module重写了注册方法,将所有类内变量中,类型派生于Parameter的变量归属到_parameters字典中,这就解释了为什么使用一个list来存放网络的每一层会导致网络中的parameters为空。

此外,在获取参数时,nn.Module是通过遍历整个_modules字典来实现的,因此在定义时可以使用nn.ModuleList类型来替代list类型存放多个网络层。

Reference

[1]Gemfield, “详解Pytorch中的网络构造,” 知乎专栏, Jan. 04, 2019. https://zhuanlan.zhihu.com/p/53927068 (accessed Sep. 26, 2022).

-------------本 文 结 束 啦 感 谢 您 的 阅 读-------------

欢迎关注我的其它发布渠道