为了更好的阅读体检,可以查看我的算法学习博客
在线评测链接:P1053
题目内容
你有一个序列,,...,,然后给你一些区间[l,r].对于每一个区间,你需要找到下式的最小值,对于所有可能的x
输入格式
第一行一个整数代表序列长度。
接下来一行有N个正整数,用空格隔开。
接下来一行一个整数,代表询问的区间次数。
接下来Q行,每行一个区间l,
输出格式
输出Q行。每行代表对应的区间的结果。
样例
5 2 3 3 4 4 3 1 2 2 2 2 5 1 0 2
思路:可持久化线段树
我们一步步剖析这道题。
对于这些数,想要找到一个数 x ,使得 最小,则 x 必然是这些数的中位数。详见下方证明。
统计 这些数中,小于等于中位数的数的和为 ltSum,个数为 ltCnt,统计大于中位数的数的和为 gtSum,个数为 gtCnt。
则
因为最多有 次询问,所以单次询问的时间复杂度至多为
如此,我们需要一个数据结构能够在的时间内获得区间 [l, r] 的中位数。
再通过这个数据结构获取到
-
小于这个中位数的所有数之和 ltSum,以及所有数的数量 ltCnt
-
大于这个中位数的所有数之和 gtSum,以及所有数的数量 gtCnt
这个数据结构叫作主席树,也叫可持久化线段树,可以用来求解区间第 k 大。
点击查看主席树教程
时间复杂度:因为 n 与 Q 同阶,故时间复杂度为 O(n\log n)
证明:
假设 是单调递增的。
-
当,
-
当 ,,,
-
当,,,|
故 ,最小。
对 和 a_{r-1} ,分析过程与上类似,以此类推,x 应该为, 的中位数。
代码
#include <bits/stdc++.h> using namespace std; #define sz(x) (int(x.size())) typedef long long ll; const int N = 100010; int a[N], n, Q; vector<int> nums; // 获取每个点离散化后的下标 int get_idx(int x) { return int(lower_bound(nums.begin(), nums.end(), x) - nums.begin()); } struct Node { int l, r; int cnt; ll sum; }tr[N * 21]; // 每个点的root int root[N], idx; // 初始化 int build(int l, int r) { int p = ++idx; if (l == r) return p; int mid = (l + r) >> 1; tr[p].l = build(l, mid); tr[p].r = build(mid + 1, r); return p; } // 插入新的点 int insert(int p, int l, int r, int x) { int q = ++idx; tr[q] = tr[p]; if (l == r) { tr[q].cnt += 1; tr[q].sum += nums[x]; return q; } int mid = (l + r) >> 1; if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x); else tr[q].r = insert(tr[p].r, mid + 1, r, x); tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt; tr[q].sum = tr[tr[q].l].sum + tr[tr[q].r].sum; return q; } // 求区间第 k 大 int query_kth_number_idx(int q, int p, int l, int r, int k) { if (l == r) return l; int lcnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt; int mid = (l + r) >> 1; if (k <= lcnt) return query_kth_number_idx(tr[q].l, tr[p].l, l, mid, k); return query_kth_number_idx(tr[q].r, tr[p].r, mid + 1 , r, k - lcnt); } // 求区间内与 x 的距离之和 ll query_sum_of_dist(int q, int p, int l, int r, int x) { if (l == r) return 0; int mid = (l + r) >> 1; if (x <= nums[mid]) { // 说明右子树的值全部大于 x,gtSum - gtCnt * x ll gtSum = tr[tr[q].r].sum - tr[tr[p].r].sum; ll gtCnt = tr[tr[q].r].cnt - tr[tr[p].r].cnt; return (gtSum - gtCnt * x) + query_sum_of_dist(tr[q].l, tr[p].l, l, mid, x); } else { // 说明左子树的值全部小于 x,ltCnt * x - ltSum ll ltCnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt; ll ltSum = tr[tr[q].l].sum - tr[tr[p].l].sum; return (ltCnt * x - ltSum) + query_sum_of_dist(tr[q].r, tr[p].r, mid + 1, r, x); } } int main() { scanf("%d", &n); for (int i = 1; i <= n; ++i) { scanf("%d", &a[i]); nums.push_back(a[i]); } // 离散化 sort(nums.begin(), nums.end()); nums.erase(unique(nums.begin(), nums.end()), nums.end()); // 构建可持久化权值线段树 root[0] = build(0, sz(nums) - 1); for (int i = 1; i <= n; ++i) { root[i] = insert(root[i - 1], 0, sz(nums) - 1, get_idx(a[i])); } scanf("%d", &Q); while (Q--) { int l, r; scanf("%d%d", &l, &r); // 获取中位数,注意这里中位数的索引应该从 1 开始 int k = (r - l + 1 + 1) / 2; int kidx = query_kth_number_idx(root[r], root[l - 1], 0, sz(nums) - 1, k); printf("%lld\n", query_sum_of_dist(root[r], root[l - 1], 0, sz(nums) - 1, nums[kidx])); } return 0; }
题目内容均收集自互联网,如如若此项内容侵犯了原著者的合法权益,可联系我: (CSDN网站注册用户名: 塔子哥学算法) 进行删除。