代码执行没有错误,但随着数据量的增加,速度降低,解决方案在某些时候不通过条件。帮助优化代码。也许还有另一种计算公式的选择。
谢谢你。
import numpy as np
ar = np.loadtxt('input.txt', skiprows=1)
ar.astype(dtype=np.float16)
ar = ar[ar[:, 0].argsort()]
s = 0
p = 0
un = np.unique(ar[:, 0])
for _ in un:
for row in ar[ar[:, 0] == _][:, 1]:
s += (
ar[np.where((ar[:, 0] > _)
& ((ar[:, 1] > row)))].shape[0] +
ar[np.where((ar[:, 0] > _)
& ((ar[:, 1] == row)))].shape[0] / 2
)
p += (
ar[np.where(ar[:, 0] > _)].shape[0]
)
print(round(s / p, 7))
我没有对代码进行太多优化,但我无法进一步减少时间。
import numpy as np
ar = np.loadtxt('input.txt', skiprows=1)
ar = ar[ar[:, 0].argsort()]
s = 0
p = 0
un = np.unique(ar[:, 0])
un_1, duplicate_count = np.unique(ar, axis=0, return_counts=True)
pivot_table = np.concatenate((un_1, (np.array([duplicate_count])).T), axis=1)
for _ in un:
for row in pivot_table[pivot_table[:, 0] == _][:, 1]:
s += (
(
np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
& (pivot_table[:, 1] > row))][:, 2])
+ np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
& (pivot_table[:, 1] == row))][:, 2])
/ 2
)
* pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
)
p += (
np.sum(pivot_table[np.where(pivot_table[:, 0] > _)][:, 2])
* pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
)
print(round(float(s / p), 7))