• solution


    solution - 简单题(K-Dimension Tree)

    咕了这么久,终于可以来讲讲KDT了。

    说句实话,KDT的算法是非常简单的,但是很少有人能很快的写对,总是会出现一些奇奇怪怪的BUG,我自己也写了一个下午。主要是写代码时注意结构的对称性,以及算法的模块性,一个function干一件事就行。

    说了这么多,就开始讲算法吧。

    #1 算法描述

    考虑一种二叉树的结构,其中每一个节点有两个功能:

    • 存储这个节点( extbf{p}=[x_0, x_1, cdots,x_{k-1}])的值
    • 存储一个 (k) 维区域的记录,并且通过这一个节点的左右儿子(( extbf p_l, extbf p_r))将其在某一个维度将其分为两半,形式化地,就是对于任意的节点( extbf a in { extbf p_l ext{及其后代}}, extbf b in { extbf p_r ext{及其后代}}),存在某一个维度向量( extbf T = [x_0,x_1,cdots,x_{k-1}], ext{其中} x_t = 1, x_{p ot=t} = 0),有$ extbf T cdot extbf a leq extbf T cdot extbf p < extbf T cdot extbf b $

    考虑到算法的简便性,我们人为规定 (t) 为这个节点的深度模 (k)

    那么就十分简单了,可以很清楚的实现这个算法的插入,查询,删除。

    但是考虑到算法的单次复杂度还是 ( extrm O(n)),需要优化。

    我们可以使用替罪羊树的思路优化,即一但某个节点的某个左右儿子的重量大于这个节点的重量的 (alpha) 倍,那么就直接重构这个树。其中 (alpha) 基本在 (0.75) 附近最好。

    下面是代码实现。

    #2 代码实现细节

    #2.1 节点定义

    先是定义节点。

    template <int D>
    struct KDT {
    	KDT<D>* ls = nullptr, * rs = nullptr;
    	int mx[D], mn[D], pos[D];
    	int val; // 这个结点的值
    	int sum; // 这个节点及其子节点的值的和
    	int tot; // 这个节点及其子节点的数量
    	const bool operator < (KDT t) const {
    		return t.val < val;
    	} // 为了pair所必须的
    };
    

    其中 mx, mn 为这个节点及其子树的在所有维度的极大值与极小值。

    pos 为这个节点的维度值。

    其他的见注释。

    你可以注意到这里使用了指针来定义。

    #2.2 插入

    这里写一下这个程序的伪代码:

    $ extbf {function} Insert ( ext{这个节点的指针}nx, ext{插入的维度} extbf d, ext{节点的值}val, ext{节点的深度depth(模过D)}) $

    ( extbf{if not exist } nx extbf {then new }nx)

    $ extbf{else if } d_{val} leq nx.pos_{val} $

    $ extbf { then }nx leftarrow insert(nx ightarrow ls, d, val, depth + 1) $

    ( extbf { else } nx leftarrow insert(nx ightarrow rs, d, val, depth + 1);)

    $ extbf {if not } ext{balance} extbf { then } rebuild(nx) $

    ( update(nx))

    $ extbf {return } nx $

    ( extbf {end function})

    template <int D>
    KDT<D>* insert(KDT<D>* nx, int d[D], int val, int depth) {
    	if (depth >= D) depth -= D;
    	if (nx == nullptr) {
    		nx = new KDT<D>;
    		for (int i = 0; i < D; i++) {
    			nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
    		}
    		nx->val = nx->sum = val;
    		nx->tot = 1;
    		return nx;
    	}
    	else {
    		int flag = 1;
    		for (int i = 0; i < D; i++) {
    			if (d[i] != nx->pos[i])
    				flag = 0;
    		}
    		if (flag) {
    			nx->val += val;
    			update(nx);
    			return nx;
    		}
    		if (d[depth] <= nx->pos[depth]) {
    			nx->ls = insert(nx->ls, d, val, depth + 1);
    		} else {
    			nx->rs = insert(nx->rs, d, val, depth + 1);
    		}
    		update(nx);
    		int mx = 0;
    		if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
    		if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
    		if (mx > nx->tot * alpha) {
    			pair<int, KDT<D> >* arr = 
                    new pair<int, KDT<D> >[nx->tot + 10];
    			pl = 0;
    			pia(nx, arr);
    			nx = rebuild(1, pl + 1, depth, arr);
    			delete[]arr;
    		}
    		return nx;
    	}
    }
    

    这里请注意以下rebuild块的写法。

    我的写法是新建一块内存存储被删除的节点 pair<int, KDT<D> >* arr = new pair<int, KDT<D> >[nx->tot + 10];

    然后将其节点删除,并放至arr中。

    代码如下:

    template <int D>
    void pia(KDT<D>* ptr, pair<int, KDT<D> >* arr) {
    	if (ptr != nullptr) {
    		arr[++pl].second = *ptr;
    		pia(ptr->ls, arr);
    		pia(ptr->rs, arr);
    		delete ptr;
    	}
    }
    

    最后是rebuild的一块。

    代码如下:

    template <int D>
    KDT<D>* rebuild(int L, int R, int dep, pair<int, KDT<D> >* arr) {
    	if (L >= R) return nullptr;
    	if (dep > D) dep -= D;
    	for (int i = L; i < R; i++) {
    		arr[i].first = arr[i].second.pos[dep];
    	}
    	int mid = (L + R) >> 1;
    	nth_element(arr + L, arr + mid, arr + R);
    	KDT<D>* ret = new KDT<D>;
    	*ret = arr[mid].second;
    	ret->ls = rebuild(L, mid, dep + 1, arr);
    	ret->rs = rebuild(mid + 1, R, dep + 1, arr);
    	update(ret);
    	return ret;
    }
    

    这里为了偷懒才用了系统的nth_element,否则可以不用pair数组。

    #2.3查询

    这一块比较简单,不予赘述。其中allin函数表示这个节点及其子节点全在所给范围之中。allout 相反。 in表示这个单独的点是否在区域中。

    代码如下:

    template <int D>
    int get_ans(KDT<D>* nx, int mx[D], int mn[D]) {
    	if (nx == nullptr) return 0;
    	if (allout(nx, mx, mn)) {
    		return 0;
    	}
    	if (allin(nx, mx, mn)) return nx->sum;
    	int ret = 0;
    	if (in(nx, mx, mn)) {
    		ret = nx->val;
    	}
    	ret += get_ans(nx->ls, mx, mn);
    	ret += get_ans(nx->rs, mx, mn);
    	return ret;
    }
    

    #3 代码呈现

    #include<cstdio>
    #include<algorithm>
    
    const double alpha = 0.75;
    const int maxn = 210000;
    
    using namespace std;
    
    template <int d>
    struct kdt {
    	kdt<d>* ls = nullptr, * rs = nullptr;
    	int mx[d], mn[d], pos[d];
    	int val;
    	int sum;
    	int tot;
    	const bool operator < (kdt t) const {
    		return t.val < val;
    	}
    };
    
    //pair <int, kdt<t> > arr[maxn];
    int pl = 0;
    template <int d>
    void update(kdt<d>* ret) {
    	ret->tot = 1;
    	ret->sum = ret->val;
    	for (int i = 0; i < d; i++) {
    		ret->mx[i] = ret->mn[i] = ret->pos[i];
    	}
    	if (ret->ls != nullptr) {
    		for (int i = 0; i < d; i++)
    			ret->mx[i] = max(ret->mx[i], ret->ls->mx[i]),
    			ret->mn[i] = min(ret->mn[i], ret->ls->mn[i]);
    		ret->sum += ret->ls->sum;
    		ret->tot += ret->ls->tot;
    	}
    	if (ret->rs != nullptr) {
    		for (int i = 0; i < d; i++)
    			ret->mx[i] = max(ret->mx[i], ret->rs->mx[i]),
    			ret->mn[i] = min(ret->mn[i], ret->rs->mn[i]);
    		ret->sum += ret->rs->sum;
    		ret->tot += ret->rs->tot;
    	}
    }
    
    template <int d>
    void pia(kdt<d>* ptr, pair<int, kdt<d> >* arr) {
    	if (ptr != nullptr) {
    		arr[++pl].second = *ptr;
    		pia(ptr->ls, arr);
    		pia(ptr->rs, arr);
    		delete ptr;
    	}
    }
    
    template <int d>
    kdt<d>* rebuild(int l, int r, int dep, pair<int, kdt<d> >* arr) {
    	if (l >= r) return nullptr;
    	if (dep > d) dep -= d;
    	for (int i = l; i < r; i++) {
    		arr[i].first = arr[i].second.pos[dep];
    	}
    	int mid = (l + r) >> 1;
    	nth_element(arr + l, arr + mid, arr + r);
    	kdt<d>* ret = new kdt<d>;
    	*ret = arr[mid].second;
    	ret->ls = rebuild(l, mid, dep + 1, arr);
    	ret->rs = rebuild(mid + 1, r, dep + 1, arr);
    	update(ret);
    	return ret;
    }
    
    template <int d>
    kdt<d>* insert(kdt<d>* nx, int d[d], int val, int depth) {
    	if (depth >= d) depth -= d;
    	if (nx == nullptr) {
    		nx = new kdt<d>;
    		for (int i = 0; i < d; i++) {
    			nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
    		}
    		nx->val = nx->sum = val;
    		nx->tot = 1;
    		return nx;
    	}
    	else {
    		int flag = 1;
    		for (int i = 0; i < d; i++) {
    			if (d[i] != nx->pos[i])
    				flag = 0;
    		}
    		if (flag) {
    			nx->val += val;
    			update(nx);
    			return nx;
    		}
    		if (d[depth] < nx->pos[depth]) {
    			nx->ls = insert(nx->ls, d, val, depth + 1);
    		} else {
    			nx->rs = insert(nx->rs, d, val, depth + 1);
    		}
    		update(nx);
    		int mx = 0;
    		if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
    		if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
    		if (mx > nx->tot * alpha) {
    			pair<int, kdt<d> >* arr = new pair<int, kdt<d> >[nx->tot + 10];
    			pl = 0;
    			pia(nx, arr);
    			nx = rebuild(1, pl + 1, depth, arr);
    			delete[]arr;
    		}
    		return nx;
    	}
    }
    
    template <int d>
    int allin(kdt<d>* nx, int mx[d], int mn[d]) {
    	for (int i = 0; i < d; i++) {
    		if (nx->mx[i] > mx[i]) {
    			return 0;
    		}
    		if (nx->mn[i] < mn[i]) {
    			return 0;
    		}
    	}
    	return 1;
    }
    
    template <int d>
    int allout(kdt<d>* nx, int mx[d], int mn[d]) {
    	for (int i = 0; i < d; i++) {
    		if (nx->mn[i] > mx[i]) {
    			return 1;
    		}
    		if (nx->mx[i] < mn[i]) {
    			return 1;
    		}
    	}
    	return 0;
    }
    
    template <int d>
    int in(kdt<d>* nx, int mx[d], int mn[d]) {
    	for (int i = 0; i < d; i++) {
    		if (nx->pos[i] > mx[i]) {
    			return 0;
    		}
    		if (nx->pos[i] < mn[i]) {
    			return 0;
    		}
    	}
    	return 1;
    }
    
    template <int d>
    int get_ans(kdt<d>* nx, int mx[d], int mn[d]) {
    	if (nx == nullptr) return 0;
    	if (allout(nx, mx, mn)) {
    		return 0;
    	}
    	if (allin(nx, mx, mn)) return nx->sum;
    	int ret = 0;
    	if (in(nx, mx, mn)) {
    		ret = nx->val;
    	}
    	ret += get_ans(nx->ls, mx, mn);
    	ret += get_ans(nx->rs, mx, mn);
    	return ret;
    }
    
    int main() {
    	kdt<2>* root = nullptr;
    	int n;scanf("%d", &n);
    	int lst_ans = 0;
    	while (1) {
    		int opt;
    		scanf("%d", &opt);
    		if (opt == 3) break;
    		if (opt == 1) {
    			int d[2] = { 0,0}, val;
    			scanf("%d%d%d", d, d + 1, &val);
    			d[0] ^= lst_ans, d[1] ^= lst_ans, val ^= lst_ans;
    			root = insert(root, d, val, 0);
    		}
    		if (opt == 2) {
    			int mx[2] = { 0,0 }, mn[2] = { 0,0 };
    			scanf("%d%d%d%d", mn, mn + 1, mx, mx + 1);
    			mx[0] ^= lst_ans, mx[1] ^= lst_ans;
    			mn[0] ^= lst_ans, mn[1] ^= lst_ans;
    			lst_ans = get_ans(root, mx, mn);
    			printf("%d
    ", lst_ans);
    		}
    	}
    }
    
  • 相关阅读:
    【转】JavaScript里的this指针
    userscript.user.js 文件头
    css clearfix
    callback调用测试
    【个人】IIS Express 配置
    Js中 关于top、clientTop、scrollTop、offsetTop的用法
    【设为首页】/【收藏本站】
    JQuery插件开发
    Google Ajax Library API使用方法(JQuery)
    并发操作问题
  • 原文地址:https://www.cnblogs.com/dgklr/p/13874129.html
Copyright © 2020-2023  润新知