负载均衡是一种手段,用来把对某种资源的访问分摊给不同的设备,从而减轻单点的压力。
架构图
图中左侧为ZooKeeper集群,右侧上方为工作服务器,下方为客户端。每台工作服务器在启动时都会在ZooKeeper的servers节点下注册一个临时节点;每台客户端在启动时都会从servers节点下获取所有可用工作服务器的列表,并通过一定的负载均衡算法(默认实现选择当前负载最小的服务器)选出一台工作服务器,与之建立网络连接。网络通信采用开源框架Netty。
流程图
负载均衡客户端流程
服务端主体流程
类图
Server端核心类
每个服务端都实现Server接口,ServerImpl是服务端的实现类。服务端启动时的注册过程被抽象为接口RegistProvider,并提供了默认实现DefaultRegistProvider,它通过上下文类ZooKeeperRegistContext获取注册所需的信息(节点路径、ZkClient和服务器数据)。服务端基于Netty,需要由ServerHandler处理与客户端之间的连接;每当有客户端建立或断开连接时,都需要修改当前服务器的负载信息。修改负载信息的过程同样被抽象为接口BalanceUpdateProvider,并提供了默认实现DefaultBalanceUpdateProvider。ServerRunner是启动类,负责启动各个Server。
Client端核心类
每个客户端都需要实现Client接口,ClientImpl是其实现类。Client需要ClientHandler来处理与服务器之间的通信,同时需要BalanceProvider为它提供负载均衡算法。BalanceProvider是接口,AbstractBalanceProvider是它的抽象实现,DefaultBalanceProvider继承该抽象类,是默认实现。ServerData是服务端和客户端共用的类:服务端把自己的基本信息(包括负载信息)打包成ServerData并写入ZooKeeper;客户端在计算负载时从ZooKeeper中读取ServerData,从而得到服务端的地址和负载信息。ClientRunner是客户端的启动类,负责启动客户端。
服务端代码:
public interface BalanceUpdateProvider { // 增加负载 public boolean addBalance(Integer step); // 减少负载 public boolean reduceBalance(Integer step); }
package com.test.cbd.zookeeper.loadbalance; import org.I0Itec.zkclient.ZkClient; import org.I0Itec.zkclient.exception.ZkBadVersionException; import org.apache.zookeeper.data.Stat; public class DefaultBalanceUpdateProvider implements BalanceUpdateProvider { private String serverPath; private ZkClient zc; public DefaultBalanceUpdateProvider(String serverPath, ZkClient zkClient) { this.serverPath = serverPath; this.zc = zkClient; } public boolean addBalance(Integer step) { Stat stat = new Stat(); ServerData sd; // 增加负载:读取服务器的信息ServerData,增加负载,并写回zookeeper while (true) { try { sd = zc.readData(this.serverPath, stat); sd.setBalance(sd.getBalance() + step); // 带上版本,因为可能有其他客户端连接到服务器修改了负载 zc.writeData(this.serverPath, sd, stat.getVersion()); return true; } catch (ZkBadVersionException e) { // ignore } catch (Exception e) { return false; } } } public boolean reduceBalance(Integer step) { Stat stat = new Stat(); ServerData sd; while (true) { try { sd = zc.readData(this.serverPath, stat); final Integer currBalance = sd.getBalance(); sd.setBalance(currBalance>step?currBalance-step:0); zc.writeData(this.serverPath, sd, stat.getVersion()); return true; } catch (ZkBadVersionException e) { // ignore } catch (Exception e) { return false; } } } }
package com.test.cbd.zookeeper.loadbalance; import org.I0Itec.zkclient.ZkClient; import org.I0Itec.zkclient.exception.ZkNoNodeException; public class DefaultRegistProvider implements RegistProvider { // 在zookeeper中创建临时节点并写入信息 public void regist(Object context) throws Exception { // Server在zookeeper中注册自己,需要在zookeeper的目标节点上创建临时节点并写入自己 // 将需要的以下3个信息包装成上下文传入 // 1:path // 2:zkClient // 3:serverData ZooKeeperRegistContext registContext = (ZooKeeperRegistContext) context; String path = registContext.getPath(); ZkClient zc = registContext.getZkClient(); try { zc.createEphemeral(path, registContext.getData()); } catch (ZkNoNodeException e) { String parentDir = path.substring(0, path.lastIndexOf('/')); zc.createPersistent(parentDir, true); regist(registContext); } } public void unRegist(Object context) throws Exception { return; } }
package com.test.cbd.zookeeper.loadbalance; public interface RegistProvider { public void regist(Object context) throws Exception; public void unRegist(Object context) throws Exception; }
package com.test.cbd.zookeeper.loadbalance; public interface Server { public void bind(); }
package com.test.cbd.zookeeper.loadbalance; import java.io.Serializable; public class ServerData implements Serializable,Comparable<ServerData> { private static final long serialVersionUID = -8892569870391530906L; private Integer balance; private String host; private Integer port; public Integer getBalance() { return balance; } public void setBalance(Integer balance) { this.balance = balance; } public String getHost() { return host; } public void setHost(String host) { this.host = host; } public Integer getPort() { return port; } public void setPort(Integer port) { this.port = port; } @Override public String toString() { return "ServerData [balance=" + balance + ", host=" + host + ", port=" + port + "]"; } public int compareTo(ServerData o) { return this.getBalance().compareTo(o.getBalance()); } }
package com.test.cbd.zookeeper.loadbalance; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; /** * 处理服务端与客户端之间的通信 */ public class ServerHandler extends ChannelHandlerAdapter { private final BalanceUpdateProvider balanceUpdater; private static final Integer BALANCE_STEP = 1; public ServerHandler(BalanceUpdateProvider balanceUpdater){ this.balanceUpdater = balanceUpdater; } public BalanceUpdateProvider getBalanceUpdater() { return balanceUpdater; } // 建立连接时增加负载 @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { System.out.println("one client connect..."); balanceUpdater.addBalance(BALANCE_STEP); } // 断开连接时减少负载 @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { balanceUpdater.reduceBalance(BALANCE_STEP); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { cause.printStackTrace(); ctx.close(); } }
package com.test.cbd.zookeeper.loadbalance; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.*; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import org.I0Itec.zkclient.ZkClient; import org.I0Itec.zkclient.serialize.SerializableSerializer; public class ServerImpl implements Server { private EventLoopGroup bossGroup = new NioEventLoopGroup(); private EventLoopGroup workGroup = new NioEventLoopGroup(); private ServerBootstrap bootStrap = new ServerBootstrap(); private ChannelFuture cf; private String zkAddress; private String serversPath; private String currentServerPath; private ServerData sd; private volatile boolean binded = false; private final ZkClient zc; private final RegistProvider registProvider; private static final Integer SESSION_TIME_OUT = 10000; private static final Integer CONNECT_TIME_OUT = 10000; public String getCurrentServerPath() { return currentServerPath; } public String getZkAddress() { return zkAddress; } public String getServersPath() { return serversPath; } public ServerData getSd() { return sd; } public void setSd(ServerData sd) { this.sd = sd; } public ServerImpl(String zkAddress, String serversPath, ServerData sd){ this.zkAddress = zkAddress; this.serversPath = serversPath; this.zc = new ZkClient(this.zkAddress,SESSION_TIME_OUT,CONNECT_TIME_OUT,new SerializableSerializer()); this.registProvider = new DefaultRegistProvider(); this.sd = sd; } //初始化服务端 private void initRunning() throws Exception { String mePath = serversPath.concat("/").concat(sd.getPort().toString()); //注册到zookeeper registProvider.regist(new ZooKeeperRegistContext(mePath,zc,sd)); currentServerPath = mePath; } public void bind() { if (binded){ return; } System.out.println(sd.getPort()+":binding..."); try { initRunning(); } catch (Exception e) { e.printStackTrace(); return; } bootStrap.group(bossGroup,workGroup) .channel(NioServerSocketChannel.class) 
.option(ChannelOption.SO_BACKLOG, 1024) .childHandler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { ChannelPipeline p = ch.pipeline(); p.addLast(new ServerHandler(new DefaultBalanceUpdateProvider(currentServerPath,zc))); } }); try { cf = bootStrap.bind(sd.getPort()).sync(); binded = true; System.out.println(sd.getPort()+":binded..."); cf.channel().closeFuture().sync(); } catch (InterruptedException e) { e.printStackTrace(); }finally{ bossGroup.shutdownGracefully(); workGroup.shutdownGracefully(); } } }
package com.test.cbd.zookeeper.loadbalance; import java.util.ArrayList; import java.util.List; /** * 用于测试,负责启动Work Server */ public class ServerRunner { private static final int SERVER_QTY = 2; private static final String ZOOKEEPER_SERVER = "192.168.1.105:2181"; private static final String SERVERS_PATH = "/servers"; public static void main(String[] args) { List<Thread> threadList = new ArrayList<Thread>(); for(int i=0; i<SERVER_QTY; i++){ final Integer count = i; Thread thread = new Thread(new Runnable() { public void run() { ServerData sd = new ServerData(); sd.setBalance(0); sd.setHost("127.0.0.1"); sd.setPort(6000+count); Server server = new ServerImpl(ZOOKEEPER_SERVER,SERVERS_PATH,sd); server.bind(); } }); threadList.add(thread); thread.start(); } for (int i=0; i<threadList.size(); i++){ try { threadList.get(i).join(); } catch (InterruptedException ignore) { // } } } }
package com.test.cbd.zookeeper.loadbalance;

import org.I0Itec.zkclient.ZkClient;

/**
 * Mutable holder for what a registration needs: the ZooKeeper node path, the
 * {@link ZkClient} to use, and the data to store in the node.
 */
public class ZooKeeperRegistContext {

    private String path;
    private ZkClient zkClient;
    private Object data;

    public ZooKeeperRegistContext(String path, ZkClient zkClient, Object data) {
        this.path = path;
        this.zkClient = zkClient;
        this.data = data;
    }

    /** @return node path to register */
    public String getPath() {
        return this.path;
    }

    /** @return client used to talk to ZooKeeper */
    public ZkClient getZkClient() {
        return this.zkClient;
    }

    /** @return payload written into the node */
    public Object getData() {
        return this.data;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public void setZkClient(ZkClient zkClient) {
        this.zkClient = zkClient;
    }

    public void setData(Object data) {
        this.data = data;
    }
}
客户端代码:
package com.test.cbd.zookeeper.loadbalance.client;

import java.util.List;

/**
 * Template for balance providers: {@link #getBalanceItem()} fetches the candidates with
 * {@link #getBalanceItems()} and lets {@link #balanceAlgorithm(List)} choose one of them.
 *
 * @param <T> type of the candidate items
 */
public abstract class AbstractBalanceProvider<T> implements BalanceProvider<T> {

    /** Chooses one item out of {@code items}. */
    protected abstract T balanceAlgorithm(List<T> items);

    /** Returns the current list of candidates. */
    protected abstract List<T> getBalanceItems();

    public T getBalanceItem() {
        List<T> candidates = getBalanceItems();
        return balanceAlgorithm(candidates);
    }
}
package com.test.cbd.zookeeper.loadbalance.client; public interface BalanceProvider<T> { public T getBalanceItem(); }
package com.test.cbd.zookeeper.loadbalance.client; public interface Client { // 连接服务器 public void connect() throws Exception; // 断开服务器 public void disConnect() throws Exception; }
package com.test.cbd.zookeeper.loadbalance.client; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; public class ClientHandler extends ChannelHandlerAdapter { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { // Close the connection when an exception is raised. cause.printStackTrace(); ctx.close(); } }
package com.test.cbd.zookeeper.loadbalance.client; import com.test.cbd.zookeeper.loadbalance.ServerData; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioSocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import io.netty.channel.*; import io.netty.channel.socket.SocketChannel; public class ClientImpl implements Client { private final BalanceProvider<ServerData> provider; private EventLoopGroup group = null; private Channel channel = null; private final Logger log = LoggerFactory.getLogger(getClass()); public ClientImpl(BalanceProvider<ServerData> provider) { this.provider = provider; } public BalanceProvider<ServerData> getProvider() { return provider; } public void connect(){ try{ ServerData serverData = provider.getBalanceItem(); // 获取负载最小的服务器信息,并与之建立连接 System.out.println("connecting to "+serverData.getHost()+":"+serverData.getPort()+", it's balance:"+serverData.getBalance()); group = new NioEventLoopGroup(); Bootstrap b = new Bootstrap(); b.group(group) .channel(NioSocketChannel.class) .handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) throws Exception { ChannelPipeline p = ch.pipeline(); p.addLast(new ClientHandler()); } }); ChannelFuture f = b.connect(serverData.getHost(),serverData.getPort()).syncUninterruptibly(); channel = f.channel(); System.out.println("started success!"); }catch(Exception e){ System.out.println("连接异常:"+e.getMessage()); } } public void disConnect(){ try{ if (channel!=null) channel.close().syncUninterruptibly(); group.shutdownGracefully(); group = null; log.debug("disconnected!"); }catch(Exception e){ log.error(e.getMessage()); } } }
package com.test.cbd.zookeeper.loadbalance.client; import com.test.cbd.zookeeper.loadbalance.ServerData; import java.io.BufferedReader; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; public class ClientRunner { private static final int CLIENT_QTY = 3; private static final String ZOOKEEPER_SERVER = "192.168.1.105:2181"; private static final String SERVERS_PATH = "/servers"; public static void main(String[] args) { List<Thread> threadList = new ArrayList<Thread>(CLIENT_QTY); final List<Client> clientList = new ArrayList<Client>(); final BalanceProvider<ServerData> balanceProvider = new DefaultBalanceProvider(ZOOKEEPER_SERVER, SERVERS_PATH); try{ for(int i=0; i<CLIENT_QTY; i++){ Thread thread = new Thread(new Runnable() { public void run() { Client client = new ClientImpl(balanceProvider); clientList.add(client); try { client.connect(); } catch (Exception e) { e.printStackTrace(); } } }); threadList.add(thread); thread.start(); //延时 Thread.sleep(2000); } System.out.println("敲回车键退出! "); new BufferedReader(new InputStreamReader(System.in)).readLine(); }catch(Exception e){ e.printStackTrace(); }finally{ //关闭客户端 for (int i=0; i<clientList.size(); i++){ try { clientList.get(i); clientList.get(i).disConnect(); } catch (Exception ignore) { //ignore } } //关闭线程 for (int i=0; i<threadList.size(); i++){ threadList.get(i).interrupt(); try{ threadList.get(i).join(); }catch (InterruptedException e){ //ignore } } } } }
package com.test.cbd.zookeeper.loadbalance.client; import com.test.cbd.zookeeper.loadbalance.ServerData; import org.I0Itec.zkclient.ZkClient; import org.I0Itec.zkclient.serialize.SerializableSerializer; import java.util.ArrayList; import java.util.Collections; import java.util.List; public class DefaultBalanceProvider extends AbstractBalanceProvider<ServerData> { private final String zkServer; // zookeeper服务器地址 private final String serversPath; // servers节点路径 private final ZkClient zc; private static final Integer SESSION_TIME_OUT = 10000; private static final Integer CONNECT_TIME_OUT = 10000; public DefaultBalanceProvider(String zkServer, String serversPath) { this.serversPath = serversPath; this.zkServer = zkServer; this.zc = new ZkClient(this.zkServer, SESSION_TIME_OUT, CONNECT_TIME_OUT, new SerializableSerializer()); } @Override protected ServerData balanceAlgorithm(List<ServerData> items) { if (items.size()>0){ Collections.sort(items); // 根据负载由小到大排序 return items.get(0); // 返回负载最小的那个 }else{ return null; } } /** * 从zookeeper中拿到所有工作服务器的基本信息 */ @Override protected List<ServerData> getBalanceItems() { List<ServerData> sdList = new ArrayList<ServerData>(); List<String> children = zc.getChildren(this.serversPath); for(int i=0; i<children.size();i++){ ServerData sd = zc.readData(serversPath+"/"+children.get(i)); sdList.add(sd); } return sdList; } }