Jaydeep999997

Unique Strings - NTT

Feb 14th, 2021 (edited)
492
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 10.24 KB | None | 0 0
  1. #include<bits/stdc++.h>
  2. using namespace std;
  3.  
  4.  
  5. #define endl '\n'
  6. #define ff first
  7. #define ss second
  8. #define pb push_back
  9. // #define int long long
  10. #define sz(v) (int)v.size()
  11. #define inf 2147483647
  12. #define llinf 9223372036854775807
  13. #define all(v) v.begin(),v.end()
  14. #define bp(n) __builtin_popcountll(n)
  15. #define f(i,l,r) for(int i=l;i<=r;i++)
  16. #define rf(i,r,l) for(int i=r;i>=l;i--)
  17. #define fast ios_base::sync_with_stdio(false),cin.tie(NULL),cout.tie(NULL)
  18.  
  19. const int N = 5e3 + 5, mod = 1e9 + 7, bit = 61;
  20.  
  21. // Credit : neal
  22.  
  23. // Be careful about overflow as we are using integer everywhere
  24.  
  25. template<const int &MOD>
  26. struct _m_int  // This is our new type which will do operations under modulo MOD
  27. {
  28.     int val;   // Value of variable
  29.  
  30.     _m_int(int64_t v = 0)  // Typecasting into our data type from 64 bit integer
  31.     {
  32.         if (v < 0) v = v % MOD + MOD;
  33.         if (v >= MOD) v %= MOD;
  34.         val = v;
  35.     }
  36.  
  37.     static int mod_inv(int a, int m = MOD)
  38.     {
  39.         // https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Example
  40.         int g = m, r = a, x = 0, y = 1;
  41.  
  42.         while (r != 0)
  43.         {
  44.             int q = g / r;
  45.             g %= r; swap(g, r);
  46.             x -= q * y; swap(x, y);
  47.         }
  48.  
  49.         return x < 0 ? x + m : x;
  50.     }
  51.  
  52.     explicit operator int() const
  53.     {
  54.         return val;
  55.     }
  56.  
  57.     explicit operator int64_t() const
  58.     {
  59.         return val;
  60.     }
  61.  
  62.     _m_int& operator += (const _m_int &other)  // Addition
  63.     {
  64.         val -= MOD - other.val;
  65.         if (val < 0) val += MOD;
  66.         return *this;
  67.     }
  68.  
  69.     _m_int& operator -= (const _m_int &other)  // Subtraction
  70.     {
  71.         val -= other.val;
  72.         if (val < 0) val += MOD;
  73.         return *this;
  74.     }
  75.  
  76.     static unsigned fast_mod(uint64_t x, unsigned m = MOD) // Mod operation
  77.     {
  78. #if !defined(_WIN32) || defined(_WIN64)
  79.         return x % m;
  80. #endif
  81.         // Optimized mod for Codeforces 32-bit machines.
  82.         // x must be less than 2^32 * m for this to work, so that x / m fits in a 32-bit integer.
  83.         unsigned x_high = x >> 32, x_low = (unsigned) x;
  84.         unsigned quot, rem;
  85.         asm("divl %4\n"
  86.             : "=a" (quot), "=d" (rem)
  87.             : "d" (x_high), "a" (x_low), "r" (m));
  88.         return rem;
  89.     }
  90.  
  91.     _m_int& operator*=(const _m_int &other)  // Multiplication
  92.     {
  93.         val = fast_mod((uint64_t) val * other.val);
  94.         return *this;
  95.     }
  96.  
  97.     _m_int& operator/=(const _m_int &other)  // Division
  98.     {
  99.         return *this *= other.inv();
  100.     }
  101.  
  102.     friend _m_int operator+(const _m_int &a, const _m_int &b) { return _m_int(a) += b; }
  103.     friend _m_int operator-(const _m_int &a, const _m_int &b) { return _m_int(a) -= b; }
  104.     friend _m_int operator*(const _m_int &a, const _m_int &b) { return _m_int(a) *= b; }
  105.     friend _m_int operator/(const _m_int &a, const _m_int &b) { return _m_int(a) /= b; }
  106.  
  107.     _m_int& operator++()  // Pre-increment
  108.     {
  109.         val = val == MOD - 1 ? 0 : val + 1;
  110.         return *this;
  111.     }
  112.  
  113.     _m_int& operator--()  // Pre-decrement
  114.     {
  115.         val = val == 0 ? MOD - 1 : val - 1;
  116.         return *this;
  117.     }
  118.  
  119.     _m_int operator++(int) { _m_int before = *this; ++*this; return before; }  // Post increment
  120.     _m_int operator--(int) { _m_int before = *this; --*this; return before; }  // Post decrement
  121.  
  122.     _m_int operator-() const  // Change sign
  123.     {
  124.         return val == 0 ? 0 : MOD - val;
  125.     }
  126.  
  127.     // Boolean operators
  128.     bool operator==(const _m_int &other) const { return val == other.val; }
  129.     bool operator!=(const _m_int &other) const { return val != other.val; }
  130.     bool operator< (const _m_int &other) const { return val <  other.val; }
  131.     bool operator<=(const _m_int &other) const { return val <= other.val; }
  132.     bool operator> (const _m_int &other) const { return val >  other.val; }
  133.     bool operator>=(const _m_int &other) const { return val >= other.val; }
  134.  
  135.     _m_int inv() const  // Calculating inverse
  136.     {
  137.         return mod_inv(val);
  138.     }
  139.  
  140.     _m_int pow(int64_t p) const  // Calculating power
  141.     {
  142.         if (p < 0)
  143.             return inv().pow(-p);
  144.  
  145.         _m_int a = *this, result = 1;
  146.  
  147.         while (p > 0)
  148.         {
  149.             if (p & 1)
  150.             {
  151.                 result *= a;
  152.             }
  153.             a *= a;
  154.             p >>= 1;
  155.         }
  156.  
  157.         return result;
  158.     }
  159.  
  160.     friend ostream& operator<<(ostream &os, const _m_int &m)  // Writing output
  161.     {
  162.         return os << m.val;
  163.     }
  164.  
  165. };
  166.  
  167. extern const int MOD = 1e9 + 7;
  168. using mod_int = _m_int<MOD>;
  169.  
  170.  
  171. template<const int &MOD>
  172. struct NTT
  173. {
  174.     using ntt_int = _m_int<MOD>;
  175.  
  176.     vector<ntt_int> roots = {0, 1};
  177.     vector<int> bit_reverse;
  178.     int max_size = -1;
  179.     ntt_int root;
  180.  
  181.     void reset()
  182.     {
  183.         roots = {0, 1};
  184.         max_size = -1;
  185.     }
  186.  
  187.     static bool is_power_of_two(int n) {  return (n & (n - 1)) == 0;  }
  188.  
  189.     static int round_up_power_two(int n)
  190.     {
  191.         while (n & (n - 1))
  192.             n = (n | (n - 1)) + 1;
  193.  
  194.         return max(n, 1);
  195.     }
  196.  
  197.     // Given n (a power of two), finds k such that n == 1 << k.
  198.     static int get_length(int n)
  199.     {
  200.         assert(is_power_of_two(n));
  201.         return __builtin_ctz(n);
  202.     }
  203.  
  204.     // Rearranges the indices to be sorted by lowest bit first, then second lowest, etc., rather than highest bit first.
  205.     // This makes even-odd div-conquer much easier.
  206.     void bit_reorder(int n, vector<ntt_int> &values)
  207.     {
  208.         if ((int) bit_reverse.size() != n)
  209.         {
  210.             bit_reverse.assign(n, 0);
  211.             int length = get_length(n);
  212.  
  213.             for (int i = 1; i < n; i++)
  214.                 bit_reverse[i] = (bit_reverse[i >> 1] >> 1) | ((i & 1) << (length - 1));
  215.         }
  216.  
  217.         for (int i = 0; i < n; i++)
  218.         {
  219.             if (i < bit_reverse[i])
  220.             {
  221.                 swap(values[i], values[bit_reverse[i]]);
  222.             }
  223.         }
  224.     }
  225.  
  226.     void find_root()
  227.     {
  228.         max_size = 1 << __builtin_ctz(MOD - 1);
  229.         root = 2;
  230.  
  231.         // Find a max_size-th primitive root of MOD.
  232.         while (!(root.pow(max_size) == 1 && root.pow(max_size / 2) != 1))
  233.         {
  234.             root++;
  235.         }
  236.     }
  237.  
  238.     void prepare_roots(int n)
  239.     {
  240.         if (max_size < 0)
  241.             find_root();
  242.  
  243.         assert(n <= max_size);
  244.  
  245.         if ((int) roots.size() >= n)
  246.             return;
  247.  
  248.         int length = get_length(roots.size());
  249.         roots.resize(n);
  250.  
  251.         // The roots array is set up such that for a given power of two n >= 2, roots[n / 2] through roots[n - 1] are
  252.         // the first half of the n-th primitive roots of MOD.
  253.         while (1 << length < n)
  254.         {
  255.             // z is a 2^(length + 1)-th primitive root of MOD.
  256.             ntt_int z = root.pow(max_size >> (length + 1));
  257.  
  258.             for (int i = 1 << (length - 1); i < 1 << length; i++)
  259.             {
  260.                 roots[2 * i] = roots[i];
  261.                 roots[2 * i + 1] = roots[i] * z;
  262.             }
  263.  
  264.             length++;
  265.         }
  266.     }
  267.  
  268.     void fft_iterative(int N, vector<ntt_int> &values)
  269.     {
  270.         assert(is_power_of_two(N));
  271.         prepare_roots(N);
  272.         bit_reorder(N, values);
  273.  
  274.         for (int n = 1; n < N; n *= 2)
  275.         {
  276.             for (int start = 0; start < N; start += 2 * n)
  277.             {
  278.                 for (int i = 0; i < n; i++)
  279.                 {
  280.                     ntt_int even = values[start + i];
  281.                     ntt_int odd = values[start + n + i] * roots[n + i];
  282.                     values[start + n + i] = even - odd;
  283.                     values[start + i] = even + odd;
  284.                 }
  285.             }
  286.         }
  287.     }
  288.  
  289.     void invert_fft(int N, vector<ntt_int> &values)
  290.     {
  291.         ntt_int inv_N = ntt_int(N).inv();
  292.  
  293.         for (int i = 0; i < N; i++)
  294.         {
  295.             values[i] *= inv_N;
  296.         }
  297.  
  298.         reverse(values.begin() + 1, values.end());
  299.         fft_iterative(N, values);
  300.     }
  301.  
  302.     const int FFT_CUTOFF = 150;
  303.  
  304.     // Note: `circular = true` can be used for a 2x speedup when only the `max(n, m) - min(n, m) + 1` fully overlapping
  305.     // ranges are needed. It computes results using indices modulo the power-of-two FFT size; see the brute force below.
  306.     template<typename T>
  307.     vector<T> mod_multiply(const vector<T> &_left, const vector<T> &_right, bool circular = false)
  308.     {
  309.         if (_left.empty() || _right.empty())
  310.             return {};
  311.  
  312.         vector<ntt_int> left(_left.begin(), _left.end());
  313.         vector<ntt_int> right(_right.begin(), _right.end());
  314.  
  315.         int n = left.size();
  316.         int m = right.size();
  317.  
  318.         int output_size = circular ? round_up_power_two(max(n, m)) : n + m - 1;
  319.  
  320.         // Brute force when either n or m is small enough.
  321.         if (min(n, m) < FFT_CUTOFF)
  322.         {
  323.             auto &&mod_output_size = [&](int x)
  324.             {
  325.                 return x < output_size ? x : x - output_size;
  326.             };
  327.  
  328.             static const uint64_t U64_BOUND = numeric_limits<uint64_t>::max() - (uint64_t) MOD * MOD;
  329.             vector<uint64_t> result(output_size, 0);
  330.  
  331.             for (int i = 0; i < n; i++)
  332.             {
  333.                 for (int j = 0; j < m; j++)
  334.                 {
  335.                     int index = mod_output_size(i + j);
  336.                     result[index] += (uint64_t) (int64_t) left[i] * (int64_t) right[j];
  337.  
  338.                     if (result[index] > U64_BOUND)
  339.                         result[index] %= MOD;
  340.                 }
  341.             }
  342.             for (uint64_t &x : result)
  343.                 if (x >= MOD)
  344.                     x %= MOD;
  345.  
  346.             return vector<T>(result.begin(), result.end());
  347.         }
  348.  
  349.         int N = round_up_power_two(output_size);
  350.         left.resize(N, 0);
  351.         right.resize(N, 0);
  352.  
  353.         if (left == right)
  354.         {
  355.             fft_iterative(N, left);
  356.             right = left;
  357.         }
  358.         else
  359.         {
  360.             fft_iterative(N, left);
  361.             fft_iterative(N, right);
  362.         }
  363.  
  364.         for (int i = 0; i < N; i++)
  365.         {
  366.             left[i] *= right[i];
  367.         }
  368.  
  369.         invert_fft(N, left);
  370.         return vector<T>(left.begin(), left.begin() + output_size);
  371.     }
  372.  
  373.     template<typename T>
  374.     vector<T> mod_power(const vector<T> &v, int exponent)
  375.     {
  376.         assert(exponent >= 0);
  377.         vector<T> result = {1};
  378.  
  379.         if (exponent == 0)
  380.             return result;
  381.  
  382.         for (int k = 31 - __builtin_clz(exponent); k >= 0; k--)
  383.         {
  384.             result = mod_multiply(result, result);
  385.  
  386.             if (exponent >> k & 1)
  387.                 result = mod_multiply(result, v);
  388.         }
  389.  
  390.         return result;
  391.     }
  392.  
  393.     template<typename T>
  394.     vector<T> mod_multiply_all(const vector<vector<T>> &polynomials)
  395.     {
  396.         if (polynomials.empty())
  397.             return {1};
  398.  
  399.         struct compare_size
  400.         {
  401.             bool operator()(const vector<T> &x, const vector<T> &y)
  402.             {
  403.                 return x.size() > y.size();
  404.             }
  405.         };
  406.  
  407.         priority_queue<vector<T>, vector<vector<T>>, compare_size> pq(polynomials.begin(), polynomials.end());
  408.  
  409.         while (pq.size() > 1)
  410.         {
  411.             vector<T> a = pq.top(); pq.pop();
  412.             vector<T> b = pq.top(); pq.pop();
  413.             pq.push(mod_multiply(a, b));
  414.         }
  415.  
  416.         return pq.top();
  417.     }
  418. };
  419.  
  420. NTT<MOD> ntt;
  421.  
  422. mod_int fact[N], ifact[N];
  423.  
  424.  
  425. void pre()
  426. {
  427.     fact[0] = 1;
  428.     int i;
  429.     for (i = 1; i < N; i++)
  430.     {
  431.         fact[i] = (fact[i - 1] * i);
  432.     }
  433.     ifact[N - 1] = fact[N - 1].inv();
  434.     for (i = N - 2; i >= 0; i--)
  435.     {
  436.         ifact[i] = (ifact[i + 1] * (i + 1));
  437.     }
  438. }
  439.  
  440. signed main()
  441. {
  442.     fast;
  443.  
  444.     pre();
  445.     string s;
  446.     int mp[26] = {0};
  447.     cin >> s;
  448.     for (auto &x : s)
  449.     {
  450.         mp[x - 'a']++;
  451.     }
  452.     vector<mod_int> res{1};
  453.     f(i, 0, 25)
  454.     {
  455.         if (mp[i] == 0)
  456.         {
  457.             continue;
  458.         }
  459.         vector<mod_int> now;
  460.         f(j, 0, mp[i])
  461.         {
  462.             now.pb(ifact[j]);
  463.         }
  464.         res = ntt.mod_multiply(res, now);
  465.     }
  466.     int n = sz(s);
  467.     n = min(n, sz(res));
  468.     mod_int ans = 0;
  469.     f(i, 1, n)
  470.     {
  471.         ans += (fact[i] * res[i]);
  472.     }
  473.     cout << ans;
  474.     return 0;
  475. }
Add Comment
Please, Sign In to add comment