Educational Codeforces Round 57 (Rated for Div. 2) G. Lucky Tickets

NTTのお手本みたいな問題だった。
Problem - G - Codeforces

問題

nを偶数とする。長さnの数字列は、先頭n/2個の和と末尾n/2個の和が等しい時luckyであると呼ばれる。
数字列に出現できる数の集合{d_1, ..., d_k}が決まっている時、luckyな数字列の総数をmod 998244353で求めよ。

  • 2 <= n <= 2 * 10^5
  • k <= 10
  • d_kは相異なる0から9までの数字

解法

長さn/2の数字列の和としてあり得るものと、その場合の数を求めれば良い。
DP[今の位置][今の総和]というDPを考えると、自然に数え上げナップサックの形になる。
数え上げナップサックは自然に多項式の積に帰着できることが多い。今回も帰着でき、(x^{d_1} + ... + x^{d_k})^{n/2}の係数を計算すれば良いことになる。
これは順NTT -> 各値をn/2乗 -> 逆NTTで計算できる。

実装上の注意点

  • NTTを使う時はmod がNTT-friendlyか確認する。今回は998244353 = 119 * 2^23 + 1なので、2^23要素までのNTTが可能である。
  • 扱う多項式の次数の最大値を正しく見積もる。今回は最大9次の多項式を最大100000乗するので、2^20-1次まで管理できれば十分である。よって、2^20要素のNTTで十分である。

提出: Submission #47673213 - Codeforces (Rust)

// ModInt省略

mod fft {
    use std::ops::*;
    /// n should be a power of 2. zeta is a primitive n-th root of unity.
    /// one is unity
    /// Note that the result should be multiplied by 1/sqrt(n).
    pub fn transform<R>(f: &mut [R], zeta: R, one: R)
        where R: Copy +
        Add<Output = R> +
        Sub<Output = R> +
        Mul<Output = R> {
        let n = f.len();
        assert!(n.is_power_of_two());
        {
            let mut i = 0;
            for j in 1 .. n - 1 {
                let mut k = n >> 1;
                loop {
                    i ^= k;
                    if k <= i { break; }
                    k >>= 1;
                }
                if j < i { f.swap(i, j); }
            }
        }
        let mut zetapow = Vec::new();
        {
            let mut m = 1;
            let mut cur = zeta;
            while m < n {
                zetapow.push(cur);
                cur = cur * cur;
                m *= 2;
            }
        }
        let mut m = 1;
        while m < n {
            let base = zetapow.pop().unwrap();
            let mut r = 0;
            while r < n {
                let mut w = one;
                for s in r .. r + m {
                    let u = f[s];
                    let d = f[s + m] * w;
                    f[s] = u + d;
                    f[s + m] = u - d;
                    w = w * base;
                }
                r += 2 * m;
            }
            m *= 2;
        }
    }
}

fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($($format:tt)*) => (write!(out,$($format)*).unwrap());
    }
    input! {
        n: usize,
        k: usize,
        d: [usize; k],
    }
    const N: usize = 1 << 20;
    let one = ModInt::new(1);
    let mut f = vec![ModInt::new(0); N];
    let zeta = ModInt::new(3).pow((MOD - 1) / N as i64);
    for i in 0 .. k {
        f[d[i]] += one;
    }
    fft::transform(&mut f, zeta, one);
    for i in 0 .. N {
        f[i] = f[i].pow(n as i64 / 2);
    }
    fft::transform(&mut f, zeta.inv(), one);
    let factor = ModInt::new(N as i64).inv();
    for i in 0 .. N {
        f[i] *= factor;
    }
    let mut tot = ModInt::new(0);
    for i in 0 .. N {
        tot += f[i].pow(2);
    }
    puts!("{}\n", tot);
}

まとめ

世界がすべて多項式に見えるタイプではない人はこういう問題をどうやって解くんだろう…?