yukicoder No.1304 あなたは基本が何か知っていますか?私は知っています.

問題設定に少々改善の余地があると思われたが、高速な別解の存在が面白かった。
No.1304 あなたは基本が何か知っていますか?私は知っています. - yukicoder

問題

以下の条件を満たす長さ N の整数列 B を数え上げ、mod 998244353 で出力せよ。

  • B の各要素は数列 A の要素のいずれかと同じ
  • B[i] != B[i + 1]
  • B の全ての要素の bitwise xor は X 以上 Y 以下

制約

  • 2 <= N <= 40, N は偶数
  • 1 <= K <= 2 * 10^5
  • 0 <= A[i] <= 1023
  • 0 <= X <= Y <= 10^9
  • K^N <= 2 * 10^15

解法

O(NM + M log M + K log K) の解法 (解説の解法の高速化) を説明する。
DP[現在位置][xor 和] という DP を行う。B[i] = B[i + 1] のとき、(i + 1 までの xor 和) = (i - 1 までの xor 和) が成立するので、普通の遷移 DP[i + 1][s] -> DP[i + 2][s ^ a[j]] を行なった後、DP[i + 2][s] から DP[i][s] を「ある程度」引けばよい。どの程度引くべきかを考える。DP[i + 1][s] として計上されているのは i, i + 1 番目が x, y (x != y) であるようなものなので、DP[i + 2][s] から引くべきなのは i, i + 1, i + 2 番目が x, y, y (x != y) となるようなものである。これは (y が K-1 通りあるので) DP[i][s] の K-1 倍存在する。
例外は i <= 2 で、i = 1 のときは引く必要がなく、i = 2 のときは上の x に当たる要素が存在しないため、DP[0][s] の K 倍引けば良い。

以上の遷移をまとめると以下である:

  • DP[0][s] = (s == 0 ? 1 : 0)
  • DP[1][s] = \sum_j DP[0][s ^ A[j]]
  • DP[2][s] = \sum_j DP[1][s ^ A[j]] - K DP[0][s]
  • DP[i + 3][s] = \sum_j DP[i + 2][s ^ A[j]] - (K-1) DP[i+1][s] (i >= 0)

上の DP 配列を、アダマール変換した形で持っておくことを考える。アダマール変換で線型結合は線型結合に写り、xor 畳み込みは要素ごとの積に写る。よって、DP[i] をアダマール変換した物を EP[i] と表記すると、上の遷移は以下の形に書ける。

  • EP[0][s] = 1
  • EP[1][s] = EP[0][s] * H[s]
  • EP[2][s] = EP[1][s] * H[s] - K EP[0][s]
  • EP[i + 3][s] = EP[i + 2][s] * H[s] - (K - 1) EP[i + 1][s] (i >= 0)

ただし、H[s] は A[j] の位置に 1 を立てた配列をアダマール変換したものである。アダマール変換した形のままで遷移ができることに注目されたい。

以上から、M = max(A[i]) とおいて、状態数が O(NM) であり各状態が O(1) で計算できること、また最初に A のソート、最後にアダマール変換をする必要があることから、全体の計算量は O(NM + M log M + K log K) である。

なお、EP の漸化式は 3 項間漸化式であるため、行列累乗などで一般項を O(log N) 時間で求めることができる。これを利用して、O(M log N + M log M + K log K) 時間で計算することもできる。

登場する典型

実装上の注意点

  • アダマール変換の逆変換を行うときは、2 で割るのを忘れないようにする。

提出: #587829 (Rust)

// MInt omitted

// O(NM + M log M + K log K)-time solution a la editorial
fn main() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($($format:tt)*) => (let _ = write!(out,$($format)*););
    }
    input! {
        n: usize, k: usize, x: usize, y: usize,
        a: [usize; k],
    }
    let mut a = a;
    a.sort(); a.dedup();
    let k = a.len();
    const W: usize = 1024;
    let mut dp = vec![vec![MInt::new(0); W]; n + 1];
    for i in 0..W {
        dp[0][i] += 1;
    }
    let inv2 = MInt::new(2).inv();
    let mut had = vec![MInt::new(0); W];
    for i in 0..k {
        had[a[i]] += 1;
    }
    for i in 0..10 {
        for j in 0..W {
            if (j & 1 << i) == 0 {
                let x = had[j];
                let y = had[j | 1 << i];
                had[j] = x + y;
                had[j | 1 << i] = x - y;
            }
        }
    }
    // From now on the values of dp are Hadamard-transformed, until every transition ends.
    for i in 1..n + 1 {
        for u in 0..W {
            let val = dp[i - 1][u];
            dp[i][u] += val;
        }
        for j in 0..W {
            dp[i][j] *= had[j];
        }
        if i >= 2 {
            for j in 0..W {
                dp[i][j] = dp[i][j] - dp[i - 2][j] * (k as i64 - if i == 2 { 0 } else { 1 });
            }
        }
    }
    for j in 0..10 {
        for u in 0..W {
            if (u & 1 << j) == 0 {
                let x = dp[n][u];
                let y = dp[n][u | 1 << j];
                dp[n][u] = (x + y) * inv2;
                dp[n][u | 1 << j] = (x - y) * inv2;
            }
        }
    }
    let mut tot = MInt::new(0);
    for i in x..min(y, W - 1) + 1 {
        tot += dp[n][i];
    }
    puts!("{}\n", tot);
}

まとめ

X と Y は 1023 以下で良いし、K は 1024 以下で A の要素は互いに異なるとしてよかったですね。