当時Code Festivalに参加していた時は、何をすればいいのか見当もつかず、出題陣のライブ解説も理解不能だったが、今解き直したらかなり簡単に思えた。A - Awkward
問題
N頂点の木が与えられる。木の頂点の並べ方はN!通りあるが、その中で以下を満たすものの個数を10^9 + 7で割った余りを求めよ。
- どの2個の隣接する頂点も、隣同士にならない。
解法
包除原理を使う。以下の値をx = 0, ..., N - 1について数え上げることができればよい。
- N - 1個の辺のうちx個を選ぶ。それらの辺については制約が破られている(隣同士になっている)ような並べ方の総数の合計。
- 例えば入力例3 (N=5, b = [1, 1, 3, 3])では120, 192, 96, 16, 0となる。問題の答えは120 - 192 + 96 - 16 + 0 = 8。
これを求めるにはどうすれば良いだろうか? まず、辺の集合Sをfixしたときに、その集合Sについては制約が破られているような並べ方の総数を求めよう。辺集合に枝分かれがある(同じ頂点にSの3本以上の辺が接続している)ときは0通りである。それ以外の時を考える。Sを辺集合とした木の部分グラフは、1点と長さ1以上のパスの寄せ集めになっている。パスは列の中で一続きの部分列になること、およびどちらを列の先頭に持ってくるかが2通りあることを考えると、場合の数は(ものの個数)! * 2^(長さ1以上のパスの個数)であることがわかる。
次に、集合Sではなく集合の大きさx = |S|をfixしたときの和を計算したい。「(ものの個数)!」についてはあとでまとめて掛けることにして、「2^(長さ1以上のパスの個数)」の合計をもとめることにすると、これは以下のようなDPで計算することができる。
- DP[v][x][k]: 頂点vを根とする部分木において、連結成分の個数がx個あり、vからは辺がk本伸びているときの、2^(長さ1以上のパスの個数)の合計
これは二乗の木 DP - (iwi) { 反省します - TopCoder部のテクニックを使えば、O(N^2)で計算することができる。
登場する典型
- 包除原理
- 2乗の木DP
実装上の注意点
- 遷移が複雑なのでバグらせないようにする
- 余計なループを回すとO(N^3)になるので回さないようにする
- RustだとDP配列をそのまま返してしまうのが賢明な気がする
提出: Submission #4342066 - CODE FESTIVAL 2017 Exhibition (Parallel) (Rust)
// ModInt省略 // Depends on ModInt.rs fn fact_init(w: usize) -> (Vec<ModInt>, Vec<ModInt>) { let mut fac = vec![ModInt::new(1); w]; let mut invfac = vec![0.into(); w]; for i in 1 .. w { fac[i] = fac[i - 1] * i as i64; } invfac[w - 1] = fac[w - 1].inv(); for i in (0 .. w - 1).rev() { invfac[i] = invfac[i + 1] * (i as i64 + 1); } (fac, invfac) } fn dfs(ch: &[Vec<usize>], v: usize) -> Vec<[ModInt; 3]> { let mut dp = vec![[ModInt::new(0); 3]; 2]; dp[1][0] = ModInt::new(1); for &w in &ch[v] { let sub = dfs(ch, w); let sub_sz = sub.len() - 1; let cur_sz = dp.len() - 1; let mut next_dp = vec![[ModInt::new(0); 3]; cur_sz + sub_sz + 1]; for i in 1..sub_sz + 1 { for j in 1..cur_sz + 1 { // add one edge for k in 0..2 { next_dp[i + j - 1][1] += sub[i][k] * dp[j][0]; next_dp[i + j - 1][2] += sub[i][k] * dp[j][1]; } // don't add an edge for k in 0..3 { for l in 0..3 { next_dp[i + j][l] += sub[i][k] * dp[j][l] * if k >= 1 { 2 } else { 1 }; } } } } dp = next_dp; } dp } fn solve() { let out = std::io::stdout(); let mut out = BufWriter::new(out.lock()); macro_rules! puts { ($($format:tt)*) => (write!(out,$($format)*).unwrap()); } input! { n: usize, b: [usize1; n - 1], } let mut ch = vec![vec![]; n]; for i in 1..n { ch[b[i - 1]].push(i); } const W: usize = 3000; let mut invtbl = vec![ModInt::new(0); W]; for i in 1..W { invtbl[i] = ModInt::new(i as i64).inv(); } let (fac, _invfac) = fact_init(W); let dp = dfs(&ch, 0); let mut tot = ModInt::new(0); let mut factor = ModInt::new(1); for i in (0..n + 1).rev() { let mut t = ModInt::new(0); for k in 0..3 { t += dp[i][k] * fac[i] * if k >= 1 { 2 } else { 1 }; } tot += t * factor; factor = -factor; } puts!("{}\n", tot); }
まとめ
当時に比べこの手の問題が多く出題されているなど、界隈が進歩しているので、今出題するとしたら800点くらいになりそうな気がする。