本番Eまで解いてF, G, Hのどれを解くか、という状態になったときに一番可能性の高かったこの問題に粘着した。結局時間内には解けなかったが、悔しかったので全力で頑張ったら解けた。
https://codeforces.com/contest/1119/problem/H
問題
整数kが与えられる。数列がn個ある。どの数列も長さ(x + y + z)で、要素は0から2^k-1までの整数である。i番目の数列s[i]はa[i]をx個、b[i]をy個、c[i]をz個含む。
t = 0, 1, ..., 2^k - 1に対して、以下の個数を数え上げて998244353で割った余りを求めよ:
- 長さnの整数列(u[0], ..., u[n - 1])であって、s[0][u[0]] xor s[1][u[1]] xor ... xor s[n - 1][u[n - 1]] = tを満たすもの。
制約
- 1 <= n <= 10^5
- 1 <= k <= 17
- 0 <= a[i], b[i], c[i] <= 2^k - 1
- 0 <= x, y, z <= 10^9
解法
各列s[i]について長さ2^kの頻度表を作り、それをv[i]と呼ぶ。v[i]のアダマール変換をh[i]とすれば、求めるものはh[i]の要素ごとの積の逆アダマール変換である。
これを愚直にやるとO(n * k * 2^k)時間かかってしまうので高速化する。一般性を失わず、a[i] = 0としてよい。(b[i]とc[i]にa[i]をxorしておき、最終的な結果にa[i]たちのxorをxorする。) こうすると、アダマール変換後の数列h[i]に登場する数はx ± y ± zという形の4通りのみとなる。それぞれがいつ起こるかは以下の通り:
- h[i][j] = x - y - z: popcount(b[i] & j)が奇数、popcount(c[i]&j)が奇数
- h[i][j] = x + y - z: popcount(b[i] & j)が偶数、popcount(c[i]&j)が奇数
- h[i][j] = x - y + z: popcount(b[i] & j)が奇数、popcount(c[i]&j)が偶数
- h[i][j] = x + y + z: popcount(b[i] & j)が偶数、popcount(c[i]&j)が偶数
各jに対して、これらが現れる回数を数え上げればよい。
実はこれもアダマール変換を用いることでできる。例えば長さ2^kの配列pを、p[j] = (popcount(b[i] & j)もpopcount(c[i]&j)も奇数 ? 1 : 0)で定めると、配列pは以下によってつくられる長さ2^kの配列qのアダマール変換として得られる:
- qを0で初期化, q[0] += 1/4, q[b[i] ^ c[i]] += 1/4, q[b[i]] -= 1/4, q[c[i]] -= 1/4
これを利用すれば、毎回配列のたかだか16箇所に足し算をし、最後にまとめてアダマール変換を行うことで、x ± y ± zに乗せるべき指数がわかる。
計算量はO(2^k * k + n)である。
登場する典型
- アダマール変換
実装上の注意点
- 手数が多い上に、アダマール変換を使う都合上所々で係数の調整が必要なので、注意して実装する。
- (x-y-z) % MODみたいな演算をバグらせない
- MOD < 10^9なので2 * MODを足すだけでは不十分、3*MODを足す (1敗)
提出: #52511941 (Rust)
// ModInt省略 fn solve() { let out = std::io::stdout(); let mut out = BufWriter::new(out.lock()); macro_rules! puts { ($($format:tt)*) => (write!(out,$($format)*).unwrap()); } input! { n: usize, k: usize, x: i64, y: i64, z: i64, abc: [(usize, usize, usize); n], } let inv2 = ModInt::new(2).inv(); let mut base = 0; // bias: 4 * let mut tbl = vec![vec![0i64; 1 << k]; 4]; for &(a, b, c) in &abc { base ^= a; let b = a ^ b; let c = a ^ c; tbl[0][0] += 4; tbl[1][0] += 2; tbl[1][b] -= 2; tbl[2][0] += 2; tbl[2][c] -= 2; tbl[3][0] += 1; tbl[3][b ^ c] += 1; tbl[3][b] -= 1; tbl[3][c] -= 1; } for bits in 0..1 << k { for i in 0..2 { for j in 0..4 { if (j & 1 << i) != 0 { tbl[j ^ 1 << i][bits] -= tbl[j][bits]; } } } } // Hadamard for c in 0..4 { for i in 0..k { for bits in 0..1 << k { if (bits & 1 << i) == 0 { let x = tbl[c][bits]; let y = tbl[c][bits | 1 << i]; tbl[c][bits] = x + y; tbl[c][bits | 1 << i] = x - y; } } } for bits in 0..1 << k { tbl[c][bits] >>= 2; } } let mut g = vec![ModInt::new(0); 4]; for kind in 0..4 { let mut t = x; t += if (kind & 1) == 0 { y } else { -y }; t += if (kind & 2) == 0 { z } else { -z }; g[kind] = ModInt::new(t + 3 * MOD); } let mut prod = vec![ModInt::new(1); 1 << k]; for bits in 0..1 << k { for c in 0..4 { prod[bits] *= g[c].pow(tbl[c][bits]); } } // Hadamard for i in 0..k { for bits in 0..1 << k { if (bits & 1 << i) == 0 { let x = prod[bits]; let y = prod[bits | 1 << i]; prod[bits] = x + y; prod[bits | 1 << i] = x - y; } } } let fac = inv2.pow(k as i64); for bits in 0..1 << k { prod[bits] *= fac; } for bits in 0..1 << k { puts!("{}{}", prod[bits ^ base], if bits + 1 == (1 << k) { "\n" } else { " " }); } }