    首先假定以 $1$ 为根做一遍 dfs,那么在以 $rt$ 为根的时候,对于一个点 $x$,如果 $rt$ 不在 $x$ 的以 $1$ 为根时的子树中,那么 $x$ 在以 $rt$ 为根时的子树和以 $1$ 为根时的子树一样。

    如果 $rt$ 在 $x$ 的以 $1$ 为根时的子树中(且 $rt \ne x$),那么我们求出 $y$,表示在以 $1$ 为根时,$x$ 的孩子中子树里包含 $rt$ 的那一个。那么 $x$ 在以 $rt$ 为根时的子树就是整棵树除去 $y$ 在以 $1$ 为根时的子树后剩下的全部部分。

    如果 $rt = x$,那么显然子树就是整棵树。

    然后根据以 $1$ 为根时子树在 dfs 序上的连续性,以 $rt$ 为根时每个点的子树对应 dfs 序上至多两段区间,我们就可以把原题转化为若干个这样的问题:

    给定两个区间 $[l_1, r_1]$ 和 $[l_2, r_2]$,求出有多少对点 $(i, j)$ 满足 $i \in [l_1, r_1]$、$j \in [l_2, r_2]$ 且两点的点权相同。

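    有了前面一道题 [Snoi2017]一个简单的询问 的经验,我们知道,这个题目的做法是把一个有四个参数的询问拆分成四个有两个参数的询问:记 $f(x, y)$ 为满足 $i \in [1, x]$、$j \in [1, y]$ 且点权相同的点对数,则答案为 $f(r_1, r_2) - f(l_1 - 1, r_2) - f(r_1, l_2 - 1) + f(l_1 - 1, l_2 - 1)$,所有 $f$ 再用莫队一起求出。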


    // fec(i, x, y): for-loop header walking the adjacency list of vertex x through the global
    // head[] / g[] edge storage; i is the current edge index and y its target vertex.
    // Use as `for fec(i, x, y) ...`; the loop ends when i becomes 0 (empty list / terminator).
    #define fec(i, x, y) (int i = head[x], y = g[i].to; i; i = g[i].ne, y = g[i].to)
    // Debug printf to stderr.
    #define dbg(...) fprintf(stderr, __VA_ARGS__)
    // Redirect stdin/stdout to the files x.in / x.out.
    #define File(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
    // Short names for pair members and push_back.
    #define fi first
    #define se second
    #define pb push_back
    // smax: raise a to b if b is larger. Returns 1 when a was changed, 0 otherwise.
    template<typename A, typename B>
    inline char smax(A &a, const B &b) {
    	if (!(a < b)) return 0;
    	a = b;
    	return 1;
    }
    // smin: lower a to b if b is smaller. Returns 1 when a was changed, 0 otherwise.
    template<typename A, typename B>
    inline char smin(A &a, const B &b) {
    	if (!(b < a)) return 0;
    	a = b;
    	return 1;
    }
    using ll = long long;
    using ull = unsigned long long;
    using pii = std::pair<int, int>;
    namespace io {
    	const int SIZE = (1 << 21) + 1;
    	char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, qu[55];
    	// c is an int (not char) so the EOF value produced by gc() is representable and comparable.
    	int c, f, qr;
    	// Buffered getchar: refills ibuf with fread when it is exhausted; yields EOF at end of input.
    	#define gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
    	// Write everything buffered in obuf to stdout and reset the buffer.
    	inline void flush () {
    		fwrite (obuf, 1, oS - obuf, stdout);
    		oS = obuf;
    	}
    	// Buffered putchar; flushes automatically when the output buffer fills up.
    	inline void putc (char x) {
    		*oS ++ = x;
    		if (oS == oT) flush ();
    	}
    	// Read a signed decimal integer into x. Non-digit characters before the number are skipped,
    	// and a '-' among them makes the result negative. At end of input x is set to 0 instead of
    	// spinning forever.
    	template <class I>
    	inline void gi (I &x) {
    		for (f = 1, c = gc(); c < '0' || c > '9'; c = gc()) {
    			if (c == EOF) { x = 0; return; }
    			if (c == '-') f = -1;
    		}
    		for (x = 0; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15);
    		x *= f;
    	}
    	// Append the decimal form of a signed integer to the output buffer.
    	// x is taken by value so the caller's variable is not destroyed by the digit extraction.
    	template <class I>
    	inline void print (I x) {
    		if (!x) putc ('0');
    		if (x < 0) putc ('-'), x = -x;
    		while (x) qu[++ qr] = x % 10 + '0', x /= 10;
    		while (qr) putc (qu[qr --]);
    	}
    }
    #define read io::gi
    constexpr int N = 100000 + 7;  // capacity of the per-vertex arrays
    constexpr int M = 500000 + 7;  // capacity of the per-operation arrays
    // 1-based index of the Mo block that position x belongs to (block size blo).
    #define bl(x) (((x) - 1) / blo + 1)
    int n, m;     // vertex count, operation count
    int dfc;      // dfs timestamp counter
    int Q;        // number of queued prefix queries
    int ansi;     // number of answer slots in use
    int blo;      // Mo block size
    ll val;       // current value maintained by the Mo pointer moves
    ll ans[M];    // per-query answers
    int a[N], b[N], cl[N], cr[N];
    int dep[N], f[N], siz[N], son[N], top[N], dfn[N], pre[N];
    struct Query {
    	int opt, l, r;
    	ll *ans;
    	inline Query() {}
    	inline Query(const int &opt, const int &l, const int &r, ll *ans) : opt(opt), l(l), r(r), ans(ans) {
    		if (l > r) std::swap(this->l, this->r);
    		assert(this->l <= this->r);
    	inline bool operator < (const Query &b) const { return bl(l) == bl(b.l) ? r < b.r : l < b.l; }
    } q[M << 4];
    // Adjacency lists stored as linked lists in g[]: head[x] is the first edge index of x
    // (0 terminates), ne links to the next edge, to is the target vertex.
    struct Edge {
    	int to, ne;
    } g[N << 1];
    int head[N], tot;
    // Push a directed edge x -> y onto the front of x's list.
    inline void addedge(int x, int y) {
    	++tot;
    	g[tot].to = y;
    	g[tot].ne = head[x];
    	head[x] = tot;
    }
    // Add an undirected edge as two directed ones.
    inline void adde(int x, int y) {
    	addedge(x, y);
    	addedge(y, x);
    }
    inline void dfs1(int x, int fa = 0) {
    	f[x] = fa, dep[x] = dep[fa] + 1, siz[x] = 1;
    	for fec(i, x, y) if (y != fa) dfs1(y, x), siz[x] += siz[y], siz[y] > siz[son[x]] && (son[x] = y);
    inline void dfs2(int x, int pa) {
    	top[x] = pa, dfn[x] = ++dfc, pre[dfc] = x;
    	if (!son[x]) return; dfs2(son[x], pa);
    	for fec(i, x, y) if (y != son[x] && y != f[x]) dfs2(y, y);
    // For x strictly inside p's subtree (tree rooted at 1), return the child of p whose
    // subtree contains x.
    inline int gson(int x, int p) {
    	int lastTop = 0;
    	// Climb heavy chains until x is on p's chain, remembering the last chain top left.
    	while (top[x] != top[p]) lastTop = top[x], x = f[lastTop];
    	// Landing exactly on p means we arrived via a light child (lastTop); otherwise x lies
    	// below p on p's own heavy chain, so the child is the heavy child.
    	return x == p ? lastTop : son[p];
    }
    inline bool intr(int x, int p) { return dfn[x] >= dfn[p] && dfn[x] <= dfn[p] + siz[p] - 1; }
    // Queue the prefix queries whose signed sum into *ans gives the count for the rectangle
    // [l1, r1] x [l2, r2], by inclusion-exclusion:
    //   f(r1, r2) - f(l1 - 1, r2) - f(l2 - 1, r1) + f(l1 - 1, l2 - 1).
    // Terms with an empty prefix (argument 0) are skipped.
    inline void addq(int l1, int r1, int l2, int r2, ll *ans) {
    	q[++Q] = Query(1, r1, r2, ans);
    	if (l1 > 1) q[++Q] = Query(-1, l1 - 1, r2, ans);
    	if (l2 > 1) q[++Q] = Query(-1, l2 - 1, r1, ans);
    	if (l1 > 1 && l2 > 1) q[++Q] = Query(1, l1 - 1, l2 - 1, ans);
    }
    // Coordinate-compress the weights a[1..n] to ranks 1..d (d = number of distinct values),
    // so they can be used directly as indices into cl[] / cr[].
    inline void lsh() {
    	memcpy(b, a, sizeof(int) * (n + 1));
    	std::sort(b + 1, b + n + 1);
    	int dis = std::unique(b + 1, b + n + 1) - b - 1;
    	for (int i = 1; i <= n; ++i) a[i] = std::lower_bound(b + 1, b + dis + 1, a[i]) - b;
    }
    inline void addl(int x) {
    	val += cr[a[pre[x]]];
    inline void addr(int x) {
    	val += cl[a[pre[x]]];
    inline void dell(int x) {
    	val -= cr[a[pre[x]]];
    inline void delr(int x) {
    	val -= cl[a[pre[x]]];
    // Answer every queued prefix query with Mo's algorithm, then print one answer per line.
    inline void work() {
    	// With the (block of l, r) ordering, cost is about Q * blo (left moves) + n * n / blo
    	// (right sweeps), minimised at blo = n / sqrt(Q). Q can greatly exceed n here.
    	// Clamp to >= 1 and guard Q == 0.
    	blo = std::max(1, (int)(n / sqrt((double)std::max(Q, 1))));
    	std::sort(q + 1, q + Q + 1);
    	int l = 0, r = 0;
    	for (int i = 1; i <= Q; ++i) {
    		while (r < q[i].r) addr(++r);
    		while (l < q[i].l) addl(++l);
    		while (l > q[i].l) dell(l--);
    		while (r > q[i].r) delr(r--);
    		*q[i].ans += q[i].opt * val;
    	}
    	for (int i = 1; i <= ansi; ++i) io::print(ans[i]), io::putc('\n');
    }
    inline void init() {
    	read(n), read(m);
    	for (int i = 1; i <= n; ++i) read(a[i]);
    	int x, y;
    	for (int i = 1; i < n; ++i) read(x), read(y), adde(x, y);
    	dfs1(1), dfs2(1, 1);
    	int rt = 1;
    	for (int i = 1; i <= m; ++i) {
    		int opt, x, y;
    		if (opt == 1) { read(rt); continue; }
    		read(x), read(y);
    		if (x == rt) {
    			if (y == rt) addq(1, n, 1, n, ans + ansi);
    			else if (!intr(rt, y)) addq(1, n, dfn[y], dfn[y] + siz[y] - 1, ans + ansi);
    			else {
    				y = gson(rt, y);
    				if (dfn[y] > 1) addq(1, n, 1, dfn[y] - 1, ans + ansi);
    				if (dfn[y] + siz[y] - 1 < n) addq(1, n, dfn[y] + siz[y], n, ans + ansi);
    		else if (!intr(rt, x)) {
    			if (y == rt) addq(dfn[x], dfn[x] + siz[x] - 1, 1, n, ans + ansi);
    			else if (!intr(rt, y)) addq(dfn[x], dfn[x] + siz[x] - 1, dfn[y], dfn[y] + siz[y] - 1, ans + ansi);
    			else {
    				y = gson(rt, y);
    				if (dfn[y] > 1) addq(dfn[x], dfn[x] + siz[x] - 1, 1, dfn[y] - 1, ans + ansi);
    				if (dfn[y] + siz[y] - 1 < n) addq(dfn[x], dfn[x] + siz[x] - 1, dfn[y] + siz[y], n, ans + ansi);
    		} else {
    			x = gson(rt, x);
    			if (y == rt) {
    				if (dfn[x] > 1) addq(1, dfn[x] - 1, 1, n, ans + ansi);
    				if (dfn[x] + siz[x] - 1 < n) addq(dfn[x] + siz[x], n, 1, n, ans + ansi);
    			} else if (!intr(rt, y)) {
    				if (dfn[x] > 1) addq(1, dfn[x] - 1, dfn[y], dfn[y] + siz[y] - 1, ans + ansi);
    				if (dfn[x] + siz[x] - 1 < n) addq(dfn[x] + siz[x], n, dfn[y], dfn[y] + siz[y] - 1, ans + ansi);
    			} else {
    				y = gson(rt, y);
    				if (dfn[x] > 1) {
    					if (dfn[y] > 1) addq(1, dfn[x] - 1, 1, dfn[y] - 1, ans + ansi);
    					if (dfn[y] + siz[y] - 1 < n) addq(1, dfn[x] - 1, dfn[y] + siz[y], n, ans + ansi);
    				if (dfn[x] + siz[x] - 1 < n) {
    					if (dfn[y] > 1) addq(dfn[x] + siz[x], n, 1, dfn[y] - 1, ans + ansi);
    					if (dfn[y] + siz[y] - 1 < n) addq(dfn[x] + siz[x], n, dfn[y] + siz[y], n, ans + ansi);
    int main() {
    #ifdef hzhkk
    	freopen("hkk.in", "r", stdin);
    //	fclose(stdin), fclose(stdout);
    	return 0;
