RError.com

RError.com Logo RError.com Logo

RError.com Navigation

  • 主页

Mobile menu

Close
  • 主页
  • 系统&网络
    • 热门问题
    • 最新问题
    • 标签
  • Ubuntu
    • 热门问题
    • 最新问题
    • 标签
  • 帮助
主页 / 问题 / 614462
Accepted
Alexander Pozharskii
Alexander Pozharskii
Asked:2020-01-14 04:27:05 +0000 UTC2020-01-14 04:27:05 +0000 UTC 2020-01-14 04:27:05 +0000 UTC

如何优化在 theano 上使用池化/非池化索引?

  • 772

实际上,任务是尽可能准确地复制theano 上 SpatialMaxPooling 和 SpatialMaxUnpooling 层的行为。

在这种情况下,SpatialMaxUnpooling 只填充对应于相应 SpatialMaxPooling 中最大值索引的那些“单元格”。

例如 - 这是输入图像

池化前的图片

SpatialMaxPooling 将存储每个 2x2 区域中具有最大值的像素及其索引。

而 SpatialMaxUnpooling - 只会将值设置为与索引对应的那些像素。也就是说,输出将是

去池化后的图像

我发布了以下实现:

def pooling2d_2x2(self, x):
    reshaped = x.reshape([
        x.shape[0], x.shape[1], x.shape[2] // 2, 2, x.shape[3] // 2, 2
    ])
    max_values, max_indices = T.max_and_argmax(reshaped, (3,5,))
    return max_values, max_indices

def unpooling2d_2x2(self, pooled, indices):
    tmp_shape = [pooled.shape[0], pooled.shape[1], pooled.shape[2], 2, pooled.shape[3], 2]
    # Resize image
    resized = pooled.repeat(2, 2).repeat(2, 3)
    pooled_reshaped = resized.reshape(tmp_shape)
    # Resize indices
    indices_repeaten = indices.repeat(2, 2).repeat(2, 3).reshape(tmp_shape)
    # Calculate output
    result = pooled_reshaped * 0.0
    result = T.set_subtensor(result[:, :, :, 0, :, 0],
                             pooled_reshaped[:, :, :, 0, :, 0] * T.eq(indices_repeaten[:, :, :, 0, :, 0], 0))
    result = T.set_subtensor(result[:, :, :, 0, :, 1],
                             pooled_reshaped[:, :, :, 0, :, 1] * T.eq(indices_repeaten[:, :, :, 0, :, 1], 1))
    result = T.set_subtensor(result[:, :, :, 1, :, 0],
                             pooled_reshaped[:, :, :, 1, :, 0] * T.eq(indices_repeaten[:, :, :, 1, :, 0], 2))
    result = T.set_subtensor(result[:, :, :, 1, :, 1],
                             pooled_reshaped[:, :, :, 1, :, 1] * T.eq(indices_repeaten[:, :, :, 1, :, 1], 3))
    result_shape = [pooled.shape[0], pooled.shape[1], pooled.shape[2] * 2, pooled.shape[3] * 2]
    return result.reshape(result_shape)

但她在速度上并不出色(顺便说一句,我不会拒绝建议 - 如何配置文件)。因此问题 - 这里可以改进什么?

cuda
  • 1 1 个回答
  • 10 Views

