C-Tokitsukaze and a+b=n (hard)_2023牛客寒假算法基础集训营2 (nowcoder.com)
题目描述
Tokitsukaze有一个整数n,以及m个区间[L, R]。
她想知道有多少种选法,满足:从m个区间中选择两个区间[L; R],[Lj;,R](i≠j),并从第一个区间选择一个整数a(Li≤a<R;),从第二个区间选择一个整数b(Lj≤b<R;),使得a+b=n。
对于两种选法,若i,j, a,b中有任意一个数不同,则算作不同的选法。
由于答案可能很大,请输出对998244 353取模后的结果。
输入描述:
第一行包含两个整数n, m (2≤n, m ≤4·105)。
接下来m行,每行包含两个整数L,R (1≤L≤R≤2·105)。
输出描述:
输出一个整数表示答案对998 244353 取模后的结果。
示例1
输入
复制
5 3
1 3
2 4
3 5
输出
12
说明
样例1解释:
选择第1个与第2个区间,即〔1,3]和2,4]时,共有3种选法,分别是: (1+4),(2+3),(3+ 2) ;
选择第1个与第3个区间,即[1,3]和[3,5]时,共有2种选法,分别是: (1+4),(2+3) ;
选择第2个与第1个区间,即[2,4和[1,3]时,共有3种选法,分别是: (2+3),(3+2),(4 +1) ;
选择第2个与第3个区间,即[2,4]和[3,5]时,共有1种选法,分别是: (2+3);
选择第3个与第1个区间,即[3,5]和1,3]时,共有⒉种选法,分别是: (3+2),(4+1
);
选择第3个与第2个区间,即「3,5]和「2,4时,共有1种选法,分别是: (3+2)。
所以总共是3+2+3+1+2+1=12种选法。
题解:
我们看到这么多区间,而范围只有2e5我们就应该联想到差分
我们记录所有区间,记录差分数组,最后可以得到每1~2e5中每个数字出现了多少次
答案应该就是
ans = (ans + (sum[n - i] * sum[i]) % mod) % mod;
但是我们少考虑了一个问题,就是这样的话身同一个区间也可能会对答案有贡献,所以我们应该
提前减去这部分贡献
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <numeric>
#include <cstring>
#include <cmath>
#include <map>
#include <unordered_map>
#include <bitset>
#include <set>
#include <random>
#include <ctime>
#include <queue>
#include <stack>
#include <climits>
#define buff \
ios::sync_with_stdio(false); \
cin.tie(0);
#define int long long
#define ll long long
#define PII pair<int, int>
#define px first
#define py second
typedef std::mt19937 Random_mt19937;
Random_mt19937 rnd(time(0));
using namespace std;
const int mod = 998244353;
const int inf = 2147483647;
const int N = 400009;
// int Mod(int a,int mod){return (a%mod+mod)%mod;}
// int lowbit(int x){return x&-x;}//最低位1及其后面的0构成的数值
// int qmi(int a, int k, int p){int res = 1 % p;while (k){if (k & 1) res = Mod(res * a , p);a = Mod(a * a , p);k >>= 1;}return res;}
// int inv(int a,int mod){return qmi(a,mod-2,mod);}
// int lcm(int a,int b){return a*b/__gcd(a,b);}
int n, m, sum[N];
void solve()
{
cin >> n >> m;
int ans = 0;
for (int i = 1; i <= m; i++)
{
int l, r;
cin >> l >> r;
sum[l]++, sum[r + 1]--;
int a = l, b = r, c = l, d = r;
ans -= max(0ll, min(n - a, d) - max(c, n- b) + 1);
cout << ans <<"\n";
}
for (int i = 1; i < N; i++)
sum[i] = (sum[i - 1] + sum[i]) % mod;
for (int i = 1; i <= n; i++)
ans = (ans + (sum[n - i] * sum[i]) % mod) % mod;
ans = (ans % mod + mod) % mod;
cout << ans << '\n';
}
signed main()
{
buff;
solve();
}