目录
- 一、归并排序(Merge Sort)
- 1.1 二路归并
- 1.2 归并排序算法
- 1.3 应用:计算逆序对的数量
- 二、快速排序(Quick Sort)
- 2.1 快速排序算法
- 2.2 应用:快速选择
- 三、模板汇总
- References
一、归并排序(Merge Sort)
在讲解什么是归并排序之前,我们先来看一下什么是二路归并。
1.1 二路归并
二路归并指的是将两个递增序列合并成一个新的递增序列(这个序列可以是链表、数组等)。例如,对于序列 A = [ 1 , 3 , 5 , 7 ] A=[1,3,5,7] A=[1,3,5,7], B = [ 2 , 4 , 6 , 8 ] B=[2,4,6,8] B=[2,4,6,8],它们合并后应当是 C = [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] C=[1,2,3,4,5,6,7,8] C=[1,2,3,4,5,6,7,8]。
二路归并算法的实现非常简单:
#include <iostream>
using namespace std;
void merge(const int a[], size_t a_len, const int b[], size_t b_len, int c[]) {
size_t i = 0, j = 0, k = 0;
while (i < a_len && j < b_len) {
// a[i] <= b[j] 是为了保证稳定性
if (a[i] <= b[j]) c[k++] = a[i++];
else c[k++] = b[j++];
}
// 此时一个数组已空,另一个数组非空,将非空的数组并入 c 中
while (i < a_len) c[k++] = a[i++];
while (j < b_len) c[k++] = b[j++];
}
int main() {
int a[] = {1, 3, 5, 7};
int b[] = {2, 4, 6, 7, 8};
int c[9];
merge(a, 4, b, 5, c);
for (auto i: c) cout << i << " ";
// 输出:1 2 3 4 5 6 7 7 8
return 0;
}
当然,我们也可以调用 <algorithm>
库里的 merge
函数(官方文档),用法如下:
#include <iostream>
#include <algorithm>
using namespace std;
int main() {
int a[] = {1, 3, 5, 7};
int b[] = {2, 4, 6, 8, 10};
int c[9];
merge(a, a + 4, b, b + 5, c); // 注意是左闭右开区间
for (auto i: c) cout << i << " ";
// 输出:1 2 3 4 5 6 7 8 10
return 0;
}
1.2 归并排序算法
接下来基于二路归并的思想实现归并排序。定义一个 merge_sort(a, l, r)
函数用来实现对数组
a
a
a 中的
[
l
,
r
]
[l,r]
[l,r] 区间实现归并排序,于是 merge_sort(a, 0, n - 1)
代表对整个数组实现归并排序。当 l >= r
时,此时至多只有一个数字,不用排序,直接返回即可;当 l < r
时,我们将区间
[
l
,
r
]
[l,r]
[l,r] 尽量分成等长的两段,即取
m
i
d
=
⌊
l
+
r
2
⌋
mid=\left\lfloor\frac{l+r}{2}\right\rfloor
mid=⌊2l+r⌋,于是我们可以调用 merge_sort(a, l, mid)
和 merge_sort(a, mid + 1, r)
得到两个递增序列,然后再使用二路归并算法将这两个递增序列合并为一个递增序列。
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
int q[N], tmp[N];
void merge_sort(int a[], int l, int r) {
if (l >= r) return;
int mid = l + r >> 1;
merge_sort(a, l, mid), merge_sort(a, mid + 1, r);
// merge_sort是左闭右闭,但merge是左闭右开
merge(a + l, a + mid + 1, a + mid + 1, a + r + 1, tmp + l);
for (int i = l; i <= r; i++) a[i] = tmp[i];
}
int main() {
int n;
cin >> n;
for (int i = 0; i < n; i++) cin >> q[i];
merge_sort(q, 0, n - 1);
for (int i = 0; i < n; i++) cout << q[i] << " ";
return 0;
}
归并排序基于分治思想将数组分段排序后合并,时间复杂度在最优、最坏与平均情况下均为 O ( n log n ) O(n\log n) O(nlogn),空间复杂度为 O ( n ) O(n) O(n)。
这是因为归并排序会不断二分长度为
n
n
n 的区间,第一层有
2
2
2 个子区间,第二层有
4
4
4 个子区间,第三层有
8
8
8 个子区间,如此进行下去最终会得到
O
(
log
n
)
O(\log n)
O(logn) 层,而每一层进行二路归并的时间复杂度均为
O
(
n
)
O(n)
O(n),因此总时间复杂度为
O
(
n
log
n
)
O(n\log n)
O(nlogn)。由于归并排序在合并的时候使用了额外的数组 tmp
,它的大小与被排序的数组相同,因此空间复杂度为
O
(
n
)
O(n)
O(n)。
1.3 应用:计算逆序对的数量
原题链接:AcWing 788. 逆序对的数量
这里简要概述一下题目。若 i < j i<j i<j 且 a [ i ] > a [ j ] a[i]>a[j] a[i]>a[j],则称 ( a [ i ] , a [ j ] ) (a[i],a[j]) (a[i],a[j]) 为一个逆序对。现给定数组 a a a,求其中逆序对的个数。
此题如果暴力求解,则时间复杂度为 O ( n 2 ) O(n^2) O(n2),会TLE。为此考虑使用归并排序解决,时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
具体来讲,假设我们定义的 merge_sort(a, l, r)
函数在对数组
a
a
a 中的
[
l
,
r
]
[l,r]
[l,r] 区间进行排序的同时也能返回区间
[
l
,
r
]
[l,r]
[l,r] 中的逆序对数量,那么 merge_sort(a, 0, n - 1)
就是数组
a
a
a 中所有逆序对的数量。当 l >= r
时,此时至多只有一个数字,不可能构成逆序对,返回
0
0
0 即可。下面重点关注 l < r
的情形。
同样地,我们将区间 [ l , r ] [l,r] [l,r] 划分成尽量等长的两个区间: [ l , m i d ] [l, mid] [l,mid] 和 [ m i d + 1 , r ] [mid + 1, r] [mid+1,r]。此时区间 [ l , r ] [l,r] [l,r] 内的逆序对由以下三部分构成:
- [ l , m i d ] [l,mid] [l,mid] 中的逆序对;
- [ m i d + 1 , r ] [mid+1,r] [mid+1,r] 中的逆序对;
- [ l , m i d ] [l,mid] [l,mid] 中的一个数和 [ m i d + 1 , r ] [mid+1,r] [mid+1,r] 中的一个数构成的逆序对;
显然,前两种逆序对的数量可以分别由 merge_sort(a, l, mid)
和 merge_sort(a, mid + 1, r)
来表示。如何求第三种逆序对的数量呢?
注意到当我们调用了 merge_sort(a, l, mid)
和 merge_sort(a, mid + 1, r)
后,
[
l
,
m
i
d
]
[l,mid]
[l,mid] 和
[
m
i
d
+
1
,
r
]
[mid+1,r]
[mid+1,r] 已经是有序区间了。设
a
[
i
]
∈
[
l
,
m
i
d
]
a[i]\in[l,mid]
a[i]∈[l,mid] 和
a
[
j
]
∈
[
m
i
d
+
1
,
r
]
a[j]\in[mid+1,r]
a[j]∈[mid+1,r] 可以构成逆序对,从而
a
[
i
]
>
a
[
j
]
a[i]>a[j]
a[i]>a[j],进而
a
[
i
.
.
m
i
d
]
>
a
[
j
]
a[i..mid]>a[j]
a[i..mid]>a[j](
a
[
i
.
.
m
i
d
]
a[i..mid]
a[i..mid] 表示
a
[
i
]
,
a
[
i
+
1
]
,
⋯
,
a
[
m
i
d
]
a[i],a[i+1],\cdots,a[mid]
a[i],a[i+1],⋯,a[mid]),故对
a
[
j
]
a[j]
a[j] 而言,逆序对的数量为
m
i
d
−
i
+
1
mid-i+1
mid−i+1。
数组
a
a
a 的最大逆序对数量为
n
(
n
−
1
)
2
\frac{n(n-1)}{2}
2n(n−1),若
n
n
n 取到
1
0
5
10^5
105,则可知
1
2
1
0
10
=
5
×
1
0
9
>
2
×
1
0
9
\frac1210^{10}=5\times10^9>2\times 10^9
211010=5×109>2×109,因此必须开 long long
。
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
int q[N], tmp[N];
LL merge_sort(int a[], int l, int r) {
if (l >= r) return 0;
int mid = l + r >> 1;
LL res = merge_sort(a, l, mid) + merge_sort(a, mid + 1, r);
int i = l, j = mid + 1, k = l;
while (i <= mid && j <= r) {
if (a[i] <= a[j]) tmp[k++] = a[i++];
else {
tmp[k++] = a[j++];
res += mid - i + 1;
}
}
while (i <= mid) tmp[k++] = a[i++];
while (j <= r) tmp[k++] = a[j++];
for (int i = l; i <= r; i++) a[i] = tmp[i];
return res;
}
int main() {
int n;
cin >> n;
for (int i = 0; i < n; i++) cin >> q[i];
cout << merge_sort(q, 0, n - 1) << endl;
return 0;
}
二、快速排序(Quick Sort)
2.1 快速排序算法
快速排序,又称分区交换排序(英语:partition-exchange sort),简称「快排」,是一种被广泛运用的排序算法。
类似地,定义一个 quick_sort(a, l, r)
函数用来实现对数组
a
a
a 中的
[
l
,
r
]
[l,r]
[l,r] 区间实现快速排序,于是 quick_sort(a, 0, n - 1)
代表对整个数组实现快速排序。当 l >= r
时,此时至多只有一个数字,不用排序,直接返回即可;当 l < r
时,按以下三个步骤执行
- 确定分界点 x x x,其中 x ∈ { a [ l ] , a [ l + 1 ] , ⋯ , a [ r ] } x\in\{a[l],a[l+1],\cdots,a[r]\} x∈{a[l],a[l+1],⋯,a[r]}。一般常用的取值有 a [ l ] , a [ r ] , a [ m i d ] , m i d = ⌊ l + r 2 ⌋ a[l],a[r],a[mid],\;mid=\lfloor\frac{ l+r}{2}\rfloor a[l],a[r],a[mid],mid=⌊2l+r⌋。这里我们取 a [ m i d ] a[mid] a[mid];
- 调整区间。即把区间分成两部分,确保左边区间的所有值 ≤ x \leq x ≤x,右边区间的所有值 ≥ x \geq x ≥x。为实现这一步,我们可以采用双指针解法;
- 递归地处理左右两个区间。
快排的关键之处在于如何调整区间。我们可以设置左右两个指针初始时分别位于区间的左端点和右端点。当左指针指向的值 ≤ x \leq x ≤x 时就让它向右移动,当右指针指向的值 ≥ x \geq x ≥x 时就让它向左移动。当两个指针都停下时,交换两个指针指向的值。这样就可以保证左指针左边的值始终 ≤ x \leq x ≤x,右指针右边的值始终 ≥ x \geq x ≥x。若左指针和右指针相遇或左指针位于右指针的右侧则调整结束。
下图展示了快排的整个流程:
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
int q[N];
void quick_sort(int a[], int l, int r) {
if (l >= r) return;
int mid = l + r >> 1;
int i = l - 1, j = r + 1, x = a[mid];
while (i < j) {
while (a[++i] < x);
while (a[--j] > x);
if (i < j) swap(a[i], a[j]);
}
quick_sort(a, l, j), quick_sort(a, j + 1, r);
}
int main() {
int n;
cin >> n;
for (int i = 0; i < n; i++) cin >> q[i];
quick_sort(q, 0, n - 1);
for (int i = 0; i < n; i++) cout << q[i] << " ";
return 0;
}
快速排序的最优时间复杂度和平均时间复杂度为 O ( n log n ) O(n\log n) O(nlogn),最坏时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
2.2 应用:快速选择
原题链接:AcWing 786. 第k个数
这里简要概述一下题目。给定一个长度为 n n n 的整数数列,请在不超过 O ( n log n ) O(n\log n) O(nlogn) 的时间内求出数列从小到大排序后的第 k k k 个数。
显然我们可以用快速排序或归并排序对该数列排序后再输出 k k k 个数,但如果使用快速选择算法,时间复杂度甚至可以下降到 O ( n ) O(n) O(n)。
定义一个 quick_select(a, l, r, k)
函数用来选取数组
a
a
a 中区间
[
l
,
r
]
[l, r]
[l,r] 内的第
k
k
k 小的数,那么本题答案即为 quick_select(a, 0, n - 1, k)
。
如何计算 quick_select(a, l, r, k)
呢?我们可以采用上述快排的做法,即采用双指针将区间
[
l
,
r
]
[l,r]
[l,r] 分成两部分。记左区间
[
l
,
j
]
[l, j]
[l,j] 的长度为
S
L
S_L
SL,右区间
[
j
+
1
,
r
]
[j+1,r]
[j+1,r] 的长度为
S
R
S_R
SR,显然,左区间的所有数字要小于等于右区间的所有数字。此时,若
k
≤
S
L
k\leq S_L
k≤SL,那么要找的数就位于左区间,此时递归执行 quick_select(a, l, j, k)
即可。若
k
>
S
L
k>S_L
k>SL,那么要找的数位于右区间。注意到
[
l
,
r
]
[l,r]
[l,r] 中第
k
k
k 小的数等价于
[
j
+
1
,
r
]
[j+1,r]
[j+1,r] 中第
k
−
S
L
k-S_L
k−SL 小的数,因此递归执行 quick_select(q, j + 1, r, k - sl)
即可。
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
int q[N];
int quick_select(int a[], int l, int r, int k) {
if (l >= r) return a[l];
int mid = l + r >> 1;
int i = l - 1, j = r + 1, x = a[mid];
while (i < j) {
while (a[++i] < x);
while (a[--j] > x);
if (i < j) swap(a[i], a[j]);
}
int sl = j - l + 1;
if (k <= sl) return quick_select(a, l, j, k);
else return quick_select(a, j + 1, r, k - sl);
}
int main() {
int n, k;
cin >> n >> k;
for (int i = 0; i < n; i++) cin >> q[i];
cout << quick_select(q, 0, n - 1, k) << endl;
return 0;
}
当然,上述代码还可以继续化简。先前的 quick_select
函数中的
k
k
k 代表的是第几个数,如果将它的含义更改为「下标」,我们就不需要 sl
这个变量了:
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
int q[N];
int quick_select(int a[], int l, int r, int k) {
if (l >= r) return a[l];
int mid = l + r >> 1;
int i = l - 1, j = r + 1, x = a[mid];
while (i < j) {
while (a[++i] < x);
while (a[--j] > x);
if (i < j) swap(a[i], a[j]);
}
if (k <= j) return quick_select(a, l, j, k);
else return quick_select(a, j + 1, r, k);
}
int main() {
int n, k;
cin >> n >> k;
for (int i = 0; i < n; i++) cin >> q[i];
// 注意函数调用也发生了变化
cout << quick_select(q, 0, n - 1, k - 1) << endl;
return 0;
}
快速选择的平均时间复杂度为 O ( n ) O(n) O(n),最坏时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
快选和快排的区别是,快排每次要递归左右两边,而快选只需要递归其中一边,这就决定了快排平均 O ( n log n ) O(n\log n) O(nlogn),快选平均 O ( n ) O(n) O(n)。
当然,我们也可以调用 <algorithm>
库里的 nth_element
函数(官方文档),它的用法如下:
template< class RandomIt >
void nth_element( RandomIt first, RandomIt nth, RandomIt last );
nth_element
是部分排序算法,它重排 [first, last)
中的元素,使得 nth
所指向的元素恰好是 [first, last)
重排后该位置会出现的元素。并且,该元素左边的数均小于等于它,该元素右边的数均大于等于它。
使用该函数,上述代码可以简化为
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
int q[N];
int main() {
int n, k;
cin >> n >> k;
for (int i = 0; i < n; i++) cin >> q[i];
nth_element(q, q + k - 1, q + n);
cout << q[k - 1] << endl;
return 0;
}
三、模板汇总
归并排序:
void merge_sort(int a[], int l, int r) {
if (l >= r) return;
int mid = l + r >> 1;
merge_sort(a, l, mid), merge_sort(a, mid + 1, r);
int i = l, j = mid + 1, k = l;
while (i <= mid && j <= r) {
if (a[i] <= a[j]) tmp[k++] = a[i++];
else tmp[k++] = a[j++];
}
while (i <= mid) tmp[k++] = a[i++];
while (j <= r) tmp[k++] = a[j++];
for (int i = l; i <= r; i++) a[i] = tmp[i];
}
注意,tmp
数组需要事先声明为全局变量,大小和数组 a
相同。
快速排序:
void quick_sort(int a[], int l, int r) {
if (l >= r) return;
int mid = l + r >> 1;
int i = l - 1, j = r + 1, x = a[mid];
while (i < j) {
while (a[++i] < x);
while (a[--j] > x);
if (i < j) swap(a[i], a[j]);
}
quick_sort(a, l, j), quick_sort(a, j + 1, r);
}
快速选择:
int quick_select(int a[], int l, int r, int k) {
if (l >= r) return a[l];
int mid = l + r >> 1;
int i = l - 1, j = r + 1, x = a[mid];
while (i < j) {
while (a[++i] < x);
while (a[--j] > x);
if (i < j) swap(a[i], a[j]);
}
if (k <= j) return quick_select(a, l, j, k);
else return quick_select(a, j + 1, r, k);
}
注意,这里的 k
是下标。
References
[1] https://oi-wiki.org/basic/sort-intro/
[2] https://www.acwing.com/activity/content/punch_the_clock/11/
[3] https://en.wikipedia.org/wiki/Quicksort