yukicoder No.802 だいたい等差数列
こういうのはかなり得意だと思っている。
No.802 だいたい等差数列 - yukicoder
問題
以下の条件を満たす長さNの整数列Aの個数を数え、10^9+7で割った余りを求めよ。
- 1 <= A[1] <= ... <= A[N] <= M
- i = 1, ..., N - 1に対して D1 <= A[i + 1] - A[i] <= D2
制約
- 2 <= N <= 3 * 10^5
- 1 <= M <= 10^6
- 0 <= D1 <= D2 <= M
解法
まず愚直DPを考える。DP[i][j] := (Aの第i項まで見たとき、A[i]=jとなるような数列の個数)としたとき、以下のような遷移を持つDPを行えばよい。
- init: DP[1][j] = 1
- step: DP[i + 1][j + k] += DP[i][j] (D1 <= k <= D2)
これのDP[N][1] + ... DP[N][M]が答えである。これを実行することは計算量が多すぎて不可能なので、高速化を行う。
DP配列を多項式と考えることにして、f(i) = \sum DP[i][j] * x^jと置く。(最終的に興味があるのはx^M以下の項なので、x^{M+1}以上の項は無視する。また多項式と級数(項が無限個続くもの)を断りなく同一視する。)
最初はf(1) = x + x^2 + ... + x^M = x/(1 - x) + O(x^{M + 1})である。
またstepを多項式の言葉に直すとf(i + 1) = f(i) * (x^{D_1} + ... + x^{D_2}) = f(i) * (x^{D_1} - x^{D_2 + 1}) / (1 - x)である。
また、多項式A + Bx + ... + Cx^Mに1 / (1 - x)を掛けたもの(のマクローリン展開)のx^Mの係数はA + B + ... + Cである。
よってこれらを組み合わせると、x * (x^{D_1} - x^{D_2 + 1})^{N - 1} / (1 - x)^{N + 1} (をマクローリン展開したもの)におけるx^Mの係数が答えであることがわかる。
これはどう計算すべきだろうか? まず、(x^{D_1} - x^{D_2 + 1})^{N - 1} = \sum_{i=0}^{N - 1} C(N - 1, i) * (-1)^i * x^{D_1 * (N - 1 - i) + (D_2 + 1) * i)}なので、「ある整数kに対してx^k / (1 - x)^{N + 1}におけるx^Mの係数」を高速に計算できればよいことがわかる(このようなクエリはN回実行される)。
これは1 / (1 - x)^{N + 1}におけるx^{M - k}の係数に他ならないので、M - k < 0であれば0で、M - k >= 0であればC(M - k + N, N)である。
登場する典型
- 母関数
- 愚直DPの高速化
実装上の注意点
- 二項係数の扱いに気をつける
- 母関数を使う時は式変形が複雑になりがちなので、式変形が間違っていないかどうかサンプルなどを使って確かめる
提出: #324980 (Rust)
fn fact_init(w: usize) -> (Vec<ModInt>, Vec<ModInt>) { let mut fac = vec![ModInt::new(1); w]; let mut invfac = vec![0.into(); w]; for i in 1 .. w { fac[i] = fac[i - 1] * i as i64; } invfac[w - 1] = fac[w - 1].inv(); for i in (0 .. w - 1).rev() { invfac[i] = invfac[i + 1] * (i as i64 + 1); } (fac, invfac) } fn solve() { let out = std::io::stdout(); let mut out = BufWriter::new(out.lock()); macro_rules! puts { ($($format:tt)*) => (write!(out,$($format)*).unwrap()); } input! { n: usize, m: usize, d1: usize, d2: usize, } let n = n - 1; let (fac, invfac) = fact_init(n + m + 2); let mut tot = ModInt::new(0); for i in 0..n + 1 { let deg = d1 * (n - i) + (d2 + 1) * i + 1; if deg > m { continue; } let rem = m - deg; let tmp = fac[rem + n + 1] * invfac[n + 1] * invfac[rem] * fac[n] * invfac[n - i] * invfac[i]; if i % 2 == 0 { tot += tmp; } else { tot -= tmp; } } puts!("{}\n", tot); }
まとめ
母関数周りも今後定跡になりそうな気がする。