第6回 ドワンゴからの挑戦状 予選 C - Cookie Distribution

かなり好きなタイプの問題。
C - Cookie Distribution

問題

N 人の人にクッキーを配りたい。i = 1, ..., K に対して、以下の操作を順番に行う。

  • N 人の中から a[i] 人を等確率にランダムに選び、その人たちにクッキーを 1 個ずつ渡す。

最終的に 人 i がもらったクッキーの枚数を c[i] とし、 c[1] * c[2] * ... c[N] を 嬉しさと呼ぶことにする。
嬉しさの期待値に C(N, a[1]) * ... * C(N, a[K]) を掛けたものは整数になるが、これを mod (10^9 + 7) で求めよ。

制約

  • 1 <= K <= 20
  • 1 <= N <= 1000
  • 1 <= a[i] <= N

解法

dp[i][j] := (i 回目まで処理を行った後の、c[1] * ... c[j] の期待値 (j = 0 の場合は 1)) と置く。
dp[i + 1][j] を dp[i][...] を使って表したい。ここで、c[1] * ... * c[j] を数える代わりに、人 1 から 人 j までがもらったクッキーの j 個組を数えることにする。 (両者は同じものである)
a[i + 1] 個のクッキーを配った後のこの量は、0 <= s <= j となる s ごとに以下の値を求めて合計すればわかる:

  • j 個のうち s 個はもともと配られていたもの、残りの (j - s) 個は今配られた a[i + 1] 個のうちのどれかとしたときの、j 個組の個数

j 個のうち s 個の場所の選び方が C(j, s) 通り、個数が dp[i][s] (対称性より)、a[i + 1] 個のクッキーのうち (j - s) 個が残りの場所に配られる確率は C(n - (j - s), n - a[i + 1]) / C(n, a[i + 1]) である。これらを全て掛けて dp[i][s] * C(j, s) * C(n - (j - s), n - a[i + 1]) / C(n, a[i + 1]) が求める量である。
これを全ての s について足し合わせれば良い。計算量は、DP テーブルのエントリ数が O(KN) 個、各エントリの計算に O(N) 時間かかるので、合計 O(KN^2) である。

なお、dp[i][s] -> dp[i + 1][j] の遷移に登場する変数が s, j - s の形でしか登場しないので、任意 mod の畳み込みを用いることで O(KN log N) にすることもできる。(参考: AtCoder AGC #005 : F - Many Easy Problems - kmjp's blog)

登場する典型

  • 積の期待値を求めたい場合に、組合せ論的に扱いやすいものに置き換えて考える
  • 積の期待値 = 期待値の積 ではないので注意
    • 今回の場合、c[i] 同士は全然独立ではないので、この式は成り立たない

実装上の注意点

とくになし

提出: #9422487 (Rust)

fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($($format:tt)*) => (write!(out,$($format)*).unwrap());
    }
    input! {
        n: usize, k: usize,
        a: [usize; k],
    }
    let (fac, invfac) = fact_init(n + 1);
    let comb = |x, y| {
        if x < y {
            ModInt::new(0)
        } else {
            fac[x] * invfac[y] * invfac[x - y]
        }
    };
    let mut dp = vec![vec![ModInt::new(0); n + 1]; k + 1];
    dp[0][0] += 1;
    for i in 0..k {
        for j in 0..n + 1 {
            let val = dp[i][j];
            dp[i + 1][j] += val;
            for l in 0..j {
                if a[i] + l >= j {
                    let val = dp[i][l] * comb(j, l) * comb(n + l - j, a[i] + l - j);
                    dp[i + 1][j] += val * invfac[n] * fac[a[i]] * fac[n - a[i]];
                }
            }
        }
    }
    let mut ans = dp[k][n];
    for i in 0..k {
        ans *= comb(n, a[i]);
    }
    puts!("{}\n", ans);
}

まとめ

個人的には D が比較的難しかった (solved 数を見る限り C > D だったらしいが)。おかげで E にあまり時間が掛けられなかった。ゆるせね〜 (責任転嫁)