This commit is contained in:
alyenc 2024-08-13 15:14:39 +08:00
commit af84b01cc2
49 changed files with 2780 additions and 0 deletions

39
.gitignore vendored Normal file
View File

@ -0,0 +1,39 @@
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
### IntelliJ IDEA ###
.idea/modules.xml
.idea/jarRepositories.xml
.idea/compiler.xml
.idea/libraries/
*.iws
*.iml
*.ipr
### Eclipse ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
### VS Code ###
.vscode/
### Mac OS ###
.DS_Store
/.idea/

70
pom.xml Normal file
View File

@ -0,0 +1,70 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.codenil</groupId>
<artifactId>comm</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>22</maven.compiler.source>
<maven.compiler.target>22</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.112.Final</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.15.0</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>33.2.1-jre</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>2.0.13</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>1.5.6</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.5.6</version>
</dependency>
<dependency>
<groupId>io.tmio</groupId>
<artifactId>tuweni-bytes</artifactId>
<version>2.4.2</version>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk18on</artifactId>
<version>1.78.1</version>
</dependency>
<dependency>
<groupId>org.xerial.snappy</groupId>
<artifactId>snappy-java</artifactId>
<version>1.1.10.5</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.11.0</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,173 @@
package org.codenil.comm;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.codenil.comm.connections.*;
import org.codenil.comm.message.DisconnectReason;
import org.codenil.comm.message.MessageCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 网络底层通信模块
*/
public class Communication {
private static final Logger logger = LoggerFactory.getLogger(Communication.class);
/**
* 连接初始化
*/
private final ConnectionInitializer connectionInitializer;
/**
* 消息订阅
*/
private final PeerConnectionEvents connectionEvents;
/**
* 连接回调订阅
*/
private final Subscribers<ConnectCallback> connectSubscribers = Subscribers.create();
/**
* 连接缓存
*/
private final Cache<String, CompletableFuture<PeerConnection>> peersConnectingCache = CacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofSeconds(30L)).concurrencyLevel(1).build();
private final AtomicBoolean started = new AtomicBoolean(false);
private final AtomicBoolean stopped = new AtomicBoolean(false);
public Communication(
final PeerConnectionEvents connectionEvents,
final ConnectionInitializer connectionInitializer) {
this.connectionEvents = connectionEvents;
this.connectionInitializer = connectionInitializer;
}
/**
* 启动
*/
public CompletableFuture<Integer> start() {
if (!started.compareAndSet(false, true)) {
return CompletableFuture.failedFuture(
new IllegalStateException("Unable to start an already started " + getClass().getSimpleName()));
}
//注册回调监听
setupListeners();
//启动连接初始化
return connectionInitializer
.start()
.thenApply((socketAddress) -> {
logger.info("P2P RLPx agent started and listening on {}.", socketAddress);
return socketAddress.getPort();
})
.whenComplete((_, err) -> {
if (err != null) {
logger.error("Failed to start Communication. Check for port conflicts.");
}
});
}
public CompletableFuture<Void> stop() {
if (!started.get() || !stopped.compareAndSet(false, true)) {
return CompletableFuture.failedFuture(
new IllegalStateException("Illegal attempt to stop " + getClass().getSimpleName()));
}
peersConnectingCache.asMap()
.values()
.forEach((conn) -> {
try {
conn.get().disconnect(DisconnectReason.UNKNOWN);
} catch (Exception e) {
logger.debug("Failed to disconnect.");
}
});
return connectionInitializer.stop();
}
/**
* 连接到远程节点
*/
public CompletableFuture<PeerConnection> connect(final RemotePeer remotePeer) {
final CompletableFuture<PeerConnection> peerConnectionCompletableFuture;
try {
synchronized (this) {
//尝试从缓存获取链接获取不到就创建一个
peerConnectionCompletableFuture = peersConnectingCache.get(
remotePeer.ip(), () -> createConnection(remotePeer));
}
} catch (final ExecutionException e) {
throw new RuntimeException(e);
}
return peerConnectionCompletableFuture;
}
/**
* 订阅消息
*/
public void subscribeMessage(final MessageCallback callback) {
connectionEvents.subscribeMessage(callback);
}
/**
* 订阅连接
*/
public void subscribeConnect(final ConnectCallback callback) {
connectSubscribers.subscribe(callback);
}
/**
* 创建远程连接
*/
@Nonnull
private CompletableFuture<PeerConnection> createConnection(final RemotePeer remotePeer) {
CompletableFuture<PeerConnection> completableFuture = initiateOutboundConnection(remotePeer);
completableFuture.whenComplete((peerConnection, throwable) -> {
if (throwable == null) {
dispatchConnect(peerConnection);
}
});
return completableFuture;
}
/**
* 初始化远程连接
*/
private CompletableFuture<PeerConnection> initiateOutboundConnection(final RemotePeer remotePeer) {
logger.trace("Initiating connection to peer: {}:{}", remotePeer.ip(), remotePeer.listeningPort());
return connectionInitializer
.connect(remotePeer)
.whenComplete((conn, err) -> {
if (err != null) {
logger.debug("Failed to connect to peer {}: {}", remotePeer.ip(), err.getMessage());
} else {
logger.debug("Outbound connection established to peer: {}", remotePeer.ip());
}
});
}
private void setupListeners() {
connectionInitializer.subscribeIncomingConnect(this::handleIncomingConnection);
}
private void handleIncomingConnection(final PeerConnection peerConnection) {
dispatchConnect(peerConnection);
}
/**
* 连接完成后调用注册的回调
*/
private void dispatchConnect(final PeerConnection connection) {
connectSubscribers.forEach(c -> c.onConnect(connection));
}
}

View File

@ -0,0 +1,25 @@
package org.codenil.comm;
import org.codenil.comm.connections.PeerConnection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class ConnectionStore {
private final Map<String, PeerConnection> pki2Conn = new ConcurrentHashMap<>();
public void registerConnection(PeerConnection peerConnection) {
pki2Conn.put(peerConnection.peerIdentity(), peerConnection);
System.out.println(peerConnection.peerIdentity());
}
public boolean unRegisterConnection(PeerConnection peerConnection) {
return pki2Conn.remove(peerConnection.peerIdentity(), peerConnection);
}
public PeerConnection getConnection(String peerIdentity) {
return pki2Conn.get(peerIdentity);
}
}

View File

@ -0,0 +1,29 @@
package org.codenil.comm;
import org.codenil.comm.message.AbstractMessage;
import org.codenil.comm.message.Message;
import org.codenil.comm.message.MessageCodes;
public class DataUpdateMessage extends AbstractMessage {
public DataUpdateMessage(byte[] data) {
super(data);
}
public static DataUpdateMessage readFrom(final Message message) {
if (message instanceof DataUpdateMessage) {
return (DataUpdateMessage) message;
}
return new DataUpdateMessage(message.getData());
}
public static DataUpdateMessage create(final String data) {
return new DataUpdateMessage(data.getBytes());
}
@Override
public int getCode() {
return MessageCodes.DATA_UPDATE;
}
}

View File

@ -0,0 +1,24 @@
package org.codenil.comm;
public class NetworkConfig {
private String bindHost;
private int bindPort;
public String bindHost() {
return bindHost;
}
public void setBindHost(String bindHost) {
this.bindHost = bindHost;
}
public int bindPort() {
return bindPort;
}
public void setBindPort(int bindPort) {
this.bindPort = bindPort;
}
}

View File

@ -0,0 +1,222 @@
package org.codenil.comm;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.tuweni.bytes.Bytes;
import org.codenil.comm.connections.*;
import org.codenil.comm.message.DefaultMessage;
import org.codenil.comm.message.DisconnectReason;
import org.codenil.comm.message.Message;
import org.codenil.comm.message.RawMessage;
import org.codenil.comm.netty.NettyConnectionInitializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.math.BigInteger;
import java.time.Clock;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
public class NetworkService {
private static final Logger logger = LoggerFactory.getLogger(NetworkService.class);
private final CountDownLatch shutdown = new CountDownLatch(1);
private final AtomicBoolean started = new AtomicBoolean(false);
private final AtomicBoolean stopped = new AtomicBoolean(false);
private final Map<Integer, Subscribers<MessageCallback>> listenersByCode = new ConcurrentHashMap<>();
private final Map<Integer, MessageResponse> messageResponseByCode = new ConcurrentHashMap<>();
private final AtomicInteger outstandingRequests = new AtomicInteger(0);
private final AtomicLong requestIdCounter = new AtomicLong(1);
private final Clock clock = Clock.systemUTC();
private final PeerReputation reputation = new PeerReputation();
private final ConnectionStore connectionStore = new ConnectionStore();
private final Communication communication;
private volatile long lastRequestTimestamp = 0;
public NetworkService(final NetworkConfig networkConfig) {
PeerConnectionEvents connectionEvents = new PeerConnectionEvents();
ConnectionInitializer connectionInitializer = new NettyConnectionInitializer(networkConfig, connectionEvents);
this.communication = new Communication(connectionEvents, connectionInitializer);
}
public CompletableFuture<Integer> start() {
if (started.compareAndSet(false, true)) {
logger.info("Starting Network.");
setupHandlers();
return communication.start();
} else {
logger.error("Attempted to start already running network.");
return CompletableFuture.failedFuture(new Throwable("Attempted to start already running network."));
}
}
public CompletableFuture<Void> stop() {
if (stopped.compareAndSet(false, true)) {
logger.info("Stopping Network.");
CompletableFuture<Void> stop = communication.stop();
return stop.whenComplete((result, throwable) -> {
shutdown.countDown();
});
} else {
logger.error("Attempted to stop already stopped network.");
return CompletableFuture.failedFuture(new Throwable("Attempted to stop already stopped network."));
}
}
public CompletableFuture<PeerConnection> connect(final RemotePeer remotePeer) {
return communication.connect(remotePeer);
}
public void sendRequest(final String peerIdentity, final Message message) {
lastRequestTimestamp = clock.millis();
this.dispatchRequest(
msg -> {
try {
PeerConnection connection = connectionStore.getConnection(peerIdentity);
if(Objects.nonNull(connection)) {
connection.send(msg);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
},
message);
}
public long subscribe(final int messageCode, final MessageCallback callback) {
return listenersByCode
.computeIfAbsent(messageCode, _ -> Subscribers.create())
.subscribe(callback);
}
public void unsubscribe(final long subscriptionId, final int messageCode) {
if (listenersByCode.containsKey(messageCode)) {
listenersByCode.get(messageCode).unsubscribe(subscriptionId);
if (listenersByCode.get(messageCode).getSubscriberCount() < 1) {
listenersByCode.remove(messageCode);
}
}
}
public void registerResponse(
final int messageCode, final MessageResponse messageResponse) {
messageResponseByCode.put(messageCode, messageResponse);
}
private void setupHandlers() {
communication.subscribeConnect(this::registerNewConnection);
communication.subscribeMessage(message -> {
try {
this.receiveMessage(message);
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
});
}
/**
* 注册新连接
*/
private void registerNewConnection(final PeerConnection newConnection) {
synchronized (this) {
connectionStore.registerConnection(newConnection);
}
}
private boolean registerDisconnect(final PeerConnection connection) {
return connectionStore.unRegisterConnection(connection);
}
private void dispatchRequest(final RequestSender sender, final Message message) {
outstandingRequests.incrementAndGet();
sender.send(message.wrapMessage(requestIdCounter.getAndIncrement() + ""));
}
private void receiveMessage(final DefaultMessage message) throws Exception {
//处理自定义回复处理器
Optional<Message> maybeResponse = Optional.empty();
try {
final int code = message.message().getCode();
Optional.ofNullable(listenersByCode.get(code))
.ifPresent(listeners -> listeners.forEach(messageCallback -> messageCallback.exec(message)));
Message requestIdAndEthMessage = message.message();
maybeResponse = Optional.ofNullable(messageResponseByCode.get(code))
.map(messageResponse -> messageResponse.response(message))
.map(responseData -> responseData.wrapMessage(requestIdAndEthMessage.getRequestId()));
} catch (Exception e) {
logger.atDebug()
.setMessage("Received malformed message {}, {}")
.addArgument(message)
.addArgument(e::toString)
.log();
this.disconnect(message.connection().peerIdentity(), DisconnectReason.UNKNOWN);
}
maybeResponse.ifPresent(
responseData -> {
try {
sendRequest(message.connection().peerIdentity(), responseData);
} catch (Exception __) {}
});
}
private void disconnect(final String peerIdentity, final DisconnectReason reason) {
PeerConnection connection = connectionStore.getConnection(peerIdentity);
if(Objects.nonNull(connection)) {
try {
connection.disconnect(reason);
} catch (Exception e) {
logger.debug("Disconnect by reason {}", reason);
}
}
}
private void recordRequestTimeout(final PeerConnection connection, final int requestCode) {
logger.atDebug()
.setMessage("Timed out while waiting for response")
.log();
logger.trace("Timed out while waiting for response from peer {}", this);
reputation.recordRequestTimeout(requestCode)
.ifPresent(reason -> {
this.disconnect(connection.peerIdentity(), reason);
});
}
private void recordUselessResponse(final PeerConnection connection, final String requestType) {
logger.atTrace()
.setMessage("Received useless response for request type {}")
.addArgument(requestType)
.log();
reputation.recordUselessResponse(System.currentTimeMillis())
.ifPresent(reason -> {
this.disconnect(connection.peerIdentity(), reason);
});
}
@FunctionalInterface
public interface RequestSender {
void send(final RawMessage message);
}
@FunctionalInterface
public interface MessageCallback {
void exec(DefaultMessage message);
}
@FunctionalInterface
public interface MessageResponse {
Message response(DefaultMessage message);
}
}

View File

@ -0,0 +1,114 @@
package org.codenil.comm;
import org.codenil.comm.message.DisconnectReason;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static com.google.common.base.Preconditions.checkArgument;
public class PeerReputation implements Comparable<PeerReputation> {
private static final long USELESS_RESPONSE_WINDOW_IN_MILLIS = TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES);
private static final int DEFAULT_MAX_SCORE = 150;
private static final int DEFAULT_INITIAL_SCORE = 100;
private static final Logger LOG = LoggerFactory.getLogger(PeerReputation.class);
private static final int TIMEOUT_THRESHOLD = 5;
private static final int USELESS_RESPONSE_THRESHOLD = 5;
private static final int SMALL_ADJUSTMENT = 1;
private static final int LARGE_ADJUSTMENT = 10;
private final ConcurrentMap<Integer, AtomicInteger> timeoutCountByRequestType = new ConcurrentHashMap<>();
private final Queue<Long> uselessResponseTimes = new ConcurrentLinkedQueue<>();
private int score;
private final int maxScore;
public PeerReputation() {
this(DEFAULT_INITIAL_SCORE, DEFAULT_MAX_SCORE);
}
public PeerReputation(final int initialScore, final int maxScore) {
checkArgument(
initialScore <= maxScore, "Initial score must be less than or equal to max score");
this.maxScore = maxScore;
this.score = initialScore;
}
public Optional<DisconnectReason> recordRequestTimeout(final int requestCode) {
final int newTimeoutCount = getOrCreateTimeoutCount(requestCode).incrementAndGet();
if (newTimeoutCount >= TIMEOUT_THRESHOLD) {
LOG.debug(
"Disconnection triggered by {} repeated timeouts for requestCode {}",
newTimeoutCount,
requestCode);
score -= LARGE_ADJUSTMENT;
return Optional.of(DisconnectReason.TIMEOUT);
} else {
score -= SMALL_ADJUSTMENT;
return Optional.empty();
}
}
public void resetTimeoutCount(final int requestCode) {
timeoutCountByRequestType.remove(requestCode);
}
private AtomicInteger getOrCreateTimeoutCount(final int requestCode) {
return timeoutCountByRequestType.computeIfAbsent(requestCode, code -> new AtomicInteger());
}
public Map<Integer, AtomicInteger> timeoutCounts() {
return timeoutCountByRequestType;
}
public Optional<DisconnectReason> recordUselessResponse(final long timestamp) {
uselessResponseTimes.add(timestamp);
while (shouldRemove(uselessResponseTimes.peek(), timestamp)) {
uselessResponseTimes.poll();
}
if (uselessResponseTimes.size() >= USELESS_RESPONSE_THRESHOLD) {
score -= LARGE_ADJUSTMENT;
LOG.debug("Disconnection triggered by exceeding useless response threshold");
return Optional.of(DisconnectReason.UNKNOWN);
} else {
score -= SMALL_ADJUSTMENT;
return Optional.empty();
}
}
public void recordUsefulResponse() {
if (score < maxScore) {
score = Math.min(maxScore, score + SMALL_ADJUSTMENT);
}
}
private boolean shouldRemove(final Long timestamp, final long currentTimestamp) {
return timestamp != null && timestamp + USELESS_RESPONSE_WINDOW_IN_MILLIS < currentTimestamp;
}
@Override
public String toString() {
return String.format(
"PeerReputation score: %d, timeouts: %s, useless: %s",
score, timeoutCounts(), uselessResponseTimes.size());
}
@Override
public int compareTo(final @Nonnull PeerReputation otherReputation) {
return Integer.compare(this.score, otherReputation.score);
}
public int getScore() {
return score;
}
}

View File

@ -0,0 +1,32 @@
package org.codenil.comm;
import java.net.InetSocketAddress;
public class RemotePeer {
private final String peerIdentity;
private final String ip;
private final int listeningPort;
public RemotePeer(
final String ip,
final int listeningPort) {
this.ip = ip;
this.listeningPort = listeningPort;
this.peerIdentity = new InetSocketAddress(ip, listeningPort).toString();
}
public String peerIdentity() {
return peerIdentity;
}
public String ip() {
return ip;
}
public int listeningPort() {
return listeningPort;
}
}

View File

@ -0,0 +1,58 @@
package org.codenil.comm.connections;
import org.codenil.comm.message.DisconnectMessage;
import org.codenil.comm.message.DisconnectReason;
import org.codenil.comm.message.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.atomic.AtomicBoolean;
public abstract class AbstractPeerConnection implements PeerConnection {
private static final Logger logger = LoggerFactory.getLogger(AbstractPeerConnection.class);
protected final PeerConnectionEvents connectionEvents;
private final AtomicBoolean disconnected = new AtomicBoolean(false);
private final AtomicBoolean terminatedImmediately = new AtomicBoolean(false);
protected AbstractPeerConnection(final PeerConnectionEvents connectionEvents) {
this.connectionEvents = connectionEvents;
}
@Override
public void send(final Message message) {
doSendMessage(message);
}
@Override
public void terminateConnection() {
if (terminatedImmediately.compareAndSet(false, true)) {
if (disconnected.compareAndSet(false, true)) {
connectionEvents.dispatchDisconnect(this);
}
// Always ensure the context gets closed immediately even if we previously sent a disconnect
// message and are waiting to close.
closeConnectionImmediately();
logger.atTrace()
.setMessage("Terminating connection, reason {}")
.addArgument(this)
.log();
}
}
@Override
public void disconnect(DisconnectReason reason) {
if (disconnected.compareAndSet(false, true)) {
connectionEvents.dispatchDisconnect(this);
doSendMessage(DisconnectMessage.create(reason));
closeConnection();
}
}
protected abstract void doSendMessage(final Message message);
protected abstract void closeConnection();
protected abstract void closeConnectionImmediately();
}

View File

@ -0,0 +1,9 @@
package org.codenil.comm.connections;
/**
* 连接回调
*/
@FunctionalInterface
public interface ConnectCallback {
void onConnect(final PeerConnection peer);
}

View File

@ -0,0 +1,18 @@
package org.codenil.comm.connections;
import org.codenil.comm.RemotePeer;
import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture;
public interface ConnectionInitializer {
CompletableFuture<InetSocketAddress> start();
CompletableFuture<Void> stop();
void subscribeIncomingConnect(final ConnectCallback callback);
CompletableFuture<PeerConnection> connect(RemotePeer remotePeer);
}

View File

@ -0,0 +1,56 @@
package org.codenil.comm.connections;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.codenil.comm.message.DisconnectReason;
import org.codenil.comm.message.PingMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
public class KeepAlive extends ChannelDuplexHandler {
private static final Logger logger = LoggerFactory.getLogger(KeepAlive.class);
private final AtomicBoolean waitingForPong;
private final PeerConnection connection;
public KeepAlive(
final PeerConnection connection,
final AtomicBoolean waitingForPong) {
this.connection = connection;
this.waitingForPong = waitingForPong;
}
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt)
throws IOException {
if (!(evt instanceof IdleStateEvent
&& ((IdleStateEvent) evt).state() == IdleState.READER_IDLE)) {
return;
}
if (waitingForPong.get()) {
logger.debug("PONG never received, disconnecting from peer.");
try {
connection.disconnect(DisconnectReason.TIMEOUT);
} catch (Exception e) {
logger.warn("Exception while disconnecting from peer.", e);
}
return;
}
try {
logger.debug("Idle connection detected, sending Wire PING to peer.");
connection.send(PingMessage.get());
waitingForPong.set(true);
} catch (Exception e) {
logger.trace("PING not sent because peer is already disconnected");
}
}
}

View File

@ -0,0 +1,15 @@
package org.codenil.comm.connections;
import org.codenil.comm.message.DisconnectReason;
import org.codenil.comm.message.Message;
public interface PeerConnection {
String peerIdentity();
void send(final Message message) throws Exception;
void disconnect(DisconnectReason reason) throws Exception;
void terminateConnection();
}

View File

@ -0,0 +1,30 @@
package org.codenil.comm.connections;
import org.codenil.comm.message.*;
public class PeerConnectionEvents {
private final Subscribers<DisconnectCallback> disconnectSubscribers = Subscribers.create(true);
private final Subscribers<MessageCallback> messageSubscribers = Subscribers.create(true);
public PeerConnectionEvents() {}
public void dispatchDisconnect(
final PeerConnection connection) {
disconnectSubscribers.forEach(s -> s.onDisconnect(connection));
}
public void dispatchMessage(final PeerConnection connection, final RawMessage message) {
final DefaultMessage msg = new DefaultMessage(connection, message);
messageSubscribers.forEach(s -> s.onMessage(msg));
}
public void subscribeDisconnect(final DisconnectCallback callback) {
disconnectSubscribers.subscribe(callback);
}
public void subscribeMessage(final MessageCallback callback) {
messageSubscribers.subscribe(callback);
}
}

View File

@ -0,0 +1,89 @@
package org.codenil.comm.connections;
import com.google.common.collect.ImmutableSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
public class Subscribers<T> {
private static final Subscribers<?> NONE = new EmptySubscribers<>();
private final AtomicLong subscriberId = new AtomicLong();
private final Map<Long, T> subscribers = new ConcurrentHashMap<>();
private final boolean suppressCallbackExceptions;
private Subscribers(final boolean suppressCallbackExceptions) {
this.suppressCallbackExceptions = suppressCallbackExceptions;
}
@SuppressWarnings("unchecked")
public static <T> Subscribers<T> none() {
return (Subscribers<T>) NONE;
}
public static <T> Subscribers<T> create() {
return new Subscribers<T>(false);
}
public static <T> Subscribers<T> create(final boolean catchCallbackExceptions) {
return new Subscribers<T>(catchCallbackExceptions);
}
public long subscribe(final T subscriber) {
final long id = subscriberId.getAndIncrement();
subscribers.put(id, subscriber);
return id;
}
public boolean unsubscribe(final long subscriberId) {
return subscribers.remove(subscriberId) != null;
}
public void forEach(final Consumer<T> action) {
ImmutableSet.copyOf(subscribers.values())
.forEach(subscriber -> {
try {
action.accept(subscriber);
} catch (final Exception e) {
if (suppressCallbackExceptions) {
// LOG.debug("Error in callback: {}", e);
} else {
throw e;
}
}
});
}
public int getSubscriberCount() {
return subscribers.size();
}
private static class EmptySubscribers<T> extends Subscribers<T> {
private EmptySubscribers() {
super(false);
}
@Override
public long subscribe(final T subscriber) {
throw new UnsupportedOperationException();
}
@Override
public boolean unsubscribe(final long subscriberId) {
return false;
}
@Override
public void forEach(final Consumer<T> action) {}
@Override
public int getSubscriberCount() {
return 0;
}
}
}

View File

@ -0,0 +1,132 @@
package org.codenil.comm.handshake;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.MessageToByteEncoder;
import org.codenil.comm.connections.PeerConnectionEvents;
import org.codenil.comm.message.*;
import org.codenil.comm.connections.PeerConnection;
import org.codenil.comm.netty.handler.MessageFrameDecoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
public abstract class AbstractHandshakeHandler extends SimpleChannelInboundHandler<ByteBuf> {
private static final Logger logger = LoggerFactory.getLogger(AbstractHandshakeHandler.class);
private final CompletableFuture<PeerConnection> connectionFuture;
private final PeerConnectionEvents connectionEvents;
protected final Handshaker handshaker;
protected AbstractHandshakeHandler(
final CompletableFuture<PeerConnection> connectionFuture,
final PeerConnectionEvents connectionEvents,
final Handshaker handshaker) {
this.connectionFuture = connectionFuture;
this.connectionEvents = connectionEvents;
this.handshaker = handshaker;
}
@Override
protected void channelRead0(final ChannelHandlerContext ctx, final ByteBuf msg) {
final Optional<ByteBuf> nextMsg = nextHandshakeMessage(msg);
if (nextMsg.isPresent()) {
ctx.writeAndFlush(nextMsg.get());
} else if (handshaker.getStatus() != HandshakeStatus.SUCCESS) {
logger.debug("waiting for more bytes");
} else {
/*
* 握手成功后替换掉握手消息处理器
* 替换为消息解码器
* 同时添加一个消息编码器
* 形成一个完整的Message处理链
* validate处理器只负责检测帧合法性尝试封帧封帧成功后移除这个处理器
*/
ctx.channel()
.pipeline()
.replace(this, "DeFramer", new MessageFrameDecoder(connectionEvents, connectionFuture))
.addBefore("DeFramer", "validate", new FirstMessageFrameEncoder());
/*
* 替换完编解码器后发送Hello消息
*/
HelloMessage helloMessage = HelloMessage.create();
RawMessage rawMessage = new RawMessage(helloMessage.getCode(), helloMessage.getData());
rawMessage.setRequestId(helloMessage.getRequestId());
ctx.writeAndFlush(rawMessage)
.addListener(ff -> {
if (ff.isSuccess()) {
logger.trace("Successfully wrote hello message");
}
});
msg.retain();
ctx.fireChannelRead(msg);
}
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable throwable) {
logger.trace("Handshake error:", throwable);
connectionFuture.completeExceptionally(throwable);
ctx.close();
}
protected abstract Optional<ByteBuf> nextHandshakeMessage(ByteBuf msg);
/** Ensures that wire hello message is the first message written. */
private static class FirstMessageFrameEncoder extends MessageToByteEncoder<RawMessage> {
private FirstMessageFrameEncoder() {}
@Override
protected void encode(
final ChannelHandlerContext context,
final RawMessage msg,
final ByteBuf out) {
if (msg.getCode() != MessageCodes.HELLO) {
throw new IllegalStateException("First message sent wasn't a HELLO.");
}
byte[] idBytes = Optional.ofNullable(msg.getRequestId()).orElse("").getBytes(StandardCharsets.UTF_8);
int channelsSize = msg.getChannels().size();
int channelBytesLength = 0;
for (String channel : msg.getChannels()) {
byte[] channelBytes = channel.getBytes(StandardCharsets.UTF_8);
channelBytesLength = channelBytesLength + 4 + channelBytes.length;
}
int payloadLength = 4 + 4 + idBytes.length + 4 + channelBytesLength + 4 + msg.getData().length;
// 写入协议头消息总长度
out.writeInt(payloadLength + 4);
// 写入payload
// 写入code
out.writeInt(msg.getCode());
// 写入id
out.writeInt(idBytes.length);
out.writeBytes(idBytes);
// 写入channels
out.writeInt(channelsSize);
for (String channel : msg.getChannels()) {
byte[] channelBytes = channel.getBytes(StandardCharsets.UTF_8);
out.writeInt(channelBytes.length);
out.writeBytes(channelBytes);
}
// 写入data
out.writeInt(msg.getData().length);
out.writeBytes(msg.getData());
context.pipeline().remove(this);
}
}
}

View File

@ -0,0 +1,33 @@
package org.codenil.comm.handshake;
import io.netty.buffer.ByteBuf;
import org.codenil.comm.connections.PeerConnection;
import org.codenil.comm.connections.PeerConnectionEvents;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
/**
* 入站握手信息处理器
*/
public class HandshakeHandlerInbound extends AbstractHandshakeHandler {
public HandshakeHandlerInbound(
final CompletableFuture<PeerConnection> connectionFuture,
final PeerConnectionEvents connectionEvent,
final Handshaker handshaker) {
super(connectionFuture, connectionEvent, handshaker);
handshaker.prepareResponder();
}
@Override
protected Optional<ByteBuf> nextHandshakeMessage(ByteBuf msg) {
final Optional<ByteBuf> nextMsg;
if (handshaker.getStatus() == HandshakeStatus.IN_PROGRESS) {
nextMsg = handshaker.handleMessage(msg);
} else {
nextMsg = Optional.empty();
}
return nextMsg;
}
}

View File

@ -0,0 +1,53 @@
package org.codenil.comm.handshake;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import org.codenil.comm.connections.PeerConnection;
import org.codenil.comm.connections.PeerConnectionEvents;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
/**
* 出站握手信息处理器
*/
public class HandshakeHandlerOutbound extends AbstractHandshakeHandler {
private static final Logger logger = LoggerFactory.getLogger(AbstractHandshakeHandler.class);
private final ByteBuf first;
public HandshakeHandlerOutbound(
final CompletableFuture<PeerConnection> connectionFuture,
final PeerConnectionEvents connectionEvent,
final Handshaker handshaker) {
super(connectionFuture, connectionEvent, handshaker);
handshaker.prepareInitiator();
this.first = handshaker.firstMessage();
}
@Override
protected Optional<ByteBuf> nextHandshakeMessage(ByteBuf msg) {
final Optional<ByteBuf> nextMsg;
if (handshaker.getStatus() == HandshakeStatus.IN_PROGRESS) {
nextMsg = handshaker.handleMessage(msg);
} else {
nextMsg = Optional.empty();
}
return nextMsg;
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
ctx.writeAndFlush(first)
.addListener(f -> {
if (f.isSuccess()) {
logger.trace("Wrote initial crypto handshake message to {}.", ctx.channel().remoteAddress());
}
});
}
}

View File

@ -0,0 +1,113 @@
package org.codenil.comm.handshake;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;
import org.bouncycastle.crypto.digests.KeccakDigest;
import java.util.Arrays;
import java.util.Objects;
import static com.google.common.base.Preconditions.checkArgument;
public class HandshakeSecrets {
private final byte[] aesSecret;
private final byte[] macSecret;
private final byte[] token;
private final KeccakDigest egressMac = new KeccakDigest(Bytes32.SIZE * 8);
private final KeccakDigest ingressMac = new KeccakDigest(Bytes32.SIZE * 8);
public HandshakeSecrets(final byte[] aesSecret, final byte[] macSecret, final byte[] token) {
checkArgument(aesSecret.length == Bytes32.SIZE, "aes secret must be exactly 32 bytes long");
checkArgument(macSecret.length == Bytes32.SIZE, "mac secret must be exactly 32 bytes long");
checkArgument(token.length == Bytes32.SIZE, "token must be exactly 32 bytes long");
this.aesSecret = aesSecret;
this.macSecret = macSecret;
this.token = token;
}
public HandshakeSecrets updateEgress(final byte[] bytes) {
egressMac.update(bytes, 0, bytes.length);
return this;
}
public HandshakeSecrets updateIngress(final byte[] bytes) {
ingressMac.update(bytes, 0, bytes.length);
return this;
}
@Override
public String toString() {
return "HandshakeSecrets{"
+ "aesSecret="
+ Bytes.wrap(aesSecret)
+ ", macSecret="
+ Bytes.wrap(macSecret)
+ ", token="
+ Bytes.wrap(token)
+ ", egressMac="
+ Bytes.wrap(snapshot(egressMac))
+ ", ingressMac="
+ Bytes.wrap(snapshot(ingressMac))
+ '}';
}
public byte[] getAesSecret() {
return aesSecret;
}
public byte[] getMacSecret() {
return macSecret;
}
public byte[] getToken() {
return token;
}
public byte[] getEgressMac() {
return snapshot(egressMac);
}
public byte[] getIngressMac() {
return snapshot(ingressMac);
}
private static byte[] snapshot(final KeccakDigest digest) {
final byte[] out = new byte[Bytes32.SIZE];
new KeccakDigest(digest).doFinal(out, 0);
return out;
}
@SuppressWarnings("EqualsWhichDoesntCheckParameterClass") // checked in delegated method
@Override
public boolean equals(final Object obj) {
return equals(obj, false);
}
public boolean equals(final Object o, final boolean flipMacs) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final HandshakeSecrets that = (HandshakeSecrets) o;
final KeccakDigest vsEgress = flipMacs ? that.ingressMac : that.egressMac;
final KeccakDigest vsIngress = flipMacs ? that.egressMac : that.ingressMac;
return Arrays.equals(aesSecret, that.aesSecret)
&& Arrays.equals(macSecret, that.macSecret)
&& Arrays.equals(token, that.token)
&& Arrays.equals(snapshot(egressMac), snapshot(vsEgress))
&& Arrays.equals(snapshot(ingressMac), snapshot(vsIngress));
}
@Override
public int hashCode() {
return Objects.hash(
Arrays.hashCode(aesSecret),
Arrays.hashCode(macSecret),
Arrays.hashCode(token),
egressMac,
ingressMac);
}
}

