代码执行没有错误,但随着数据量的增加,速度降低,解决方案在某些时候不通过条件。帮助优化代码。也许还有另一种计算公式的选择。
谢谢你。
import numpy as np
ar = np.loadtxt('input.txt', skiprows=1)
ar.astype(dtype=np.float16)
ar = ar[ar[:, 0].argsort()]
s = 0
p = 0
un = np.unique(ar[:, 0])
for _ in un:
for row in ar[ar[:, 0] == _][:, 1]:
s += (
ar[np.where((ar[:, 0] > _)
& ((ar[:, 1] > row)))].shape[0] +
ar[np.where((ar[:, 0] > _)
& ((ar[:, 1] == row)))].shape[0] / 2
)
p += (
ar[np.where(ar[:, 0] > _)].shape[0]
)
print(round(s / p, 7))
我没有对代码进行太多优化,但我无法进一步减少时间。
import numpy as np
ar = np.loadtxt('input.txt', skiprows=1)
ar = ar[ar[:, 0].argsort()]
s = 0
p = 0
un = np.unique(ar[:, 0])
un_1, duplicate_count = np.unique(ar, axis=0, return_counts=True)
pivot_table = np.concatenate((un_1, (np.array([duplicate_count])).T), axis=1)
for _ in un:
for row in pivot_table[pivot_table[:, 0] == _][:, 1]:
s += (
(
np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
& (pivot_table[:, 1] > row))][:, 2])
+ np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
& (pivot_table[:, 1] == row))][:, 2])
/ 2
)
* pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
)
p += (
np.sum(pivot_table[np.where(pivot_table[:, 0] > _)][:, 2])
* pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
)
print(round(float(s / p), 7))
你的算法是二次的。无论您如何优化它,任何比较所有记录对的算法都是相同的。
一个快速的解决方案是基于一种用于计算排列/数组中反转数的算法 - 合并排序的修改计算 NlogN 中的反转数。我们对反转的总数不感兴趣,但是对于数组的每个元素,我们可以计算它前面有多少比它小的元素。
输入对
(t_i, y_i)
按t_i
(如果有点t_i
相等,y_i
则按降序排序)排序并用两个计数器填充:对按 排序的记录列表进行合并排序
t_i
构建按 排序的记录列表y_i
。lt_i
它分两次计算:首先,在归并排序中,计算较小或相等的元素,然后从它们中减去相等元素的数量。eq_i
计入一次通过已排序的列表。使用计数器
lt_i
,eq_i
您可以在线性时间内计算 ROC-AUC。NlogN 算法的总复杂度。我无法访问检查系统,我将您的解决方案的性能与我在随机数据集上的性能进行了比较。检查了两组唯一值和重复值。在所有情况下,ROC-AUC 的前六位数都匹配。
最大的困难是由重复
t_i
和y_i
各种组合的处理引起的。细节我就不说了,太多了。交货时间:
PS
print(round(s / p, 7))
- 使用格式化打印四舍五入的值更好。真正的小数不能在计算机中精确表示。您冒着打印未舍入结果的风险。所以更好:print(f'{s / p:.7f}')
。