Theme NexT works best with JavaScript enabled
0%

Java求解TopK问题

^ _ ^

TopK 问题

10亿个数中如何高效地找到最大的一个数以及最大的K个数

参考链接

  1. 参考1:https://github.com/weitingyuk/LeetCode-Notes-Waiting/blob/main/2021-02-17/TopK.md
  2. 参考2:https://zhuanlan.zhihu.com/p/72164039

方法一:全部排序

  • 思路:将n个数进行排序后取前k个
  • 时间复杂度:快排时间复杂度为O(nlogn)
  • 空间:在32位机器上,float类型占用4Byte,$10^9$个数需要占用越4GB的存储空间。
  • 缺点:全部排序需要将数据全部载入内存,如果电脑内存小于4GB,则无法使用该方法。另一方面,该方法对于内存大于4GB的电脑也不高效,因为问题要求只是找到TopK,而排序将所有元素都排序了,做了很多无用功。

实验

实验目的:生成一个$10^6$个数,用快速排序对其运行时间进行测试。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/**
* @ Author LuckyQ
* @ Date 2021-05-22 13:59
* @ Description 快速排序
* @ 时间复杂度 O(nlogn)
* @ 空间复杂度 O(n)
*/

import java.util.Random;
import java.lang.Math;

public class QuickSort{
// Return the index of nums[0] should be
private static int partion(int[] nums, int left, int right){
int num = nums[left];
int i = left + 1, j = right;

// Divide nums by num to this situation --> | <= num | num | > num |
while(i <= j){
while(i <= j && nums[i] >= num){
i++;
}
while(i <= j && nums[j] < num){
j--;
}
if(i < j){
int temp = nums[j];
nums[j] = nums[i];
nums[i] = temp;
}
}
int temp = nums[j];
nums[j] = num;
nums[left] = temp;

return j;
}

public static void sort(int[] nums, int left, int right){
if(right <= left){
return;
}
int mid = QuickSort.partion(nums, left, right);
sort(nums, left, mid - 1);
sort(nums, mid + 1, right);
}

public static int[] generateNums(int N, int highBound){
int[] nums = new int[N];
Random random = new Random();
for(int i = 0; i < N; i++){
nums[i] = random.nextInt(highBound);
}
return nums;
}

public static void printArray(int[] nums){
System.out.println("-----------------------------------\n");
for(int i = 0; i < nums.length; i++){
System.out.print(nums[i]);
if((i + 1) % 10 == 0){
System.out.println();
}
else{
System.out.print("\t");
}
}
System.out.println("\n-----------------------------------");
}

// N is the number of experience
public static double experience(int[] nums, int k, int N){
double time = 0.0;
for(int i = 0;i < N; i++){
int[] container = new int[k];

long startTime = System.currentTimeMillis();
QuickSort.sort(nums, 0, nums.length - 1);
for(int j = 0;j < k; j++){
container[j] = nums[j];
}
long endTime = System.currentTimeMillis();
// if(i == 0){
// AdvancedQuickSort.printArray(container);
// }
time += (endTime - startTime) / 1000.0;
}
return time / N;
}

public static void main(String[] args){
int[] nums = QuickSort.generateNums(1000000, 101);
// QuickSort.printArray(nums);
double averageTime = QuickSort.experience(nums, 100, 10);
System.out.println("程序运行平均时间:" + Math.round(averageTime * 1000) + "ms");
}

}

实验结果

方法二:快排的改进算法

  • 思路:本质上是运用了快排中分治思想。首先对数组第一个元素对数组进行一个划分,划分后使其左边的元素都不小于该元素,右边的元素都大于该元素。不妨将划分元素归为左部分。如果左边部分元素的数目等于K,则代表找到前K个数;如果左边部分元素的数目大于K,则继续对左部分数组进行划分;如果左边部分元素的数目小于K,则继续对右部元素做(K-左部元素个数)划分。
  • 时间复杂度O(nlogn)(不确定)
  • 空间复杂度:4n Byte
  • 缺点:数据也要全部载入内存,但是速度会比普通的快排快。

实验

实验目的:生成一个$10^6$个数,用快速排序对其运行时间进行测试。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/**
* @ Author LuckyQ
* @ Date 2021-05-22 22:16
* @ Description 改进版快速排序
* @ 时间复杂度 O(nlogn)
* @ 空间复杂度 O(n)
*/

