题解 P5394 [模板]下降幂多项式乘法

题面
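和普通多项式乘法一样，考虑点值表示。
乘积的次数为 $n+m$，直接取 $0,1,\dots,n+m$ 这 $n+m+1$ 个整点处的点值最方便。
推一下式子可以发现：下降幂多项式 $F(x)=\sum_k f_k x^{\underline{k}}$ 的点值的指数生成函数，就是系数的普通生成函数卷上一个系数为阶乘逆元的函数（即 $e^x$），也就是 $\sum_{i\ge 0} F(i)\frac{x^i}{i!} = \left(\sum_k f_k x^k\right) e^x$；反过来乘上 $e^{-x}$ 就能从点值还原系数。
于是先分别转成点值、逐点相乘、再转回系数，用 NTT 在 $O(n\log n)$ 内做完就好了。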

代码：
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int MAXN = 800010;
// NTT modulus 998244353 = 119 * 2^23 + 1; G = 3 is the root used to build the NTT
// twiddle tables, and Ginv is its inverse modulo mod (3 * Ginv == mod + 1).
constexpr int mod = 998244353;
constexpr int G = 3;
constexpr int Ginv = (mod + 1) / 3;
// Bit-reversal permutation table, filled by get_rev().
int rev[MAXN];
// Builds the bit-reversal table for a transform of length 2^l:
// rev[i] becomes i with its lowest l bits reversed, for 1 <= i < 2^l.
// rev[0] is never written here; it relies on rev being zero-initialized.
inline void get_rev(int l) {
    const int len = 1 << l;
    for (int idx = 1; idx < len; ++idx) {
        const int top = (idx & 1) << (l - 1);  // lowest bit of idx moves to bit l-1
        rev[idx] = top | (rev[idx >> 1] >> 1);
    }
}
inline int power(int a, int b) {
ll res = a, ans = 1;
for(; b; b >>= 1, res = res * res % mod) if(b & 1) ans = ans * res % mod;
return ans;
}
// Per-stage NTT twiddle roots indexed by the butterfly half-length p (a power of two);
// main() sets gg[p] = G^((mod-1)/(2p)) and ggi[p] = Ginv^((mod-1)/(2p)).
int gg[MAXN];   // forward-transform roots
int ggi[MAXN];  // inverse-transform roots
// In-place iterative radix-2 number-theoretic transform modulo mod.
//   a    : at least 2^l entries, expected to lie in [0, mod); overwritten with the result.
//   l    : log2 of the transform length. rev[] must already hold the bit-reversal
//          table for this l (see get_rev).
//   type : 1 for the forward transform using gg[]; -1 for the inverse using ggi[],
//          which additionally multiplies every entry by (2^l)^{-1}.
inline void ntt(int *a, int l, int type) {
    const int len = 1 << l;
    // Bit-reversal permutation; the rev[i] < i guard swaps each pair exactly once.
    // std::swap replaces an xor-swap chain that was unsequenced (UB) before C++17.
    for (int i = 0; i < len; i++)
        if (rev[i] < i) std::swap(a[i], a[rev[i]]);
    for (int p = 1; p < len; p <<= 1) {
        // Root for butterflies of span 2p: G^((mod-1)/(2p)) or its inverse.
        const int wn = (type == 1 ? gg[p] : ggi[p]);
        for (int s = 0; s < len; s += p << 1) {
            int w = 1;
            for (int i = 0; i < p; i++) {
                const int h1 = a[s + i];
                const int h2 = 1ll * w * a[s + p + i] % mod;
                a[s + i] = (h1 + h2) % mod;
                a[s + p + i] = (h1 - h2 + mod) % mod;
                w = 1ll * w * wn % mod;
            }
        }
    }
    if (type == -1) {
        // Named inv_len so it no longer shadows the global inv[] array.
        const int inv_len = power(len, mod - 2);
        for (int i = 0; i < len; i++) a[i] = 1ll * a[i] * inv_len % mod;
    }
}
// Coefficient-wise sum modulo mod: writes max(sizea, sizeb) coefficients to c, treating
// coefficients past the end of the shorter input as 0. Inputs are expected in [0, mod).
// NOTE(review): not called anywhere in the visible code.
inline void add(int *a, int sizea, int *b, int sizeb, int *c) {
    const int len = std::max(sizea, sizeb);
    for (int i = 0; i < len; i++)
        c[i] = ((i < sizea ? a[i] : 0) + (i < sizeb ? b[i] : 0)) % mod;
}
// Scratch buffers for mult(): f and g hold the zero-padded operands while transforming.
int f[MAXN];
int g[MAXN];
// NOTE(review): h is not referenced in the visible code; kept in case other code uses it.
int h[MAXN];
// Linear convolution modulo mod: multiplies a (sizea coefficients) by b (sizeb coefficients)
// and writes all sizea + sizeb - 1 product coefficients to c.
// The transform length is the smallest power of two >= sizea + sizeb - 1, so there is no
// cyclic wrap-around. Uses the global scratch buffers f and g. Both inputs are copied before c
// is written, so c may alias a or b.
inline void mult(int *a, int sizea, int *b, int sizeb, int *c) {
    const int need = sizea + sizeb - 1;
    int l = 0;
    while ((1 << l) < need) l++;
    const int len = 1 << l;
    get_rev(l);
    for (int i = 0; i < len; i++) {
        f[i] = (i < sizea ? a[i] : 0);
        g[i] = (i < sizeb ? b[i] : 0);
    }
    ntt(f, l, 1);
    ntt(g, l, 1);
    for (int i = 0; i < len; i++)
        f[i] = 1ll * f[i] * g[i] % mod;
    ntt(f, l, -1);
    for (int i = 0; i < need; i++) c[i] = f[i];
}
// n, m: degrees read from input; main() turns them into coefficient counts.
int n;
int m;
// a, b: input coefficients; c: output coefficients;
// fac[i] = i!, inv[i] = 1 / i!, dinv[i] = (-1)^i / i! (all modulo mod, filled in main()).
int a[MAXN], b[MAXN], c[MAXN];
int fac[MAXN], inv[MAXN], dinv[MAXN];
// da, db: point values of the inputs at i divided by i!; dc: the product's point values at i divided by i!.
int da[MAXN], db[MAXN], dc[MAXN];
// Reads degrees n, m and the falling-factorial coefficients of A and B, then prints the
// n + m + 1 falling-factorial coefficients of A * B.
// Method: for F(x) = sum_k f_k x^(k falling), sum_i F(i) x^i / i! = (sum_k f_k x^k) * e^x.
// So we (1) get the point values of A and B at 0..n+m via a product with e^x,
// (2) multiply them point-wise, and (3) convert back by multiplying with e^{-x}.
// Returns 1 without output on malformed input.
int main() {
    // gg[p] = G^((mod-1)/(2p)) and ggi[p] = Ginv^((mod-1)/(2p)) for every power of two p < MAXN.
    for (int i = 1; i < MAXN; i <<= 1) gg[i] = power(G, (mod - 1) / 2 / i);
    for (int i = 1; i < MAXN; i <<= 1) ggi[i] = power(Ginv, (mod - 1) / 2 / i);
    if (scanf("%d%d", &n, &m) != 2) return 1;
    if (n < 0 || m < 0) return 1;
    n++, m++;  // degrees -> coefficient counts
    for (int i = 0; i < n; i++)
        if (scanf("%d", a + i) != 1) return 1;
    for (int i = 0; i < m; i++)
        if (scanf("%d", b + i) != 1) return 1;
    const int len = n + m - 1;  // number of coefficients (and sample points) of the product
    // fac[i] = i!, inv[i] = 1 / i! for 0 <= i < len.
    fac[0] = 1;
    for (int i = 1; i < len; i++) fac[i] = 1ll * i * fac[i - 1] % mod;
    inv[len - 1] = power(fac[len - 1], mod - 2);
    for (int i = len - 1; i > 0; i--) inv[i - 1] = 1ll * i * inv[i] % mod;
    // da[i] = A(i) / i!, db[i] = B(i) / i!  (coefficients of A * e^x and B * e^x).
    mult(a, n, inv, len, da);
    mult(b, m, inv, len, db);
    // (A(i) * B(i)) / i! = da[i] * db[i] * i!.
    for (int i = 0; i < len; i++) dc[i] = 1ll * da[i] * db[i] % mod * fac[i] % mod;
    // dinv holds the coefficients of e^{-x}: (-1)^i / i!.
    for (int i = 0; i < len; i++) dinv[i] = (i & 1 ? mod - inv[i] : inv[i]) % mod;
    mult(dc, len, dinv, len, c);
    for (int i = 0; i < len; i++) printf("%d%c", c[i], " \n"[i == len - 1]);
    return 0;
}