CODE FESTIVAL 2017 Exhibition A - Awkward

当時Code Festivalに参加していた時は、何をすればいいのか見当もつかず、出題陣のライブ解説も理解不能だったが、今解き直したらかなり簡単に思えた。A - Awkward

問題

N頂点の木が与えられる。木の頂点の並べ方はN!通りあるが、その中で以下を満たすものの個数を10^9 + 7で割った余りを求めよ。

  • どの2個の隣接する頂点も、隣同士にならない。

解法

包除原理を使う。以下の値をx = 0, ..., N - 1について数え上げることができればよい。

  • N - 1個の辺のうちx個を選ぶ。それらの辺については制約が破られている(隣同士になっている)ような並べ方の総数の合計。
    • 例えば入力例3 (N=5, b = [1, 1, 3, 3])では120, 192, 96, 16, 0となる。問題の答えは120 - 192 + 96 - 16 + 0 = 8。

これを求めるにはどうすれば良いだろうか? まず、辺の集合Sをfixしたときに、その集合Sについては制約が破られているような並べ方の総数を求めよう。辺集合に枝分かれがある(同じ頂点にSの3本以上の辺が接続している)ときは0通りである。それ以外の時を考える。Sを辺集合とした木の部分グラフは、1点と長さ1以上のパスの寄せ集めになっている。パスは列の中で一続きの部分列になること、およびどちらを列の先頭に持ってくるかが2通りあることを考えると、場合の数は(ものの個数)! * 2^(長さ1以上のパスの個数)であることがわかる。
次に、集合Sではなく集合の大きさx = |S|をfixしたときの和を計算したい。「(ものの個数)!」についてはあとでまとめて掛けることにして、「2^(長さ1以上のパスの個数)」の合計をもとめることにすると、これは以下のようなDPで計算することができる。

  • DP[v][x][k]: 頂点vを根とする部分木において、連結成分の個数がx個あり、vからは辺がk本伸びているときの、2^(長さ1以上のパスの個数)の合計

これは二乗の木 DP - (iwi) { 反省します - TopCoder部のテクニックを使えば、O(N^2)で計算することができる。

登場する典型

  • 包除原理
  • 2乗の木DP

実装上の注意点

  • 遷移が複雑なのでバグらせないようにする
  • 余計なループを回すとO(N^3)になるので回さないようにする
    • RustだとDP配列をそのまま返してしまうのが賢明な気がする

提出: Submission #4342066 - CODE FESTIVAL 2017 Exhibition (Parallel) (Rust)

// ModInt省略
// Depends on ModInt.rs
fn fact_init(w: usize) -> (Vec<ModInt>, Vec<ModInt>) {
    let mut fac = vec![ModInt::new(1); w];
    let mut invfac = vec![0.into(); w];
    for i in 1 .. w {
        fac[i] = fac[i - 1] * i as i64;
    }
    invfac[w - 1] = fac[w - 1].inv();
    for i in (0 .. w - 1).rev() {
        invfac[i] = invfac[i + 1] * (i as i64 + 1);
    }
    (fac, invfac)
}

fn dfs(ch: &[Vec<usize>], v: usize) -> Vec<[ModInt; 3]> {
    let mut dp = vec![[ModInt::new(0); 3]; 2];
    dp[1][0] = ModInt::new(1);
    for &w in &ch[v] {
        let sub = dfs(ch, w);
        let sub_sz = sub.len() - 1;
        let cur_sz = dp.len() - 1;
        let mut next_dp = vec![[ModInt::new(0); 3]; cur_sz + sub_sz + 1];
        for i in 1..sub_sz + 1 {
            for j in 1..cur_sz + 1 {
                // add one edge
                for k in 0..2 {
                    next_dp[i + j - 1][1] += sub[i][k] * dp[j][0];
                    next_dp[i + j - 1][2] += sub[i][k] * dp[j][1];
                }
                // don't add an edge
                for k in 0..3 {
                    for l in 0..3 {
                        next_dp[i + j][l] += sub[i][k] * dp[j][l]
                            * if k >= 1 { 2 } else { 1 };
                    }
                }
            }
        }
        dp = next_dp;
    }
    dp
}

fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($($format:tt)*) => (write!(out,$($format)*).unwrap());
    }
    input! {
        n: usize,
        b: [usize1; n - 1],
    }
    let mut ch = vec![vec![]; n];
    for i in 1..n {
        ch[b[i - 1]].push(i);
    }
    const W: usize = 3000;
    let mut invtbl = vec![ModInt::new(0); W];
    for i in 1..W {
        invtbl[i] = ModInt::new(i as i64).inv();
    }
    let (fac, _invfac) = fact_init(W);
    let dp = dfs(&ch, 0);
    let mut tot = ModInt::new(0);
    let mut factor = ModInt::new(1);
    for i in (0..n + 1).rev() {
        let mut t = ModInt::new(0);
        for k in 0..3 {
            t += dp[i][k] * fac[i]
                * if k >= 1 { 2 } else { 1 };
        }
        tot += t * factor;
        factor = -factor;
    }
    puts!("{}\n", tot);
}

まとめ

当時に比べこの手の問題が多く出題されているなど、界隈が進歩しているので、今出題するとしたら800点くらいになりそうな気がする。