RError.com

RError.com Logo RError.com Logo

RError.com Navigation

  • 主页

Mobile menu

Close
  • 主页
  • 系统&网络
    • 热门问题
    • 最新问题
    • 标签
  • Ubuntu
    • 热门问题
    • 最新问题
    • 标签
  • 帮助
主页 / 问题 / 1441137
Accepted
Сергей Маслобойщиков
Сергей Маслобойщиков
Asked:2022-08-21 01:28:16 +0000 UTC2022-08-21 01:28:16 +0000 UTC 2022-08-21 01:28:16 +0000 UTC

代码速度优化

  • 772

我正在尝试解决一个问题。 在此处输入图像描述

代码执行没有错误,但随着数据量的增加,速度降低,解决方案在某些时候不通过条件。帮助优化代码。也许还有另一种计算公式的选择。

谢谢你。

import numpy as np

ar = np.loadtxt('input.txt', skiprows=1)

ar.astype(dtype=np.float16)
ar = ar[ar[:, 0].argsort()]

s = 0
p = 0

un = np.unique(ar[:, 0])

for _ in un:

    for row in ar[ar[:, 0] == _][:, 1]:
        s += (
            ar[np.where((ar[:, 0] > _)
                        & ((ar[:, 1] > row)))].shape[0] +
            ar[np.where((ar[:, 0] > _)
                        & ((ar[:, 1] == row)))].shape[0] / 2
        )
        p += (
            ar[np.where(ar[:, 0] > _)].shape[0]
        )
print(round(s / p, 7))

我没有对代码进行太多优化,但我无法进一步减少时间。

import numpy as np

ar = np.loadtxt('input.txt', skiprows=1)

ar = ar[ar[:, 0].argsort()]

s = 0
p = 0

un = np.unique(ar[:, 0])
un_1, duplicate_count = np.unique(ar, axis=0, return_counts=True)
pivot_table = np.concatenate((un_1, (np.array([duplicate_count])).T), axis=1)

for _ in un:

    for row in pivot_table[pivot_table[:, 0] == _][:, 1]:

        s += (
                (
                        np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
                                                    & (pivot_table[:, 1] > row))][:, 2])
                        + np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
                                                      & (pivot_table[:, 1] == row))][:, 2])
                        / 2
                 )
                * pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
        )
        p += (
                np.sum(pivot_table[np.where(pivot_table[:, 0] > _)][:, 2])
                * pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
        )
        
print(round(float(s / p), 7))
python оптимизация
  • 1 1 个回答
  • 84 Views

1 个回答

  • Voted
  1. Best Answer
    Stanislav Volodarskiy
    2022-08-27T22:21:44Z2022-08-27T22:21:44Z

    你的算法是二次的。无论您如何优化它,任何比较所有记录对的算法都是相同的。

    一个快速的解决方案是基于一种用于计算排列/数组中反转数的算法 - 合并排序的修改计算 NlogN 中的反转数。我们对反转的总数不感兴趣,但是对于数组的每个元素,我们可以计算它前面有多少比它小的元素。

    输入对(t_i, y_i)按t_i(如果有点t_i相等,y_i则按降序排序)排序并用两个计数器填充:

    t_i - точное значение
    y_i - предсказанное значение
    lt_i - число элементов, стоящих перед текущим, таких что `y` меньше текущего
    eq_i - число элементов, стоящих перед текущим, таких что `y` равен текущему
    

    对按 排序的记录列表进行合并排序t_i构建按 排序的记录列表y_i。

    lt_i它分两次计算:首先,在归并排序中,计算较小或相等的元素,然后从它们中减去相等元素的数量。 eq_i计入一次通过已排序的列表。

    使用计数器lt_i,eq_i您可以在线性时间内计算 ROC-AUC。

    NlogN 算法的总复杂度。我无法访问检查系统,我将您的解决方案的性能与我在随机数据集上的性能进行了比较。检查了两组唯一值和重复值。在所有情况下,ROC-AUC 的前六位数都匹配。

    最大的困难是由重复t_i和y_i各种组合的处理引起的。细节我就不说了,太多了。

    def merge(a, b):
        c = []
        i = 0
        j = 0
        while i < len(a) and j < len(b):
            if a[i][1] <= b[j][1]:
                c.append(a[i])
                i += 1
            else:
                b[j][2] += i
                c.append(b[j])
                j += 1
        while i < len(a):
            c.append(a[i])
            i += 1
        while j < len(b):
            b[j][2] += i
            c.append(b[j])
            j += 1
        return c
    
    
    def merge_sort(a):
        if len(a) <= 1:
            return a
        h = len(a) // 2
        a1 = merge_sort(a[:h])
        a2 = merge_sort(a[h:])
        return merge(a1, a2)
    
    
    def fill_counters(a):
        a = merge_sort(a)
    
        prev_y = None
        c_y = 0
        for e in a:
            if e[1] == prev_y:
                c_y += 1
            else:
                c_y = 0
            e[2] -= c_y
            prev_y = e[1]
    
        prev_y = None
        c_y = 0
        c_t = 0
        prev_t = None
        for e in a:
            if e[1] == prev_y:
                c_y += 1
                if e[0] != prev_t:
                    c_t = c_y
            else:
                c_y = 0
                c_t = 0
            e[3] = c_t
            prev_y = e[1]
            prev_t = e[0]
    
    
    def main():
        array = [
            [float(t) for t in input().split()] + [0, 0]
            for _ in range(int(input()))
        ]
        array.sort(key=lambda e: (e[0], -e[1]))
        fill_counters(array)
    
        n = 0
        d = 0
        k = 0
        prev_t = None
        for i, e in enumerate(array):
            if e[0] != prev_t:
                k = i
            n += e[2] + e[3] / 2
            d += k
            prev_t = e[0]
        print(n / d)
    
    
    main()
    

    交货时间:

    размер   разница  время       время
    списка   ROC-AUC  работы      работы
       (N)            NumPy (c)   merge_sort (с)
                       
     10000  0.000000      3.794     0.119
     20000  0.000000     13.699     0.225
     30000  0.000000     27.586     0.318
     40000  0.000000     48.019     0.460
     50000  0.000000     75.546     0.643
     60000  0.000000    106.449     0.717
     70000  0.000000    149.666     0.892
     80000  0.000000    183.675     0.922
     90000  0.000000    230.494     1.073
    100000  0.000000    303.946     1.214
    

    PS print(round(s / p, 7)) - 使用格式化打印四舍五入的值更好。真正的小数不能在计算机中精确表示。您冒着打印未舍入结果的风险。所以更好:print(f'{s / p:.7f}')。

    • 1

