luogu9221 「TAOI-1」Pentiment 题解

先考虑部分分怎么打。

根据个人习惯,规定下文中「直角蛇」是从最上面一行到达最下面一行。

subtask 2

不妨这样考虑:到达每一行后,都可以通过左右移动,到达下一行的任意一个位置。到达第一行的位置可以任选;从任意位置到达最后一行后,可以再通过左右移动,在任意位置结束。所以答案就是 \[ m^{n+1} \]

subtask 1 and 3

这两个子任务都可以用 \(O(nm)\) 的做法解决。

结合 subtask 2,我们对每一行考虑。发现第 \(i-1\) 行和第 \(i\) 行本质上是「输送与接收」的关系。

不妨称不能走的节点为关键点,我们能发现如下性质。

  • 关键点把第 \(i-1\) 行和第 \(i\) 行划分成了若干个连续段,只有当两行的连续段有交时,才能完成方案的传递。

  • 到达同一个连续段内任意节点的方案数是相等的。这个容易理解,我们到达这一段后可以左右任意走。

这两个性质启发我们这样做:

\(f(i,j)\) 为直角蛇到达 \((i,j)\) 的方案数。对于第 \(i\) 行,如果 \((i,j)\) 不是关键点,我们就把 \(f(i-1,j)\) 的方案下传到 \(f(i,j)\)。然后对 \(j\) 这一维做前缀和,最后扫一遍,每个点的方案就是其所在连续段的总和。

使用滚动数组,时间和空间都可以接受。

const int mod=998244353;
int n, m, q;
bitset<10002> v[10002];
bitset<1000002> v1[3];
int f[2][10002], r[10002], s[10002];
namespace sub3 {
    // 这是subtask3的代码。
    // 如果要求解subtask1,那么就交换行和列,用v1做标记即可。
	void solve() {
		rep(i,1,n) {
			int lst=m;
			per(j,m,1) {
				if(v[i][j]) r[j]=j, lst=j-1;
				else r[j]=lst;
			}
			rep(j,1,m) {
				f[i&1][j]=0;
				if(!v[i][j]) {
					if(i==1) s[j]=1; else s[j]=f[(i-1)&1][j];
					s[j]=(1ll*s[j]+s[j-1])%mod;
				}
			}
			int L=0, R=0;
			rep(j,1,m) {
				R=r[j];
				if(!v[i][j]) f[i&1][j]=(s[R]-s[L]+mod)%mod;
				else L=j, f[i&1][j]=0;
			}
		}
		int L=0, R=0, ans=0;
		rep(i,1,m) ans=(1ll*ans+f[n&1][i])%mod;
		printf("%d\n",ans);
	}
};

subtask 6

我们发现方案的下传像是在做类似于区间合并的东西。由于本人太菜,不会使用数据结构,所以就对着题解中的离线做法写了,就此学习一下此类问题的处理方法。

对于连续的不存在关键点的行,其方案数是容易求出的,所以有用的只有不同行的关键点。对此我们可以将所有关键点按照横坐标排序。

先不考虑连续空行的情况。我们应该先找到一个临界点 \(i\),满足 \(i\)\(i+1\) 在不同行,再维护一个上一行的末尾位置 \(k\)。这样我们就能得到这一整行的信息了,同时要维护上一行的区间以及区间内每个位置的方案数(每个位置的方案数都相等)。

struct node {
	int l, r; ll x;
    // 区间[l,r],每个点的方案数都是x
	node() {};
	node(int _l,int _r,ll _x) { l=_l, r=_r, x=_x; }
} f[N];

考虑如何区间合并。对于当前行的区间,只要上一行的某个区间与其有交,方案就能下传,看起来不很好做。但反过来想,如果当前行被上一行两个区间下传方案了,那么说明一定有关键点把上面那两个区间隔开。也就是说,如果我们把上一行的关键点当作当前行的关键点,这样得到的区间一定会被上一行唯一确定的一个区间下传方案。

我们先把两行的关键点都存下来,排序后去重。

对于两个关键点确定的一个区间 \([L,R]\),我们找到之前合并完的第一个与这个区间有交的区间,\([L,R]\) 的方案数就是 \(R-L+1\) 乘那个区间的方案。注意要把边界 \(0\)\(m+1\) 都加入。

p.clear();
int tot=0;
if(a[k].fi+1==a[i].fi) {
    for(int j=k;j&&a[j].fi==a[k].fi;--j) b[++tot]=a[j].se;
}
for(int j=k+1;j<=i;++j) b[++tot]=a[j].se, p[a[j].se]=1;
b[++tot]=0, b[++tot]=m+1;
uniq(b,tot);

vector<node> v;
for(int j=1,pos=1;j<tot;++j) {
	if(b[j]&&!p.count(b[j])) v.pb(node(b[j],b[j],0ll));
	int L=b[j]+1, R=b[j+1]-1;
	if(L>R) continue;
	while(pos<=cnt&&f[pos].r<L) ++pos;
	ll sum=f[pos].x*(R-L+1)%mod;
	v.pb(node(L,R,sum));
}

