yukicoder No.802 だいたい等差数列

こういうのはかなり得意だと思っている。
No.802 だいたい等差数列 - yukicoder

問題

以下の条件を満たす長さNの整数列Aの個数を数え、10^9+7で割った余りを求めよ。

  • 1 <= A[1] <= ... <= A[N] <= M
  • i = 1, ..., N - 1に対して D1 <= A[i + 1] - A[i] <= D2

制約

  • 2 <= N <= 3 * 10^5
  • 1 <= M <= 10^6
  • 0 <= D1 <= D2 <= M

解法

まず愚直DPを考える。DP[i][j] := (Aの第i項まで見たとき、A[i]=jとなるような数列の個数)としたとき、以下のような遷移を持つDPを行えばよい。

  • init: DP[1][j] = 1
  • step: DP[i + 1][j + k] += DP[i][j] (D1 <= k <= D2)

これのDP[N][1] + ... DP[N][M]が答えである。これを実行することは計算量が多すぎて不可能なので、高速化を行う。
DP配列を多項式と考えることにして、f(i) = \sum DP[i][j] * x^jと置く。(最終的に興味があるのはx^M以下の項なので、x^{M+1}以上の項は無視する。また多項式級数(項が無限個続くもの)を断りなく同一視する。)
最初はf(1) = x + x^2 + ... + x^M = x/(1 - x) + O(x^{M + 1})である。
またstepを多項式の言葉に直すとf(i + 1) = f(i) * (x^{D_1} + ... + x^{D_2}) = f(i) * (x^{D_1} - x^{D_2 + 1}) / (1 - x)である。
また、多項式A + Bx + ... + Cx^Mに1 / (1 - x)を掛けたもの(のマクローリン展開)のx^Mの係数はA + B + ... + Cである。
よってこれらを組み合わせると、x * (x^{D_1} - x^{D_2 + 1})^{N - 1} / (1 - x)^{N + 1} (をマクローリン展開したもの)におけるx^Mの係数が答えであることがわかる。
これはどう計算すべきだろうか? まず、(x^{D_1} - x^{D_2 + 1})^{N - 1} = \sum_{i=0}^{N - 1} C(N - 1, i) * (-1)^i * x^{D_1 * (N - 1 - i) + (D_2 + 1) * i)}なので、「ある整数kに対してx^k / (1 - x)^{N + 1}におけるx^Mの係数」を高速に計算できればよいことがわかる(このようなクエリはN回実行される)。
これは1 / (1 - x)^{N + 1}におけるx^{M - k}の係数に他ならないので、M - k < 0であれば0で、M - k >= 0であればC(M - k + N, N)である。

登場する典型

  • 母関数
    • 愚直DPの高速化

実装上の注意点

  • 二項係数の扱いに気をつける
  • 母関数を使う時は式変形が複雑になりがちなので、式変形が間違っていないかどうかサンプルなどを使って確かめる

提出: #324980 (Rust)

fn fact_init(w: usize) -> (Vec<ModInt>, Vec<ModInt>) {
    let mut fac = vec![ModInt::new(1); w];
    let mut invfac = vec![0.into(); w];
    for i in 1 .. w {
        fac[i] = fac[i - 1] * i as i64;
    }
    invfac[w - 1] = fac[w - 1].inv();
    for i in (0 .. w - 1).rev() {
        invfac[i] = invfac[i + 1] * (i as i64 + 1);
    }
    (fac, invfac)
}

fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($($format:tt)*) => (write!(out,$($format)*).unwrap());
    }
    input! {
        n: usize,
        m: usize,
        d1: usize,
        d2: usize,
    }
    let n = n - 1;
    let (fac, invfac) = fact_init(n + m + 2);
    let mut tot = ModInt::new(0);
    for i in 0..n + 1 {
        let deg = d1 * (n - i) + (d2 + 1) * i + 1;
        if deg > m { continue; }
        let rem = m - deg;
        let tmp = fac[rem + n + 1] * invfac[n + 1] * invfac[rem]
            * fac[n] * invfac[n - i] * invfac[i];
        if i % 2 == 0 {
            tot += tmp;
        } else {
            tot -= tmp;
        }
    }
    puts!("{}\n", tot);
}

まとめ

母関数周りも今後定跡になりそうな気がする。