晚上无聊写了个二叉树(图)的广度和深度遍历算法,算法本身很简单,但是如何做到通用呢,一下代码是我的设计,请大家帮忙看看有什么问题,我自己感觉有问题就是不知道具体什么问题
public interface IGraph<TVertex> { IEnumerable<IEdge<TVertex>> Edges { get; } }
public interface IEdge<TVertex> { TVertex From { get; set; } TVertex To { get; set; } }
public interface INode { IEnumerable<TNode> GetNextNodes<TNode>() where TNode : INode; }
public class Edge<TVertex> : IEdge<TVertex> { public TVertex From { get; set; } public TVertex To { get; set; } }
public static class NodeVisitor { public static void BreadthVisit<TNode>(TNode rootNode, Action<TNode> visitAction) where TNode : INode { BreadthVisit(rootNode, n => n.GetNextNodes<TNode>(), visitAction); } public static void BreadthVisit<TNode>(TNode rootNode, Func<TNode,IEnumerable<TNode>> nextNodeSelector, Action<TNode> visitAction) { var nodeQueue = new Queue<TNode>(); nodeQueue.Enqueue(rootNode); while (nodeQueue.Any()) { var currentNode = nodeQueue.Dequeue(); if (visitAction != null) { visitAction(currentNode); } foreach (var nextNode in nextNodeSelector(currentNode)) { nodeQueue.Enqueue(nextNode); } } } public static void DepthVisit<TNode>(TNode rootNode, Func<TNode, IEnumerable<TNode>> nextNodeSelector, Action<TNode> visitAction) { var nodeStack = new Stack<TNode>(); nodeStack.Push(rootNode); while (nodeStack.Any()) { var currentNode = nodeStack.Pop(); if (visitAction != null) { visitAction(currentNode); } foreach (var nextNode in nextNodeSeletor(currentNode)) { nodeStack.Push(nextNode); } } } public static void DepthVisit<TNode>(TNode rootNode, Action<TNode> visitAction) where TNode : INode { DepthVisit(rootNode, n => n.GetNextNodes<TNode>(), visitAction); } }
public class GraphVisitor<TVertex> { private IGraph<TVertex> _graph; public GraphVisitor(IGraph<TVertex> graph) { _graph = graph; } public TVertex GetRoot() { var vertexs = _graph.Edges.Select(t => t.From).Concat(_graph.Edges.Select(t => t.To)); var toVertexs = _graph.Edges.Select(t => t.To); return vertexs.FirstOrDefault(t => toVertexs.All(v => !v.Equals(t))); } public IEnumerable<TVertex> GetNextVertexs(TVertex current) { return _graph.Edges.Where(t => t.From.Equals(current)).Select(t => t.To); } public void BreadthVisit(Action<TVertex> visitAction, TVertex startVertex) { NodeVisitor.BreadthVisit(startVertex, t => GetNextVertexs(t), visitAction); } public void BreadthVisit(Action<TVertex> visitAction) { NodeVisitor.BreadthVisit(GetRoot(), t => GetNextVertexs(t), visitAction); } public void DepthVisit(Action<TVertex> visitAction, TVertex startVertex) { NodeVisitor.DepthVisit(startVertex, t => GetNextVertexs(t), visitAction); } public void DepthVisit(Action<TVertex> visitAction) { NodeVisitor.DepthVisit(GetRoot(), t => GetNextVertexs(t), visitAction); } private class GraphNode : INode { private IList<INode> nodes = new List<INode>(); public string Id { get; set; } public void AddNext(INode node) { nodes.Add(node); } public IEnumerable<TNode> GetNextNodes<TNode>() where TNode : INode { return nodes.Cast<TNode>(); } } }
单元测试代码:
[TestClass] public class BreadthVisitorTest { [TestMethod] public void TestVisit() { var node1 = new TestNode() { Id = "1" }; var node1_1 = new TestNode() { Id = "1_1" }; var node1_2 = new TestNode() { Id = "1_2" }; var node1_1_1 = new TestNode() { Id = "1_1_1" }; var node1_1_2 = new TestNode() { Id = "1_1_2" }; var node1_1_3 = new TestNode() { Id = "1_1_3" }; var node1_2_1 = new TestNode() { Id = "1_2_1" }; var node1_2_2 = new TestNode() { Id = "1_2_2" }; node1.AddNext(node1_1); node1.AddNext(node1_2); node1_1.AddNext(node1_1_1); node1_1.AddNext(node1_1_2); node1_1.AddNext(node1_1_3); node1_2.AddNext(node1_2_1); node1_2.AddNext(node1_2_2); var expected = "1.1_1.1_2.1_1_1.1_1_2.1_1_3.1_2_1.1_2_2"; var actual = ""; NodeVisitor.BreadthVisit(node1, n => { actual += n.Id + "."; }); Assert.AreEqual(expected, actual.Trim('.')); expected = "1.1_1.1_1_1.1_1_2.1_1_3.1_2.1_2_1.1_2_2"; actual = ""; NodeVisitor.DepthVisit(node1, n => { actual += n.Id + "."; }); } [TestMethod] public void TestGraphVisit() { var graph = new Graph(); var graphVisitor = new GraphVisitor<int>(graph); graph.AddEdge(1, 2); graph.AddEdge(1, 3); graph.AddEdge(2, 4); graph.AddEdge(2, 5); graph.AddEdge(3, 6); var expected = "123456"; var actual = ""; graphVisitor.BreadthVisit(a => { actual += a.ToString(); }); Assert.AreEqual(expected, actual); expected = "124536"; actual = ""; graphVisitor.DepthVisit(a => { actual += a.ToString(); }); } } public class TestNode : INode { private IList<INode> nodes = new List<INode>(); public string Id { get; set; } public void AddNext(INode node) { nodes.Add(node); } public IEnumerable<TNode> GetNextNodes<TNode>() where TNode : INode { return nodes.Cast<TNode>(); } } public class Graph : IGraph<int> { private IList<IEdge<int>> _edges = new List<IEdge<int>>(); public IEnumerable<IEdge<int>> Edges { get { return _edges; } } public void AddEdge(int from, int to) { _edges.Add(new Edge<int>() { From = from, To = to }); } }