1 个回答

  • Voted
  1. Best Answer
    Alexander Pozharskii
    2020-01-14T05:22:49Z2020-01-14T05:22:49Z

    下一个替换(据我了解 theano(这可能是一个非常平庸的理解 :-))——这里我们不再为新张量分配内存,仅指“增加的”输入张量和索引)稍微增加了速度。但是,也许 - 还有其他可能的改进吗?

    def unpooling2d_2x2(self, pooled, indices):
        tmp_shape = [pooled.shape[0], pooled.shape[1], pooled.shape[2], 2, pooled.shape[3], 2]
        # Resize image
        resized = pooled.repeat(2, 2).repeat(2, 3)
        pooled_reshaped = resized.reshape(tmp_shape)
        # Resize indices
        indices_repeaten = indices.repeat(2, 2).repeat(2, 3).reshape(tmp_shape)
        # Calculate output
        result = pooled_reshaped * 0.0
        # Calculate output
        result = T.set_subtensor(pooled_reshaped[:, :, :, 0, :, 0],
                                 pooled_reshaped[:, :, :, 0, :, 0] * T.eq(indices_repeaten[:, :, :, 0, :, 0], 0))
        result = T.set_subtensor(result[:, :, :, 0, :, 1],
                                 pooled_reshaped[:, :, :, 0, :, 1] * T.eq(indices_repeaten[:, :, :, 0, :, 1], 1))
        result = T.set_subtensor(result[:, :, :, 1, :, 0],
                                 pooled_reshaped[:, :, :, 1, :, 0] * T.eq(indices_repeaten[:, :, :, 1, :, 0], 2))
        result = T.set_subtensor(result[:, :, :, 1, :, 1],
                                 pooled_reshaped[:, :, :, 1, :, 1] * T.eq(indices_repeaten[:, :, :, 1, :, 1], 3))
        result_shape = [pooled.shape[0], pooled.shape[1], pooled.shape[2] * 2, pooled.shape[3] * 2]
    
    • 1

相关问题

Sidebar

Stats

  • 问题 10021
  • Answers 30001
  • 最佳答案 8000
  • 用户 6900
  • 常问
  • 回答
  • Marko Smith

    Python 3.6 - 安装 MySQL (Windows)

    • 1 个回答
  • Marko Smith

    C++ 编写程序“计算单个岛屿”。填充一个二维数组 12x12 0 和 1

    • 2 个回答
  • Marko Smith

    返回指针的函数

    • 1 个回答
  • Marko Smith

    我使用 django 管理面板添加图像,但它没有显示

    • 1 个回答
  • Marko Smith

    这些条目是什么意思,它们的完整等效项是什么样的

    • 2 个回答
  • Marko Smith

    浏览器仍然缓存文件数据

    • 1 个回答
  • Marko Smith

    在 Excel VBA 中激活工作表的问题

    • 3 个回答
  • Marko Smith

    为什么内置类型中包含复数而小数不包含?

    • 2 个回答
  • Marko Smith

    获得唯一途径

    • 3 个回答
  • Marko Smith

    告诉我一个像幻灯片一样创建滚动的库

    • 1 个回答
  • Martin Hope
    Air 究竟是什么标识了网站访问者? 2020-11-03 15:49:20 +0000 UTC
  • Martin Hope
    Алексей Шиманский 如何以及通过什么方式来查找 Javascript 代码中的错误? 2020-08-03 00:21:37 +0000 UTC
  • Martin Hope
    Qwertiy 号码显示 9223372036854775807 2020-07-11 18:16:49 +0000 UTC
  • Martin Hope
    user216109 如何为黑客设下陷阱,或充分击退攻击? 2020-05-10 02:22:52 +0000 UTC
  • Martin Hope
    Qwertiy 并变成3个无穷大 2020-11-06 07:15:57 +0000 UTC
  • Martin Hope
    koks_rs 什么是样板代码? 2020-10-27 15:43:19 +0000 UTC
  • Martin Hope
    user207618 Codegolf——组合选择算法的实现 2020-10-23 18:46:29 +0000 UTC
  • Martin Hope
    Sirop4ik 向 git 提交发布的正确方法是什么? 2020-10-05 00:02:00 +0000 UTC
  • Martin Hope
    faoxis 为什么在这么多示例中函数都称为 foo? 2020-08-15 04:42:49 +0000 UTC
  • Martin Hope
    Pavel Mayorov 如何从事件或回调函数中返回值?或者至少等他们完成。 2020-08-11 16:49:28 +0000 UTC

热门标签

javascript python java php c# c++ html android jquery mysql

Explore

  • 主页
  • 问题
    • 热门问题
    • 最新问题
  • 标签
  • 帮助

Footer

RError.com

关于我们

  • 关于我们
  • 联系我们

Legal Stuff

  • Privacy Policy

帮助

© 2023 RError.com All Rights Reserve   沪ICP备12040472号-5