CADDi 2018 F - Square

線形代数をやりすぎて迷走してしまった。
https://atcoder.jp/contests/caddi2018/tasks/caddi2018_d

問題

N行N列のマス目がある。これに0または1の値を書き込む。マス目のうちいくつかにはすでに数字が書き込まれている。
完成したマス目は以下の条件を満たす必要がある。

  • 1 <= i < j <= Nであるような整数i,jに対して、(i,i)を左上、(j,j)を右下にするような正方形の中に含まれる1の個数は偶数である。

条件を満たすマス目の埋め方を数え上げよ。

解法

(i,j)に書き込む数字をa[i][j]と表記する。
まず、|i-j| >= 3ならばa[i][j] = a[j][i]である。これは冒頭の規則を(i,j), (i, j-1), (i+1,j), (i+1,j-1)に対して適用すればわかる。
残りは対角成分付近である。対角成分に着目すると、a[i+1][i+1] = a[i][i]+(a[i][i+1]+a[i+1][i]), a[i+1][i+1] = a[i][i+2] + a[i+2][i]がわかる。
ここから、対角成分付近の自由度を計算しつつ、ありえる対角成分の数え上げをすればよい。
対角成分の数え上げは、以下の問題を解くことで可能である。

  • 長さnの0または1からなる数列bおよびそのmod 2での階差数列cについて、b[i]=jという形の制約と、c[i]=jという形の制約が何個か与えられる。条件を全て満たす(b,c)を数え上げよ。

これはDPでできる。(コード中のcalc_acc)
また、対角成分付近の自由度の数え上げは、各iについて(i+1,i)と(i,i+1)についてはmax(0,もともと埋まっていない個数-1)が自由度、(i+2,i)と(i,i+2)についても同様、とすればよい。(マイナス1は上のcの数え上げ部分に吸収されている分であることに注意。)

実装上の注意点

  • 端で壊れがちなので気をつける
  • 制約の矛盾などを発見するときにビット演算を使うが、そこでバグらせないように細心の注意を払う
  • 数え上げだが、実質自由度を求める問題なので、本当に自由度があっているか確認する

提出: Submission #3849384 - CADDi 2018 (Rust)

// ModInt 省略
fn calc_acc(acc: &[i32], inb: &[i32]) -> ModInt {
    let n = acc.len();
    assert_eq!(inb.len(), n - 1);
    // eprintln!("acc = {:?}, inb = {:?}", acc, inb);
    let mut dp = vec![[ModInt::new(0); 2]; n];
    let one = ModInt::new(1);
    if acc[0] == 0 {
        dp[0][0] = one;
        dp[0][1] = one;
    } else {
        dp[0][acc[0] as usize - 1] = one;
    }
    for i in 1 .. n {
        if inb[i - 1] == 0 {
            for k in 0 .. 2 {
                dp[i][0] += dp[i - 1][k];
            }
            dp[i][1] = dp[i][0];
        } else {
            let diff = (inb[i - 1] - 1) as usize;
            for k in 0 .. 2 {
                dp[i][k ^ diff] += dp[i - 1][k];
            }
        }
        if acc[i] != 0 {
            let opp = 1 - (acc[i] - 1) as usize;
            dp[i][opp] = ModInt::new(0);
        }
    }
    dp[n - 1][0] + dp[n - 1][1]
}


fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($format:expr) => (write!(out,$format).unwrap());
        ($format:expr, $($args:expr),+) => (write!(out,$format,$($args),*).unwrap())
    }
    input! {
        n: usize,
        m: usize,
        abc: [(i32, i32, i32); m],
    }
    let mut dim: i64 = 0;
    let mut hm = HashMap::new();
    for &(a, b, c) in abc.iter() {
        let (a, b) = if a < b || (a - b).abs() <= 2 { (a - 1, b - 1) } else { (b - 1, a - 1) };
        *hm.entry((a, b)).or_insert(0) |= 1 << c;
    }
    // rest, |x - y| <= 2, even
    let mut rest = vec![vec![0; 3]; n];
    let mut odd = vec![Vec::new(); n - 1];
    for (&(a, b), &val) in hm.iter() {
        if val == 3 {
            puts!("0\n");
            return;
        }
        if (a - b).abs() >= 3 {
            dim -= 1;
        }
        if (a - b).abs() <= 2 && (a + b) % 2 == 0 {
            let idx = (a + b) as usize / 2;
            let u = (b - a + 2) as usize / 2;
            rest[idx][u] = val;
        }
        if (a - b).abs() == 1 {
            odd[(a + b) as usize / 2].push((a, b, val));
        }
    }
    let mut adden = 0;
    if n > 3 {
        for i in 0 .. n {
            let lock = if i <= 1 {
                3 + i as i64
            } else if i <= n - 2 {
                5
            } else {
                (n - 1 - i) as i64 + 2
            };
            adden += n as i64 - lock;
        }
    }
    dim += adden / 2;
    // eprintln!("dim = {}", dim);
    let mut cons = vec![0; n]; // constraints
    for i in 0 .. n {
        cons[i] |= rest[i][1];
    }
    // rest even
    for i in 1 .. n - 1 {
        if rest[i][0] != 0 && rest[i][2] != 0 {
            let x = (rest[i][0] - 1) ^ (rest[i][2] - 1);
            cons[i] |= 1 << x;
        }
        let mut f = 0;
        if rest[i][0] == 0 { f += 1; }
        if rest[i][2] == 0 { f += 1; }
        dim += max(0, f - 1);
    }
    for i in 0 .. n {
        if cons[i] == 3 {
            puts!("0\n");
            return;
        }
    }
    let mut inb = vec![0; n - 1];
    for i in 0 .. n - 1 {
        if odd[i].len() == 2 {
            let x = (odd[i][0].2 - 1) ^ (odd[i][1].2 - 1);
            inb[i] |= 1 << x;
        }
        let mut f = 2 - odd[i].len() as i64;
        dim += max(f - 1, 0);
    }
    // rest odd
    let fac = calc_acc(&cons, &inb);
    /*
    eprintln!("dim = {}", dim);
    eprintln!("fac = {}", fac);
     */
    puts!("{}\n", ModInt::new(2).pow(dim) * fac);    
}

まとめ

  • 頭が数学モードになっていると、競技プログラミングのことをすっかり忘れてしまう…。
  • AtCoderのRustのバージョンが低いため最初の提出がCEになり、順位が1つ下がった (は?)