import java.util.Random;
import java.lang.Math;

public class AdvancedQuickSort{
// Return the index of nums[0] should be
private static int partion(int[] nums, int left, int right){
int num = nums[left];
int i = left + 1, j = right;

// Divide nums by num to this situation --> | <= num | num | > num |
while(i <= j){
while(i <= j && nums[i] >= num){
i++;
}
while(i <= j && nums[j] < num){
j--;
}
if(i < j){
int temp = nums[j];
nums[j] = nums[i];
nums[i] = temp;
}
}
int temp = nums[j];
nums[j] = num;
nums[left] = temp;

return j;
}

public static void sort(int[] nums, int left, int right, int k){
if(right <= left){
return;
}
int mid = AdvancedQuickSort.partion(nums, left, right);
int num = mid - left + 1;
if(num == k){
return;
}
else if(num > k){
AdvancedQuickSort.sort(nums, left, mid - 1, k);
}
else{
AdvancedQuickSort.sort(nums, mid + 1, right, k - num);
}
}

public static int[] generateNums(int N, int highBound){
int[] nums = new int[N];
Random random = new Random();
for(int i = 0; i < N; i++){
nums[i] = random.nextInt(highBound);
}
return nums;
}

public static void printArray(int[] nums){
System.out.println("-----------------------------------\n");
for(int i = 0; i < nums.length; i++){
System.out.print(nums[i]);
if((i + 1) % 10 == 0){
System.out.println();
}
else{
System.out.print("\t");
}
}
System.out.println("\n-----------------------------------");
}

// N is the number of experience
public static double experience(int[] nums, int k, int N){
double time = 0.0;
for(int i = 0;i < N; i++){
long startTime = System.currentTimeMillis();
int[] container = new int[k];
AdvancedQuickSort.sort(nums, 0, nums.length - 1, k);
for(int j = 0;j < k; j++){
container[j] = nums[j];
}
long endTime = System.currentTimeMillis();
// if(i == 0){
// AdvancedQuickSort.printArray(container);
// }
time += (endTime - startTime) / 1000.0;
}
return time / N;
}

public static void main(String[] args){
int[] nums = AdvancedQuickSort.generateNums(1000000, 101);
// AdvancedQuickSort.printArray(nums);
double averageTime = AdvancedQuickSort.experience(nums, 100, 10);
System.out.println("程序运行平均时间:" + Math.round(averageTime * 1000) + "ms");
}
}

实验结果

方法三:局部淘汰法

  • 思路:用一个容器保存前K个数,然后将剩余的所有数字一一与容器内的最小数字相比,如果某一后续元素比容器内最小数字大,则删掉容器内最小元素,并将该元素插入容器。
  • 时间复杂度O(Kn)
  • 空间:只用容器是一定要保存在内存中的,所以需要的空间为$4m$Byte,m为容器大小。
  • 缺点:时间复杂度较高

实验

实验目的:生成一个$10^6$个数,用局部淘汰法对其运行时间进行测试,其中K值为100。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
* @ Author LuckyQ
* @ Date 2021-05-22 21:11
* @ Description 局部淘汰法
* @ 时间复杂度 O(nk)
* @ 空间复杂度 O(k)
*/

import java.util.Random;
import java.lang.Math;

public class PartialElimination{

public static int[] partialElimination(int[] nums, int k){
int[] container = new int[k];
for(int i = 0; i < k; i++){
container[i] = nums[i];
}
for(int i = k; i < nums.length; i++){
// is there any item lower than nums[i] in container,if the answer is yes then index != -1
int index = -1, minNum = nums[i];
for(int j = 0; j < k; j++){
if(container[j] < minNum){
index = j;
minNum = container[j];
}
}
if(index != -1){
container[index] = nums[i];
}
}
return container;
}

public static int[] generateNums(int N, int highBound){
int[] nums = new int[N];
Random random = new Random();
for(int i = 0; i < N; i++){
nums[i] = random.nextInt(highBound);
}
return nums;
}

public static void printArray(int[] nums){
System.out.println("-----------------------------------\n");
for(int i = 0; i < nums.length; i++){
System.out.print(nums[i]);
if((i + 1) % 10 == 0){
System.out.println();
}
else{
System.out.print("\t");
}
}
System.out.println("\n-----------------------------------");
}

// k is the number of target, N is the number of experience
public static double experience(int[] nums, int k, int N){
double time = 0.0;
for(int i = 0;i < N; i++){
long startTime = System.currentTimeMillis();
int[] container = PartialElimination.partialElimination(nums, k);
long endTime = System.currentTimeMillis();
// PartialElimination.printArray(container);
time += (endTime - startTime) / 1000.0;
}
return time / N;
}

public static void main(String[] args){
int[] nums = PartialElimination.generateNums(1000000, 101);
// PartialElimination.printArray(nums);
double averageTime = PartialElimination.experience(nums, 100, 10);
System.out.println("程序运行平均时间:" + Math.round(averageTime * 1000) + "ms");
}
}

