• 一篇系列——KD-Tree 一篇就够了


    Python代码部分

    
    import heapq
    from collections import Counter
    
    maxn = 10 ** 5  # size of the preallocated pools
    K = 13  # number of dimensions of every point
    
    class Pt:
        """A point: a list of K coordinates ``x`` plus an attached value ``val``."""

        def __init__(self, x = None) -> None:
            # Using a list as the default value of x is a well-known pitfall,
            # hence the None sentinel.
            # href: https://www.cnblogs.com/jclian91/p/10325849.html
            # A caller-supplied x is stored as-is (not copied).
            self.x = [0 for _ in range(K)] if x is None else x
            self.val = 0
    
    class Node:
        """Entry stored in the ``heapq`` heap: a distance ``d`` and an index ``id``.

        The comparison is reversed on purpose, so ``heapq`` (a min-heap) keeps
        the entry with the LARGEST distance at position 0.
        """

        def __init__(self, d = 0, id = 0) -> None:
            self.d, self.id = d, id

        def __lt__(self, other):
            # Reversed ordering: an entry is "smaller" when it is farther away.
            return other.d < self.d

        def __str__(self) -> str:
            return "(d: %.2f, id: %d)" % (self.d, self.id)
    
    class KDTree_Node:
        """One slot of the KD-tree node pool; every field starts zeroed."""

        def __init__(self) -> None:
            self.SZ = 0  # size of the subtree rooted at this node
            self.lc = self.rc = 0  # indices of the left / right child
            self.maxn = [0 for _ in range(K)]  # per-dimension maximum over the subtree
            self.minn = [0 for _ in range(K)]  # per-dimension minimum over the subtree
            self.place = Pt()  # the point stored at this node
    
    class KDTree:
        """Array-backed KD-tree with scapegoat-style rebuilding and k-NN search.

        Nodes live in the preallocated pool ``self.Tr`` and are addressed by
        integer index. Index 0 is never handed out by ``New``; it acts as the
        "no child" sentinel and its ``SZ`` stays 0.
        """

        def __init__(self, n = 0, alpha= 0.75) -> None:
            """Allocate the point and node pools.

            n: number of points the caller placed in ``self.p[1..n]``.
            alpha: rebuild factor -- a subtree is rebuilt as soon as one of its
                children holds more than ``alpha`` times the subtree's size.
            """
            self.n = n
            self.p = [Pt() for _ in range(maxn)]
            self.Tr = [KDTree_Node() for _ in range(maxn)]
            self.hp = []  # heap of Node entries; the farthest candidate sits at hp[0]
            self.alpha = alpha
            self.cuK = 0  # dimension selected by the most recent build step
            self.top = 0  # number of recycled indices currently held in self.store
            self.tot = 0  # highest node index handed out so far
            self.store = [0] * maxn  # recycled node indices, stored at positions 1..top
            self.root = 1  # index of the tree root (callers overwrite it after build)

        def New(self) -> int:
            """Return a free node index, preferring a recycled one."""
            if self.top != 0:
                self.top -= 1
                return self.store[self.top + 1]
            self.tot += 1
            return self.tot

        def update(self, x) -> None:
            """Recompute the bounding box and size of node x from its point and children."""
            node = self.Tr[x]
            children = [self.Tr[c] for c in (node.lc, node.rc) if c != 0]
            for i in range(K):
                low = high = node.place.x[i]
                for child in children:
                    high = max(high, child.maxn[i])
                    low = min(low, child.minn[i])
                node.maxn[i] = high
                node.minn[i] = low
            # Tr[0].SZ is 0, so a missing child contributes nothing.
            node.SZ = self.Tr[node.lc].SZ + self.Tr[node.rc].SZ + 1

        def build(self, l: int, r: int, dep= 0) -> int:
            """Build a subtree from ``self.p[l..r]`` (inclusive) and return its root index.

            The split dimension is ``dep % K``; the range is sorted along it and
            the middle point becomes the subtree root. Returns 0 for an empty range.
            """
            if l > r:
                return 0
            self.cuK = dep % K
            axis = self.cuK
            mid = (l + r) >> 1
            x = self.New()

            self.p[l:r + 1] = sorted(self.p[l:r + 1], key=lambda pt: pt.x[axis])

            self.Tr[x].place = self.p[mid]
            self.Tr[x].lc = self.build(l, mid - 1, dep + 1)
            self.Tr[x].rc = self.build(mid + 1, r, dep + 1)
            self.update(x)
            return x

        def rebuild(self, x: int, base= 0):
            """Flatten subtree x in-order into ``self.p[base + 1 ...]`` and recycle its nodes."""
            lc = self.Tr[x].lc
            rc = self.Tr[x].rc
            slot = base + self.Tr[lc].SZ + 1  # in-order position of x's own point
            if lc != 0:
                self.rebuild(lc, base)
            self.p[slot] = self.Tr[x].place
            self.top += 1
            self.store[self.top] = x
            if rc != 0:
                self.rebuild(rc, slot)

        def check(self, x: int, dep: int) -> int:
            """Rebuild subtree x when it is too unbalanced; return its (possibly new) root."""
            node = self.Tr[x]
            limit = self.alpha * node.SZ
            if limit < self.Tr[node.lc].SZ or limit < self.Tr[node.rc].SZ:
                self.rebuild(x)
                # rebuild() stored the points in p[1..SZ]; SZ itself is untouched.
                x = self.build(1, node.SZ, dep)
            return x

        def insert(self, inP: Pt, loc: int, dep= 0) -> int:
            """Insert ``inP`` below node ``loc`` and return the resulting subtree root.

            Pass 0 as ``loc`` to create a fresh leaf. Points whose coordinate on
            the split dimension is <= the node's go to the left child.
            """
            if loc == 0:
                loc = self.New()
                self.Tr[loc].place = inP
                self.Tr[loc].lc = self.Tr[loc].rc = 0
                self.update(loc)
                return loc
            axis = dep % K
            if inP.x[axis] <= self.Tr[loc].place.x[axis]:
                self.Tr[loc].lc = self.insert(inP, self.Tr[loc].lc, dep + 1)
            else:
                self.Tr[loc].rc = self.insert(inP, self.Tr[loc].rc, dep + 1)
            self.update(loc)
            return self.check(loc, dep)

        def getdis(self, temp: Pt, x: int):
            """Squared distance from ``temp`` to the bounding box of subtree x.

            This is 0 when the point lies inside the box, and it never exceeds
            the squared distance to any point stored in the subtree, so it is a
            LOWER bound used for pruning.
            """
            box = self.Tr[x]
            res = 0
            for i in range(K):
                res += (max(0, temp.x[i] - box.maxn[i]) + max(0, box.minn[i] - temp.x[i])) ** 2
            return res

        def dist(self, a, b):
            """Squared Euclidean distance between points a and b over K dimensions."""
            res = 0
            for i in range(K):
                res += (a.x[i] - b.x[i]) ** 2
            return res

        def _query(self, ask: Pt, x: int, k: int) -> None:
            """Visit subtree x, keeping the k nearest nodes seen so far in ``self.hp``."""
            if x == 0:
                # Sentinel index: there is no node here, so nothing to visit.
                return
            node = self.Tr[x]
            heapq.heappush(self.hp, Node(self.dist(ask, node.place), x))
            if len(self.hp) > k:
                heapq.heappop(self.hp)  # drop the farthest candidate

            lc, rc = node.lc, node.rc
            inf = float("inf")
            dl = self.getdis(ask, lc) if lc != 0 else inf
            dr = self.getdis(ask, rc) if rc != 0 else inf
            # Visit the more promising child first.
            if dl > dr:
                dl, dr = dr, dl
                lc, rc = rc, lc

            for child, bound in ((lc, dl), (rc, dr)):
                if child == 0:
                    continue  # never descend into the sentinel
                # The pruning bound is re-read for each child: visiting the
                # first child may have filled the heap and changed hp[0].
                if len(self.hp) < k or bound < self.hp[0].d:
                    self._query(ask, child, k)

        def query(self, ask: Pt, x: int, k= 1):
            """Return the k nearest neighbours of ``ask`` within subtree x.

            The result is a list of ``[squared_distance, node_index]`` pairs in
            heap order (not sorted by distance).

            Raises AssertionError when k is not in ``1..size``; size is the
            larger of ``self.n`` and the actual size of subtree x, so trees
            grown purely through ``insert`` can be queried as well.
            """
            assert k > 0 and k <= max(self.n, self.Tr[x].SZ)
            self.hp = []
            self._query(ask, x, k)
            return [[e.d, e.id] for e in self.hp]
    
    # KNN-KDTree
    class KNN_KDTree:
        """k-nearest-neighbour classifier backed by ``KDTree``."""

        def __init__(self, X_train, y_train, n_neighbors=3, p=2):
            """Copy the training set into a KD-tree and build it.

            X_train: indexable samples; the first K features of each are used.
            y_train: one label per sample, stored as the point's ``val``.
            n_neighbors: number of neighbours that vote in ``predict``.
            p: distance-metric parameter; it is only stored, nothing in this
                class reads it.
            """
            self.n = n_neighbors
            self.p = p
            self.X_train = X_train
            self.y_train = y_train
            self.KDTree = KDTree()
            self.KDTree.n = len(X_train)

            Len = len(self.X_train)
            # Points are stored 1-based in the tree's p array.
            for i in range(1, Len + 1):
                point = self.KDTree.p[i]
                for j in range(K):
                    point.x[j] = self.X_train[i - 1][j]
                point.val = self.y_train[i - 1]
            self.KDTree.root = self.KDTree.build(1, Len)

        def predict(self, X):
            """Return the most frequent label among the nearest neighbours of X."""
            res = self.KDTree.query(Pt(X), self.KDTree.root, self.n)
            # Each result is [distance, node index]; look up the label stored there.
            knn = [self.KDTree.Tr[hit[-1]].place.val for hit in res]
            count_pairs = Counter(knn)
            # Stable ascending sort by count: on a tie the label seen last wins.
            max_count = sorted(count_pairs.items(), key= lambda x: x[1])[-1][0]
            return max_count

        def score(self, X_test, y_test):
            """Return the fraction of samples in X_test whose prediction equals y_test."""
            right_count = 0
            for X, y in zip(X_test, y_test):
                if self.predict(X) == y:
                    right_count += 1
            return right_count / len(X_test)
    

    C++代码部分

    
    #include <bits/stdc++.h>
    using namespace std;
    
    using ll = long long;
    using db = double;
    const int maxn = 3e5 + 50; // capacity of the global arrays declared with it
    const int K = 2; // number of dimensions
    const db alpha = 0.75; // scapegoat-tree rebuild factor
    int cuK; // current K
    int top, tot; // top -> store array, tot -> total KDTree size 
    int store[maxn]; // store useless node, to save memory 
    
    
    struct Pt{
        int x[K]; // coordinates of the point, one per dimension
        int val; // value attached to the point
        /* Order points by their coordinate on the currently selected
           dimension cuK. const-qualified so it can also be applied to
           const Pt objects. */
        inline bool operator< (const Pt &other) const {
            return x[cuK] < other.x[cuK];
        }
    }p[maxn];
    
    struct KDTree{
        int SZ; // size of the subtree rooted at this node
        int lc, rc; // indices of the left and right child; 0 means no child
        int maxn[K], minn[K]; // maxn[i] / minn[i]: largest / smallest i-th coordinate in the subtree
        Pt place; // the point stored at this node (the split point)
    }Tr[maxn];
    
    inline int New() {
        /* Hand out a node index: reuse a recycled one if any, otherwise take a new one. */
        if (top != 0) {
            const int reused = store[top];
            --top;
            return reused;
        }
        tot += 1;
        return tot;
    }
    
    #define chmax(x, y) x = max(x, y)
    #define chmin(x, y) x = min(x, y)
    inline void update(int x){
        /* Recompute the bounding box and the size of node x from its own
           point and from its children. */
        const int kids[2] = {Tr[x].lc, Tr[x].rc};
        for (int i = 0; i < K; ++ i){
            int lo = Tr[x].place.x[i];
            int hi = lo;
            for (int c : kids){
                if (!c) continue; // 0 means "no child"
                hi = max(hi, Tr[c].maxn[i]);
                lo = min(lo, Tr[c].minn[i]);
            }
            Tr[x].maxn[i] = hi;
            Tr[x].minn[i] = lo;
        }
        Tr[x].SZ = Tr[kids[0]].SZ + Tr[kids[1]].SZ + 1; // children plus this node
    }
    
    int build(int l, int r, int dep= 0){
        /* Build a subtree from p[l..r] (inclusive) and return its root index,
           or 0 when the range is empty. The split dimension is dep % K. */
        if (l > r) return 0;
        cuK = dep % K; // operator< of Pt compares on this dimension
        const int mid = (1ll * l + r) >> 1;
        const int x = New();
        // Put the median on position mid; smaller points end up on its left.
        nth_element(p + l, p + mid, p + r + 1);
        Tr[x].place = p[mid];
        Tr[x].lc = build(l, mid - 1, dep + 1);
        Tr[x].rc = build(mid + 1, r, dep + 1);
        update(x);
        return x;
    }
    
    void rebuild(int x, int base= 0){
        /* Flatten subtree x in-order into p[base + 1 ...] and push every
           visited node index onto the recycle stack. */
        const int lc = Tr[x].lc;
        const int rc = Tr[x].rc;
        const int slot = base + Tr[lc].SZ + 1; // in-order position of x's own point
        if (lc) rebuild(lc, base);
        p[slot] = Tr[x].place;
        store[++top] = x;
        if (rc) rebuild(rc, slot);
    }
    
    void check(int &x, int dep){
        /* if meet the limit, rebuild the sub-tree */
        if (alpha * Tr[x].SZ < Tr[Tr[x].lc].SZ || alpha * Tr[x].SZ < Tr[Tr[x].rc].SZ)
            rebuild(x), x = build(1, Tr[x].SZ, dep);
    }
    
    void insert(Pt inP, int &loc, int dep= 0){
        /* Insert inP into the subtree referenced by loc; loc is updated when
           a new leaf is created or the subtree is rebuilt. */
        if (!loc){
            loc = New();
            Tr[loc].place = inP;
            Tr[loc].lc = Tr[loc].rc = 0;
            update(loc);
            return;
        }
        const int axis = dep % K;
        // Coordinates <= the node's go left, larger ones go right.
        int &child = (inP.x[axis] <= Tr[loc].place.x[axis]) ? Tr[loc].lc : Tr[loc].rc;
        insert(inP, child, dep + 1);
        update(loc);
        check(loc, dep);
    }
    
    
    int getdis(Pt temp, int x){
        /* Manhattan distance from temp to the bounding box of subtree x;
           0 when temp lies inside the box. */
        int total = 0;
        for (int i = 0; i < K; ++ i){
            const int above = temp.x[i] - Tr[x].maxn[i];
            const int below = Tr[x].minn[i] - temp.x[i];
            if (above > 0) total += above;
            if (below > 0) total += below;
        }
        return total;
    }
    
    int dist(Pt a, Pt b){
        /* Manhattan distance between a and b over all K dimensions. */
        int total = 0;
        for (int i = 0; i < K; ++ i){
            const int delta = a.x[i] - b.x[i];
            total += (delta < 0) ? -delta : delta;
        }
        return total;
    }
    
    int ans; // best (smallest) distance found by the current query
    const int inf = 0x3f3f3f3f;
    inline void query(Pt ask, int x){
        /* Lower ans to the smallest Manhattan distance between ask and any
           point of subtree x. The caller sets ans (e.g. to inf) beforehand. */
        if (!x) return; // empty tree: index 0 is the sentinel, not a real point
        chmin(ans, dist(ask, Tr[x].place));
        int lc = Tr[x].lc, rc = Tr[x].rc;
        int dl = inf, dr = inf; // box distances of the children, inf when absent
        if (lc) dl = getdis(ask, lc);
        if (rc) dr = getdis(ask, rc);
        // Visit the closer box first, then prune whatever cannot beat ans.
        if (dl > dr) swap(dl, dr), swap(lc, rc);
        if (dl < ans) query(ask, lc);
        if (dr < ans) query(ask, rc);
    }
    
    void solve(){
        /* Read n initial points and m operations from stdin.
           Operation type 1 inserts the given point; any other type prints
           the smallest Manhattan distance from the given point to the tree. */
        int n, m; std::cin >> n >> m;
        for (int i = 1; i <= n; ++ i) std::cin >> p[i].x[0] >> p[i].x[1];
        int root = build(1, n);
        for (int i = 0; i < m; ++ i){
            int type; std::cin >> type;
            Pt ask{}; // value-initialised so val is not left indeterminate
            std::cin >> ask.x[0] >> ask.x[1];
            if (type == 1) insert(ask, root);
            else { ans = inf; query(ask, root); std::cout << ans << "\n"; }
        }
    }
    
    int main(){
        /* Untie and unsync the C++ streams for faster I/O, then run solve(). */
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        cout.tie(nullptr);
        solve();
        return 0;
    }
    
    
  • 相关阅读:
    在cmd命令行中弹出Windows对话框
    Windows远程桌面连接如何直接使用剪贴板功能
    升级Windows10后Apache服务器启动失败的解决方法
    Windows下尝试PHP7提示丢失VCRUNTIME140.DLL的问题解决
    手动构建Servlet项目的流程
    更改Apache默认网站根目录
    windows下安装Appserv等php套件之后无法进入数据库管理的问题
    Java web项目的字符集问题
    五谷-小米:白小米
    五谷-小米:黑小米
  • 原文地址:https://www.cnblogs.com/Last--Whisper/p/14550592.html
Copyright © 2020-2023  润新知