RError.com

RError.com Logo RError.com Logo

RError.com Navigation

  • 主页

Mobile menu

Close
  • 主页
  • 系统&网络
    • 热门问题
    • 最新问题
    • 标签
  • Ubuntu
    • 热门问题
    • 最新问题
    • 标签
  • 帮助
主页 / 问题 / 1575323
Accepted
Тима
Тима
Asked:2024-04-09 05:30:33 +0000 UTC2024-04-09 05:30:33 +0000 UTC 2024-04-09 05:30:33 +0000 UTC

如何跟踪Python类对象的创建?

  • 772

pytorch 深度学习库具有这种设计,允许您从不同层组装神经网络。

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x) # activation function
        x = self.conv2(x)
        return F.relu(x)

我正在编写自己的库,其外观与pytorch类似(一切都是出于教育目的,同时我每次都了解深度学习的新方面。)我对字符串感兴趣

self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)

我们创建两个类对象Conv2d并将它们的引用分配给该类对象Model。在学习过程(反向传播)期间,梯度从末端开始依次穿过此类网络的所有层(首先通过self.conv2,然后通过self.conv1)。这就引出了一个问题,算法如何了解该类对象的所有层Model?这个讨论部分回答了我的问题。我意识到还有另一个类在创建“层”类的对象时将对它们的引用收集到某个列表中。我自己是这样实现的

class Parameter:
    layers = []
    calling = dict()
    number_of_classes = 0

    def __init__(self, info):
        Parameter.layers.append(info[0])
        Parameter.calling[info[0]] = info[1:]


class Conv2d:
    def __init__(self, input_channels: int, output_channels: int, kernel_size: tuple, bias = True):
        # something happen here
        Parameter([self, self.kernel_array, self.bias_array])

这个方法有效。然而,目前只有Model一个类对象。一旦创建了另一个对象(另一个神经元),Parameter.layers该类的不同对象的引用就已经存储在其中Model,因此,一切都停止工作。

您有什么想法可以避免这种情况吗?

python
  • 2 2 个回答
  • 75 Views

2 个回答

  • Voted
  1. Best Answer
    kristal
    2024-04-09T17:21:42Z2024-04-09T17:21:42Z

    我马上说,我对神经元了解不多,我不知道pytorch内部是如何实现的,我是从纯python的角度看问题。

    我想到了两种解决方案,但您描述的行为与closures极其相似。因此,第一个解决方案是这个选项:

    def ParameterObj():
        layers_list = []
    
        class Parameter:
            layers = layers_list
            calling = dict()
            number_of_classes = 0
    
            def __init__(self, info):
                Parameter.layers.append(info[0])
                Parameter.calling[info[0]] = info[1:]
    
        return Parameter
    
    
    def Conv2dObj():
        constructor_Parameter = ParameterObj()
    
        class Conv2d:
            def __init__(self,
                         input_channels: int,
                         output_channels: int,
                         kernel_size: tuple,
                         bias=True
                         ):
                self.layers = constructor_Parameter([self, input_channels, output_channels])
    
        return Conv2d
    
    
    class Model:
        def __init__(self, *args):
            self._constructor_Conv2d = Conv2dObj()
            self.conv1 = self._constructor_Conv2d(*args)
            self.conv2 = self._constructor_Conv2d(*[i + 1 for i in args])
    
    
    m1 = Model(1, 2, 3)
    print(m1.conv1.layers.layers)
    # [<__main__.Conv2dObj.<locals>.Conv2d object at 0x00CB41B0>, <__main__.Conv2dObj.<locals>.Conv2d object at 0x00CB4210>]
    m2 = Model(1, 2, 3)
    print(m2.conv1.layers.layers)
    # [<__main__.Conv2dObj.<locals>.Conv2d object at 0x00CB4330>, <__main__.Conv2dObj.<locals>.Conv2d object at 0x00CB4370>]
    

    可以看到,不同数组中的对象在内存中的位置是不同的。到底这是怎么回事?创建类的实例时Model,会为其创建一个字段_constructor_Conv2d,其中(逻辑上)存储类的构造函数Conv2d,但有一个小功能。使用此构造函数创建实例时,它将Parameter.layers仅包含对使用此构造函数创建的那些对象的引用。如果您创建一个新的构造函数,根据您的问题,所有链接都将落入不同的列表中。这是因为我们以相同的方式为类创建了一个构造函数,并且在我们关闭的Parameter类中,因此对于每个新的构造函数,都会创建一个新的空数组(如果仍然不清楚,请阅读有关闭包的内容)。Parameterlayers_listParameterObj

    对于第二个想法,代码将是多余的。这就是要点:您可以Model创建一个类的实例Parameter,并在创建实例时简单地传递它,其中将Conv2d通过引用Parameter.layers添加您需要的数据。好吧,不要忘记layers在 init 中初始化它,否则在当前形式下,该列表对于该类的所有实例都是通用的。

    第二种选择可能更简单、更清晰,但如果我们想象未来的类链会变得非常大,那么第一种选择对我来说似乎更可取。

    PS 一般来说,您可以Model将其包装在函数中以关闭它constructor_Conv2d,而不是在类中创建额外的字段。

    • 1
  2. Тима
    2024-04-10T09:27:52Z2024-04-10T09:27:52Z

    非常感谢@kristal 的回答。正是在闭包的帮助下,我才能够自己实现一切。

    在一个单独的模块中,我制作了相同的辅助实体参数,但作为闭包

    def ParameterObj():
        class Parameter:
            layers = []
            calling = dict()
            def __init__(self, info):
                Parameter.layers.append(info[0])
                Parameter.calling[info[0]] = info[1:]
        return Parameter
    
    

    在主模块中,我决定使用全局参数变量,当初始化子类(模块的子级)的对象时以及__call__.

    由于学习(反向传播)仅在直接通过网络(前向传播,调用__call__)之后发生,因此预计全局 Parameter 变量将准确引用在创建新对象时创建的 Parameter 类,相应地,在调用时self._constructor_Parameter = ParameterObj()

    Parameter = None
    
    class Module:
        def __init__(self):
            self._constructor_Parameter = ParameterObj()
            global Parameter
            Parameter = self._constructor_Parameter
    
    class Conv:
        def __init__(self, input_channels: int, output_channels: int, kernel_size: tuple, bias = True):
            # something happen here
            Parameter([self, self.filter_array, np.zeros(self.n_filters)])
    

    现在创建一个神经网络看起来像这样

    import nn
    
    class SimpleNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Сonv(3, 5)
            self.act1 = nn.Relu()
            self.conv2 = nn.Conv(5, 5)
            self.sftmx = nn.Softmax()
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.act1(x)
            x = self.conv2(x)
            x = self.sftmx(x)
            return x
    
    model = SimpleNet()
    prediction = model(x)
    
    • 0

