想起来好久没写题解了,随便写一下把
感觉写多了div3后面的题就变得简单了,div3似乎没什么思维含量,甚至有时候能开出div3的2100....
心血来潮写一下这个*1800的题解,思路一下就出了,但是一开始多了个log被卡了,提醒一下自己
Problem - G - Codeforces
题意:
思路:
首先做法肯定要给询问排序,那就相当于离线
离线的写法还是有讲究的,最好是把答案全部求出来再去O(1)查询,不然容易被卡
排序之后枚举边的权值,就是把边一条条加上去,相当于Kruskal的过程了
问题是点对的贡献怎么算,很显然可以拆,对于一个连通块,贡献为sz * (sz - 1)
那么对于每个询问是不是都要考虑遍历所有连通块,但是这样是O(nq)的
这个也是套路了,考虑询问之间的变化量,其实就是数据结构多一格的思想
如果加边之后两个连通块变成一个了,贡献的变化量是什么呢,这个可以手推
之前是 x * (x - 1) / 2,y * (y - 1) / 2
之后是 (x + y) * (x + y - 1) / 2
减一减就是 x * y
那么在合并之后贡献 += x * y即可
一开始我写成这样
#include <bits/stdc++.h>
#define int long long
constexpr int N = 2e5 + 10;
constexpr int M = 2e5 + 10;
constexpr int mod = 1e9 + 7;
constexpr int Inf = 0x3f3f3f3f;
std::pair<int, int> q[N];
std::vector<std::array<int, 2> > E[N];
int n, m;
int b[N];
int f[N], sz[N];
int find(int x) {
return f[x] = (x == f[x]) ? x : find(f[x]);
}
void join(int u, int v) {
int f1 = find(u), f2 = find(v);
if (sz[f1] > sz[f2]) std::swap(f1, f2);
if (f1 != f2) {
f[f1] = f2;
sz[f2] += sz[f1];
}
}
void solve() {
std::cin >> n >> m;
for (int i = 1; i <= n; i ++) {
f[i] = i;
sz[i] = 1;
}
for (int i = 1; i <= n - 1; i ++) {
int u, v, w;
std::cin >> u >> v >> w;
E[w].push_back({u, v});
}
for (int i = 1; i <= m; i ++) {
std::cin >> q[i].first;
q[i].second = i;
}
std::sort(q + 1, q + 1 + m);
int ans = 0;
for (int i = 1; i <= m; i ++) {
int w = q[i].first;
for (auto [u, v] : E[w]) {
int f1 = find(u), f2 = find(v);
if (f1 != f2) {
ans += sz[f1] * sz[f2];
}
join(u, v);
}
b[q[i].second] = ans;
}
for (int i = 1; i <= m; i ++) {
std::cout << b[i] << " \n" [i == m];
}
}
signed main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
while(t --) {
solve();
}
return 0;
}
这样子 T7了,但是枚举权值就没问题,不知道为什么这样会快,可能是STL的问题....
以后离线就写成把所有答案都求出来再去O(1)查询的形式会好一点,不容易被卡....
#include <bits/stdc++.h>
#define int long long
constexpr int N = 2e5 + 10;
constexpr int M = 2e5 + 10;
constexpr int mod = 1e9 + 7;
constexpr int Inf = 0x3f3f3f3f;
std::vector<std::array<int, 2> > E[N];
int n, m;
int b[N];
int f[N], sz[N];
int find(int x) {
return f[x] = (x == f[x]) ? x : find(f[x]);
}
void join(int u, int v) {
int f1 = find(u), f2 = find(v);
if (sz[f1] > sz[f2]) std::swap(f1, f2);
if (f1 != f2) {
f[f1] = f2;
sz[f2] += sz[f1];
}
}
void solve() {
std::cin >> n >> m;
for (int i = 1; i <= n; i ++) {
f[i] = i;
sz[i] = 1;
}
for (int i = 1; i <= n - 1; i ++) {
int u, v, w;
std::cin >> u >> v >> w;
E[w].push_back({u, v});
}
int ans = 0;
for (int w = 1; w <= 2e5; w ++) {
for (auto [u, v] : E[w]) {
int f1 = find(u), f2 = find(v);
if (f1 != f2) {
ans += sz[f1] * sz[f2];
}
join(u, v);
}
b[w] = ans;
}
int q = 1;
for (int i = 1; i <= m; i ++) {
std::cin >> q;
std::cout << b[q] << " \n" [i == m];
}
}
signed main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
while(t --) {
solve();
}
return 0;
}