Tips for Constructing PyTorch Networks

This post was automatically translated by an LLM; the original is in Chinese. If you find any translation errors, please leave a comment to help me improve the translation. Thanks!

About Network Parameters

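In PyTorch, every network inherits from the nn.Module class. When a module is instantiated, nn.Module.__init__ creates the following eight dictionaries, which hold the module's state and hooks and underpin most of its functionality (newer PyTorch releases add a few more hook-related dictionaries, such as _forward_hooks_with_kwargs):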

```
_parameters
_buffers
_backward_hooks
_forward_hooks
_forward_pre_hooks
_state_dict_hooks
_load_state_dict_pre_hooks
_modules
```
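Each dictionary is filled by a matching registration mechanism. The sketch below, assuming a recent PyTorch release, uses an nn.Linear layer together with the public register_buffer() and register_forward_hook() methods, then inspects the private dictionaries directly. The underscore attributes are implementation details and may change between versions, and the buffer name running_total is a hypothetical name chosen for the example.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)

# Parameters assigned in nn.Linear.__init__ land in _parameters
print(list(layer._parameters.keys()))  # ['weight', 'bias']

# register_buffer() stores a non-trainable tensor in _buffers
# ("running_total" is a hypothetical name for this example)
layer.register_buffer("running_total", torch.zeros(2))
print(list(layer._buffers.keys()))  # ['running_total']

# register_forward_hook() adds an entry to _forward_hooks
def log_shape(module, inputs, output):
    print("output shape:", tuple(output.shape))

handle = layer.register_forward_hook(log_shape)
print(len(layer._forward_hooks))  # 1

layer(torch.randn(3, 4))  # the hook prints: output shape: (3, 2)

# removing the handle deletes the entry from _forward_hooks again
handle.remove()
print(len(layer._forward_hooks))  # 0
```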

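During network initialization in __init__, every attribute assignment such as self.fc = ... goes through the __setattr__ method, which for an ordinary Python object stores the value in the instance's __dict__. nn.Module overrides __setattr__: values of type Parameter are placed in the _parameters dictionary, values that are themselves nn.Module instances are placed in _modules, and only other values end up in __dict__. This explains why storing the layers of a network in a plain Python list makes their parameters disappear: the list is neither a Parameter nor a Module, so it is stored in __dict__, the layers inside it are never registered, and model.parameters() does not return their weights.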
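The routing performed by nn.Module.__setattr__ can be observed directly. The sketch below uses a hypothetical Toy module with one attribute of each kind and checks which dictionary each one ends up in:

```python
import torch
from torch import nn

class Toy(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        # a Parameter is routed into self._parameters
        self.scale = nn.Parameter(torch.ones(1))
        # an nn.Module is routed into self._modules
        self.fc = nn.Linear(3, 3)
        # a plain list is neither, so it lands in the ordinary __dict__
        self.blocks = [nn.Linear(3, 3), nn.Linear(3, 3)]

toy = Toy()
print("scale" in toy._parameters)  # True
print("fc" in toy._modules)        # True
print("blocks" in toy.__dict__)    # True
print("blocks" in toy._modules)    # False
```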

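In addition, when collecting parameters (for example in model.parameters() or model.named_parameters()), nn.Module recursively traverses the _modules dictionary and gathers each submodule's _parameters. Therefore, use nn.ModuleList instead of a plain list to store multiple network layers: nn.ModuleList is itself an nn.Module, so assigning it registers it in _modules, and every layer it contains is registered as one of its children.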
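A quick way to see the difference is to count parameters. In this sketch, WithList and WithModuleList are hypothetical classes that differ only in the container holding three nn.Linear(4, 4) layers:

```python
from torch import nn

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        # plain list: the layers are invisible to nn.Module
        self.layers = [nn.Linear(4, 4) for _ in range(3)]

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList is itself a Module, so it is registered in _modules
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

print(count_params(WithList()))        # 0
print(count_params(WithModuleList()))  # 60 = 3 layers x (16 weights + 4 biases)
```

Besides parameters(), registration also matters for model.to(device), state_dict() and train()/eval(), all of which walk _modules in the same way.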

Modifying an Instantiated Network

When you take a pretrained model from someone else, you sometimes need to modify it, for example to reset some parameters or to replace certain modules before retraining. The most direct approach is to locate the corresponding attribute on the model and assign a new value to it.

There are two approaches to modifying parameters. Since changing parameter values does not change the network structure, the first is to use the state dictionary and the loading method that PyTorch provides, state_dict() and load_state_dict():

```python
state_dict = model.state_dict()
# change some parameters in state_dict
model.load_state_dict(state_dict)
```
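A runnable version of this round trip, assuming a small nn.Sequential model in place of a real pretrained network; the keys are the dotted parameter names of that toy model:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 1))

state_dict = model.state_dict()
print(list(state_dict.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Replace one entry with a new tensor of the same shape and dtype
state_dict["2.bias"] = torch.full_like(state_dict["2.bias"], 0.5)

# load_state_dict copies the values into the model's existing parameters
model.load_state_dict(state_dict)
print(model[2].bias)  # the bias now holds 0.5
```

state_dict() returns detached tensors that share storage with the parameters, so in-place edits on them already affect the model; assigning a new tensor to a key, as above, only takes effect after load_state_dict().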

Another approach is to use named_parameters() to perform corresponding modifications:

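```python
import torch
for name, params in model.named_parameters():
    if name == "name of the parameter to modify":  # e.g. "fc.weight"
        with torch.no_grad():  # edit in place without autograd tracking
            params.zero_()  # keeps the parameter's device and dtype
```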
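The names yielded by named_parameters() are dotted paths of the form "<submodule>.<parameter>". The sketch below, again assuming a toy nn.Sequential model, prints them and zeroes one weight matrix in place:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 1))

# Parameter names are dotted paths: "<submodule>.<parameter>"
print([name for name, _ in model.named_parameters()])
# ['0.weight', '0.bias', '2.weight', '2.bias']

for name, params in model.named_parameters():
    if name == "0.weight":
        # modify in place under no_grad so autograd does not record the edit;
        # the tensor keeps its device, dtype and requires_grad flag
        with torch.no_grad():
            params.zero_()

print(model[0].weight.abs().sum().item())  # 0.0
```

Creating a fresh tensor with torch.zeros(shape) would default to the CPU and float32, which can mismatch a model on the GPU or in half precision; editing the existing tensor in place avoids that.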

Those are the ways to modify parameters. Modifying parameters is comparatively simple, since PyTorch provides convenient methods for replacing them. Modifying the model structure is just as direct: find where the layer sits in the model and replace it, using the named_modules() method to locate it. Taking the open-source large language model llama-2-13b as an example, after loading the model, first print the names of all layers:

```python
# Print every submodule together with its dotted name
for name, layer in model.named_modules():
    print(name, layer)

# If you are familiar with the network structure, printing only the names is enough
print([name for (name, module) in model.named_modules()])
```
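To see what these names look like without downloading a large model, here is a sketch on a small nested nn.Sequential (a stand-in assumption, not llama-2-13b). The root module has the empty name, and each dotted name mirrors attribute or index access, so a printed name such as model.embed_tokens corresponds to model.model.embed_tokens on the loaded object:

```python
from torch import nn

model = nn.Sequential(
    nn.Embedding(10, 4),
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
)

# named_modules() walks the tree depth-first, starting with the root ("")
names = [name for name, _ in model.named_modules()]
print(names)  # ['', '0', '1', '1.0', '1.1']
```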

Then find the name of the layer that needs to be replaced and assign a new module to it. Taking the embedding layer as an example:

```python
from torch import nn
# llama-2-13b: vocabulary size 32000, hidden size 5120
model.model.embed_tokens = nn.Linear(32000, 5120)
```
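An embedding lookup is equivalent to multiplying a one-hot vector by the embedding matrix, which is one way to see how an nn.Linear(vocab_size, hidden_size) can stand in for an nn.Embedding. The toy-sized sketch below (vocabulary 10 and width 4 instead of 32000 and 5120) copies the embedding weights into a bias-free linear layer and checks that both produce the same output; note that the linear layer takes float one-hot vectors rather than integer token ids.

```python
import torch
import torch.nn.functional as F
from torch import nn

vocab_size, hidden_size = 10, 4
embed = nn.Embedding(vocab_size, hidden_size)

# nn.Linear stores weight as (out_features, in_features) = (hidden, vocab),
# which is the transpose of the embedding matrix (vocab, hidden)
linear = nn.Linear(vocab_size, hidden_size, bias=False)
with torch.no_grad():
    linear.weight.copy_(embed.weight.T)

token_ids = torch.tensor([3, 7, 1])
one_hot = F.one_hot(token_ids, num_classes=vocab_size).float()

print(torch.allclose(embed(token_ids), linear(one_hot)))  # True
```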

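Printing the model afterwards shows that the embedding layer has been replaced with a fully connected layer. Other layers can be modified in the same way. Keep in mind that the replacement must accept the inputs the rest of the model feeds it: nn.Embedding takes integer token ids, whereas nn.Linear(32000, 5120) expects float vectors of length 32000, such as one-hot encodings.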
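When a name printed by named_modules() is long or contains numeric indices from nn.ModuleList or nn.Sequential (such as the 0 in layers.0.mlp), writing the attribute path by hand gets tedious. A minimal sketch of a generic helper, assuming PyTorch 1.9 or newer for Module.get_submodule(); replace_module is a hypothetical name, not a PyTorch API:

```python
import torch
from torch import nn

def replace_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Hypothetical helper: swap the submodule at dotted path `name`."""
    parent_name, _, child_name = name.rpartition(".")
    # top-level names like "fc" have no parent path, so use the model itself
    parent = model.get_submodule(parent_name) if parent_name else model
    # setattr goes through nn.Module.__setattr__, which re-registers
    # the new module in parent._modules under the same key
    setattr(parent, child_name, new_module)

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)
replace_module(model, "1.1", nn.GELU())
print(type(model[1][1]).__name__)  # GELU

out = model(torch.randn(2, 4))  # the forward pass still works
print(tuple(out.shape))  # (2, 8)
```

Because the key in _modules is reused, nn.Sequential keeps the original layer order after the swap.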