实验结果

方法四:分治法

  • 思路:将n个数据分为m份,每份$\frac{n}{m}$个数据,找到每份数据中最大的K个,最后在得到的mk个数据中找出最大的k个。m份数据可以分别在m台计算机中同时进行计算。
  • 时间复杂度O(nlog(n/m))
  • 空间:4 * max(n/m, mk) Byte
  • 缺点:分布式的部署和通信比较复杂

实验

实验目的:生成一个$10^6$个数,用一个线程模拟一台计算机,共设置了10台计算机。用分治法对其运行时间进行测试,其中K值为100。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
/**
* @ Author LuckyQ
* @ Date 2021-05-22 23:00
* @ Description 分治思想求TopK
* @ 时间复杂度 O(nlog(n/m))
* @ 空间复杂度 O(max(n/m, mk))
*/

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.Random;
import java.lang.Math;
import java.util.List;
import java.util.ArrayList;

public class DivideConquer{
private static int partion(int[] nums, int left, int right){
int num = nums[left];
int i = left + 1, j = right;

// Divide nums by num to this situation --> | <= num | num | > num |
while(i <= j){
while(i <= j && nums[i] >= num){
i++;
}
while(i <= j && nums[j] < num){
j--;
}
if(i < j){
int temp = nums[j];
nums[j] = nums[i];
nums[i] = temp;
}
}
int temp = nums[j];
nums[j] = num;
nums[left] = temp;

return j;
}

public static void sort(int[] nums, int left, int right, int k){
if(right <= left){
return;
}
int mid = DivideConquer.partion(nums, left, right);
int num = mid - left + 1;
if(num == k){
return;
}
else if(num > k){
DivideConquer.sort(nums, left, mid - 1, k);
}
else{
DivideConquer.sort(nums, mid + 1, right, k - num);
}
}

public static int[] divideConquer(int[] nums, int m, int k){
int n = nums.length / m;

ExecutorService executor = Executors.newFixedThreadPool(m);
List<Future<int[]>> futureList = new ArrayList<>();
for(int i = 0; i < m; i++){
int[] partNums = new int[n];
for(int j = 0; j < n; j++){
partNums[j] = nums[i * n + j];
}
futureList.add(executor.submit(new ProcessThread(partNums, k)));
}

executor.shutdown();
int[] topkMerge = new int[m * k];
int index = 0;
for(Future<int[]> future: futureList){
try{
int[] container = future.get();
// DivideConquer.printArray(container, "computer[" + (index + 1) + "]");
for(int i = 0; i < container.length; i++){
topkMerge[index * k + i] = container[i];
}
index += 1;
}catch(Exception e){
e.printStackTrace();
}
}
// DivideConquer.printArray(topkMerge, "merge");
DivideConquer.sort(topkMerge, 0, topkMerge.length - 1, k);
int[] res = new int[k];
for(int j = 0; j < k; j++){
res[j] = topkMerge[j];
}

return res;
}

public static int[] generateNums(int N, int highBound){
int[] nums = new int[N];
Random random = new Random();
for(int i = 0; i < N; i++){
nums[i] = random.nextInt(highBound);
}
return nums;
}

public static void printArray(int[] nums, String title){
System.out.println("------------" + title + "-----------------------\n");
for(int i = 0; i < nums.length; i++){
System.out.print(nums[i]);
if((i + 1) % 10 == 0){
System.out.println();
}
else{
System.out.print("\t");
}
}
System.out.println("\n-----------------------------------");
}

// N is the number of experience
public static double experience(int[] nums, int m, int k, int N){
double time = 0.0;
for(int epoch = 0; epoch < N; epoch++){
long startTime = System.currentTimeMillis();
int[] result = DivideConquer.divideConquer(nums, m, k);
long endTime = System.currentTimeMillis();
// if(epoch == 0){
// DivideConquer.printArray(result,"result");
// }
time += (endTime - startTime) / 1000.0;
}

return time / N;
}

public static void main(String[] args){
int[] nums = DivideConquer.generateNums(1000000, 101);
// DivideConquer.printArray(nums,"origin");
double averageTime = DivideConquer.experience(nums, 10, 100, 10);
System.out.println("程序运行平均时间:" + Math.round(averageTime * 1000) + "ms");
}
}