相关问题

Sidebar

Stats

  • 问题 10021
  • Answers 30001
  • 最佳答案 8000
  • 用户 6900
  • 常问
  • 回答
  • Marko Smith

    我看不懂措辞

    • 1 个回答
  • Marko Smith

    请求的模块“del”不提供名为“default”的导出

    • 3 个回答
  • Marko Smith

    "!+tab" 在 HTML 的 vs 代码中不起作用

    • 5 个回答
  • Marko Smith

    我正在尝试解决“猜词”的问题。Python

    • 2 个回答
  • Marko Smith

    可以使用哪些命令将当前指针移动到指定的提交而不更改工作目录中的文件?

    • 1 个回答
  • Marko Smith

    Python解析野莓

    • 1 个回答
  • Marko Smith

    问题:“警告:检查最新版本的 pip 时出错。”

    • 2 个回答
  • Marko Smith

    帮助编写一个用值填充变量的循环。解决这个问题

    • 2 个回答
  • Marko Smith

    尽管依赖数组为空,但在渲染上调用了 2 次 useEffect

    • 2 个回答
  • Marko Smith

    数据不通过 Telegram.WebApp.sendData 发送

    • 1 个回答
  • Martin Hope
    Alexandr_TT 2020年新年大赛! 2020-12-20 18:20:21 +0000 UTC
  • Martin Hope
    Alexandr_TT 圣诞树动画 2020-12-23 00:38:08 +0000 UTC
  • Martin Hope
    Air 究竟是什么标识了网站访问者? 2020-11-03 15:49:20 +0000 UTC
  • Martin Hope
    Qwertiy 号码显示 9223372036854775807 2020-07-11 18:16:49 +0000 UTC
  • Martin Hope
    user216109 如何为黑客设下陷阱,或充分击退攻击? 2020-05-10 02:22:52 +0000 UTC
  • Martin Hope
    Qwertiy 并变成3个无穷大 2020-11-06 07:15:57 +0000 UTC
  • Martin Hope
    koks_rs 什么是样板代码? 2020-10-27 15:43:19 +0000 UTC
  • Martin Hope
    Sirop4ik 向 git 提交发布的正确方法是什么? 2020-10-05 00:02:00 +0000 UTC
  • Martin Hope
    faoxis 为什么在这么多示例中函数都称为 foo? 2020-08-15 04:42:49 +0000 UTC
  • Martin Hope
    Pavel Mayorov 如何从事件或回调函数中返回值?或者至少等他们完成。 2020-08-11 16:49:28 +0000 UTC

热门标签

javascript python java php c# c++ html android jquery mysql

Explore

  • 主页
  • 问题
    • 热门问题
    • 最新问题
  • 标签
  • 帮助

Footer

RError.com

关于我们

  • 关于我们
  • 联系我们

Legal Stuff

  • Privacy Policy

帮助

© 2023 RError.com All Rights Reserve   沪ICP备12040472号-5