solution - 简单题(K-Dimension Tree)
咕了这么久,终于可以来讲讲KDT了。
说句实话,KDT的算法是非常简单的,但是很少有人能很快的写对,总是会出现一些奇奇怪怪的BUG,我自己也写了一个下午。主要是写代码时注意结构的对称性,以及算法的模块性,一个function干一件事就行。
说了这么多,就开始讲算法吧。
#1 算法描述
考虑一种二叉树的结构,其中每一个节点有两个功能:
- 存储这个节点( extbf{p}=[x_0, x_1, cdots,x_{k-1}])的值
- 存储一个 (k) 维区域的记录,并且通过这一个节点的左右儿子(( extbf p_l, extbf p_r))将其在某一个维度将其分为两半,形式化地,就是对于任意的节点( extbf a in { extbf p_l ext{及其后代}}, extbf b in { extbf p_r ext{及其后代}}),存在某一个维度向量( extbf T = [x_0,x_1,cdots,x_{k-1}], ext{其中} x_t = 1, x_{p ot=t} = 0),有$ extbf T cdot extbf a leq extbf T cdot extbf p < extbf T cdot extbf b $
考虑到算法的简便性,我们人为规定 (t) 为这个节点的深度模 (k)。
那么就十分简单了,可以很清楚的实现这个算法的插入,查询,删除。
但是考虑到算法的单次复杂度还是 ( extrm O(n)),需要优化。
我们可以使用替罪羊树的思路优化,即一但某个节点的某个左右儿子的重量大于这个节点的重量的 (alpha) 倍,那么就直接重构这个树。其中 (alpha) 基本在 (0.75) 附近最好。
下面是代码实现。
#2 代码实现细节
#2.1 节点定义
先是定义节点。
template <int D>
struct KDT {
KDT<D>* ls = nullptr, * rs = nullptr;
int mx[D], mn[D], pos[D];
int val; // 这个结点的值
int sum; // 这个节点及其子节点的值的和
int tot; // 这个节点及其子节点的数量
const bool operator < (KDT t) const {
return t.val < val;
} // 为了pair所必须的
};
其中 mx, mn
为这个节点及其子树的在所有维度的极大值与极小值。
pos
为这个节点的维度值。
其他的见注释。
你可以注意到这里使用了指针来定义。
#2.2 插入
这里写一下这个程序的伪代码:
$ extbf {function} Insert ( ext{这个节点的指针}nx, ext{插入的维度} extbf d, ext{节点的值}val, ext{节点的深度depth(模过D)}) $
( extbf{if not exist } nx extbf {then new }nx)
$ extbf{else if } d_{val} leq nx.pos_{val} $
$ extbf { then }nx leftarrow insert(nx ightarrow ls, d, val, depth + 1) $
( extbf { else } nx leftarrow insert(nx ightarrow rs, d, val, depth + 1);)
$ extbf {if not } ext{balance} extbf { then } rebuild(nx) $
( update(nx))
$ extbf {return } nx $
( extbf {end function})
template <int D>
KDT<D>* insert(KDT<D>* nx, int d[D], int val, int depth) {
if (depth >= D) depth -= D;
if (nx == nullptr) {
nx = new KDT<D>;
for (int i = 0; i < D; i++) {
nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
}
nx->val = nx->sum = val;
nx->tot = 1;
return nx;
}
else {
int flag = 1;
for (int i = 0; i < D; i++) {
if (d[i] != nx->pos[i])
flag = 0;
}
if (flag) {
nx->val += val;
update(nx);
return nx;
}
if (d[depth] <= nx->pos[depth]) {
nx->ls = insert(nx->ls, d, val, depth + 1);
} else {
nx->rs = insert(nx->rs, d, val, depth + 1);
}
update(nx);
int mx = 0;
if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
if (mx > nx->tot * alpha) {
pair<int, KDT<D> >* arr =
new pair<int, KDT<D> >[nx->tot + 10];
pl = 0;
pia(nx, arr);
nx = rebuild(1, pl + 1, depth, arr);
delete[]arr;
}
return nx;
}
}
这里请注意以下rebuild
块的写法。
我的写法是新建一块内存存储被删除的节点 pair<int, KDT<D> >* arr = new pair<int, KDT<D> >[nx->tot + 10];
然后将其节点删除,并放至arr
中。
代码如下:
template <int D>
void pia(KDT<D>* ptr, pair<int, KDT<D> >* arr) {
if (ptr != nullptr) {
arr[++pl].second = *ptr;
pia(ptr->ls, arr);
pia(ptr->rs, arr);
delete ptr;
}
}
最后是rebuild
的一块。
代码如下:
template <int D>
KDT<D>* rebuild(int L, int R, int dep, pair<int, KDT<D> >* arr) {
if (L >= R) return nullptr;
if (dep > D) dep -= D;
for (int i = L; i < R; i++) {
arr[i].first = arr[i].second.pos[dep];
}
int mid = (L + R) >> 1;
nth_element(arr + L, arr + mid, arr + R);
KDT<D>* ret = new KDT<D>;
*ret = arr[mid].second;
ret->ls = rebuild(L, mid, dep + 1, arr);
ret->rs = rebuild(mid + 1, R, dep + 1, arr);
update(ret);
return ret;
}
这里为了偷懒才用了系统的nth_element
,否则可以不用pair
数组。
#2.3查询
这一块比较简单,不予赘述。其中allin
函数表示这个节点及其子节点全在所给范围之中。allout
相反。 in
表示这个单独的点是否在区域中。
代码如下:
template <int D>
int get_ans(KDT<D>* nx, int mx[D], int mn[D]) {
if (nx == nullptr) return 0;
if (allout(nx, mx, mn)) {
return 0;
}
if (allin(nx, mx, mn)) return nx->sum;
int ret = 0;
if (in(nx, mx, mn)) {
ret = nx->val;
}
ret += get_ans(nx->ls, mx, mn);
ret += get_ans(nx->rs, mx, mn);
return ret;
}
#3 代码呈现
#include<cstdio>
#include<algorithm>
const double alpha = 0.75;
const int maxn = 210000;
using namespace std;
template <int d>
struct kdt {
kdt<d>* ls = nullptr, * rs = nullptr;
int mx[d], mn[d], pos[d];
int val;
int sum;
int tot;
const bool operator < (kdt t) const {
return t.val < val;
}
};
//pair <int, kdt<t> > arr[maxn];
int pl = 0;
template <int d>
void update(kdt<d>* ret) {
ret->tot = 1;
ret->sum = ret->val;
for (int i = 0; i < d; i++) {
ret->mx[i] = ret->mn[i] = ret->pos[i];
}
if (ret->ls != nullptr) {
for (int i = 0; i < d; i++)
ret->mx[i] = max(ret->mx[i], ret->ls->mx[i]),
ret->mn[i] = min(ret->mn[i], ret->ls->mn[i]);
ret->sum += ret->ls->sum;
ret->tot += ret->ls->tot;
}
if (ret->rs != nullptr) {
for (int i = 0; i < d; i++)
ret->mx[i] = max(ret->mx[i], ret->rs->mx[i]),
ret->mn[i] = min(ret->mn[i], ret->rs->mn[i]);
ret->sum += ret->rs->sum;
ret->tot += ret->rs->tot;
}
}
template <int d>
void pia(kdt<d>* ptr, pair<int, kdt<d> >* arr) {
if (ptr != nullptr) {
arr[++pl].second = *ptr;
pia(ptr->ls, arr);
pia(ptr->rs, arr);
delete ptr;
}
}
template <int d>
kdt<d>* rebuild(int l, int r, int dep, pair<int, kdt<d> >* arr) {
if (l >= r) return nullptr;
if (dep > d) dep -= d;
for (int i = l; i < r; i++) {
arr[i].first = arr[i].second.pos[dep];
}
int mid = (l + r) >> 1;
nth_element(arr + l, arr + mid, arr + r);
kdt<d>* ret = new kdt<d>;
*ret = arr[mid].second;
ret->ls = rebuild(l, mid, dep + 1, arr);
ret->rs = rebuild(mid + 1, r, dep + 1, arr);
update(ret);
return ret;
}
template <int d>
kdt<d>* insert(kdt<d>* nx, int d[d], int val, int depth) {
if (depth >= d) depth -= d;
if (nx == nullptr) {
nx = new kdt<d>;
for (int i = 0; i < d; i++) {
nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
}
nx->val = nx->sum = val;
nx->tot = 1;
return nx;
}
else {
int flag = 1;
for (int i = 0; i < d; i++) {
if (d[i] != nx->pos[i])
flag = 0;
}
if (flag) {
nx->val += val;
update(nx);
return nx;
}
if (d[depth] < nx->pos[depth]) {
nx->ls = insert(nx->ls, d, val, depth + 1);
} else {
nx->rs = insert(nx->rs, d, val, depth + 1);
}
update(nx);
int mx = 0;
if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
if (mx > nx->tot * alpha) {
pair<int, kdt<d> >* arr = new pair<int, kdt<d> >[nx->tot + 10];
pl = 0;
pia(nx, arr);
nx = rebuild(1, pl + 1, depth, arr);
delete[]arr;
}
return nx;
}
}
template <int d>
int allin(kdt<d>* nx, int mx[d], int mn[d]) {
for (int i = 0; i < d; i++) {
if (nx->mx[i] > mx[i]) {
return 0;
}
if (nx->mn[i] < mn[i]) {
return 0;
}
}
return 1;
}
template <int d>
int allout(kdt<d>* nx, int mx[d], int mn[d]) {
for (int i = 0; i < d; i++) {
if (nx->mn[i] > mx[i]) {
return 1;
}
if (nx->mx[i] < mn[i]) {
return 1;
}
}
return 0;
}
template <int d>
int in(kdt<d>* nx, int mx[d], int mn[d]) {
for (int i = 0; i < d; i++) {
if (nx->pos[i] > mx[i]) {
return 0;
}
if (nx->pos[i] < mn[i]) {
return 0;
}
}
return 1;
}
template <int d>
int get_ans(kdt<d>* nx, int mx[d], int mn[d]) {
if (nx == nullptr) return 0;
if (allout(nx, mx, mn)) {
return 0;
}
if (allin(nx, mx, mn)) return nx->sum;
int ret = 0;
if (in(nx, mx, mn)) {
ret = nx->val;
}
ret += get_ans(nx->ls, mx, mn);
ret += get_ans(nx->rs, mx, mn);
return ret;
}
int main() {
kdt<2>* root = nullptr;
int n;scanf("%d", &n);
int lst_ans = 0;
while (1) {
int opt;
scanf("%d", &opt);
if (opt == 3) break;
if (opt == 1) {
int d[2] = { 0,0}, val;
scanf("%d%d%d", d, d + 1, &val);
d[0] ^= lst_ans, d[1] ^= lst_ans, val ^= lst_ans;
root = insert(root, d, val, 0);
}
if (opt == 2) {
int mx[2] = { 0,0 }, mn[2] = { 0,0 };
scanf("%d%d%d%d", mn, mn + 1, mx, mx + 1);
mx[0] ^= lst_ans, mx[1] ^= lst_ans;
mn[0] ^= lst_ans, mn[1] ^= lst_ans;
lst_ans = get_ans(root, mx, mn);
printf("%d
", lst_ans);
}
}
}