Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- int bitPow(int num, int pow) {
- if (pow == 0) return 1;
- if (pow % 2 == 0) {
- int temp = bitPow(num, pow / 2);
- return temp * temp % M;
- }
- return bitPow(num, pow - 1) * num % M;
- }
- // M = 998244353;
- void ntt(vector<int>& a, bool inv) {
- int n = a.size();
- if (n == 1) {
- return;
- }
- vector<int> a0(n / 2), a1(n / 2);
- for (int i = 0; i < n / 2; i++) {
- a0[i] = a[2 * i];
- a1[i] = a[2 * i + 1];
- }
- ntt(a0, 0);
- ntt(a1, 0);
- int w = 1, wn = bitPow(3, (M - 1) / n);
- for (int i = 0; i < n / 2; i++) {
- a[i] = (a0[i] + w * a1[i] % M) % M;
- a[i + n / 2] = (a0[i] - w * a1[i] % M + M) % M;
- w = wn * w % M;
- }
- int nn = bitPow(n, M - 2);
- if (inv) {
- for (int i = 0; i < n; i++) a[i] = a[i] * nn % M;
- reverse(++a.begin(), a.end());
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement