分析
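就是把李超线段树放到树链剖分上:以节点到根的带权距离 dis 作为横坐标,每条重链在 dfn 序上是连续的一段,且(边权非负时)dis 在其上单调,所以可以直接在链上插入线段。注意区间查询时,除了完整覆盖节点上记录的最小值,还要比较当前节点所存线段在查询区间两个端点处的函数值(一次函数在单调区间上的最小值必在端点取到)。
把一次路径修改 (s, t, A, B) 拆成两部分:从起点 s 到 LCA 的一段,节点 x 的值为 A·(dis[s] − dis[x]) + B,即关于 dis[x] 斜率为 −A 的线段;
从 LCA 到终点 t 的一段,值为 A·(dis[s] + dis[x] − 2·dis[LCA]) + B,即关于 dis[x] 斜率为 A 的线段。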
代码
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <utility>
// `register` was removed from the language in C++17 (clang rejects it there),
// so the legacy `rr` hint now expands to nothing; existing `rr` uses still compile.
#define rr
using namespace std;
using lll = long long;
const int N = 100011;  // array capacity for tree nodes (node ids must be < N)

// Adjacency-list edge: target node, edge weight, index of the next edge leaving
// the same node (0 ends the list). Both directions of each tree edge are stored.
struct node {
    int y, w, next;
} e[N << 1];

int p[N << 2];  // p[k]: index into `line` of the line kept at segment-tree node k
int Tot;        // number of lines created so far (line[1] is the initial constant line)
int dep[N];     // depth of each tree node (the root gets depth 1)
int top[N];     // top node of the heavy chain containing each node
int k = 1;      // last used edge index; starts at 1 so edges occupy indices 2, 3, ...
int tot;        // dfn counter used while numbering nodes
int fat[N];     // parent of each node (0 for the root)
int nfd[N];     // nfd[i]: tree node whose dfn position is i
int dfn[N];     // dfn position of each node in heavy-child-first order
int son[N];     // subtree size of each node
int big[N];     // heavy child (child with the largest subtree); 0 for leaves
int n, m;       // node count, operation count
int as[N];      // as[x]: first edge index leaving x (0 = none)

// A line value(X) = a * X + b, where X is the dis[] of a tree node.
struct rec {
    lll a, b;
} line[N << 1];

lll dis[N];     // weighted distance from the root
lll w[N << 2];  // w[k]: minimum, over node k's dfn range, of the lines stored in k's subtree
// Reads one decimal integer from stdin and returns it.
// Any non-digit characters before the number are skipped; every '-' among them
// flips the sign. The character is kept as `int` so that EOF and bytes >= 0x80
// are valid arguments to isdigit (passing a negative `char` is undefined
// behaviour). Hitting EOF before any digit returns 0 instead of looping forever.
inline signed iut() {
    int c = getchar();
    int sign = 1;
    while (c != EOF && !isdigit(c)) {
        if (c == '-')
            sign = -sign;
        c = getchar();
    }
    int ans = 0;
    while (isdigit(c)) {  // isdigit(EOF) is false, so EOF also ends the number
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * sign;
}
// Writes `ans` to stdout in decimal (no trailing newline).
// The magnitude is computed in unsigned arithmetic so that LLONG_MIN, whose
// negation overflows a signed long long, is printed correctly.
inline void print(lll ans) {
    unsigned long long mag = static_cast<unsigned long long>(ans);
    if (ans < 0) {
        putchar('-');
        mag = 0ULL - mag;  // modular negation: exact magnitude even for LLONG_MIN
    }
    char digits[20];  // 2^64 - 1 has 20 decimal digits
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0)
        putchar(digits[--len]);
}
// Smaller of two long long values (a plain non-template overload for lll).
inline lll min(lll a, lll b) { return (b < a) ? b : a; }
// Value of line #t at the tree node occupying dfn position x,
// i.e. line[t].a * dis[node] + line[t].b.
inline lll calc(int t, int x) {
    const rec &ln = line[t];
    const lll xs = dis[nfd[x]];
    return ln.a * xs + ln.b;
}
// Initializes segment-tree node k (covering dfn range [l, r]) and all of its
// descendants so that each holds line #1 and records that line's constant b
// as the current minimum.
inline void build(int k, int l, int r) {
    p[k] = 1;
    w[k] = line[1].b;
    if (l < r) {
        const int half = (l + r) >> 1;
        build(k << 1, l, half);
        build(k << 1 | 1, half + 1, r);
    }
}
inline void update(int k, int l, int r, int x, int y, int z) {
rr int mid = (l + r) >> 1;
if (x <= l && r <= y) {
rr lll la = calc(p[k], l), lb = calc(z, l);
rr lll ra = calc(p[k], r), rb = calc(z, r), mb = min(lb, rb);
if (la <= lb && ra <= rb)
return;
if (la >= lb && ra >= rb) {
p[k] = z, w[k] = min(w[k], mb);
return;
}
rr double pos = 1.0 * (line[p[k]].b - line[z].b) / (line[z].a - line[p[k]].a);
if (la >= lb) {
if (pos <= dis[nfd[mid]])
update(k << 1, l, mid, x, y, z);
else
update(k << 1 | 1, mid + 1, r, x, y, p[k]), p[k] = z;
} else {
if (pos > dis[nfd[mid]])
update(k << 1 | 1, mid + 1, r, x, y, z);
else
update(k << 1, l, mid, x, y, p[k]), p[k] = z;
}
w[k] = min(min(w[k], mb), min(w[k << 1], w[k << 1 | 1]));
return;
}
if (x <= mid)
update(k << 1, l, mid, x, y, z);
if (mid < y)
update(k << 1 | 1, mid + 1, r, x, y, z);
w[k] = min(w[k], min(w[k << 1], w[k << 1 | 1]));
}
// Lowest common ancestor of x and y using the heavy-light decomposition
// (top/fat/dep): repeatedly lift the endpoint whose chain top is deeper until
// both lie on one chain, then the shallower endpoint is the answer.
inline signed lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        x = fat[top[x]];
    }
    return dep[x] <= dep[y] ? x : y;
}
// Minimum value over dfn range [x, y] (contained in node k's range [l, r]) of
// every line stored on node k and below. Ranges come from one heavy chain, so a
// line's minimum over [x, y] is at one of the two endpoints. When [x, y]
// coincides with [l, r] the cached subtree minimum w[k] is returned directly.
inline lll query(int k, int l, int r, int x, int y) {
    if (x == l && y == r)
        return w[k];
    const lll here = min(calc(p[k], x), calc(p[k], y));
    const int half = (l + r) >> 1;
    lll below;
    if (y <= half) {
        below = query(k << 1, l, half, x, y);
    } else if (half < x) {
        below = query(k << 1 | 1, half + 1, r, x, y);
    } else {
        const lll leftPart = query(k << 1, l, half, x, half);
        const lll rightPart = query(k << 1 | 1, half + 1, r, half + 1, y);
        below = min(leftPart, rightPart);
    }
    return min(here, below);
}
// Inserts line #z on every node of the path from x up to LCA (inclusive),
// one heavy-chain segment at a time. LCA must be an ancestor of x; main passes
// lca(x, y).
inline void Update(int x, int LCA, int z) {
    while (top[x] != top[LCA]) {
        const int head = top[x];
        update(1, 1, n, dfn[head], dfn[x], z);
        x = fat[head];
    }
    update(1, 1, n, dfn[LCA], dfn[x], z);
}
// Minimum value stored on any node of the tree path x - y.
// The path is split into heavy-chain segments, each queried on the segment
// tree. The result starts from line #1's constant, the value every position is
// initialised with.
inline lll Query(int x, int y) {
    lll ans = line[1].b;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);  // always lift the endpoint with the deeper chain top
        ans = min(ans, query(1, 1, n, dfn[top[x]], dfn[x]));
        x = fat[top[x]];
    }
    if (dep[x] > dep[y])
        std::swap(x, y);  // x is now the shallower endpoint on the shared chain
    return min(ans, query(1, 1, n, dfn[x], dfn[y]));
}
inline void dfs1(int x, int fa) {
dep[x] = dep[fa] + 1, fat[x] = fa, son[x] = 1;
for (rr int i = as[x], mson = -1; i; i = e[i].next)
if (e[i].y != fa) {
dis[e[i].y] = dis[x] + e[i].w, dfs1(e[i].y, x), son[x] += son[e[i].y];
if (son[e[i].y] > mson)
big[x] = e[i].y, mson = son[e[i].y];
}
}
inline void dfs2(int x, int linp) {
dfn[x] = ++tot, nfd[tot] = x, top[x] = linp;
if (!big[x])
return;
dfs2(big[x], linp);
for (rr int i = as[x]; i; i = e[i].next)
if (e[i].y != fat[x] && e[i].y != big[x])
dfs2(e[i].y, e[i].y);
}
signed main() {
n = iut();
m = iut();
for (rr int i = 1; i < n; ++i) {
rr int x = iut(), y = iut(), w = iut();
e[++k] = (node){ y, w, as[x] }, as[x] = k;
e[++k] = (node){ x, w, as[y] }, as[y] = k;
}
line[Tot = 1] = (rec){ 0, 123456789123456789ll };
dfs1(1, 0), dfs2(1, 1), build(1, 1, n);
for (rr int i = 1; i <= m; ++i)
if (iut() & 1) {
rr int x = iut(), y = iut(), A = iut(), B = iut(), LCA = lca(x, y);
line[++Tot] = (rec){ -A, dis[x] * A + B }, Update(x, LCA, Tot);
line[++Tot] = (rec){ A, (dis[x] - dis[LCA] * 2) * A + B }, Update(y, LCA, Tot);
} else
print(Query(iut(), iut())), putchar(10);
return 0;
}