View File

@ -0,0 +1,13 @@
package org.codenil.comm.handshake;
public enum HandshakeStatus {
UNINITIALIZED,
PREPARED,
IN_PROGRESS,
SUCCESS,
FAILED
}

View File

@ -0,0 +1,20 @@
package org.codenil.comm.handshake;
import io.netty.buffer.ByteBuf;
import java.util.Optional;
public interface Handshaker {
void prepareInitiator();
void prepareResponder();
HandshakeStatus getStatus();
ByteBuf firstMessage();
HandshakeSecrets secrets();
Optional<ByteBuf> handleMessage(ByteBuf buf);
}

View File

@ -0,0 +1,92 @@
package org.codenil.comm.handshake;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.codenil.comm.message.MessageType;
import org.codenil.comm.netty.handler.MessageHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import static com.google.common.base.Preconditions.checkState;
public class PlainHandshaker implements Handshaker {
private static final Logger logger = LoggerFactory.getLogger(AbstractHandshakeHandler.class);
private final AtomicReference<HandshakeStatus> status =
new AtomicReference<>(HandshakeStatus.UNINITIALIZED);
private boolean initiator;
private byte[] initiatorMsg;
private byte[] responderMsg;
@Override
public void prepareInitiator() {
checkState(status.compareAndSet(
HandshakeStatus.UNINITIALIZED, HandshakeStatus.PREPARED),
"handshake was already prepared");
this.initiator = true;
}
@Override
public void prepareResponder() {
checkState(status.compareAndSet(
HandshakeStatus.UNINITIALIZED, HandshakeStatus.IN_PROGRESS),
"handshake was already prepared");
this.initiator = false;
}
@Override
public HandshakeStatus getStatus() {
return status.get();
}
@Override
public ByteBuf firstMessage() {
checkState(initiator, "illegal invocation of firstMessage on non-initiator end of handshake");
checkState(status.compareAndSet(HandshakeStatus.PREPARED, HandshakeStatus.IN_PROGRESS),
"illegal invocation of firstMessage, handshake had already started");
initiatorMsg = MessageHandler.buildMessage(MessageType.PING, MessageType.PING.getValue(), new byte[0]);
logger.trace("First plain handshake message under INITIATOR role");
return Unpooled.wrappedBuffer(initiatorMsg);
}
@Override
public Optional<ByteBuf> handleMessage(ByteBuf buf) {
checkState(status.get() == HandshakeStatus.IN_PROGRESS,
"illegal invocation of onMessage on handshake that is not in progress");
PlainMessage message = MessageHandler.parseMessage(buf);
Optional<byte[]> nextMsg = Optional.empty();
if (initiator) {
checkState(responderMsg == null,
"unexpected message: responder message had " + "already been received");
checkState(message.messageType().equals(MessageType.PONG),
"unexpected message: needs to be a pong");
responderMsg = message.data();
} else {
checkState(initiatorMsg == null,
"unexpected message: initiator message " + "had already been received");
checkState(message.messageType().equals(MessageType.PING),
"unexpected message: needs to be a ping");
initiatorMsg = message.data();
responderMsg = MessageHandler.buildMessage(MessageType.PONG, MessageType.PONG.getValue(), new byte[0]);
nextMsg = Optional.of(responderMsg);
}
status.set(HandshakeStatus.SUCCESS);
logger.trace("Handshake status set to {}", status.get());
return nextMsg.map(Unpooled::wrappedBuffer);
}
@Override
public HandshakeSecrets secrets() {
return null;
}
}

