pytorch 深度学习库具有这种设计,允许您从不同层组装神经网络。
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x) # activation function
x = self.conv2(x)
return F.relu(x)
我正在编写自己的库,其外观与pytorch类似(一切都是出于教育目的,同时我每次都了解深度学习的新方面。)我对字符串感兴趣
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
我们创建两个类对象Conv2d并将它们的引用分配给该类对象Model。在学习过程(反向传播)期间,梯度从末端开始依次穿过此类网络的所有层(首先通过self.conv2,然后通过self.conv1)。这就引出了一个问题,算法如何了解该类对象的所有层Model?这个讨论部分回答了我的问题。我意识到还有另一个类在创建“层”类的对象时将对它们的引用收集到某个列表中。我自己是这样实现的
class Parameter:
layers = []
calling = dict()
number_of_classes = 0
def __init__(self, info):
Parameter.layers.append(info[0])
Parameter.calling[info[0]] = info[1:]
class Conv2d:
def __init__(self, input_channels: int, output_channels: int, kernel_size: tuple, bias = True):
# something happen here
Parameter([self, self.kernel_array, self.bias_array])
这个方法有效。然而,目前只有Model一个类对象。一旦创建了另一个对象(另一个神经元),Parameter.layers该类的不同对象的引用就已经存储在其中Model,因此,一切都停止工作。
您有什么想法可以避免这种情况吗?
我马上说,我对神经元了解不多,我不知道pytorch内部是如何实现的,我是从纯python的角度看问题。
我想到了两种解决方案,但您描述的行为与closures极其相似。因此,第一个解决方案是这个选项:
可以看到,不同数组中的对象在内存中的位置是不同的。到底这是怎么回事?创建类的实例时
Model,会为其创建一个字段_constructor_Conv2d,其中(逻辑上)存储类的构造函数Conv2d,但有一个小功能。使用此构造函数创建实例时,它将Parameter.layers仅包含对使用此构造函数创建的那些对象的引用。如果您创建一个新的构造函数,根据您的问题,所有链接都将落入不同的列表中。这是因为我们以相同的方式为类创建了一个构造函数,并且在我们关闭的Parameter类中,因此对于每个新的构造函数,都会创建一个新的空数组(如果仍然不清楚,请阅读有关闭包的内容)。Parameterlayers_listParameterObj对于第二个想法,代码将是多余的。这就是要点:您可以
Model创建一个类的实例Parameter,并在创建实例时简单地传递它,其中将Conv2d通过引用Parameter.layers添加您需要的数据。好吧,不要忘记layers在 init 中初始化它,否则在当前形式下,该列表对于该类的所有实例都是通用的。第二种选择可能更简单、更清晰,但如果我们想象未来的类链会变得非常大,那么第一种选择对我来说似乎更可取。
PS 一般来说,您可以
Model将其包装在函数中以关闭它constructor_Conv2d,而不是在类中创建额外的字段。非常感谢@kristal 的回答。正是在闭包的帮助下,我才能够自己实现一切。
在一个单独的模块中,我制作了相同的辅助实体参数,但作为闭包
在主模块中,我决定使用全局参数变量,当初始化子类(模块的子级)的对象时以及
__call__.由于学习(反向传播)仅在直接通过网络(前向传播,调用
__call__)之后发生,因此预计全局 Parameter 变量将准确引用在创建新对象时创建的 Parameter 类,相应地,在调用时self._constructor_Parameter = ParameterObj()现在创建一个神经网络看起来像这样