ARC058 F - 文字列大好きいろはちゃん / Iroha Loves Strings (解説を見た)

実装が大変すぎる。
F - 文字列大好きいろはちゃん / Iroha Loves Strings

問題

文字列の列s_1, ..., s_Nが与えられる。これの部分列を順序を保ったまま連結して、長さKの文字列を作ることを考える。ありえる文字列の辞書順最小値を求めよ。
制約

  • 1 <= N <= 2000
  • 1 <= K <= 10000
  • 1 <= |s_i| <= K
  • \sum |s_i| <= 10^6
  • アルファベットの大きさは26

解法

以下の解法では、辞書順を以下で定義する。

  • どちらかがどちらかのprefixである場合、長い方が小さい。
  • そうでない場合は普通の辞書順。

O(NK^2)のDPは色々考えられるが、どれもTLEする。そのため、うまいDPを考える必要がある。
以下のようにすると都合の良い性質が得られる。

  • DP[i][j]: i番目の文字列までを見て、j文字の文字列を作る時の最小値。ただし、(i+1)番目以降の文字列をどう使っても最終的に長さKの文字列を作れない場合は無効な値とする。

こうすることで、DP[i][j]のうち辞書順最小(ただし長い方が小さい)のものをCS[i]とおけば、CS[i]のprefix以外から遷移するのは無駄であるため、各jについてDP[i][j]はCS[i]のprefixか無効な値かのいずれかである。空間計算量はO(NK)である。
この上の遷移CS[i] -> CS[i + 1], DP[i][..] -> DP[i + 1][..]を高速化する方法を考える。ここで必要になるのは以下のタイプの比較である。

  • 添字u, vおよびブール値f1, f2が与えられた時、CS[i][..u] + (f1 ? S[i] : "") < CS[i][..v] + (f2 ? S[i] : "")か?

これはS[i] + CS[i]にz_algorithmを適用しておけば高速に計算できる。比較は以下のように考えるとわかりやすい多少はマシになる。

  • S[i]が先頭に来ているので、S[i]とS[i]のsuffixおよびCS[i]のsuffixのLCP(最長共通prefix)の長さは高速に計算できる。
  • それを利用して、S[i]とS[i]のスライスおよびCS[i]のスライスの辞書順比較も高速に計算できる(下記コードのcmp_slice)
  • 「CS[i][..u] < CS[i][..v] + S[i]か?」や「CS[i][..u] + S[i] < CS[i][..v] + S[i]か?」という問題を解く時には、先頭から見ていってどちらかのS[i]との境目で区切って比較する
    • 例: u > vのときCS[i].[..u] < CS[i][..v] + S[i]か? という問いに答えることを考える。先頭v文字は同一。なのでCS[i][v..u]とS[i]の比較結果が分かれば良い。これは上のデータ構造を使えば高速にできる。
    • 例2: u < v, u + |S[i]| > vのときCS[i][..u] + S[i] < CS[i][..v] + S[i]か?という問いに答えることを考える。先頭u文字は同一。u..v文字目を見て(<=> S[i][..v-u]とCS[i][u..v]を比較して)等しくなければそれが答えで、等しい場合はさらにv文字目以降を見る(<=> S[i][v-u..]とS[i]を比較する)。全ての比較がS[i]のprefixとの比較になっていることに注意。

上を頑張って実装するとACできる。計算量はO(NK)である。

登場する典型

実装上の注意点

  • 力任せに実装するとバグらせたり頭が爆発したりするので、極力共通化・単純化する
  • DPの更新を2ステップに分ける都合上、それぞれで必要な比較が微妙に違うので注意!
    • 1回目はCS[i + 1]を求めるためのものなので、辞書順で小さくなるかどうかを判定する (cmpsの返り値の0番目)
    • 2回目はDP[i][j]を求めるためのものなので、CS[i + 1]のprefixであるかどうかを判定する (cmpsの返り値の1番目)

提出: Submission #4032098 - AtCoder Regular Contest 058 (Rust)

/*
 * Z algorithm. Calculates an array a[i] = |lcp(s, s[i...|s|])|,
 * where s is the given string.
 * If n = s.length(), the returned array has length n + 1.
 * E.g. z_algorithm("ababa") = {5, 0, 3, 0, 1, 0}
 * Reference: http://snuke.hatenablog.com/entry/2014/12/03/214243
 * Verified by: AtCoder ARC055-C (http://arc055.contest.atcoder.jp/submissions/1061771)
 */
fn z_algorithm<T: PartialEq>(s: &[T]) -> Vec<usize> {
    let n = s.len();
    let mut ret = vec![0; n + 1];
    ret[0] = n;
    let mut i = 1; let mut j = 0;
    while i < n {
        while i + j < n && s[j] == s[i + j] { j += 1; }
        ret[i] = j;
        if j == 0 { i += 1; continue; }
        let mut k = 1;
        while i + k < n && k + ret[k] < j {
            ret[i + k] = ret[k];
            k += 1;
        }
        i += k; j -= k;
    }
    ret
}

