LOJ#3381 函数调用 题解
分析
函数调用,满足不出现递归,本身就构成了一张 DAG。
由于乘法操作是全局乘法,所以可以提前预处理乘法操作,而加法操作最终的贡献,取决于在它之后乘上多少次。
设 \(pls_i\) 为第 \(i\) 个函数执行的加法的值,如果没有的话则为 \(0\),\(id_i\) 为第 \(i\) 个函数要执行加法的元素的编号,没有的话为 \(0\),\(mul_i\) 为第 \(i\) 个函数对全局的乘法值,如果没有的话则为 \(1\)。
设 \(cnt_i\) 为第 \(i\) 个函数加法操作之后又被乘了多少。
对于函数 \(i\) 要调用的函数 \(x\),连边 \((i \rightarrow x)\)。用 DFS 可以预处理出当运行完任意一个函数时,对全局的乘法操作,而函数 \(i\) 必然在 \(x\) 之后操作,因此 \(x\) 加法操作的最终值依赖于 \(i\),因此处理 \(i\) 要早于处理 \(x\),拓扑排序即可。
但是函数的调用是有序的,依次从左到右。因此必须新建立一个节点 \(0\),向所有调用的函数连边,并且在遍历 \(x\) 能到达的所有点时,必须要按照倒序。
DFS 之后,\(mul_0\) 即为全局最终乘上的值,而对于每一个加法函数 \(i\),都会对 \(a_{id_i}\) 产生 \(pls_i \cdot cnt_i\) 的贡献。
CODE
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pb push_back
int read() {
int a=0, f=1; char c=getchar();
while(!isdigit(c)) {
if(c=='-') f=-1;
c=getchar();
}
while(isdigit(c)) a=a*10+c-'0', c=getchar();
return a*f;
}
const int M=1e6+5, N=1e5+5, mod=998244353;
int n, m, a[N], id[N], pls[N], mul[N], cnt[N];
int in[N], op[N];
bool v[N];
vector<int> p[M];
void dfs(int x) {
if(v[x]) return;
v[x]=1;
// 关于为什么这里要用v数组来判环
// 因为从0连边之前是DAG,但是连边之后就可能不是了
for(auto y:p[x]) {
dfs(y);
(mul[x]*=mul[y])%=mod;
}
}
void toposort() {
queue<int> q;
for(int i=0;i<=m;++i) if(!in[i]) q.push(i);
while(q.size()) {
int x=q.front(); q.pop();
int dlt=1;
for(int i=p[x].size()-1;~i;--i) {
// 倒序遍历
int y=p[x][i];
(cnt[y]+=cnt[x]*dlt)%=mod;
// 这里的层次一定要分清楚
// y的加法操作只会被x影响,要按照拓扑序处理,而不是看全局乘上的值
(dlt*=mul[y])%=mod;
if(--in[y]==0) q.push(y);
}
}
}
signed main() {
freopen("call.in","r",stdin);
freopen("call.out","w",stdout);
n=read();
for(int i=1;i<=n;++i) a[i]=read();
m=read();
for(int i=1;i<=m;++i) {
op[i]=read();
if(op[i]==1) id[i]=read(), pls[i]=read(), mul[i]=1;
else if(op[i]==2) mul[i]=read();
else {
int k=read();
mul[i]=1;
for(int j=1;j<=k;++j) {
int x=read();
p[i].pb(x), ++in[x];
}
}
}
int q=read();
for(int i=1;i<=q;++i) {
int x=read();
p[0].pb(x), ++in[x];
}
cnt[0]=mul[0]=1;
dfs(0);
toposort();
for(int i=1;i<=n;++i) (a[i]*=mul[0])%=mod;
for(int i=1;i<=m;++i) if(op[i]==1) (a[id[i]]+=cnt[i]*pls[i]%mod)%=mod;
for(int i=1;i<=n;++i) printf("%lld%c",a[i]," \n"[i==n]);
}
LOJ#3381 函数调用 题解
https://yozora0908.github.io/2022/loj3381-solution/