本篇,我们用go简单的实现二叉查找树。
1.节点定义
type BSNode struct{
data int
left, right, parent *BSNode
}
2.前序遍历
func (p *BSNode) PreTraverse() error{
if p.data == 0 {
return errors.New("Error: no data!")
}
fmt.Printf("%d ", p.data)
if p.left != nil {
p.left.PreTraverse()
}
if p.right != nil {
p.right.PreTraverse()
}
return nil
}
3.中序遍历
func (p *BSNode) InTraverse() error{
if p.data == 0 {
return errors.New("Error: no data!")
}
if p.left != nil {
p.left.InTraverse()
}
fmt.Printf("%d ", p.data)
if p.right != nil {
p.right.InTraverse()
}
return nil
}
4.后序遍历
func (p *BSNode) PostTraverse() error{
if p.data == 0 {
return errors.New("Error: no data!")
}
if p.left != nil {
p.left.PostTraverse()
}
if p.right != nil {
p.right.PostTraverse()
}
fmt.Printf("%d ", p.data)
return nil
}
5.添加节点
func (p *BSNode) Add(data int) error {
if data == 0 {
return errors.New("Error: not support 0 value!")
}
if p.data == 0 {
p.data = data
return nil
}
if p.data == data {
return errors.New("Error: add repeated data!")
} else if data < p.data {
if p.left == nil {
p.left = new(BSNode)
p.left.data = data
p.left.parent = p
return nil
}
p.left.Add(data)
} else {
if p.right == nil {
p.right = new(BSNode)
p.right.data = data
p.right.parent = p
return nil
}
p.right.Add(data)
}
return nil
}
6.删除节点
func (p *BSNode) Delete(data int) {
bsnode := p.Find(data)
if bsnode == nil {
return
}
if bsnode.left != nil {
var tmp *BSNode
for bsnode.left != nil {
bsnode.data = bsnode.left.data
tmp = bsnode
bsnode = bsnode.left
}
tmp.left = nil
return
}
if bsnode.right != nil {
var tmp *BSNode
for bsnode.right != nil {
bsnode.data = bsnode.right.data
tmp = bsnode
bsnode = bsnode.right
}
tmp.right = nil
return
}
if bsnode.parent != nil {
if bsnode.parent.left == bsnode {
bsnode.parent.left = nil
} else {
bsnode.parent.right = nil
}
}
}
7.查询节点
func (p *BSNode) Find(data int) *BSNode {
if p.data == data {
return p
} else if data < p.data {
if p.left != nil {
return p.left.Find(data)
}
return nil
} else {
if p.right != nil {
return p.right.Find(data)
}
return nil
}
}
8.测试代码
func main() {
num := []int{50, 20, 60, 40, 80, 10, 55, 52, 56}
var root *BSNode = new(BSNode)
for _, v := range num {
root.Add(v)
}
fmt.Println("前序遍历:")
root.PreTraverse()
fmt.Printf("
")
fmt.Println("中序遍历:")
root.InTraverse()
fmt.Printf("
")
fmt.Println("后序遍历:")
root.PostTraverse()
fmt.Printf("
")
bsnode := root.Find(60)
if bsnode != nil {
fmt.Println("查询结果:")
fmt.Printf("节点:%d 父节点:%d 左子节点:%d 右子节点:%d
", bsnode.data, bsnode.parent.data, bsnode.left.data, bsnode.right.data)
}
root.Delete(50)
fmt.Println("删除后前序遍历:")
root.PreTraverse()
fmt.Printf("
")
}
9.完整代码
package main
import (
"fmt"
"errors"
)
type BSNode struct{
data int
left, right, parent *BSNode
}
// 前序遍历
func (p *BSNode) PreTraverse() error{
if p.data == 0 {
return errors.New("Error: no data!")
}
fmt.Printf("%d ", p.data)
if p.left != nil {
p.left.PreTraverse()
}
if p.right != nil {
p.right.PreTraverse()
}
return nil
}
// 中序遍历
func (p *BSNode) InTraverse() error{
if p.data == 0 {
return errors.New("Error: no data!")
}
if p.left != nil {
p.left.InTraverse()
}
fmt.Printf("%d ", p.data)
if p.right != nil {
p.right.InTraverse()
}
return nil
}
// 后序遍历
func (p *BSNode) PostTraverse() error{
if p.data == 0 {
return errors.New("Error: no data!")
}
if p.left != nil {
p.left.PostTraverse()
}
if p.right != nil {
p.right.PostTraverse()
}
fmt.Printf("%d ", p.data)
return nil
}
// 添加节点
func (p *BSNode) Add(data int) error {
if data == 0 {
return errors.New("Error: not support 0 value!")
}
if p.data == 0 {
p.data = data
return nil
}
if p.data == data {
return errors.New("Error: add repeated data!")
} else if data < p.data {
if p.left == nil {
p.left = new(BSNode)
p.left.data = data
p.left.parent = p
return nil
}
p.left.Add(data)
} else {
if p.right == nil {
p.right = new(BSNode)
p.right.data = data
p.right.parent = p
return nil
}
p.right.Add(data)
}
return nil
}
// 删除节点
func (p *BSNode) Delete(data int) {
bsnode := p.Find(data)
if bsnode == nil {
return
}
if bsnode.left != nil {
var tmp *BSNode
for bsnode.left != nil {
bsnode.data = bsnode.left.data
tmp = bsnode
bsnode = bsnode.left
}
tmp.left = nil
return
}
if bsnode.right != nil {
var tmp *BSNode
for bsnode.right != nil {
bsnode.data = bsnode.right.data
tmp = bsnode
bsnode = bsnode.right
}
tmp.right = nil
return
}
if bsnode.parent != nil {
if bsnode.parent.left == bsnode {
bsnode.parent.left = nil
} else {
bsnode.parent.right = nil
}
}
}
// 查询节点
func (p *BSNode) Find(data int) *BSNode {
if p.data == data {
return p
} else if data < p.data {
if p.left != nil {
return p.left.Find(data)
}
return nil
} else {
if p.right != nil {
return p.right.Find(data)
}
return nil
}
}
func main() {
num := []int{50, 20, 60, 40, 80, 10, 55, 52, 56}
var root *BSNode = new(BSNode)
for _, v := range num {
root.Add(v)
}
fmt.Println("前序遍历:")
root.PreTraverse()
fmt.Printf("
")
fmt.Println("中序遍历:")
root.InTraverse()
fmt.Printf("
")
fmt.Println("后序遍历:")
root.PostTraverse()
fmt.Printf("
")
bsnode := root.Find(60)
if bsnode != nil {
fmt.Println("查询结果:")
fmt.Printf("节点:%d 父节点:%d 左子节点:%d 右子节点:%d
", bsnode.data, bsnode.parent.data, bsnode.left.data, bsnode.right.data)
}
root.Delete(50)
fmt.Println("删除后前序遍历:")
root.PreTraverse()
fmt.Printf("
")
}