相关问题

  • 是否可以以某种方式自定义 QTabWidget?

  • telebot.anihelper.ApiException 错误

  • Python。检查一个数字是否是 3 的幂。输出 无

  • 解析多个响应

  • 交换两个数组的元素,以便它们的新内容也反转

Sidebar

Stats

  • 问题 10021
  • Answers 30001
  • 最佳答案 8000
  • 用户 6900
  • 常问
  • 回答
  • Marko Smith

    我看不懂措辞

    • 1 个回答
  • Marko Smith

    请求的模块“del”不提供名为“default”的导出

    • 3 个回答
  • Marko Smith

    "!+tab" 在 HTML 的 vs 代码中不起作用

    • 5 个回答
  • Marko Smith

    我正在尝试解决“猜词”的问题。Python

    • 2 个回答
  • Marko Smith

    可以使用哪些命令将当前指针移动到指定的提交而不更改工作目录中的文件?

    • 1 个回答
  • Marko Smith

    Python解析野莓

    • 1 个回答
  • Marko Smith

    问题:“警告:检查最新版本的 pip 时出错。”

    • 2 个回答
  • Marko Smith

    帮助编写一个用值填充变量的循环。解决这个问题

    • 2 个回答
  • Marko Smith

    尽管依赖数组为空,但在渲染上调用了 2 次 useEffect

    • 2 个回答
  • Marko Smith

    数据不通过 Telegram.WebApp.sendData 发送

    • 1 个回答
  • Martin Hope
    Alexandr_TT 2020年新年大赛! 2020-12-20 18:20:21 +0000 UTC
  • Martin Hope
    Alexandr_TT 圣诞树动画 2020-12-23 00:38:08 +0000 UTC
  • Martin Hope
    Air 究竟是什么标识了网站访问者? 2020-11-03 15:49:20 +0000 UTC
  • Martin Hope
    Qwertiy 号码显示 9223372036854775807 2020-07-11 18:16:49 +0000 UTC
  • Martin Hope
    user216109 如何为黑客设下陷阱,或充分击退攻击? 2020-05-10 02:22:52 +0000 UTC
  • Martin Hope
    Qwertiy 并变成3个无穷大 2020-11-06 07:15:57 +0000 UTC
  • Martin Hope
    koks_rs 什么是样板代码? 2020-10-27 15:43:19 +0000 UTC
  • Martin Hope
    Sirop4ik 向 git 提交发布的正确方法是什么? 2020-10-05 00:02:00 +0000 UTC
  • Martin Hope
    faoxis 为什么在这么多示例中函数都称为 foo? 2020-08-15 04:42:49 +0000 UTC
  • Martin Hope
    Pavel Mayorov 如何从事件或回调函数中返回值?或者至少等他们完成。 2020-08-11 16:49:28 +0000 UTC

热门标签

javascript python java php c# c++ html android jquery mysql

Explore

  • 主页
  • 问题
    • 热门问题
    • 最新问题
  • 标签
  • 帮助

Footer

RError.com

关于我们

  • 关于我们
  • 联系我们

Legal Stuff

  • Privacy Policy

帮助

© 2023 RError.com All Rights Reserve   沪ICP备12040472号-5