実装上の注意点
- 重心分解を使う都合上かなり手数が多いので、デバッグプリントなどを駆使して絶対に間違いがないようにすること
提出: Submission #4339696 - CODE FESTIVAL 2018 Final (Parallel) (Rust)
/// Builds the centroid tree of the weighted tree `g`.
///
/// `g[v]` lists `(neighbor, edge_weight)` pairs and every edge must be present
/// in both endpoint lists. The graph has to be acyclic; only the component
/// containing vertex 0 is decomposed.
///
/// Returns one `(parent_centroid, child_centroid, distance)` triple per
/// non-root centroid, parents before their descendants. `distance` is the
/// length of the path in `g` between the two centroids. An empty graph yields
/// an empty vector.
fn centroid_decompose(g: &[Vec<(usize, i64)>]) -> Vec<(usize, usize, i64)> {
    // Writes into `dp[u]` the number of vertices below-or-equal `u` when the
    // not-yet-removed component containing `root` is rooted at `root`.
    // Done iteratively so that the depth of the tree never reaches the call stack.
    fn find_subtree_sizes(g: &[Vec<(usize, i64)>], root: usize, dp: &mut [usize], vis: &[bool]) {
        // `g.len()` is the "no parent" marker, as in the rest of the file.
        let mut order = vec![(root, g.len())];
        let mut head = 0;
        while head < order.len() {
            let (v, par) = order[head];
            head += 1;
            dp[v] = 1;
            order.extend(
                g[v].iter()
                    .filter(|&&(w, _)| w != par && !vis[w])
                    .map(|&(w, _)| (w, v)),
            );
        }
        // Every vertex was queued after its parent, so walking the queue
        // backwards adds each finished subtree to its parent exactly once.
        for &(v, par) in order[1..].iter().rev() {
            dp[par] += dp[v];
        }
    }

    // Decomposes the component containing `v`. `par` is the centroid chosen one
    // level above (`g.len()` for the top level) and `cost` is the weight of the
    // edge that leads from that centroid to `v`.
    fn centroid_decompose_inner(
        g: &[Vec<(usize, i64)>],
        v: usize,
        par: usize,
        cost: i64,
        edges: &mut Vec<(usize, usize, i64)>,
        dp: &mut [usize],
        vis: &mut [bool],
    ) {
        let n = g.len();
        find_subtree_sizes(g, v, dp, vis);
        let half = dp[v] / 2;

        // Starting at `v`, keep stepping into the first child whose subtree holds
        // more than half of the component; `dist` accumulates the walked length.
        let (mut cent, mut prev, mut dist) = (v, n, 0i64);
        loop {
            let heavy = g[cent]
                .iter()
                .copied()
                .find(|&(w, _)| w != prev && !vis[w] && dp[w] > half);
            match heavy {
                Some((w, c)) => {
                    dist += c;
                    prev = cent;
                    cent = w;
                }
                None => break,
            }
        }

        if par < n {
            edges.push((par, cent, dist + cost));
        }
        // Removing the centroid splits the component; recurse into each piece.
        vis[cent] = true;
        for &(w, c) in &g[cent] {
            if !vis[w] {
                centroid_decompose_inner(g, w, cent, c, edges, dp, vis);
            }
        }
    }

    let n = g.len();
    let mut edges = Vec::with_capacity(n.saturating_sub(1));
    if n == 0 {
        // Nothing to decompose; indexing vertex 0 below would panic.
        return edges;
    }
    let mut dp = vec![0; n];
    let mut vis = vec![false; n];
    centroid_decompose_inner(g, 0, n, 0, &mut edges, &mut dp, &mut vis);
    edges
}
// Number of binary-lifting levels: `init_lca` builds `B` ancestor rows and the
// LCA closure in `solve` tries jumps of `1 << i` for every `i` in `0..B`.
// NOTE(review): 17 levels only cover depths below 2^17 = 131072 — confirm the
// input size stays within that bound.
const B: usize = 17;
/// Fills the parent, depth and weighted-depth tables for everything below `v`.
///
/// * `g` - adjacency lists of `(neighbor, edge_weight)`; must be a tree.
/// * `v` - vertex to start from. `par[v]`, `dep[v]` and `dep_dist[v]` must
///   already hold the values for `v` itself; the edge towards `par[v]` is not
///   followed.
/// * `par` - receives, for each reached vertex, the vertex it was reached from.
/// * `dep` - receives the number of edges from `v`, offset by `dep[v]`.
/// * `dep_dist` - receives the sum of edge weights from `v`, offset by
///   `dep_dist[v]`.
///
/// The traversal uses an explicit stack instead of recursion, so a very deep
/// tree (for example a long path) cannot exhaust the call stack.
fn init_lca_dfs(g: &[Vec<(usize, i64)>], v: usize, par: &mut [usize],
                dep: &mut [usize], dep_dist: &mut [i64]) {
    let mut stack = vec![v];
    while let Some(u) = stack.pop() {
        for &(w, c) in &g[u] {
            // Skip the edge that leads back to where `u` came from.
            if w == par[u] {
                continue;
            }
            par[w] = u;
            dep[w] = dep[u] + 1;
            dep_dist[w] = dep_dist[u] + c;
            stack.push(w);
        }
    }
}
/// Precomputes the tables used for LCA and distance queries on the tree `g`,
/// rooted at vertex 0.
///
/// Returns `(dep, dep_dist, lca)` where
/// * `dep[v]` is the number of edges between the root and `v`,
/// * `dep_dist[v]` is the sum of edge weights between the root and `v`,
/// * `lca[i][v]` is the ancestor of `v` that is `2^i` edges up, or `n`
///   (`g.len()`) when that would go above the root. `lca` has `B` rows.
///
/// For an empty graph all returned vectors are empty (`lca` holds empty rows).
fn init_lca(g: &[Vec<(usize, i64)>]) -> (Vec<usize>, Vec<i64>, Vec<Vec<usize>>) {
    let n = g.len();
    let mut dep = vec![0; n];
    let mut dep_dist = vec![0; n];
    let mut par = vec![n; n];
    if n > 0 {
        // Vertex 0 only exists in a non-empty graph.
        init_lca_dfs(g, 0, &mut par, &mut dep, &mut dep_dist);
    }

    // Row 0 is the direct parent; each further row doubles the jump length by
    // applying the previous row twice.
    let mut lca = Vec::with_capacity(B);
    lca.push(par);
    for i in 1..B {
        let next: Vec<usize> = lca[i - 1]
            .iter()
            .map(|&mid| if mid < n { lca[i - 1][mid] } else { n })
            .collect();
        lca.push(next);
    }
    (dep, dep_dist, lca)
}
/// Reads the instance with the file's `input!` macro, runs a DP over the
/// centroid tree and prints the largest DP value followed by a newline.
///
/// Input, as parsed below: `n`, `m`, then `n - 1` weighted edges `(a, b, d)`
/// and `m` records `(s, e, c, x)`; `usize1` fields are presumably 1-based
/// indices converted to 0-based by the macro.
/// NOTE(review): from the way they are used, `s`/`e` look like a start and an
/// end time, `c` like a vertex and `x` like the gain of a record — confirm
/// against the problem statement.
///
/// Records are processed in sorted order. For record `i`, every vertex `cur`
/// on the centroid-tree path from `c` to the root is asked for the best stored
/// value whose key is at most `s - dist(cur, c)`; afterwards the pair
/// `(e + dist(cur, c), dp[i])` is stored at each of those vertices.
///
/// Prints `0` when `m == 0`.
fn solve() {
    let out = std::io::stdout();
    let mut out = BufWriter::new(out.lock());
    macro_rules! puts {
        ($format:expr) => (write!(out,$format).unwrap());
        ($format:expr, $($args:expr),+) => (write!(out,$format,$($args),*).unwrap())
    }

    // Linear congruential generator with wrapping arithmetic. Its values are
    // handed to `Treap::insert_mut` — presumably as node priorities (TODO confirm).
    let mut state = 0x15262627i64;
    let mul = 0x245711;
    let add = 0x13331;
    let mut next = || {
        state = state.wrapping_mul(mul).wrapping_add(add);
        state
    };

    input! {
        n: usize,
        m: usize,
        abd: [(usize1, usize1, i64); n - 1],
        secx: [(i64, i64, usize1, i64); m],
    }

    // Undirected weighted adjacency lists of the input tree.
    let mut g = vec![vec![]; n];
    for &(a, b, d) in &abd {
        g[a].push((b, d));
        g[b].push((a, d));
    }

    // Parent of every vertex in the centroid tree; `n` marks the root.
    // Only the parent links are needed below, so no child lists are built.
    let mut par = vec![n; n];
    for &(p, child, _) in &centroid_decompose(&g) {
        par[child] = p;
    }

    let (dep, dep_dist, lca_aux) = init_lca(&g);
    // Lowest common ancestor of `x` and `y` in the tree rooted at vertex 0.
    let lca = |mut x: usize, mut y: usize| {
        // Make `y` the deeper vertex, then lift it to the depth of `x`.
        if dep[x] > dep[y] {
            std::mem::swap(&mut x, &mut y);
        }
        for i in (0..B).rev() {
            if dep[y] >= dep[x] + (1 << i) {
                y = lca_aux[i][y];
            }
        }
        assert_eq!(dep[x], dep[y]);
        if x == y {
            return x;
        }
        // Lift both vertices as long as their ancestors still differ.
        for i in (0..B).rev() {
            if dep[x] <= 1 << i {
                continue;
            }
            if lca_aux[i][x] == lca_aux[i][y] {
                continue;
            }
            x = lca_aux[i][x];
            y = lca_aux[i][y];
            assert_ne!(x, y);
        }
        // Now both sit directly below the common ancestor.
        x = lca_aux[0][x];
        y = lca_aux[0][y];
        assert_eq!(x, y);
        x
    };
    // Weighted distance between `x` and `y` in the input tree.
    let gdist = |x: usize, y: usize| {
        let l = lca(x, y);
        dep_dist[x] + dep_dist[y] - 2 * dep_dist[l]
    };

    let mut secx = secx;
    secx.sort();

    let mut dp = vec![0; m];
    // One ordered set of `(key, value)` pairs per centroid-tree vertex.
    let mut pool: Vec<Treap<(i64, i64)>> = vec![Treap::new(); n];

    // Inserts `(a, b)` into `set` unless the element right before its position
    // already has a value of at least `b`; afterwards removes the following
    // elements whose value does not exceed `b`. Assumes `find_index` returns the
    // position where `(a, b)` would be inserted — TODO confirm against `Treap`.
    fn update_element<F>(set: &mut Treap<(i64, i64)>, a: i64, b: i64, next: &mut F)
    where
        F: FnMut() -> i64,
    {
        let (idx, _) = set.find_index(&(a, b));
        if idx >= 1 {
            let &(_, pro) = set.at(idx - 1).unwrap();
            if pro >= b {
                return;
            }
        }
        set.insert_mut((a, b), next());
        while let Some(&(_cost, pro)) = set.at(idx + 1) {
            if b >= pro {
                set.erase_at_mut(idx + 1);
            } else {
                break;
            }
        }
    }

    for (i, &(s, e, c, x)) in secx.iter().enumerate() {
        // Query: climb the centroid tree from `c` and take the best stored value
        // whose key still fits before `s` after travelling `dist`.
        let mut cur = c;
        loop {
            let dist = gdist(cur, c);
            let (idx, _) = pool[cur].find_index(&(s - dist + 1, 0));
            if idx >= 1 {
                let &(last, profit) = pool[cur].at(idx - 1).unwrap();
                assert!(last + dist <= s);
                dp[i] = max(dp[i], profit);
            }
            if par[cur] >= n {
                break;
            }
            cur = par[cur];
        }
        dp[i] += x;

        // Publish: store this record's result at every centroid ancestor.
        let mut cur = c;
        loop {
            let dist = gdist(cur, c);
            update_element(&mut pool[cur], dist + e, dp[i], &mut next);
            if par[cur] >= n {
                break;
            }
            cur = par[cur];
        }
    }

    // With no records there is nothing to maximise; report 0 instead of panicking.
    // NOTE(review): 0 matches the DP's starting value — confirm it is the expected answer.
    let ma: i64 = dp.into_iter().max().unwrap_or(0);
    puts!("{}\n", ma);
    // Flush explicitly: dropping a `BufWriter` ignores write errors.
    out.flush().unwrap();
}