题目描述
简要题意:太长了,就不总结了,自己看吧。
分析
我们首先考虑
m
=
1
m = 1
m=1 的情况:
T
>
0
T > 0
T>0 时,显然我们可以
O
(
n
)
O(n)
O(n) 的维护一个 前缀积 和 前缀积的逆元,然后每次询问
O
(
1
)
O(1)
O(1) 得到答案就好了。时间复杂度
O
(
n
)
O(n)
O(n)。
T
=
0
T = 0
T=0 时,我们需要在
O
(
n
)
O(n)
O(n) 的复杂度算出答案。设
d
p
i
dp_i
dpi 表示终点是
i
i
i 的所有路径的答案,那么转移就是
d
p
i
=
d
p
i
−
1
×
a
i
+
a
i
dp_i = dp_{i - 1} \times a_i + a_i
dpi=dpi−1×ai+ai。然后把所有
d
p
i
dp_i
dpi 累加起来就好了。
然后我们考虑 m = 2 m = 2 m=2 的情况:
因为有一条限制是路径上的位置 列数要单调不降,所以我们 以列为阶段 dp,设 d p i , 0 / 1 dp_{i, 0/1} dpi,0/1 表示当起点固定,从起点到第 0 / 1 0/1 0/1 行,第 i i i 列的所有合法路径的答案。但是发现在同一列的时候可以向上走,这样的话同一列就能相互转移了。好像有后效性?
我们考虑怎样把这个后效性去掉。我们直接列出 d p i , j dp_{i, j} dpi,j 关于 第 i − 1 i - 1 i−1 列的转移:
d p i , j = d p i − 1 , j ∗ a j , i + d p i − 1 , 1 − j ∗ a 1 − j , i ∗ a j , i dp_{i, j} = dp_{i - 1, j} * a_{j, i} + dp_{i - 1, 1 - j} * a_{1 - j, i} * a_{j, i} dpi,j=dpi−1,j∗aj,i+dpi−1,1−j∗a1−j,i∗aj,i。
相当于我们是以 第 i − 1 i - 1 i−1 列是由哪个位置到达第 i i i 列 划分,这样能够保证不重不漏,并且没有后效性。
有了这个式子,我们可以在 O ( T n ) O(Tn) O(Tn) 的复杂度内处理 T > 0 T > 0 T>0 的答案。在 O ( n 2 ) O(n^2) O(n2) 的复杂度内处理出 T = 0 T = 0 T=0 的答案。并且当 T ≤ 1 0 5 , n ≤ 5000 T \leq 10^5,n \leq 5000 T≤105,n≤5000 时。我们可以预处理出来任意两点作为起点和终点的答案。时间复杂度是 O ( n 2 ) O(n^2) O(n2)。
这样我们就有了 45 p t s 45pts 45pts。
接下来我们思考正解:
不难发现,转移式满足 矩阵乘法 的运算规则,并且需要加速阶段,每一阶段状态数很少,符合矩阵乘法优化DP的特点。
设 [ d p 0 d p 1 ] \begin{bmatrix}dp_0 & dp_1 \end{bmatrix} [dp0dp1] 表示到某一列第 0 / 1 0 / 1 0/1 行的 dp 值。那么对于第 i i i 列而言,可以构建一下伴随矩阵:
[ a 0 , i a 0 , i × a 1 , i a 1 , i × a 0 , i a 1 , i ] \begin{bmatrix} a_{0,i} & a_{0, i} \times a_{1, i}\\ a_{1, i} \times a_{0, i} & a_{1, i} \end{bmatrix} [a0,ia1,i×a0,ia0,i×a1,ia1,i]
能够发现不同列的伴随矩阵是不同的,我们使用线段树维护区间矩阵的乘积即可。
那么 T > 0 T > 0 T>0 的情况就可以在 O ( T l o g 2 n ) O(Tlog_2n) O(Tlog2n) 的复杂度内解决。
至于 T = 0 T = 0 T=0 的情况,我们可以在矩阵中多维护一维 S S S 表示固定一个起点时,从起点到当前列上一列的所有路径的和。伴随矩阵变成一个下列一个 3 × 3 3 \times 3 3×3 的矩阵:
[ a 0 , i a 0 , i × a 1 , i 1 a 1 , i × a 0 , i a 1 , i 1 0 0 1 ] \begin{bmatrix} a_{0,i} & a_{0, i} \times a_{1, i} & 1\\ a_{1, i} \times a_{0, i} & a_{1, i} & 1\\ 0 & 0 & 1 \end{bmatrix} a0,ia1,i×a0,i0a0,i×a1,ia1,i0111
然后每次枚举起点算出答案并累加就好了。时间复杂度 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)。
CODE:
#include<bits/stdc++.h>// 动态DP
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
const LL mod = 1e9 + 7;
inline int read(){
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
return x * f;
}
struct matrix{
LL mat[3][3];
friend matrix operator * (matrix a, matrix b){
matrix c;
memset(c.mat, 0, sizeof c.mat);
for(int i = 0; i < 3; i++)
for(int j = 0; j < 3; j++)
for(int k = 0; k < 3; k++)
c.mat[i][j] = (c.mat[i][j] + a.mat[i][k] * b.mat[k][j]) % mod;
return c;
}
};
struct SegmentTree{
int l, r; matrix c;
#define l(x) t[x].l
#define r(x) t[x].r
#define c(x) t[x].c
}t[N * 4];
int n, m, T, sx, sy, ex, ey;
LL a[2][N];
void update(int p){c(p) = c(p << 1) * c(p << 1 | 1);}
void build(int p, int l, int r){
l(p) = l, r(p) = r;
if(l == r){
c(p).mat[0][0] = a[0][l] % mod; c(p).mat[1][0] = (a[0][l] * a[1][l]) % mod; c(p).mat[2][0] = 0;
c(p).mat[0][1] = (a[0][l] * a[1][l]) % mod; c(p).mat[1][1] = a[1][l] % mod; c(p).mat[2][1] = 0;
c(p).mat[0][2] = 1; c(p).mat[1][2] = 1; c(p).mat[2][2] = 1;
return ;
}
int mid = (l + r >> 1);
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
update(p);
}
matrix ask(int p, int l, int r){
if(l <= l(p) && r >= r(p)) return c(p);
int mid = (l(p) + r(p) >> 1);
if(r <= mid) return ask(p << 1, l, r);
else if(l > mid) return ask(p << 1 | 1, l, r);
else return ask(p << 1, l, r) * ask(p << 1 | 1, l, r);
}
LL query(int sx, int sy, int ex, int ey){
matrix res;
res.mat[0][sx] = a[sx][sy];
res.mat[0][1 - sx] = (a[sx][sy] * a[1 - sx][sy]) % mod;
if(sy != ey){
matrix tmp = ask(1, sy + 1, ey);
res = res * tmp;
}
return res.mat[0][ex];
}
LL solve(){
LL res = 0;
for(int i = 1; i <= n; i++){//枚举起点
for(int j = 0; j <= 1; j++){
matrix tmp;
tmp.mat[0][j] = a[j][i];
tmp.mat[0][1 - j] = (a[j][i] * a[1 - j][i]) % mod;
tmp.mat[0][2] = 0;
if(i != n){
matrix cur = ask(1, i + 1, n);
tmp = tmp * cur;
}
res = (res + tmp.mat[0][2] + tmp.mat[0][0] + tmp.mat[0][1]) % mod;
}
}
return res;
}
int main(){
m = read(), n = read(), T = read();
for(int i = 0; i < m; i++)
for(int j = 1; j <= n; j++)
a[i][j] = 1LL * read() % mod;
build(1, 1, n);//构建矩阵
if(T){
while(T--){
sx = read(), sy = read(), ex = read(), ey = read();
sx--, ex--;
printf("%lld\n", query(sx, sy, ex, ey));
}
}
else printf("%lld\n", solve());
return 0;
}