示例代码:
https://github.com/gordonklg/study 中的 socket 模块
A. LineSeparate
基于 Buffer 实现逐行读取的 EchoServer 比传统 Socket 编程困难,相当于需要自己通过 Buffer 实现 BufferedReader 的 readLine 功能。
代码如下。假设单行(含换行符)不超过 256 字节;支持 Windows(\r\n)和 Linux(\n)两种换行符,但不支持单独的 \r 作为换行符;空行会被忽略。
代码就不分析了,写了好久才跑对测试,分包粘包真是麻烦,要去刷 LeetCode 基本题提高编码能力了,不能整天都 CTRL C CTRL V 啊。
gordon.study.socket.nio.basic.LineSeparateBlockingEchoServer.java
/**
 * Blocking, thread-per-connection echo server built on {@link ServerSocketChannel} that splits the
 * incoming byte stream into lines by hand -- the NIO counterpart of
 * {@code BufferedReader.readLine()}.
 *
 * <p>Lines end with {@code \n} or {@code \r\n}; a lone {@code \r} is not treated as a terminator.
 * Blank lines are ignored. A line, including its terminator, must fit in {@link #BUFFER_SIZE}
 * bytes. Text is encoded as UTF-8 by the demo client and decoded as UTF-8 by the server.
 */
public class LineSeparateBlockingEchoServer {
    private static final int PORT = 8888;
    /** Capacity of the per-connection read and write buffers, i.e. the maximum line length. */
    private static final int BUFFER_SIZE = 256;
    private static final java.nio.charset.Charset CHARSET = java.nio.charset.StandardCharsets.UTF_8;

    /**
     * Binds the server on {@link #PORT}, launches three demo clients and serves every accepted
     * connection on its own thread. Never returns.
     */
    public static void main(String[] args) throws Exception {
        // Bind BEFORE starting the clients; otherwise they can race the server and be refused.
        ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
        serverSocketChannel.bind(new InetSocketAddress(PORT));
        for (int i = 0; i < 3; i++) {
            new Thread(new Client()).start();
        }
        while (true) {
            SocketChannel socketChannel = serverSocketChannel.accept();
            new Thread(new ServerHandler(socketChannel)).start();
        }
    }

    /** Reads lines from one connection and echoes every non-blank line back to the peer. */
    private static class ServerHandler implements Runnable {
        private final SocketChannel socketChannel;

        public ServerHandler(SocketChannel socketChannel) {
            this.socketChannel = socketChannel;
        }