fn calc_poss(s: &[Vec<char>], k: usize) -> Vec<Vec<bool>> {
    let n = s.len();
    let mut poss = vec![vec![false; k + 1]; n + 1];
    poss[n][0] = true;
    for i in (0 .. n).rev() {
        let len = s[i].len();
        for j in 0 .. k - len + 1 {
            poss[i][j + len] |= poss[i + 1][j];
        }
        for j in 0 .. k + 1 {
            poss[i][j] |= poss[i + 1][j];
        }
    }
    poss
}

fn solve() {
    use std::cmp::Ordering::*;
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($($format:tt)*) => (write!(out,$($format)*).unwrap());
    }
    input! {
        n: usize,
        k: usize,
        s: [chars; n],
    }
    let poss = calc_poss(&s, k);
    let mut cs = vec![Vec::new(); n + 1];
    let mut dp = vec![vec![false; k + 1]; n + 1];
    dp[0][0] = true;
    for i in 0 .. n {
        let len = s[i].len();
        let mut css = s[i].clone();
        css.extend_from_slice(&cs[i]);
        let zarr = z_algorithm(&css);
        // cs[i + 1] is implicitly represented by
        // cs[i][..cs_len] + if s_appended { s[i] } else { "" },
        // where (cs_len, s_appended) = cs_next.
        let mut cs_next = (0, false);
        // compares css[a .. a + blen] and css[b .. b + blen].
        // Returns (comparison, is_prefix).
        // If one is a prefix of the other, the longer one is smaller.
        // Either a or b must be 0.
        let cmp_slice = |a: usize, alen: usize, b: usize, blen: usize| {
            let common_len = match (a, b) {
                (0, _) => zarr[b],
                (_, 0) => zarr[a],
                _ => unreachable!(),
            };
            if min(alen, blen) <= common_len {
                (blen.cmp(&alen), true) // longer is smaller
            } else {
                (css[a + common_len].cmp(&css[b + common_len]), false)
            }
        };
        // Compares cs[i][..j] + (if fappend { s[i] } else { "" }) and the current cs[i + 1].
        // If one is a prefix of the other, the longer one wins (is smaller).
        let cmps = |(j, fappend): (usize, bool),
        (cs_len, s_appended): (usize, bool)| {
            match (fappend, s_appended) {
                (true, true) if j == cs_len => (Equal, true),
                (true, true) if j > cs_len + len =>
                    cmp_slice(len + cs_len, j - cs_len, 0, len),
                (true, true) if cs_len > j + len =>
                    cmp_slice(0, len, len + j, cs_len - j),
                (true, true) if j > cs_len => {
                    let res = cmp_slice(len + cs_len, j - cs_len, 0, j - cs_len);
                    if res.0 != Equal { return res }
                    cmp_slice(0, len, j - cs_len, len - (j - cs_len))
                },
                (true, true) if cs_len > j => {
                    let res = cmp_slice(0, cs_len - j, len + j, cs_len - j);
                    if res.0 != Equal { return res }
                    cmp_slice(cs_len - j, len - (cs_len - j), 0, len)
                },
                (true, true) => unreachable!(),

                (true, false) if cs_len <= j => (Less, true),
                (true, false) => cmp_slice(0, len, len + j, cs_len - j),

                (false, true) if j <= cs_len => (Greater, true),
                (false, true) => cmp_slice(len + cs_len, j - cs_len, 0, len),

                (false, false) => (cs_len.cmp(&j), true),
            }
        };
        // get min
        for j in 0 .. k + 1 {
            if !poss[i + 1][k - j] || !dp[i][j] { continue; }
            if cmps((j, false), cs_next).0 == Less {
                cs_next = (j, false);
            }
        }
        for j in 0 .. k - len + 1 {
            if !poss[i + 1][k - j - len] || !dp[i][j] { continue; }
            if cmps((j, true), cs_next).0 == Less {
                cs_next = (j, true);
            }
        }
        // update dp
        for j in 0 .. k + 1 {
            if !poss[i + 1][k - j] || !dp[i][j] { continue; }
            dp[i + 1][j] |= cmps((j, false), cs_next).1;
        }
        for j in 0 .. k - len + 1 {
            if !poss[i + 1][k - j - len] || !dp[i][j] { continue; }
            dp[i + 1][j + len] |= cmps((j, true), cs_next).1;
        }
        cs[i + 1] = cs[i][..cs_next.0].to_vec();
        if cs_next.1 {
            cs[i + 1].extend_from_slice(&s[i]);
        }
    }

    puts!("{}\n", cs[n][..k].iter().cloned().collect::<String>());
}

まとめ

実装の簡略化を頑張ったが、それでもかなり複雑になってしまった。実装が上手く行かず鬱病になったが、某氏にこれは2度と実装したくない問題worst5に入ると言われたので気分が楽になった(は?)