View File

@ -0,0 +1,31 @@
package org.codenil.comm.handshake;
import org.codenil.comm.message.MessageType;
public class PlainMessage {
private final MessageType messageType;
private final int code;
private final byte[] data;
public PlainMessage(final MessageType messageType, final byte[] data) {
this(messageType, -1, data);
}
public PlainMessage(final MessageType messageType, final int code, final byte[] data) {
this.messageType = messageType;
this.code = code;
this.data = data;
}
public MessageType messageType() {
return messageType;
}
public byte[] data() {
return data;
}
public int code() {
return code;
}
}

View File

@ -0,0 +1,32 @@
package org.codenil.comm.message;
public abstract class AbstractMessage implements Message {
private String requestId;
private byte[] data = new byte[]{};
public AbstractMessage(
final byte[] data) {
this.data = data;
}
@Override
public String getRequestId() {
return requestId;
}
@Override
public final int getSize() {
return data.length;
}
@Override
public byte[] getData() {
return data;
}
public void setRequestId(String requestId) {
this.requestId = requestId;
}
}

View File

@ -0,0 +1,26 @@
package org.codenil.comm.message;
import org.codenil.comm.connections.PeerConnection;
public class DefaultMessage {
private final RawMessage message;
private final PeerConnection connection;
public DefaultMessage(
final PeerConnection connection,
final RawMessage message) {
this.message = message;
this.connection = connection;
}
public RawMessage message() {
return message;
}
public PeerConnection connection() {
return connection;
}
}