        @Override
        public void run() {
            // try-with-resources closes the connection both on EOF and on error.
            try (SocketChannel channel = socketChannel) {
                ByteBuffer buf = ByteBuffer.allocate(BUFFER_SIZE);
                ByteBuffer writeBuf = ByteBuffer.allocate(BUFFER_SIZE);
                LineExtractor extractor = new LineExtractor();
                // read() returns -1 once the peer has shut down its output.
                while (channel.read(buf) >= 0) {
                    byte[] line;
                    // Drain every complete line; a blank line is skipped but does not stop draining.
                    while ((line = extractor.extractLine(buf)) != null) {
                        if (line.length > 0) {
                            echo(channel, writeBuf, line);
                        }
                    }
                    if (!buf.hasRemaining()) {
                        // A full buffer with no terminator would make read() return 0 forever.
                        throw new IOException("line longer than " + BUFFER_SIZE + " bytes");
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        /** Logs {@code content} and writes it back to {@code channel} followed by {@code \n}. */
        private static void echo(SocketChannel channel, ByteBuffer writeBuf, byte[] content)
                throws IOException {
            System.out.println("ECHO: " + new String(content, CHARSET));
            writeBuf.clear();
            writeBuf.put(content);
            writeBuf.put((byte) '\n');
            writeBuf.flip();
            while (writeBuf.hasRemaining()) {
                channel.write(writeBuf);
            }
        }
    }

    /**
     * Incrementally extracts lines from a buffer that is being filled by channel reads.
     *
     * <p>The buffer is expected in "write mode" (data in {@code [0, position)}), which is how
     * {@code SocketChannel.read} leaves it. The extractor remembers how far it has already scanned,
     * so bytes are not rescanned after every read.
     */
    static final class LineExtractor {
        /** Buffer index below which the buffered bytes are known to contain no line feed. */
        private int lastScannedPos = 0;

        /**
         * Removes the next complete line, including its terminator, from {@code buf} and compacts
         * the buffer so that it stays ready for the next read.
         *
         * @param buf buffer in write mode holding the bytes received so far
         * @return the line without its {@code \n} / {@code \r\n} terminator (an empty array for a
         *     blank line), or {@code null} if no complete line is buffered yet
         */
        byte[] extractLine(ByteBuffer buf) {
            int end = buf.position();
            for (int i = lastScannedPos; i < end; i++) {
                if (buf.get(i) != '\n') {
                    continue;
                }
                // Drop a '\r' directly in front of the '\n' (Windows line ending).
                int lineLength = (i > 0 && buf.get(i - 1) == '\r') ? i - 1 : i;
                byte[] line = new byte[lineLength];
                buf.flip();             // position = 0, limit = end
                buf.get(line);
                buf.position(i + 1);    // skip the terminator
                buf.compact();          // move the unread tail to the front, back to write mode
                lastScannedPos = 0;
                return line;
            }
            lastScannedPos = end;
            return null;
        }
    }

    /**
     * Demo client that sends lines in fragments, mixing CRLF and LF endings and blank lines, so
     * that lines arrive both split across reads and merged into one read.
     */
    private static class Client implements Runnable {
        private static final String[] CHUNKS = {
            "hello\r\n", "\n", "你瞅啥?\r\n", "\r\nchi", " le ", "ma?\r\nni hao\nhi d", "ude\n"
        };

        @Override
        public void run() {
            try (Socket socket = new Socket()) {
                socket.connect(new InetSocketAddress(PORT));
                DataOutputStream dos = new DataOutputStream(socket.getOutputStream());
                for (String chunk : CHUNKS) {
                    dos.write(chunk.getBytes(CHARSET));
                    Thread.sleep(100);
                }
                // Half-close so the server sees EOF, then drain the echoes: closing a socket that
                // still has unread input makes the OS send a RST ("Connection reset" on the server).
                socket.shutdownOutput();
                byte[] sink = new byte[BUFFER_SIZE];
                while (socket.getInputStream().read(sink) != -1) {
                    // The server already logs every echoed line; the bytes are discarded here.
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}
B. CustomProtocol
最经典的自定义协议是基于 TLV(Type, Length, Value)格式编码。我们约定:Type 占 1 字节,0 表示通讯结束,1 表示文本消息;Length 占 2 字节(大端无符号整数),表示包含 3 字节头部在内的整条消息长度;Value 为消息文本的字节。
代码依然很难写,而且极不优雅,留给自己以后吐槽用吧。
gordon.study.socket.nio.basic.CustomProtocolBlockingPrintServer.java
/**
 * Blocking, thread-per-connection server built on {@link ServerSocketChannel} that decodes a
 * simple TLV protocol and prints every text message it receives.
 *
 * <p>Wire format of one message: a 1-byte type ({@link #TYPE_END} = end of communication,
 * {@link #TYPE_TEXT} = text message), a 2-byte big-endian unsigned length that counts the WHOLE
 * message including the 3-byte header, then the UTF-8 encoded text. A message must fit in
 * {@link #BUFFER_SIZE} bytes.
 */
public class CustomProtocolBlockingPrintServer {
    private static final int PORT = 8888;
    /** Capacity of the per-connection read buffer, i.e. the maximum message length. */
    private static final int BUFFER_SIZE = 256;
    /** Type byte plus the 2-byte length field. */
    private static final int HEADER_LENGTH = 3;
    private static final byte TYPE_END = 0;
    private static final byte TYPE_TEXT = 1;
    private static final java.nio.charset.Charset CHARSET = java.nio.charset.StandardCharsets.UTF_8;

    /**
     * Binds the server on {@link #PORT}, launches three demo clients and serves every accepted
     * connection on its own thread. Never returns.
     */
    public static void main(String[] args) throws Exception {
        // Bind BEFORE starting the clients; otherwise they can race the server and be refused.
        ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
        serverSocketChannel.bind(new InetSocketAddress(PORT));
        for (int i = 0; i < 3; i++) {
            new Thread(new Client()).start();
        }
        while (true) {
            SocketChannel socketChannel = serverSocketChannel.accept();
            new Thread(new ServerHandler(socketChannel)).start();
        }
    }

    /** Decodes the messages of one connection and prints each text message. */
    private static class ServerHandler implements Runnable {
        private final SocketChannel socketChannel;

        public ServerHandler(SocketChannel socketChannel) {
            this.socketChannel = socketChannel;
        }

        @Override
        public void run() {
            // try-with-resources closes the connection on end-of-communication, EOF and error alike.
            try (SocketChannel channel = socketChannel) {
                ByteBuffer buf = ByteBuffer.allocate(BUFFER_SIZE);
                TlvDecoder decoder = new TlvDecoder();
                // Stop on an end-of-communication message, or when the peer closes (read() == -1).
                while (!decoder.isEnded() && channel.read(buf) >= 0) {
                    String message;
                    while ((message = decoder.nextMessage(buf)) != null) {
                        System.out.println(message);
                    }
                }
                System.out.println("===============exit==============");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Decodes TLV messages from a buffer that is being filled by channel reads.
     *
     * <p>The buffer is expected in "write mode" (data in {@code [0, position)}), which is how
     * {@code SocketChannel.read} leaves it. The header is re-read from the buffer on every call, so
     * a message split at any byte boundary is handled without extra state.
     */
    static final class TlvDecoder {
        private boolean ended;

        /** Returns {@code true} once an end-of-communication message has been decoded. */
        boolean isEnded() {
            return ended;
        }

        /**
         * Removes the next complete text message from {@code buf} and returns its text, leaving the
         * buffer compacted and ready for the next read.
         *
         * @param buf buffer in write mode holding the bytes received so far
         * @return the decoded text (possibly empty), or {@code null} if no complete message is
         *     buffered yet or the end-of-communication message has been reached (see
         *     {@link #isEnded()})
         * @throws IOException if the type byte is unknown or the length field is smaller than the
         *     header or larger than the buffer (such a message could never be completed)
         */
        String nextMessage(ByteBuffer buf) throws IOException {
            int available = buf.position();
            if (ended || available == 0) {
                return null;
            }
            byte type = buf.get(0);
            if (type == TYPE_END) {
                ended = true;
                return null;
            }
            if (type != TYPE_TEXT) {
                throw new IOException("unknown message type " + type);
            }
            if (available < HEADER_LENGTH) {
                return null;  // length field not complete yet
            }
            int totalLength = buf.getChar(1);  // unsigned 16-bit, big-endian
            if (totalLength < HEADER_LENGTH || totalLength > buf.capacity()) {
                throw new IOException("invalid message length " + totalLength);
            }
            if (available < totalLength) {
                return null;  // body not complete yet
            }
            byte[] value = new byte[totalLength - HEADER_LENGTH];
            buf.flip();                     // position = 0, limit = available
            buf.position(HEADER_LENGTH);
            buf.get(value);
            buf.compact();                  // keep the bytes of following messages, back to write mode
            return new String(value, CHARSET);
        }
    }

    /** Demo client that sends messages whole, in single-byte fragments, and several per write. */
    private static class Client implements Runnable {
        @Override
        public void run() {
            try (Socket socket = new Socket()) {
                socket.connect(new InetSocketAddress(PORT));
                DataOutputStream dos = new DataOutputStream(socket.getOutputStream());
                print(dos, "hello");
                Thread.sleep(100);
                print(dos, "");
                Thread.sleep(100);
                print(dos, "你瞅啥?");
                Thread.sleep(100);

                // The 9-byte message "ni hao", sent fragment by fragment: type, length high byte,
                // length low byte, then the first part of the text.
                dos.writeByte(TYPE_TEXT);
                dos.flush();
                Thread.sleep(100);
                dos.writeByte(9 >> 8);
                dos.flush();
                Thread.sleep(100);
                dos.writeByte(9 & 0xFF);
                dos.flush();
                Thread.sleep(100);
                dos.write("ni".getBytes(CHARSET));
                dos.flush();
                Thread.sleep(100);
                dos.write(" ".getBytes(CHARSET));
                dos.flush();
                Thread.sleep(100);

                // The tail of "ni hao" plus two complete messages, all in a single write.
                ByteBuffer buf = ByteBuffer.allocate(100);
                buf.put("hao".getBytes(CHARSET));
                buf.put(TYPE_TEXT).putShort((short) 9).put("abcdef".getBytes(CHARSET));
                buf.put(TYPE_TEXT).putShort((short) 8).put("12345".getBytes(CHARSET));
                dos.write(buf.array(), 0, buf.position());
                dos.flush();
                Thread.sleep(100);

                dos.writeByte(TYPE_END);
                dos.flush();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        /**
         * Writes {@code message} as one complete text message.
         *
         * @throws IllegalArgumentException if the encoded message does not fit the 16-bit length
         */
        private void print(DataOutputStream dos, String message) throws IOException {
            byte[] bytes = message.getBytes(CHARSET);
            int totalLength = HEADER_LENGTH + bytes.length;
            if (totalLength > 0xFFFF) {
                throw new IllegalArgumentException("message too long: " + bytes.length + " bytes");
            }
            dos.writeByte(TYPE_TEXT);
            dos.writeShort(totalLength);  // big-endian, high byte first
            dos.write(bytes);
            dos.flush();
        }
    }
}