class ProcessThread implements Callable<int[]> {
private int[] nums;
private int k;

public ProcessThread(int[] nums, int k){
this.nums = nums;
this.k = k;
}

// AdvancedQuickSort
// Return the index of nums[0] should be
private int partion(int left, int right){
int num = nums[left];
int i = left + 1, j = right;

// Divide nums by num to this situation --> | <= num | num | > num |
while(i <= j){
while(i <= j && nums[i] >= num){
i++;
}
while(i <= j && nums[j] < num){
j--;
}
if(i < j){
int temp = nums[j];
nums[j] = nums[i];
nums[i] = temp;
}
}
int temp = nums[j];
nums[j] = num;
nums[left] = temp;

return j;
}

// AdvancedQuickSort
private void sort(int left, int right){
if(right <= left){
return;
}
int mid = partion(left, right);
int num = mid - left + 1;
if(num == k){
return;
}
else if(num > k){
sort(left, mid - 1);
}
else{
sort(mid + 1, right);
}
}

@Override
public int[] call(){
try{
sort(0, nums.length - 1);

int[] container = new int[k];
for(int i = 0; i < k; i++){
container[i] = nums[i];
}
return container;
}catch(Exception e){
e.printStackTrace();
return null;
}
}
}

实验结果

方法五:Hash法

思路:先通过Hash法,对数据集进行去重复。这样如果重复率很高的话,会减少很大的内存用量,从而缩小运算空间,然后通过分治法或最小堆法查找最大的K个数。

关于这个方法我有一个疑惑,前K个数不是包含相等的数吗,去重之后不就少数了吗?比如,生成了100万个取值范围在[0,100]的数,希望找到最大的1000个数。这样去重后最大K个数最多也才100个啊…

方法六:小顶堆

  • 思路:首先读入前K个数来创建大小为K的小顶堆,建堆的时间复杂度为O(K),然后遍历后续的数字,并于堆顶(最小)数字进行比较。如果比最小的数小,则继续读取后续数字;如果比堆顶数字大,则替换堆顶元素并重新调整堆为最小堆。
  • 时间复杂度O(NlogK)
  • 空间:$2^K - 1$

实验1

实验目的:生成一个$10^6$个数,用小顶堆法对其运行时间进行测试,其中K值为100。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/**
* @ Author LuckyQ
* @ Date 2021-06-08 9:27
* @ Description 小顶堆实现topK
* @ 时间复杂度 O(nlogk)
* @ 空间复杂度 O(2^k-1)
*/

import java.util.Random;
import java.lang.Math;