View File

@ -0,0 +1,8 @@
package org.codenil.comm.message;
import org.codenil.comm.connections.PeerConnection;
@FunctionalInterface
public interface DisconnectCallback {
void onDisconnect(final PeerConnection connection);
}

View File

@ -0,0 +1,31 @@
package org.codenil.comm.message;
import com.google.gson.Gson;
public class DisconnectMessage extends AbstractMessage {
private DisconnectMessage(final byte[] data) {
super(data);
}
public static DisconnectMessage create(final DisconnectReason reason) {
return new DisconnectMessage(new Gson().toJson(reason).getBytes());
}
public static DisconnectMessage readFrom(final Message message) {
if (message instanceof DisconnectMessage) {
return (DisconnectMessage) message;
}
final int code = message.getCode();
if (code != MessageCodes.DISCONNECT) {
throw new IllegalArgumentException(
String.format("Message has code %d and thus is not a DisconnectMessage.", code));
}
return new DisconnectMessage(message.getData());
}
@Override
public int getCode() {
return MessageCodes.DISCONNECT;
}
}

View File

@ -0,0 +1,27 @@
package org.codenil.comm.message;
import java.util.Optional;
public enum DisconnectReason {
UNKNOWN(null),
TIMEOUT((byte) 0x0b),
INVALID_MESSAGE_RECEIVED((byte) 0x02, "An exception was caught decoding message"),
;
private final Optional<Byte> code;
private final Optional<String> message;
DisconnectReason(final Byte code) {
this.code = Optional.ofNullable(code);
this.message = Optional.empty();
}
DisconnectReason(final Byte code, final String message) {
this.code = Optional.ofNullable(code);
this.message = Optional.of(message);
}
}

View File

@ -0,0 +1,19 @@
package org.codenil.comm.message;
public abstract class EmptyMessage implements Message {
@Override
public final int getSize() {
return 0;
}
@Override
public byte[] getData() {
return new byte[]{};
}
@Override
public String toString() {
return getClass().getSimpleName() + "{ code=" + getCode() + ", size=0}";
}
}

View File

@ -0,0 +1,17 @@
package org.codenil.comm.message;
public class HelloMessage extends AbstractMessage {
public HelloMessage(final byte[] data) {
super(data);
}
public static HelloMessage create() {
return new HelloMessage(new byte[0]);
}
@Override
public int getCode() {
return MessageCodes.HELLO;
}
}

View File

@ -0,0 +1,18 @@
package org.codenil.comm.message;
public interface Message {
String getRequestId();
int getSize();
int getCode();
byte[] getData();
default RawMessage wrapMessage(final String requestId) {
RawMessage rawMessage = new RawMessage(getCode(), getData());
rawMessage.setRequestId(requestId);
return rawMessage;
}
}

View File

@ -0,0 +1,7 @@
package org.codenil.comm.message;
@FunctionalInterface
public interface MessageCallback {
void onMessage(final DefaultMessage message);
}

View File

@ -0,0 +1,22 @@
package org.codenil.comm.message;
public class MessageCodes {
public static final int HELLO = 0x00;
public static final int DISCONNECT = 0x01;
public static final int PING = 0x02;
public static final int PONG = 0x03;
public static final int DATA_UPDATE = 0x04;
private MessageCodes() {}
public static String messageName(final int code) {
return switch (code) {
case HELLO -> "Hello";
case DISCONNECT -> "Disconnect";
case PING -> "Ping";
case PONG -> "Pong";
default -> "invalid";
};
}
}

View File

@ -0,0 +1,29 @@
package org.codenil.comm.message;
public enum MessageType {
PING(0),
PONG(1),
DATA(2),
UNRECOGNIZED(-1),
;
private final int value;
MessageType(final int value) {
this.value = value;
}
public int getValue() {
return value;
}
public static MessageType forNumber(final int value) {
return switch (value) {
case 0 -> PING;
case 1 -> PONG;
case 2 -> DATA;
case -1 -> UNRECOGNIZED;
default -> null;
};
}
}

View File

@ -0,0 +1,26 @@
package org.codenil.comm.message;
public class PingMessage extends EmptyMessage {
private static final PingMessage INSTANCE = new PingMessage();
public static PingMessage get() {
return INSTANCE;
}
private PingMessage() {}
@Override
public String getRequestId() {
return "";
}
@Override
public int getCode() {
return MessageCodes.PING;
}
@Override
public String toString() {
return "PingMessage{data=''}";
}
}

View File

@ -0,0 +1,21 @@
package org.codenil.comm.message;
public class PongMessage extends EmptyMessage {
private static final PongMessage INSTANCE = new PongMessage();
public static PongMessage get() {
return INSTANCE;
}
private PongMessage() {}
@Override
public String getRequestId() {
return "";
}
@Override
public int getCode() {
return MessageCodes.PONG;
}
}

View File

@ -0,0 +1,18 @@
package org.codenil.comm.message;
public class RawMessage extends AbstractMessage {
private final int code;
public RawMessage(
final int code,
final byte[] data) {
super(data);
this.code = code;
}
@Override
public int getCode() {
return code;
}
}

View File

@ -0,0 +1,200 @@
package org.codenil.comm.netty;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.codenil.comm.RemotePeer;
import org.codenil.comm.connections.*;
import org.codenil.comm.NetworkConfig;
import org.codenil.comm.handshake.HandshakeHandlerInbound;
import org.codenil.comm.handshake.HandshakeHandlerOutbound;
import org.codenil.comm.handshake.PlainHandshaker;
import org.codenil.comm.netty.handler.TimeoutHandler;
import javax.annotation.Nonnull;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.GeneralSecurityException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* netty初始化
*/
public class NettyConnectionInitializer implements ConnectionInitializer {
private static final int TIMEOUT_SECONDS = 10;
private final Subscribers<ConnectCallback> connectSubscribers = Subscribers.create();
private final PeerConnectionEvents eventDispatcher;
private final EventLoopGroup boss = new NioEventLoopGroup(1);
private final EventLoopGroup workers = new NioEventLoopGroup(10);
private final AtomicBoolean started = new AtomicBoolean(false);
private final AtomicBoolean stopped = new AtomicBoolean(false);
private final NetworkConfig config;
private ChannelFuture server;
public NettyConnectionInitializer(
final NetworkConfig config,
final PeerConnectionEvents eventDispatcher) {
this.config = config;
this.eventDispatcher = eventDispatcher;
}
/**
* 启动netty服务器
*/
@Override
public CompletableFuture<InetSocketAddress> start() {
final CompletableFuture<InetSocketAddress> listeningPortFuture = new CompletableFuture<>();
if (!started.compareAndSet(false, true)) {
listeningPortFuture.completeExceptionally(
new IllegalStateException(
"Attempt to start an already started " + this.getClass().getSimpleName()));
return listeningPortFuture;
}
this.server = new ServerBootstrap()
.group(boss, workers)
.channel(NioServerSocketChannel.class)
.childHandler(inboundChannelInitializer())
.bind(config.bindHost(), config.bindPort());
server.addListener(future -> {
final InetSocketAddress socketAddress =
(InetSocketAddress) server.channel().localAddress();
if (!future.isSuccess() || socketAddress == null) {
final String message =
String.format("Unable to start listening on %s:%s. Check for port conflicts.",
config.bindHost(), config.bindPort());
listeningPortFuture.completeExceptionally(
new IllegalStateException(message, future.cause()));
return;
}
listeningPortFuture.complete(socketAddress);
});
return listeningPortFuture;
}
/**
* 停止netty服务器
*/
@Override
public CompletableFuture<Void> stop() {
final CompletableFuture<Void> stoppedFuture = new CompletableFuture<>();
if (!started.get() || !stopped.compareAndSet(false, true)) {
stoppedFuture.completeExceptionally(
new IllegalStateException("Illegal attempt to stop " + this.getClass().getSimpleName()));
return stoppedFuture;
}
workers.shutdownGracefully();
boss.shutdownGracefully();
server.channel()
.closeFuture()
.addListener((future) -> {
if (future.isSuccess()) {
stoppedFuture.complete(null);
} else {
stoppedFuture.completeExceptionally(future.cause());
}
});
return stoppedFuture;
}
@Override
public void subscribeIncomingConnect(ConnectCallback callback) {
connectSubscribers.subscribe(callback);
}
/**
* 连接到远程
*/
@Override
public CompletableFuture<PeerConnection> connect(RemotePeer remotePeer) {
final CompletableFuture<PeerConnection> connectionFuture = new CompletableFuture<>();
new Bootstrap()
.group(workers)
.channel(NioSocketChannel.class)
.remoteAddress(new InetSocketAddress(remotePeer.ip(), remotePeer.listeningPort()))
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, TIMEOUT_SECONDS * 1000)
.handler(outboundChannelInitializer(remotePeer, connectionFuture))
.connect()
.addListener(
(f) -> {
if (!f.isSuccess()) {
connectionFuture.completeExceptionally(f.cause());
}
});
return connectionFuture;
}
private ChannelInitializer<SocketChannel> inboundChannelInitializer() {
return new ChannelInitializer<>() {
@Override
protected void initChannel(final SocketChannel ch) throws Exception {
final CompletableFuture<PeerConnection> connectionFuture = new CompletableFuture<>();
connectionFuture.thenAccept(connection -> connectSubscribers.forEach(c -> c.onConnect(connection)));
//连接处理器
ch.pipeline().addLast(timeoutHandler(connectionFuture, "Timed out waiting to fully establish incoming connection"));
//其他处理器TLS之类的
addAdditionalInboundHandlers(ch);
//握手消息处理器
ch.pipeline().addLast(inboundHandler(connectionFuture));
}
};
}
@Nonnull
private ChannelInitializer<SocketChannel> outboundChannelInitializer(
final RemotePeer remotePeer, final CompletableFuture<PeerConnection> connectionFuture) {
return new ChannelInitializer<>() {
@Override
protected void initChannel(final SocketChannel ch) throws Exception {
//连接处理器
ch.pipeline().addLast(timeoutHandler(connectionFuture, "Timed out waiting to establish connection with peer: " + remotePeer.ip()));
//其他处理器
addAdditionalOutboundHandlers(ch, remotePeer);
//握手消息处理器
ch.pipeline().addLast(outboundHandler(remotePeer, connectionFuture));
}
};
}
@Nonnull
private TimeoutHandler<Channel> timeoutHandler(
final CompletableFuture<PeerConnection> connectionFuture, final String message) {
return new TimeoutHandler<>(connectionFuture::isDone, TIMEOUT_SECONDS,
() -> connectionFuture.completeExceptionally(new TimeoutException(message)));
}
@Nonnull
private HandshakeHandlerInbound inboundHandler(
final CompletableFuture<PeerConnection> connectionFuture) {
return new HandshakeHandlerInbound(connectionFuture, eventDispatcher, new PlainHandshaker());
}
@Nonnull
private HandshakeHandlerOutbound outboundHandler(
final RemotePeer remotePeer, final CompletableFuture<PeerConnection> connectionFuture) {
return new HandshakeHandlerOutbound(connectionFuture, eventDispatcher, new PlainHandshaker());
}
void addAdditionalOutboundHandlers(final Channel channel, final RemotePeer remotePeer)
throws GeneralSecurityException, IOException {}
void addAdditionalInboundHandlers(final Channel channel)
throws GeneralSecurityException, IOException {}
}

View File

@ -0,0 +1,45 @@
package org.codenil.comm.netty;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import org.codenil.comm.connections.AbstractPeerConnection;
import org.codenil.comm.connections.PeerConnectionEvents;
import org.codenil.comm.message.Message;
import org.codenil.comm.message.RawMessage;
import java.util.concurrent.Callable;
import static java.util.concurrent.TimeUnit.SECONDS;
public class NettyPeerConnection extends AbstractPeerConnection {
private final ChannelHandlerContext ctx;
public NettyPeerConnection(
final ChannelHandlerContext ctx,
final PeerConnectionEvents connectionEvents) {
super(connectionEvents);
this.ctx = ctx;
}
@Override
public String peerIdentity() {
return ctx.channel().remoteAddress().toString();
}
@Override
protected void doSendMessage(final Message message) {
ctx.channel().writeAndFlush(new RawMessage(message.getCode(), message.getData()));
}
@Override
protected void closeConnectionImmediately() {
ctx.close();
}
@Override
protected void closeConnection() {
ctx.channel().eventLoop().schedule((Callable<ChannelFuture>) ctx::close, 2L, SECONDS);
}
}