对于合并,只需要将上一行的关键点当作连接区间的桥梁。具体地,我们将通过上述做法得到的区间都存下来,然后把上一行的关键点当作长度为 \(1\) 的区间加进去。合并时,只需要扫一边所有区间,根据端点判断是否可以合并即可,方案数就直接累加。

cnt=0;
for(auto xx:v) {
	if(!cnt||f[cnt].r!=xx.l-1) f[++cnt]=xx;
	else f[cnt].r=xx.r, (f[cnt].x+=xx.x)%=mod;
}

对于连续的没有关键点的行,我们只需要先将方案下放并累加,然后乘 \(m\) 的对应次幂,最后只留下 \([1,m]\)

if(a[k].fi+1!=a[i].fi) {
	int dx=a[i].fi-a[k].fi-2;
    ll sum=0;
	rep(j,1,cnt) (sum+=f[j].x*(f[j].r-f[j].l+1)%mod)%=mod;
	(sum*=fp(m,dx))%=mod;
	f[1]=node(1,m,sum);
	cnt=1;
}

做完之后,最后一定还剩下若干没有关键点的行,照做即可。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define uint unsigned long long
#define PII pair<int,int>
#define MP make_pair
#define fi first
#define se second
#define pb emplace_back
#define SET(a,b) memset(a,b,sizeof(a))
#define CPY(a,b) memcpy(a,b,sizeof(b))
#define rep(i,j,k) for(int i=(j);i<=(k);++i)
#define per(i,j,k) for(int i=(j);i>=(k);--i)
int read() {
	int a=0, f=1; char c=getchar();
	while(!isdigit(c)) {
		if(c=='-') f=-1;
		c=getchar();
	}
	while(isdigit(c)) a=a*10+c-'0', c=getchar();
	return a*f;
}
const int N=1e5+5, mod=998244353;
int n, m, q, cnt, ans, b[N], t[N];
PII a[N];
map<int,bool> p;
struct node {
	int l, r; ll x;
	node() {};
	node(int _l,int _r,ll _x) { l=_l, r=_r, x=_x; }
} f[N];
ll fp(int a,int b) {
	ll c=1;
	for(;b;a=1ll*a*a%mod,b>>=1) if(b&1) c=c*a%mod;
	return c;
}
void uniq(int* b,int& tot) {
	sort(b+1,b+tot+1);
	tot=unique(b+1,b+tot+1)-b-1;
}

signed main() {
	n=read(), m=read(), q=read();
	rep(i,1,q) a[i].fi=read(), a[i].se=read();
    int k=0;
	f[++cnt]=node(1,m,1);
	rep(i,1,q) if(i==q||a[i].fi!=a[i+1].fi) {
		if(a[k].fi+1!=a[i].fi) {

			int dx=a[i].fi-a[k].fi-2; ll sum=0;
			rep(j,1,cnt) (sum+=f[j].x*(f[j].r-f[j].l+1)%mod)%=mod;
			(sum*=fp(m,dx))%=mod;
			f[1]=node(1,m,sum);
			cnt=1;
		}
		
		p.clear();
	
		int tot=0;
		if(a[k].fi+1==a[i].fi) for(int j=k;j&&a[j].fi==a[k].fi;--j) b[++tot]=a[j].se;
	
		for(int j=k+1;j<=i;++j) b[++tot]=a[j].se, p[a[j].se]=1;
		
		b[++tot]=0, b[++tot]=m+1;
		uniq(b,tot);

	
		vector<node> v;
		for(int j=1,pos=1;j<tot;++j) {
			if(b[j]&&!p.count(b[j])) v.pb(node(b[j],b[j],0ll));
	
			int L=b[j]+1, R=b[j+1]-1;

			if(L>R) continue;
			while(pos<=cnt&&f[pos].r<L) ++pos;

			ll sum=f[pos].x*(R-L+1)%mod;
			v.pb(node(L,R,sum));
		}

		cnt=0;
		for(auto xx:v) {
			if(!cnt||f[cnt].r!=xx.l-1) f[++cnt]=xx;
			else f[cnt].r=xx.r, (f[cnt].x+=xx.x)%=mod;

		}
		k=i;
	}
	int dx=n-a[k].fi;
    ll sum=0;
	rep(i,1,cnt) (sum+=f[i].x*(f[i].r-f[i].l+1)%mod)%=mod;
	printf("%lld\n",sum*fp(m,dx)%mod);
}

计数部分不难,难在对区间的处理。

怎么说,在考场上,除非时间很充足并且有很大把握,否则是不会去写这种题的。

但过一遍这道题也有所收获。

初看这道题,很容易与某组合典题联系起来,从而想到利用关键点去容斥。尽管正解不是这样做,但最终也需要在关键点上下功夫,算是完善一下科技树并锻炼代码能力了。


luogu9221 「TAOI-1」Pentiment 题解
https://yozora0908.github.io/2023/lg9221-solution/
作者
yozora0908
发布于
2023年8月30日
许可协议