package cn.aust.zyw.demo; public class BST<Key extends Comparable<Key>,Value> { public static void main(String args[]){ BST<Integer,String> bst=new BST<>(); bst.put(3,"three"); bst.put(5,"five"); bst.put(1,"one"); bst.put(6,"six"); bst.put(4,"four"); System.out.println(bst.celling(2)); } private Node root; private class Node{ private Key key; private Value val; private Node left,right; private int N;//以该结点为根的子树的结点总数 public Node(Key key,Value val,int N){ this.key=key;this.val=val;this.N=N; } } public int size(){return size(root);} private int size(Node x){ if(x==null) return 0; return x.N; } public Value get(Key key){return get(root,key);} //按索引值查找 private Value get(Node x,Key key){ if(x==null) return null; int cmp=key.compareTo(x.key); if(cmp<0) return get(x.left,key); if(cmp>0) return get(x.right,key); else return x.val; } public void put(Key key,Value val){ root=put(root,key,val); } //若存在直接改变val,不存在new node,则相应子节点数目+1 private Node put(Node x,Key key,Value val){ if(x==null) return new Node(key,val,1); int cmp=key.compareTo(x.key); if(cmp<0) x.left=put(x.left,key,val); else if(cmp>0) x.right=put(x.right,key,val); else x.val=val; x.N=size(x.left)+size(x.right)+1; return x; } //return 最小子节点 public Key min(){return min(root).key;} private Node min(Node x){ if(x.left==null) return x; return min(x.left); } //return 最大子节点 public Key max(){return max(root).key;} public Node max(Node x){ if(x.right==null) return x; return max(x.right); } //[key]向下取整 public Key floor(Key key){ Node x=floor(root,key); if(x==null) return null; return x.key; } private Node floor(Node x,Key key){ if(x==null) return null; int cmp=key.compareTo(x.key); if(cmp==0) return x; if(cmp<0) return floor(x.left,key); Node t=floor(x.right,key); if(t!=null) return t; else return x; } public Key celling(Key key){ Node x=celling(root,key); if(x==null) return null; return x.key; } private Node celling(Node x,Key key){ if(x==null) return null; int cmp=key.compareTo(x.key); if(cmp==0) return x; if(cmp>0) return celling(x.right,key); Node t=celling(x.left,key); if(t!=null) return t; else return x; } }