View File

@ -0,0 +1,17 @@
package org.codenil.comm.netty;
import org.codenil.comm.message.Message;
public class OutboundMessage {
private final Message message;
public OutboundMessage(
final Message message) {
this.message = message;
}
public Message message() {
return message;
}
}

View File

@ -0,0 +1,67 @@
package org.codenil.comm.netty.handler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.codenil.comm.connections.PeerConnection;
import org.codenil.comm.connections.PeerConnectionEvents;
import org.codenil.comm.message.Message;
import org.codenil.comm.message.MessageCodes;
import org.codenil.comm.message.PongMessage;
import org.codenil.comm.message.RawMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.atomic.AtomicBoolean;
public class CommonHandler extends SimpleChannelInboundHandler<RawMessage> {
private static final Logger logger = LoggerFactory.getLogger(CommonHandler.class);
private final AtomicBoolean waitingForPong;
private final PeerConnection connection;
private final PeerConnectionEvents connectionEvents;
public CommonHandler(
final PeerConnection connection,
final PeerConnectionEvents connectionEvents,
final AtomicBoolean waitingForPong) {
this.connection = connection;
this.connectionEvents = connectionEvents;
this.waitingForPong = waitingForPong;
}
@Override
protected void channelRead0(final ChannelHandlerContext ctx, final RawMessage originalMessage) {
logger.debug("Received a message from {}", originalMessage.getCode());
switch (originalMessage.getCode()) {
case MessageCodes.PING:
logger.trace("Received Wire PING");
try {
connection.send(PongMessage.get());
} catch (Exception e) {
// Nothing to do
}
break;
case MessageCodes.PONG:
logger.trace("Received Wire PONG");
waitingForPong.set(false);
break;
case MessageCodes.DISCONNECT:
try {
logger.trace("Received DISCONNECT Message");
} catch (final Exception e) {
logger.error("Received Wire DISCONNECT, but unable to parse reason. ");
}
connection.terminateConnection();
}
connectionEvents.dispatchMessage(connection, originalMessage);
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable throwable) {
logger.error("Error:", throwable);
connectionEvents.dispatchDisconnect(connection);
ctx.close();
}
}

View File

@ -0,0 +1,139 @@
package org.codenil.comm.netty.handler;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.timeout.IdleStateHandler;
import org.codenil.comm.connections.KeepAlive;
import org.codenil.comm.connections.PeerConnection;
import org.codenil.comm.connections.PeerConnectionEvents;
import org.codenil.comm.message.DisconnectMessage;
import org.codenil.comm.message.DisconnectReason;
import org.codenil.comm.message.MessageCodes;
import org.codenil.comm.message.RawMessage;
import org.codenil.comm.netty.NettyPeerConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
public class MessageFrameDecoder extends ByteToMessageDecoder {
private static final Logger logger = LoggerFactory.getLogger(MessageFrameDecoder.class);
private final CompletableFuture<PeerConnection> connectFuture;
private final PeerConnectionEvents connectionEvents;
private boolean hellosExchanged;
public MessageFrameDecoder(
final PeerConnectionEvents connectionEvents,
final CompletableFuture<PeerConnection> connectFuture) {
this.connectionEvents = connectionEvents;
this.connectFuture = connectFuture;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) throws Exception {
if (byteBuf.readableBytes() < 4) {
return; // 不足4字节长度字段等待更多数据
}
byteBuf.readerIndex(0);
// 读取协议头消息总长度
int totalLength = byteBuf.readInt();
if (byteBuf.readableBytes() < totalLength - 4) {
return; // 不足消息总长度等待更多数据
}
// 读取payload
// 读取id
int idLength = byteBuf.readInt();
byte[] idBytes = new byte[idLength];
byteBuf.readBytes(idBytes);
String id = new String(idBytes, StandardCharsets.UTF_8);
// 读取code
int code = byteBuf.readInt();
// 读取data
int dataLength = byteBuf.readInt();;
byte[] data = new byte[dataLength];
byteBuf.readBytes(data);
// 创建消息对象
RawMessage message = new RawMessage(code, data);
message.setRequestId(id);
if (hellosExchanged) {
out.add(message);
} else if (message.getCode() == MessageCodes.HELLO) {
hellosExchanged = true;
final PeerConnection connection = new NettyPeerConnection(ctx, connectionEvents);
/*
* 如果收到的消息是Hello消息
* 添加一个空闲链接检测处理器
* 添加一个连接保活处理器检测到连接空闲后发送一个Ping消息
* 通用消息处理器处理所有的协议消息
* 添加一个消息封帧处理器
*/
final AtomicBoolean waitingForPong = new AtomicBoolean(false);
ctx.channel()
.pipeline()
.addLast(new IdleStateHandler(15, 0, 0),
new KeepAlive(connection, waitingForPong),
new CommonHandler(connection, connectionEvents, waitingForPong),
new MessageFrameEncoder());
connectFuture.complete(connection);
} else if (message.getCode() == MessageCodes.DISCONNECT) {
logger.debug("Disconnected before sending HELLO.");
ctx.close();
connectFuture.completeExceptionally(new RuntimeException("Disconnect"));
} else {
logger.debug(
"Message received before HELLO's exchanged, disconnecting. Code: {}, Data: {}",
message.getCode(), Arrays.toString(message.getData()));
DisconnectMessage disconnectMessage = DisconnectMessage.create(DisconnectReason.UNKNOWN);
ctx.writeAndFlush(new RawMessage(disconnectMessage.getCode(), disconnectMessage.getData()))
.addListener((f) -> ctx.close());
connectFuture.completeExceptionally(new RuntimeException("Message received before HELLO's exchanged"));
}
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable throwable)
throws Exception {
final Throwable cause =
throwable instanceof DecoderException && throwable.getCause() != null
? throwable.getCause()
: throwable;
if (cause instanceof IllegalArgumentException) {
logger.debug("Invalid incoming message ", throwable);
if (connectFuture.isDone() && !connectFuture.isCompletedExceptionally()) {
connectFuture.get().disconnect(DisconnectReason.INVALID_MESSAGE_RECEIVED);
return;
}
} else if (cause instanceof IOException) {
// IO failures are routine when communicating with random peers across the network.
logger.debug("IO error while processing incoming message", throwable);
} else {
logger.error("Exception while processing incoming message", throwable);
}
if (connectFuture.isDone() && !connectFuture.isCompletedExceptionally()) {
connectFuture.get().terminateConnection();
} else {
connectFuture.completeExceptionally(throwable);
ctx.close();
}
}
}

View File

@ -0,0 +1,32 @@
package org.codenil.comm.netty.handler;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import org.codenil.comm.message.RawMessage;
import org.codenil.comm.serialize.SerializeHelper;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
public class MessageFrameEncoder extends MessageToByteEncoder<RawMessage> {
public MessageFrameEncoder() {}
@Override
protected void encode(
final ChannelHandlerContext ctx,
final RawMessage msg,
final ByteBuf out) {
byte[] idBytes = Optional.ofNullable(msg.getRequestId()).orElse("").getBytes(StandardCharsets.UTF_8);
SerializeHelper builder = new SerializeHelper();
ByteBuf buf = builder.writeBytes(idBytes)
.writeInt(msg.getCode())
.writeBytes(msg.getData())
.build();
out.writeBytes(buf);
buf.release();
}
}

View File

@ -0,0 +1,51 @@
package org.codenil.comm.netty.handler;
import io.netty.buffer.ByteBuf;
import org.codenil.comm.handshake.PlainMessage;
import org.codenil.comm.message.MessageType;
import org.codenil.comm.serialize.SerializeHelper;
public class MessageHandler {
public static byte[] buildMessage(final PlainMessage message) {
SerializeHelper builder = new SerializeHelper();
ByteBuf buf = builder.writeInt(message.messageType().getValue())
.writeInt(message.code())
.writeBytes(message.data()).build();
byte[] result = new byte[buf.readableBytes()];
buf.readBytes(result);
buf.release();
return result;
}
public static byte[] buildMessage(
final MessageType messageType,
final int code,
final byte[] data) {
return buildMessage(new PlainMessage(messageType, code, data));
}
public static PlainMessage parseMessage(final ByteBuf buf) {
PlainMessage ret = null;
buf.readerIndex(0);
//跳过版本
int versionLength = buf.readInt();
buf.skipBytes(versionLength);
int payloadLength = buf.readInt();
if(payloadLength < 8) {
return ret;
}
int messageType = buf.readInt();
int code = buf.readInt();
int dataLength = buf.readInt();
byte[] data = new byte[dataLength];
buf.readBytes(data);
ret = new PlainMessage(MessageType.forNumber(messageType), code, data);
return ret;
}
}

View File

@ -0,0 +1,38 @@
package org.codenil.comm.netty.handler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
public class TimeoutHandler<C extends Channel> extends ChannelInitializer<C> {
private final Supplier<Boolean> condition;
private final int timeoutInSeconds;
private final OnTimeoutCallback callback;
public TimeoutHandler(
final Supplier<Boolean> condition,
final int timeoutInSeconds,
final OnTimeoutCallback callback) {
this.condition = condition;
this.timeoutInSeconds = timeoutInSeconds;
this.callback = callback;
}
@Override
protected void initChannel(final C ch) throws Exception {
ch.eventLoop().schedule(() -> {
if (!condition.get()) {
callback.invoke();
ch.close();
}
}, timeoutInSeconds, TimeUnit.SECONDS);
}
@FunctionalInterface
public interface OnTimeoutCallback {
void invoke();
}
}

View File

@ -0,0 +1,374 @@
package org.codenil.comm.serialize;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import org.apache.commons.lang3.StringUtils;
import java.nio.charset.StandardCharsets;
import java.util.List;
/**
* 根据协议格式写入数据
* 返回组装好协议的ByteBuf
* 只能通过这个方法创建协议其他方式创建的ByteBuf
* 后续修改协议头时不确定兼容
*/
public class SerializeHelper {
private final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
private ByteBuf buf;
private boolean hasHeader = false;
public SerializeHelper() {
buf = allocator.buffer();
}
public SerializeHelper(
final ByteBuf buf,
final boolean hasHeader) {
this.buf = buf;
this.hasHeader = hasHeader;
}
/**
* 协议版本号固定长度4
*/
public SerializeHelper writeVersion(final String version) {
if (!Version.supported(version)) {
throw new IllegalArgumentException("Unsupported protocol version: " + version);
}
if (StringUtils.isBlank(version)) {
throw new IllegalArgumentException("Version cannot be null or empty");
}
byte[] bytes = version.getBytes();
if (bytes.length > 4) {
throw new IllegalArgumentException("Version length exceeds 4 bytes");
}
ByteBuf buf = allocator.buffer();
if (this.buf.readableBytes() == 0) {
buf.writeInt(bytes.length);
buf.writeBytes(bytes);
buf.writeInt(0);
this.buf.release();
this.buf = buf;
return this;
}
String oldVersion = readVersion(this.buf);
if (oldVersion.equals(version)) {
return this;
}
//version长度
buf.writeInt(bytes.length);
//version内容
buf.writeBytes(bytes);
//payload长度
buf.writeInt(0);
this.buf.release();
this.buf = buf;
this.hasHeader = true;
return this;
}
public SerializeHelper writeInt(int value) {
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeInt(value);
this.buf.release();
this.buf = buf;
return this;
}
if (hasHeader) {
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
//写入总长度
int payloadLength = this.buf.readInt();
buf.writeInt(payloadLength + 4);
}
byte[] payloadBytes = new byte[buf.readableBytes()];
this.buf.readBytes(payloadBytes);
buf.writeBytes(payloadBytes);
buf.writeInt(value);
this.buf.release();
this.buf = buf;
return this;
}
public SerializeHelper writeLong(long value) {
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeLong(value);
this.buf.release();
this.buf = buf;
return this;
}
if (hasHeader) {
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
int payloadLength = this.buf.readInt();
buf.writeInt(payloadLength + 8);
}
byte[] payloadBytes = new byte[buf.readableBytes()];
this.buf.readBytes(payloadBytes);
buf.writeBytes(payloadBytes);
buf.writeLong(value);
this.buf.release();
this.buf = buf;
return this;
}
public SerializeHelper writeString(String value) {
if (value == null) {
return this;
}
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(value.getBytes(StandardCharsets.UTF_8));
this.buf.release();
this.buf = buf;
return this;
}
if (hasHeader) {
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
int payloadLength = this.buf.readInt();
buf.writeInt(payloadLength + value.getBytes(StandardCharsets.UTF_8).length + 4);
}
byte[] payloadBytes = new byte[buf.readableBytes()];
this.buf.readBytes(payloadBytes);
buf.writeBytes(payloadBytes);
buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(value.getBytes(StandardCharsets.UTF_8));
this.buf.release();
this.buf = buf;
return this;
}
public SerializeHelper writeBoolean(boolean value) {
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeInt(1);
buf.writeBoolean(value);
this.buf.release();
this.buf = buf;
return this;
}
if (hasHeader) {
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
int payloadLength = this.buf.readInt();
buf.writeInt(payloadLength + 1 + 4);
}
byte[] payloadBytes = new byte[buf.readableBytes()];
this.buf.readBytes(payloadBytes);
buf.writeBytes(payloadBytes);
buf.writeInt(1);
buf.writeBoolean(value);
this.buf.release();
this.buf = buf;
return this;
}
public SerializeHelper writeBytes(byte[] value) {
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeInt(value.length);
buf.writeBytes(value);
this.buf.release();
this.buf = buf;
return this;
}
if (hasHeader) {
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
int payloadLength = this.buf.readInt();
buf.writeInt(payloadLength + value.length + 4);
}
byte[] payloadBytes = new byte[buf.readableBytes()];
this.buf.readBytes(payloadBytes);
buf.writeBytes(payloadBytes);
buf.writeInt(value.length);
buf.writeBytes(value);
this.buf.release();
this.buf = buf;
return this;
}
public SerializeHelper writeBytes(List<byte[]> values) {
if (values == null) {
return this;
}
int totalLength = values.stream()
.mapToInt(bytes -> bytes.length + 4).sum();
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeInt(totalLength);
values.forEach(value -> {
buf.writeInt(value.length);
buf.writeBytes(value);
});
this.buf.release();
this.buf = buf;
return this;
}
if (hasHeader) {
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
int payloadLength = this.buf.readInt();
buf.writeInt(payloadLength + totalLength);
}
byte[] payloadBytes = new byte[buf.readableBytes()];
this.buf.readBytes(payloadBytes);
buf.writeBytes(payloadBytes);
buf.writeInt(totalLength);
values.forEach(value -> {
buf.writeInt(value.length);
buf.writeBytes(value);
});
this.buf.release();
this.buf = buf;
return this;
}
public SerializeHelper writeList(List<String> values) {
if (values == null) {
return this;
}
int totalLength = values.stream()
.map(value -> value.getBytes(StandardCharsets.UTF_8))
.mapToInt(bytes -> bytes.length + 4)
.sum();
ByteBuf buf = allocator.buffer();
this.buf.readerIndex(0);
if (this.buf.readableBytes() == 0) {
buf.writeInt(Version.defaultVersion().getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(Version.defaultVersion().getBytes(StandardCharsets.UTF_8));
buf.writeInt(totalLength);
values.forEach(value -> {
buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(value.getBytes(StandardCharsets.UTF_8));
});
this.buf.release();
this.buf = buf;
return this;
}
String version = readVersion(this.buf);
buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(version.getBytes(StandardCharsets.UTF_8));
int payloadLength = this.buf.readInt();
byte[] payloadBytes = new byte[payloadLength];
this.buf.readBytes(payloadBytes);
buf.writeInt(payloadLength + totalLength);
buf.writeBytes(payloadBytes);
values.forEach(value -> {
buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length);
buf.writeBytes(value.getBytes(StandardCharsets.UTF_8));
});
this.buf.release();
this.buf = buf;
return this;
}
public String readVersion(ByteBuf buf) {
this.buf.readerIndex(0);
int versionLength = this.buf.readInt();
byte[] versionBytes = new byte[versionLength];
this.buf.readBytes(versionBytes);
return new String(versionBytes);
}
public byte[] readPayload(ByteBuf buf) {
this.buf.readerIndex(0);
if (hasHeader) {
int versionLength = this.buf.readInt();
this.buf.skipBytes(versionLength);
}
if (this.buf.readableBytes() == 0) {
return new byte[]{};
}
int payloadLength = this.buf.readInt();
byte[] payloadBytes = new byte[payloadLength];
this.buf.readBytes(payloadBytes);
return payloadBytes;
}
public ByteBuf build() {
return buf;
}
}

View File

@ -0,0 +1,14 @@
package org.codenil.comm.serialize;
public class Version {
public static final String VERSION_1_0 = "1.0";
public static boolean supported(String version) {
return VERSION_1_0.equals(version);
}
public static String defaultVersion() {
return VERSION_1_0;
}
}

View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<root level="trace">
<appender-ref ref="STDOUT" />
</root>
</configuration>