// Reference
#include<bits/stdc++.h>
// Midpoint of the current segment-tree range [l, r]. It expands in terms of
// identifiers named `l` and `r`, so it is only valid where both are in scope.
// Precedence: (l+r>>1) parses as ((l+r)>>1).
#define mid (l+r>>1)
using namespace std;
// Capacity of the per-node arrays; presumably n <= 1e5 (plus slack) — confirm against constraints.
const int N = 1e5 + 100;
// G[x]: children of node x (main stores parent -> child edges only).
vector<int>G[N];
// n: number of nodes; k: distance bound used by cal.
int n, k;
// val: node value; son: heavy child (0 = none, since sz[0] == 0);
// dep: depth (root has depth 1); sz: subtree size.
int val[N], son[N], dep[N], sz[N];
// First pass over the subtree of u: fills sz (subtree size), dep (depth) and
// son (heavy child = child with the largest subtree; the first such child wins ties).
// Precondition: dep[u] is already set by the caller.
void dfs1(int u) {
    sz[u] = 1;
    for (size_t idx = 0; idx < G[u].size(); ++idx) {
        const int child = G[u][idx];
        // The child's depth must be assigned BEFORE recursing, because the
        // recursive call reads dep[child] to compute its own children's depths.
        dep[child] = dep[u] + 1;
        dfs1(child);
        sz[u] += sz[child];
        const int heavy = son[u];  // 0 while unset; sz[0] == 0
        if (sz[child] > sz[heavy]) son[u] = child;
    }
}
// Pooled dynamic segment trees over depths [1, n], one tree per node value:
// root[x] is the tree counting stored nodes with val == x (0 = empty tree),
// lc/rc are child links (0 = absent), sum holds the counts, tot is the last
// allocated node id (id 0 is the null sentinel).
// Pool size: a root-to-leaf path over a range of size <= N has at most 18 nodes,
// and only the first insert of a given (value, depth) pair allocates anything
// (later inserts reuse the path). There are at most n such pairs, so at most
// 18 * n nodes ever exist. N * 20 covers that; N * 200 wasted ~240 MB.
int root[N], tot, lc[N * 20], rc[N * 20], sum[N * 20];
// Adds `val` to the count at position `pos` in the dynamic segment tree whose
// root id is stored in `rt` (allocating the root and any missing nodes on the
// root-to-leaf path, top-down, from the shared pool).
void insert(int& rt, int l, int r, int val, int pos) {
    int* slot = &rt;  // the link that should point at the current node
    for (;;) {
        if (*slot == 0) *slot = ++tot;
        const int node = *slot;
        sum[node] += val;
        if (l == r) return;
        if (pos <= mid) {
            slot = &lc[node];
            r = mid;
        } else {
            slot = &rc[node];
            l = mid + 1;
        }
    }
}
// Sum of counts at positions in [L, R] in the subtree rooted at node `rt`,
// which covers range [l, r]. A null node (0) contributes nothing.
long long query(int rt, int l, int r, int L, int R) {
    if (rt == 0) return 0;
    if (L <= l && r <= R) return sum[rt];  // node range fully inside the query
    const int half = mid;
    long long total = 0;
    if (L <= half) total += query(lc[rt], l, half, L, R);
    if (half < R) total += query(rc[rt], half + 1, r, L, R);
    return total;
}
// Final answer, accumulated by cal (each matching pair is added twice, i.e. ordered pairs).
long long res;
// For every node u in the subtree passed in (DFS calls this with a light child
// of rt), counts the nodes x currently stored in the per-value trees with
//   val[x] == 2*val[rt] - val[u]   and   dep[x] <= 2*dep[rt] + k - dep[u].
// The depth condition is (dep[x]-dep[rt]) + (dep[u]-dep[rt]) <= k, i.e. the
// u-x distance is at most k when rt is their lowest common ancestor.
// Each match adds 2 to res.
void cal(int u, int rt) {
    // Evaluated in 64 bits: in int, dep[rt] + k could overflow (UB) for k near
    // INT_MAX. Clamping to n keeps the query range inside the tree's [1, n]
    // (the "re ?" concern), since no stored depth exceeds n.
    long long limit = 2LL * dep[rt] + k - dep[u];
    if (limit > n) limit = n;
    const int d = static_cast<int>(limit);
    const int v = 2 * val[rt] - val[u];
    // v must be a valid index into root[]; d < 1 means no depth can qualify.
    if (d >= 1 && v >= 0 && v <= n) res += 2LL * query(root[v], 1, n, 1, d);
    for (int x : G[u]) cal(x, rt);
}
// Adds `v` (DFS passes +1 or -1) to the count of every node in the subtree
// of u, stored in tree root[val[node]] at position dep[node]. The node comes
// first, then its children in G order.
void add(int u, int v) {
    insert(root[val[u]], 1, n, v, dep[u]);
    const vector<int>& kids = G[u];
    for (auto it = kids.begin(); it != kids.end(); ++it) {
        add(*it, v);
    }
}
// Small-to-large ("DSU on tree") counting pass over the subtree of u.
// Invariant: the per-value trees are empty on entry (main calls DFS(1) before
// any insert, and each light-child call below is undone by add(v, -1)).
// On return they hold exactly the nodes of u's subtree.
void DFS(int u) {
// 1) Solve each light child on its own, then remove its subtree from the trees.
for (int v : G[u]) {
if (v == son[u])continue;
DFS(v); add(v, -1);
}
// 2) Solve the heavy child last and keep its subtree in the trees.
if (son[u])DFS(son[u]);
// 3) For each light child: first count pairs between its subtree and everything
//    already stored (heavy subtree + earlier light subtrees), then add it, so
//    pairs inside a single child subtree are never counted at u.
for (int v : G[u]) {
if (v == son[u])continue;
cal(v, u); add(v, 1);
}
// 4) Insert u itself only after its queries. Pairs with u as an endpoint are
//    therefore never counted at u; presumably intended — confirm against the
//    problem statement.
insert(root[val[u]], 1, n, 1, dep[u]);
}
int main() {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++)scanf("%d", val + i);
for (int i = 2; i <= n; i++) {
int x; scanf("%d", &x);
G[x].push_back(i);
}
dep[1] = 1;
dfs1(1);
DFS(1);
printf("%lld
", res);
}