public class SmallHeap{
public static int[] getTopK(int[] nums, int N, int K){
// res是结果数组
int[] res = new int[K];
for(int i = 0; i < K; i++){
res[i] = nums[i];
}
// 建立一个大小为 K 的小顶堆,设堆顶为0
buildHeap(res, K);
// 对于剩余 N-K 个数字,逐个判断是否能加入小顶堆
// 若大于小顶堆的堆顶(最小值),则将其替换最小堆堆顶并调准堆。
for(int i = K; i < N; i++){
if(nums[i] > res[0]){
res[0] = nums[i];
shiftDown(res, K, 0);
}
}
return res;
}

// 建堆
private static void buildHeap(int[] nums, int size){
// 最小的树即最后一个叶子的根,它的坐标为(size - 1) 其父亲结点坐标为(size - 1 - 1) / 2
for(int i = (size - 2) / 2; i >= 0; i--){
shiftDown(nums, size, i);
}
}

// 向下调准
private static void shiftDown(int[] nums, int size, int index){
int left = 2 * index + 1;
while(left < size){
int min = left;
int right = left + 1;
if(right < size && nums[right] < nums[left]){
min = right;
}
if(nums[index] < nums[min]){
break;
}
int tmp = nums[min];
nums[min] = nums[index];
nums[index] = tmp;
// 更新 left, index
index = min;
left = 2 * index + 1;
}
}

public static int[] generateNums(int N, int highBound){
int[] nums = new int[N];
Random random = new Random();
for(int i = 0; i < N; i++){
nums[i] = random.nextInt(highBound);
}
return nums;
}

public static void printArray(int[] nums){
System.out.println("-----------------------------------\n");
for(int i = 0; i < nums.length; i++){
System.out.print(nums[i]);
if((i + 1) % 10 == 0){
System.out.println();
}
else{
System.out.print("\t");
}
}
System.out.println("\n-----------------------------------");
}

// N is the number of experience
public static double experience(int[] nums, int k, int N){
double time = 0.0;
for(int i = 0;i < N; i++){
long startTime = System.currentTimeMillis();
int[] container = SmallHeap.getTopK(nums, nums.length, k);
long endTime = System.currentTimeMillis();
// if(i == 0){
// SmallHeap.printArray(container);
// }
time += (endTime - startTime) / 1000.0;
}
return time / N;
}

public static void main(String[] args){
int[] nums = SmallHeap.generateNums(1000000, 101);
// SmallHeap.printArray(nums);
double averageTime = SmallHeap.experience(nums, 100, 10);
System.out.println("程序运行平均时间:" + Math.round(averageTime * 1000) + "ms");
}
}

实验结果

实验2

实验1代码中手动实现了小顶堆,其实还可以使用 Java 中的优先队列 PriorityQueue 进行实现。

实验目的:重复上述实验,只是将小顶堆使用优先队列进行实现。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/**
* @ Author LuckyQ
* @ Date 2021-06-10 13:42
* @ Description 队列实现小顶堆的topK
* @ 时间复杂度 O(nlogk)
* @ 空间复杂度 O(2^k-1)
*/

import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Random;
import java.lang.Math;

public class PriorityHeap {
public static int[] getTopK(int[] nums, int N, int K){
// 初始化优先队列,默认情况下是小顶堆
Queue<Integer> queue = new PriorityQueue<>();
// 将前 K 个元素放入优先队列
for(int i = 0; i < K; i++){
queue.add(nums[i]);
}
// 对于剩余 N-K 个数字,逐个判断是否能加入小顶堆
// 若大于小顶堆的堆顶(最小值),则将其替换最小堆堆顶并调准堆。
for(int i = K; i < N; i++){
if(nums[i] > queue.peek()){
queue.poll();
queue.add(nums[i]);
}
}
// res是结果数组
int[] res = new int[K];
for(int i = 0; i < K; i++){
res[i] = queue.poll();
}
return res;
}

public static int[] generateNums(int N, int highBound){
int[] nums = new int[N];
Random random = new Random();
for(int i = 0; i < N; i++){
nums[i] = random.nextInt(highBound);
}
return nums;
}

public static void printArray(int[] nums){
System.out.println("-----------------------------------\n");
for(int i = 0; i < nums.length; i++){
System.out.print(nums[i]);
if((i + 1) % 10 == 0){
System.out.println();
}
else{
System.out.print("\t");
}
}
System.out.println("\n-----------------------------------");
}

// N is the number of experience
public static double experience(int[] nums, int k, int N){
double time = 0.0;
for(int i = 0;i < N; i++){
long startTime = System.currentTimeMillis();
int[] container = PriorityHeap.getTopK(nums, nums.length, k);
long endTime = System.currentTimeMillis();
// if(i == 0){
// PriorityHeap.printArray(container);
// }
time += (endTime - startTime) / 1000.0;
}
return time / N;
}

public static void main(String[] args){
int[] nums = SmallHeap.generateNums(1000000, 101);
// PriorityHeap.printArray(nums);
double averageTime = PriorityHeap.experience(nums, 100, 10);
System.out.println("程序运行平均时间:" + Math.round(averageTime * 1000) + "ms");
}
}

实验结果: