题解 P5393 [模板]下降幂多项式转普通多项式

题面
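考虑先转成点值表示，再用快速插值还原普通多项式的系数。
点值的指数生成函数等于下降幂系数的普通生成函数与 $e^x=\sum_{i\ge 0}\frac{x^i}{i!}$（系数为阶乘逆元）的卷积：由于 $k^{\underline{i}}=\frac{k!}{(k-i)!}$，有 $\frac{F(k)}{k!}=\sum_{i}\frac{b_i}{(k-i)!}$。
这一步只需一次卷积，复杂度 $O(n\log n)$。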
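然后快速插值即可。
这里有一个优化：一般的快速插值需要对 $\prod_j (x-x_j)$ 的导数做一次多点求值来得到分母；这里插值点恰为 $0,1,\dots,n-1$，分母 $\prod_{j\ne i}(i-j)=(-1)^{n-1-i}\,i!\,(n-1-i)!$，用阶乘逆元 $O(n)$ 直接算出即可。
分治插值每层一次卷积，总复杂度 $O(n\log^2 n)$。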

代码：
#include <bits/stdc++.h>
using namespace std;
// 64-bit type used to hold intermediate modular products.
using ll = long long;
// Capacity of the global work arrays. It has to cover the segment-tree node
// indices (up to ~4n) used by get_p/inter as well as the largest NTT length.
const int MAXN = 800010;
// NTT-friendly prime modulus with generator G = 3.
// Ginv is the modular inverse of G (3 * 332748118 = mod + 1).
const int mod = 998244353;
const int G = 3;
const int Ginv = (mod + 1) / G;
int rev[MAXN];
// Fill rev[1 .. 2^l - 1] with the l-bit bit-reversal permutation used by the
// iterative NTT. rev[0] is never written and stays 0 (zero-initialised global).
inline void get_rev(int l) {
    const int len = 1 << l;
    for (int idx = 1; idx < len; ++idx) {
        // Reverse idx by reusing the reversal of idx >> 1 and moving the
        // lowest bit of idx into the top position.
        const int top_bit = (idx & 1) << (l - 1);
        rev[idx] = (rev[idx >> 1] >> 1) | top_bit;
    }
}
inline int power(int a, int b) {
ll res = a, ans = 1;
for(; b; b >>= 1, res = res * res % mod) if(b & 1) ans = ans * res % mod;
return ans;
}
// gg[p] / ggi[p] (p a power of two) hold G^((mod-1)/(2p)) and its inverse,
// i.e. the root of unity used for butterflies of span 2p. main() fills them
// before any transform runs.
int gg[MAXN], ggi[MAXN];
// In-place iterative number-theoretic transform of a[0 .. 2^l - 1].
//   type ==  1 : forward transform.
//   type == -1 : inverse transform, including the final division by 2^l.
// get_rev(l) must have been called for the same l beforehand, and the
// entries of a are expected to be residues in [0, mod).
inline void ntt(int *a, int l, int type) {
    const int len = 1 << l;
    // Bit-reversal permutation. std::swap replaces the former chained XOR
    // swap, which modified a[i] twice without sequencing (UB before C++17).
    for (int i = 0; i < len; i++)
        if (rev[i] < i) std::swap(a[i], a[rev[i]]);
    for (int p = 1; p < len; p <<= 1) {
        const int wn = (type == 1 ? gg[p] : ggi[p]);  // step root for span 2p
        for (int s = 0; s < len; s += p << 1) {
            int w = 1;
            for (int i = 0; i < p; i++) {
                const int h1 = a[s + i];
                const int h2 = 1ll * w * a[s + p + i] % mod;
                a[s + i] = (h1 + h2) % mod;
                a[s + p + i] = (h1 - h2 + mod) % mod;
                w = 1ll * w * wn % mod;
            }
        }
    }
    if (type == -1) {
        const int inv = power(len, mod - 2);  // 1 / 2^l modulo mod
        for (int i = 0; i < len; i++) a[i] = 1ll * a[i] * inv % mod;
    }
}
// Coefficient-wise sum modulo mod: c[i] = a[i] + b[i] for
// i < max(sizea, sizeb), treating entries past an array's size as 0.
// c may alias a or b, since each c[i] depends only on a[i] and b[i].
inline void add(int *a, int sizea, int *b, int sizeb, int *c) {
    const int size = std::max(sizea, sizeb);  // loop-invariant bound
    for (int i = 0; i < size; i++)
        c[i] = ((i < sizea ? a[i] : 0) + (i < sizeb ? b[i] : 0)) % mod;
}
// Scratch buffers for mult(); h is not referenced by the visible code.
int f[MAXN], g[MAXN], h[MAXN];
// Polynomial product via NTT: writes the sizea + sizeb - 1 coefficients of
// a * b (mod `mod`) into c. a has sizea coefficients, b has sizeb; the
// transform length 2^l >= sizea + sizeb - 1 must fit in f/g (MAXN).
// The inputs are copied into f/g before c is written, so c may alias a or b.
inline void mult(int *a, int sizea, int *b, int sizeb, int *c) {
    const int need = sizea + sizeb - 1;  // number of product coefficients
    int l = 0;
    while ((1 << l) < need) l++;
    const int len = 1 << l;
    get_rev(l);
    for (int i = 0; i < len; i++) {
        f[i] = (i < sizea ? a[i] : 0);
        g[i] = (i < sizeb ? b[i] : 0);
    }
    ntt(f, l, 1);
    ntt(g, l, 1);
    for (int i = 0; i < len; i++) f[i] = 1ll * f[i] * g[i] % mod;
    ntt(f, l, -1);
    for (int i = 0; i < need; i++) c[i] = f[i];
}
// p[now]: coefficients (constant term first) of prod_{j=l}^{r} (x - j) for
// the segment-tree node `now` covering [l, r]; built by get_p, read by inter.
int *p[MAXN];
// sta[d]: scratch buffer shared by the inter() calls at recursion depth d;
// cnt is inter()'s current depth.
int *sta[30], cnt;
// y[i]: value combined by inter() at the point x = i (main stores the
// Lagrange weights here).
int y[MAXN];
// Recursively build p[] for the subtree rooted at `now` over points l..r.
// p[now] receives r - l + 2 coefficients.
void get_p(int now, int l, int r) {
    if (l == r) {
        // Leaf: the linear factor (x - l).
        p[now] = static_cast<int *>(malloc(sizeof(int) * 2));
        p[now][0] = (mod - l) % mod;
        p[now][1] = 1;
        return;
    }
    int mid = (l + r) >> 1;
    get_p(now << 1, l, mid);
    get_p(now << 1 | 1, mid + 1, r);
    // sizeof(int) * count rather than sizeof(int[count]): a runtime bound
    // inside sizeof forms a VLA type, which standard C++ does not allow.
    p[now] = static_cast<int *>(malloc(sizeof(int) * (r - l + 2)));
    mult(p[now << 1], mid - l + 2, p[now << 1 | 1], r - mid + 1, p[now]);
}
void inter(int now, int l, int r, int *ans) {
if(l == r) {
ans[0] = y[l];
return;
}
cnt++;
if(sta[cnt] == 0) sta[cnt] = (int*)malloc(sizeof(int[r - l + 1]));
int *ans1 = sta[cnt], mid = (l + r) >> 1;
inter(now << 1, l, mid, ans1);
mult(ans1, mid - l + 1, p[now << 1 | 1], r - mid + 1, ans);
inter(now << 1 | 1, mid + 1, r, ans1);
mult(ans1, r - mid, p[now << 1], mid - l + 2, ans1);
add(ans, r - l + 1, ans1, r - l + 1, ans);
cnt--;
}
int n;
// ar: output coefficients of the ordinary polynomial.
int ar[MAXN];
// a: input coefficients b_i (falling-factorial basis); pi[i] = 1 / i! mod p.
int a[MAXN], pi[MAXN];
// Reads n and b_0..b_{n-1} of F(x) = sum_i b_i x^{\underline i}, then prints
// the ordinary coefficients of F:
//  1. y[k] = sum_i b_i / (k - i)! = F(k) / k!  (one convolution with 1/i!).
//  2. Lagrange interpolation on x = 0..n-1, whose denominators are
//     prod_{j != i} (i - j) = (-1)^{n-1-i} * i! * (n-1-i)!.
int main() {
    for (int i = 1; i < MAXN; i <<= 1) gg[i] = power(G, (mod - 1) / 2 / i);
    for (int i = 1; i < MAXN; i <<= 1) ggi[i] = power(Ginv, (mod - 1) / 2 / i);
    if (scanf("%d", &n) != 1 || n <= 0) {
        // Nothing to convert; without this guard pi[n - 1] would be out of bounds.
        puts("");
        return 0;
    }
    for (int i = 0; i < n; i++) scanf("%d", a + i);
    // Inverse factorials 1/0! .. 1/(n-1)!.
    pi[0] = 1;
    for (int i = 1; i < n; i++) pi[i] = 1ll * pi[i - 1] * i % mod;
    pi[n - 1] = power(pi[n - 1], mod - 2);
    for (int i = n - 1; i > 0; i--) pi[i - 1] = 1ll * pi[i] * i % mod;
    // y[k] = F(k) / k! for k < n (higher entries are not used).
    mult(a, n, pi, n, y);
    // Lagrange weight of node i is F(i) / ((-1)^{n-1-i} * i! * (n-1-i)!).
    // y[i] already equals F(i) / i!, so the former multiply-by-i! followed by
    // multiply-by-1/i! cancelled out; only the sign and 1/(n-1-i)! remain.
    for (int i = 0; i < n; i++) {
        const int inv_tail = pi[n - i - 1];
        const int weight = ((n - i - 1) % 2 == 0) ? inv_tail : mod - inv_tail;
        y[i] = 1ll * y[i] * weight % mod;
    }
    get_p(1, 0, n - 1);
    inter(1, 0, n - 1, ar);
    for (int i = 0; i < n; i++) printf("%d ", ar[i]);
    puts("");
    return 0;
}