yukicoder No.732 3PrimeCounting

想定解よりも理論的に速い解法を思いついたので記念に。
No.732 3PrimeCounting - yukicoder

問題

N以下の相異なる素数の3つ組(a,b,c)の中で、a+b+cも素数であるものの個数を数えよ。2 <= a < b < c <= Nとする。
5<=N<=10^5

解法

相異なる素数(a,b,c)の和としてある数を表す方法の総数がわかれば、それを3N以下の全ての素数について数えて総和を答えればよい。
a+b+c=kとなる相異なる素数の組(a,b,c)の個数を数えたい。包除原理を使うことで、

  • 相異なるとは限らない素数a,b,cの和としてkを実現する方法の総数 A_k
  • 相異なるとは限らない素数の組(a,b)であって、2a+b=kを満たすものの個数 B_k
  • 素数aであって、3a=kを満たすものの個数 C_k

がわかれば、A_k - 3B_k + 2C_kが求める答えであることがわかる。これらの値は高速数論変換(NTT)と中国剰余定理(Garnerのアルゴリズム)により高速に計算できる。計算量はO(NlogN)である。

想定解は、N以下の素数の個数が10^4くらいしかないことを利用したO(N^2/log^2N)アルゴリズムである。理論的にはNTTの方が速いが、問題の制約では想定解の方が速い。

実装上の注意点

  • NTTは定数倍が重いので注意する。
  • 最終的な答えの上限から、NTTを何個並列で実行すべきかを判断する。今回の場合は、最終的な答えがN^3<=10^15を超えないことがわかるので、10^9程度の大きさのNTT-friendlyな素数を2個用意して、2並列で実行すればよい。
  • 2並列の場合は特殊化されたGarnerのアルゴリズムが使えるが、計算途中でのオーバーフローに注意する。今回は(素数の積)<2^62なので問題ない。
  • (a,b,c)に順序をつける必要があるので、最後に6=3!で割るのを忘れない!!

提出: #303386 No.732 3PrimeCounting - yukicoder (Rust)

/// FFT (in-place)
/// R: Ring + Copy
/// Verified by: ATC001-C (http://atc001.contest.atcoder.jp/submissions/1175827)
mod fft {
    use std::ops::*;
    fn inplace_internal_fft<R>(
        f: &[R], output: &mut [R], pztbl: &[R], one: R,
        x: usize, fstart: usize, fstep: usize,
        n: usize, ostart: usize)
        where R: Copy +
        Add<Output = R> +
        Sub<Output = R> +
        Mul<Output = R> {
        if n == 1 {
            output[ostart] = f[fstart];
            return;
        }
        inplace_internal_fft(f, output, pztbl, one, x + 1,
                             fstart, 2 * fstep, n / 2, ostart);
        inplace_internal_fft(f, output, pztbl, one, x + 1,
			     fstart + fstep, 2 * fstep, n / 2, ostart + n / 2);
        let mut cnt = 0;
        for i in 0 .. n / 2 {
            let pzeta = pztbl[cnt];
            let f0 = output[ostart + i];
            let f1 = output[ostart + i + n / 2];
            let tmp = pzeta * f1;
            output[ostart + i] = f0 + tmp;
            output[ostart + i + n / 2] = f0 - tmp;
            cnt += 1 << x;
        }
    }
    /// n should be a power of 2. zeta is a primitive n-th root of unity.
    /// one is unity
    /// Note that the result should be multiplied by 1/sqrt(n).
    pub fn transform<R>(f: &[R], zeta: R, one: R) -> Vec<R>
        where R: Copy +
        Add<Output = R> +
        Sub<Output = R> +
        Mul<Output = R> {
        let n = f.len();
        assert!(n.is_power_of_two());
        let mut pztbl = vec![one; n];
        for i in 1 .. n {
            pztbl[i] = pztbl[i - 1] * zeta;
        }
        let mut output = vec![zeta; n];
        inplace_internal_fft(&f, &mut output, &pztbl, one, 0, 0, 1, n, 0);
        output
    }
}

mod mod_int { /* 省略 */ }

macro_rules! define_mod {
    ($struct_name: ident, $modulo: expr) => {
        #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        struct $struct_name {}
        impl mod_int::Mod for $struct_name { fn m() -> i64 { $modulo } }
    }
}
const MOD1: i64 = 998244353;
const MOD2: i64 = 1004535809;
define_mod!(P1, MOD1);
define_mod!(P2, MOD2);
type ModInt1 = mod_int::ModInt<P1>;
type ModInt2 = mod_int::ModInt<P2>;

use mod_int::*;

const N: usize = 1 << 19;

fn calc<M: Mod>(n: usize, pr: &[bool], zeta: ModInt<M>) -> ModInt<M> {
    let zeta = zeta.pow((M::m() - 1) / N as i64);
    let zeta_inv = zeta.inv();
    let mut f = vec![ModInt::new(0); N];
    let mut f2 = vec![ModInt::new(0); N];
    let mut f3 = vec![ModInt::new(0); N];
    for i in 1 .. n + 1 {
        if pr[i] {
            f[i] = ModInt::new(1);
            f2[2 * i] = ModInt::new(1);
            f3[3 * i] = ModInt::new(1);
        }
    }
    // f^3 - 3 * f2 * f + 2 * f3
    let f = fft::transform(&f, zeta, ModInt::new(1));
    let f2 = fft::transform(&f2, zeta, ModInt::new(1));
    let f3 = fft::transform(&f3, zeta, ModInt::new(1));
    let mut g = vec![ModInt::new(0); N];
    for i in 0 .. N {
        g[i] = f[i].pow(3) - ModInt::new(3) * f2[i] * f[i] + f3[i] + f3[i];
    }
    let g = fft::transform(&g, zeta_inv, ModInt::new(1));
    let mut tot = ModInt::new(0);
    for i in 2 .. N {
        if pr[i] { tot += g[i]; }
    }
    tot *= ModInt::new(N as i64).inv();
    tot *= ModInt::new(6).inv();
    tot
}

/// Depends on ModInt.rs
fn garner2<M1: mod_int::Mod, M2: mod_int::Mod>(a: mod_int::ModInt<M1>,
                                               b: mod_int::ModInt<M2>)
                                               -> i64 {
    let factor2 = mod_int::ModInt::new(M1::m()).inv();
    let factor1 = mod_int::ModInt::new(M2::m()).inv();
    ((b * factor2).x * M1::m() + (a * factor1).x * M2::m()) % (M1::m() * M2::m())
}


fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($format:expr) => (write!(out,$format).unwrap());
        ($format:expr, $($args:expr),+) => (write!(out,$format,$($args),*).unwrap())
    }
    let mut pr = vec![true; N];
    pr[0] = false;
    pr[1] = false;
    for i in 2 .. N {
        if !pr[i] { continue; }
        for j in 2 .. (N - 1) / i + 1 { pr[i * j] = false; }
    }
    input! {
        n: usize,
    }
    let a: ModInt1 = calc(n, &pr, ModInt::new(3));
    let b: ModInt2 = calc(n, &pr, ModInt::new(3));
    puts!("{}\n", garner2(a, b));
}

まとめ

初手でこういう面倒な方針が思いつく癖は直した方が良いかもしれない。あと、この記事を書いている途中でC_kを計算する必要がないことに気付いた。