commit af84b01cc2b8050565d212d2ee89f3669042bf10 Author: alyenc Date: Tue Aug 13 15:14:39 2024 +0800 提交 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a91c35d --- /dev/null +++ b/.gitignore @@ -0,0 +1,39 @@ +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ + +### IntelliJ IDEA ### +.idea/modules.xml +.idea/jarRepositories.xml +.idea/compiler.xml +.idea/libraries/ +*.iws +*.iml +*.ipr + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store +/.idea/ diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..d2801cc --- /dev/null +++ b/pom.xml @@ -0,0 +1,70 @@ + + + 4.0.0 + + org.codenil + comm + 1.0-SNAPSHOT + + + 22 + 22 + UTF-8 + + + + + io.netty + netty-all + 4.1.112.Final + + + org.apache.commons + commons-lang3 + 3.15.0 + + + com.google.guava + guava + 33.2.1-jre + + + org.slf4j + slf4j-api + 2.0.13 + + + ch.qos.logback + logback-core + 1.5.6 + + + ch.qos.logback + logback-classic + 1.5.6 + + + io.tmio + tuweni-bytes + 2.4.2 + + + org.bouncycastle + bcprov-jdk18on + 1.78.1 + + + org.xerial.snappy + snappy-java + 1.1.10.5 + + + com.google.code.gson + gson + 2.11.0 + + + + \ No newline at end of file diff --git a/src/main/java/org/codenil/comm/Communication.java b/src/main/java/org/codenil/comm/Communication.java new file mode 100644 index 0000000..02e8adb --- /dev/null +++ b/src/main/java/org/codenil/comm/Communication.java @@ -0,0 +1,173 @@ +package org.codenil.comm; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import org.codenil.comm.connections.*; +import org.codenil.comm.message.DisconnectReason; +import org.codenil.comm.message.MessageCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * 网络底层通信模块 + */ +public class Communication { + + private static final Logger logger = LoggerFactory.getLogger(Communication.class); + + /** + * 连接初始化 + */ + private final ConnectionInitializer connectionInitializer; + /** + * 消息订阅 + */ + private final PeerConnectionEvents connectionEvents; + /** + * 连接回调订阅 + */ + private final Subscribers connectSubscribers = Subscribers.create(); + /** + * 连接缓存 + */ + private final Cache> peersConnectingCache = CacheBuilder.newBuilder() + .expireAfterWrite(Duration.ofSeconds(30L)).concurrencyLevel(1).build(); + + private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean stopped = new AtomicBoolean(false); + + public Communication( + final PeerConnectionEvents connectionEvents, + final ConnectionInitializer connectionInitializer) { + this.connectionEvents = connectionEvents; + this.connectionInitializer = connectionInitializer; + } + + /** + * 启动 + */ + public CompletableFuture start() { + if (!started.compareAndSet(false, true)) { + return CompletableFuture.failedFuture( + new IllegalStateException("Unable to start an already started " + getClass().getSimpleName())); + } + + //注册回调监听 + setupListeners(); + + //启动连接初始化 + return connectionInitializer + .start() + .thenApply((socketAddress) -> { + logger.info("P2P RLPx agent started and listening on {}.", socketAddress); + return socketAddress.getPort(); + }) + .whenComplete((_, err) -> { + if (err != null) { + logger.error("Failed to start Communication. Check for port conflicts."); + } + }); + } + + public CompletableFuture stop() { + if (!started.get() || !stopped.compareAndSet(false, true)) { + return CompletableFuture.failedFuture( + new IllegalStateException("Illegal attempt to stop " + getClass().getSimpleName())); + } + + peersConnectingCache.asMap() + .values() + .forEach((conn) -> { + try { + conn.get().disconnect(DisconnectReason.UNKNOWN); + } catch (Exception e) { + logger.debug("Failed to disconnect."); + } + }); + return connectionInitializer.stop(); + } + + /** + * 连接到远程节点 + */ + public CompletableFuture connect(final RemotePeer remotePeer) { + final CompletableFuture peerConnectionCompletableFuture; + try { + synchronized (this) { + //尝试从缓存获取链接,获取不到就创建一个 + peerConnectionCompletableFuture = peersConnectingCache.get( + remotePeer.ip(), () -> createConnection(remotePeer)); + } + } catch (final ExecutionException e) { + throw new RuntimeException(e); + } + return peerConnectionCompletableFuture; + } + + /** + * 订阅消息 + */ + public void subscribeMessage(final MessageCallback callback) { + connectionEvents.subscribeMessage(callback); + } + + /** + * 订阅连接 + */ + public void subscribeConnect(final ConnectCallback callback) { + connectSubscribers.subscribe(callback); + } + + /** + * 创建远程连接 + */ + @Nonnull + private CompletableFuture createConnection(final RemotePeer remotePeer) { + CompletableFuture completableFuture = initiateOutboundConnection(remotePeer); + + completableFuture.whenComplete((peerConnection, throwable) -> { + if (throwable == null) { + dispatchConnect(peerConnection); + } + }); + return completableFuture; + } + + /** + * 初始化远程连接 + */ + private CompletableFuture initiateOutboundConnection(final RemotePeer remotePeer) { + logger.trace("Initiating connection to peer: {}:{}", remotePeer.ip(), remotePeer.listeningPort()); + + return connectionInitializer + .connect(remotePeer) + .whenComplete((conn, err) -> { + if (err != null) { + logger.debug("Failed to connect to peer {}: {}", remotePeer.ip(), err.getMessage()); + } else { + logger.debug("Outbound connection established to peer: {}", remotePeer.ip()); + } + }); + } + + private void setupListeners() { + connectionInitializer.subscribeIncomingConnect(this::handleIncomingConnection); + } + + private void handleIncomingConnection(final PeerConnection peerConnection) { + dispatchConnect(peerConnection); + } + + /** + * 连接完成后调用注册的回调 + */ + private void dispatchConnect(final PeerConnection connection) { + connectSubscribers.forEach(c -> c.onConnect(connection)); + } +} diff --git a/src/main/java/org/codenil/comm/ConnectionStore.java b/src/main/java/org/codenil/comm/ConnectionStore.java new file mode 100644 index 0000000..06ef13d --- /dev/null +++ b/src/main/java/org/codenil/comm/ConnectionStore.java @@ -0,0 +1,25 @@ +package org.codenil.comm; + +import org.codenil.comm.connections.PeerConnection; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class ConnectionStore { + + private final Map pki2Conn = new ConcurrentHashMap<>(); + + public void registerConnection(PeerConnection peerConnection) { + pki2Conn.put(peerConnection.peerIdentity(), peerConnection); + System.out.println(peerConnection.peerIdentity()); + } + + public boolean unRegisterConnection(PeerConnection peerConnection) { + return pki2Conn.remove(peerConnection.peerIdentity(), peerConnection); + } + + public PeerConnection getConnection(String peerIdentity) { + return pki2Conn.get(peerIdentity); + } +} + diff --git a/src/main/java/org/codenil/comm/DataUpdateMessage.java b/src/main/java/org/codenil/comm/DataUpdateMessage.java new file mode 100644 index 0000000..966597a --- /dev/null +++ b/src/main/java/org/codenil/comm/DataUpdateMessage.java @@ -0,0 +1,29 @@ +package org.codenil.comm; + +import org.codenil.comm.message.AbstractMessage; +import org.codenil.comm.message.Message; +import org.codenil.comm.message.MessageCodes; + +public class DataUpdateMessage extends AbstractMessage { + + public DataUpdateMessage(byte[] data) { + super(data); + } + + public static DataUpdateMessage readFrom(final Message message) { + if (message instanceof DataUpdateMessage) { + return (DataUpdateMessage) message; + } + return new DataUpdateMessage(message.getData()); + } + + public static DataUpdateMessage create(final String data) { + return new DataUpdateMessage(data.getBytes()); + } + + @Override + public int getCode() { + return MessageCodes.DATA_UPDATE; + } + +} diff --git a/src/main/java/org/codenil/comm/NetworkConfig.java b/src/main/java/org/codenil/comm/NetworkConfig.java new file mode 100644 index 0000000..32f41ce --- /dev/null +++ b/src/main/java/org/codenil/comm/NetworkConfig.java @@ -0,0 +1,24 @@ +package org.codenil.comm; + +public class NetworkConfig { + + private String bindHost; + + private int bindPort; + + public String bindHost() { + return bindHost; + } + + public void setBindHost(String bindHost) { + this.bindHost = bindHost; + } + + public int bindPort() { + return bindPort; + } + + public void setBindPort(int bindPort) { + this.bindPort = bindPort; + } +} diff --git a/src/main/java/org/codenil/comm/NetworkService.java b/src/main/java/org/codenil/comm/NetworkService.java new file mode 100644 index 0000000..30a78b8 --- /dev/null +++ b/src/main/java/org/codenil/comm/NetworkService.java @@ -0,0 +1,222 @@ +package org.codenil.comm; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.tuweni.bytes.Bytes; +import org.codenil.comm.connections.*; +import org.codenil.comm.message.DefaultMessage; +import org.codenil.comm.message.DisconnectReason; +import org.codenil.comm.message.Message; +import org.codenil.comm.message.RawMessage; +import org.codenil.comm.netty.NettyConnectionInitializer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigInteger; +import java.time.Clock; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +public class NetworkService { + + private static final Logger logger = LoggerFactory.getLogger(NetworkService.class); + + private final CountDownLatch shutdown = new CountDownLatch(1); + private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean stopped = new AtomicBoolean(false); + + private final Map> listenersByCode = new ConcurrentHashMap<>(); + private final Map messageResponseByCode = new ConcurrentHashMap<>(); + + private final AtomicInteger outstandingRequests = new AtomicInteger(0); + private final AtomicLong requestIdCounter = new AtomicLong(1); + private final Clock clock = Clock.systemUTC(); + private final PeerReputation reputation = new PeerReputation(); + private final ConnectionStore connectionStore = new ConnectionStore(); + + private final Communication communication; + + private volatile long lastRequestTimestamp = 0; + + public NetworkService(final NetworkConfig networkConfig) { + PeerConnectionEvents connectionEvents = new PeerConnectionEvents(); + ConnectionInitializer connectionInitializer = new NettyConnectionInitializer(networkConfig, connectionEvents); + this.communication = new Communication(connectionEvents, connectionInitializer); + } + + public CompletableFuture start() { + if (started.compareAndSet(false, true)) { + logger.info("Starting Network."); + setupHandlers(); + return communication.start(); + } else { + logger.error("Attempted to start already running network."); + return CompletableFuture.failedFuture(new Throwable("Attempted to start already running network.")); + } + } + + public CompletableFuture stop() { + if (stopped.compareAndSet(false, true)) { + logger.info("Stopping Network."); + CompletableFuture stop = communication.stop(); + return stop.whenComplete((result, throwable) -> { + shutdown.countDown(); + }); + } else { + logger.error("Attempted to stop already stopped network."); + return CompletableFuture.failedFuture(new Throwable("Attempted to stop already stopped network.")); + } + } + + public CompletableFuture connect(final RemotePeer remotePeer) { + return communication.connect(remotePeer); + } + + public void sendRequest(final String peerIdentity, final Message message) { + lastRequestTimestamp = clock.millis(); + this.dispatchRequest( + msg -> { + try { + PeerConnection connection = connectionStore.getConnection(peerIdentity); + if(Objects.nonNull(connection)) { + connection.send(msg); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + message); + } + + public long subscribe(final int messageCode, final MessageCallback callback) { + return listenersByCode + .computeIfAbsent(messageCode, _ -> Subscribers.create()) + .subscribe(callback); + } + + public void unsubscribe(final long subscriptionId, final int messageCode) { + if (listenersByCode.containsKey(messageCode)) { + listenersByCode.get(messageCode).unsubscribe(subscriptionId); + if (listenersByCode.get(messageCode).getSubscriberCount() < 1) { + listenersByCode.remove(messageCode); + } + } + } + + public void registerResponse( + final int messageCode, final MessageResponse messageResponse) { + messageResponseByCode.put(messageCode, messageResponse); + } + + private void setupHandlers() { + communication.subscribeConnect(this::registerNewConnection); + communication.subscribeMessage(message -> { + try { + this.receiveMessage(message); + } catch (Exception e) { + logger.error(e.getMessage(), e); + } + }); + } + + /** + * 注册新连接 + */ + private void registerNewConnection(final PeerConnection newConnection) { + synchronized (this) { + connectionStore.registerConnection(newConnection); + } + } + + private boolean registerDisconnect(final PeerConnection connection) { + return connectionStore.unRegisterConnection(connection); + } + + private void dispatchRequest(final RequestSender sender, final Message message) { + outstandingRequests.incrementAndGet(); + sender.send(message.wrapMessage(requestIdCounter.getAndIncrement() + "")); + } + + private void receiveMessage(final DefaultMessage message) throws Exception { + + //处理自定义回复处理器 + Optional maybeResponse = Optional.empty(); + try { + final int code = message.message().getCode(); + Optional.ofNullable(listenersByCode.get(code)) + .ifPresent(listeners -> listeners.forEach(messageCallback -> messageCallback.exec(message))); + + Message requestIdAndEthMessage = message.message(); + maybeResponse = Optional.ofNullable(messageResponseByCode.get(code)) + .map(messageResponse -> messageResponse.response(message)) + .map(responseData -> responseData.wrapMessage(requestIdAndEthMessage.getRequestId())); + } catch (Exception e) { + logger.atDebug() + .setMessage("Received malformed message {}, {}") + .addArgument(message) + .addArgument(e::toString) + .log(); + this.disconnect(message.connection().peerIdentity(), DisconnectReason.UNKNOWN); + } + + maybeResponse.ifPresent( + responseData -> { + try { + sendRequest(message.connection().peerIdentity(), responseData); + } catch (Exception __) {} + }); + } + + private void disconnect(final String peerIdentity, final DisconnectReason reason) { + PeerConnection connection = connectionStore.getConnection(peerIdentity); + if(Objects.nonNull(connection)) { + try { + connection.disconnect(reason); + } catch (Exception e) { + logger.debug("Disconnect by reason {}", reason); + } + } + } + + private void recordRequestTimeout(final PeerConnection connection, final int requestCode) { + logger.atDebug() + .setMessage("Timed out while waiting for response") + .log(); + logger.trace("Timed out while waiting for response from peer {}", this); + reputation.recordRequestTimeout(requestCode) + .ifPresent(reason -> { + this.disconnect(connection.peerIdentity(), reason); + }); + } + + private void recordUselessResponse(final PeerConnection connection, final String requestType) { + logger.atTrace() + .setMessage("Received useless response for request type {}") + .addArgument(requestType) + .log(); + reputation.recordUselessResponse(System.currentTimeMillis()) + .ifPresent(reason -> { + this.disconnect(connection.peerIdentity(), reason); + }); + } + + @FunctionalInterface + public interface RequestSender { + void send(final RawMessage message); + } + + @FunctionalInterface + public interface MessageCallback { + void exec(DefaultMessage message); + } + + @FunctionalInterface + public interface MessageResponse { + Message response(DefaultMessage message); + } +} diff --git a/src/main/java/org/codenil/comm/PeerReputation.java b/src/main/java/org/codenil/comm/PeerReputation.java new file mode 100644 index 0000000..2798329 --- /dev/null +++ b/src/main/java/org/codenil/comm/PeerReputation.java @@ -0,0 +1,114 @@ +package org.codenil.comm; + +import org.codenil.comm.message.DisconnectReason; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.base.Preconditions.checkArgument; + +public class PeerReputation implements Comparable { + private static final long USELESS_RESPONSE_WINDOW_IN_MILLIS = TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES); + private static final int DEFAULT_MAX_SCORE = 150; + private static final int DEFAULT_INITIAL_SCORE = 100; + private static final Logger LOG = LoggerFactory.getLogger(PeerReputation.class); + private static final int TIMEOUT_THRESHOLD = 5; + private static final int USELESS_RESPONSE_THRESHOLD = 5; + private static final int SMALL_ADJUSTMENT = 1; + private static final int LARGE_ADJUSTMENT = 10; + + private final ConcurrentMap timeoutCountByRequestType = new ConcurrentHashMap<>(); + private final Queue uselessResponseTimes = new ConcurrentLinkedQueue<>(); + + private int score; + + private final int maxScore; + + public PeerReputation() { + this(DEFAULT_INITIAL_SCORE, DEFAULT_MAX_SCORE); + } + + public PeerReputation(final int initialScore, final int maxScore) { + checkArgument( + initialScore <= maxScore, "Initial score must be less than or equal to max score"); + this.maxScore = maxScore; + this.score = initialScore; + } + + public Optional recordRequestTimeout(final int requestCode) { + final int newTimeoutCount = getOrCreateTimeoutCount(requestCode).incrementAndGet(); + if (newTimeoutCount >= TIMEOUT_THRESHOLD) { + LOG.debug( + "Disconnection triggered by {} repeated timeouts for requestCode {}", + newTimeoutCount, + requestCode); + score -= LARGE_ADJUSTMENT; + return Optional.of(DisconnectReason.TIMEOUT); + } else { + score -= SMALL_ADJUSTMENT; + return Optional.empty(); + } + } + + public void resetTimeoutCount(final int requestCode) { + timeoutCountByRequestType.remove(requestCode); + } + + private AtomicInteger getOrCreateTimeoutCount(final int requestCode) { + return timeoutCountByRequestType.computeIfAbsent(requestCode, code -> new AtomicInteger()); + } + + public Map timeoutCounts() { + return timeoutCountByRequestType; + } + + public Optional recordUselessResponse(final long timestamp) { + uselessResponseTimes.add(timestamp); + while (shouldRemove(uselessResponseTimes.peek(), timestamp)) { + uselessResponseTimes.poll(); + } + if (uselessResponseTimes.size() >= USELESS_RESPONSE_THRESHOLD) { + score -= LARGE_ADJUSTMENT; + LOG.debug("Disconnection triggered by exceeding useless response threshold"); + return Optional.of(DisconnectReason.UNKNOWN); + } else { + score -= SMALL_ADJUSTMENT; + return Optional.empty(); + } + } + + public void recordUsefulResponse() { + if (score < maxScore) { + score = Math.min(maxScore, score + SMALL_ADJUSTMENT); + } + } + + private boolean shouldRemove(final Long timestamp, final long currentTimestamp) { + return timestamp != null && timestamp + USELESS_RESPONSE_WINDOW_IN_MILLIS < currentTimestamp; + } + + @Override + public String toString() { + return String.format( + "PeerReputation score: %d, timeouts: %s, useless: %s", + score, timeoutCounts(), uselessResponseTimes.size()); + } + + @Override + public int compareTo(final @Nonnull PeerReputation otherReputation) { + return Integer.compare(this.score, otherReputation.score); + } + + public int getScore() { + return score; + } +} \ No newline at end of file diff --git a/src/main/java/org/codenil/comm/RemotePeer.java b/src/main/java/org/codenil/comm/RemotePeer.java new file mode 100644 index 0000000..83d2be8 --- /dev/null +++ b/src/main/java/org/codenil/comm/RemotePeer.java @@ -0,0 +1,32 @@ +package org.codenil.comm; + +import java.net.InetSocketAddress; + +public class RemotePeer { + + private final String peerIdentity; + + private final String ip; + + private final int listeningPort; + + public RemotePeer( + final String ip, + final int listeningPort) { + this.ip = ip; + this.listeningPort = listeningPort; + this.peerIdentity = new InetSocketAddress(ip, listeningPort).toString(); + } + + public String peerIdentity() { + return peerIdentity; + } + + public String ip() { + return ip; + } + + public int listeningPort() { + return listeningPort; + } +} diff --git a/src/main/java/org/codenil/comm/connections/AbstractPeerConnection.java b/src/main/java/org/codenil/comm/connections/AbstractPeerConnection.java new file mode 100644 index 0000000..935636a --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/AbstractPeerConnection.java @@ -0,0 +1,58 @@ +package org.codenil.comm.connections; + +import org.codenil.comm.message.DisconnectMessage; +import org.codenil.comm.message.DisconnectReason; +import org.codenil.comm.message.Message; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.atomic.AtomicBoolean; + +public abstract class AbstractPeerConnection implements PeerConnection { + + private static final Logger logger = LoggerFactory.getLogger(AbstractPeerConnection.class); + + protected final PeerConnectionEvents connectionEvents; + private final AtomicBoolean disconnected = new AtomicBoolean(false); + private final AtomicBoolean terminatedImmediately = new AtomicBoolean(false); + + protected AbstractPeerConnection(final PeerConnectionEvents connectionEvents) { + this.connectionEvents = connectionEvents; + } + + @Override + public void send(final Message message) { + doSendMessage(message); + } + + @Override + public void terminateConnection() { + if (terminatedImmediately.compareAndSet(false, true)) { + if (disconnected.compareAndSet(false, true)) { + connectionEvents.dispatchDisconnect(this); + } + // Always ensure the context gets closed immediately even if we previously sent a disconnect + // message and are waiting to close. + closeConnectionImmediately(); + logger.atTrace() + .setMessage("Terminating connection, reason {}") + .addArgument(this) + .log(); + } + } + + @Override + public void disconnect(DisconnectReason reason) { + if (disconnected.compareAndSet(false, true)) { + connectionEvents.dispatchDisconnect(this); + doSendMessage(DisconnectMessage.create(reason)); + closeConnection(); + } + } + + protected abstract void doSendMessage(final Message message); + + protected abstract void closeConnection(); + + protected abstract void closeConnectionImmediately(); +} diff --git a/src/main/java/org/codenil/comm/connections/ConnectCallback.java b/src/main/java/org/codenil/comm/connections/ConnectCallback.java new file mode 100644 index 0000000..463c560 --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/ConnectCallback.java @@ -0,0 +1,9 @@ +package org.codenil.comm.connections; + +/** + * 连接回调 + */ +@FunctionalInterface +public interface ConnectCallback { + void onConnect(final PeerConnection peer); +} diff --git a/src/main/java/org/codenil/comm/connections/ConnectionInitializer.java b/src/main/java/org/codenil/comm/connections/ConnectionInitializer.java new file mode 100644 index 0000000..9e21ea4 --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/ConnectionInitializer.java @@ -0,0 +1,18 @@ +package org.codenil.comm.connections; + +import org.codenil.comm.RemotePeer; + +import java.net.InetSocketAddress; +import java.util.concurrent.CompletableFuture; + +public interface ConnectionInitializer { + + CompletableFuture start(); + + CompletableFuture stop(); + + void subscribeIncomingConnect(final ConnectCallback callback); + + CompletableFuture connect(RemotePeer remotePeer); +} + diff --git a/src/main/java/org/codenil/comm/connections/KeepAlive.java b/src/main/java/org/codenil/comm/connections/KeepAlive.java new file mode 100644 index 0000000..b8d880d --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/KeepAlive.java @@ -0,0 +1,56 @@ +package org.codenil.comm.connections; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; +import org.codenil.comm.message.DisconnectReason; +import org.codenil.comm.message.PingMessage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; + +public class KeepAlive extends ChannelDuplexHandler { + + private static final Logger logger = LoggerFactory.getLogger(KeepAlive.class); + + private final AtomicBoolean waitingForPong; + + private final PeerConnection connection; + + public KeepAlive( + final PeerConnection connection, + final AtomicBoolean waitingForPong) { + this.connection = connection; + this.waitingForPong = waitingForPong; + } + + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) + throws IOException { + if (!(evt instanceof IdleStateEvent + && ((IdleStateEvent) evt).state() == IdleState.READER_IDLE)) { + return; + } + + if (waitingForPong.get()) { + logger.debug("PONG never received, disconnecting from peer."); + try { + connection.disconnect(DisconnectReason.TIMEOUT); + } catch (Exception e) { + logger.warn("Exception while disconnecting from peer.", e); + } + return; + } + + try { + logger.debug("Idle connection detected, sending Wire PING to peer."); + connection.send(PingMessage.get()); + waitingForPong.set(true); + } catch (Exception e) { + logger.trace("PING not sent because peer is already disconnected"); + } + } +} diff --git a/src/main/java/org/codenil/comm/connections/PeerConnection.java b/src/main/java/org/codenil/comm/connections/PeerConnection.java new file mode 100644 index 0000000..b1fff5d --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/PeerConnection.java @@ -0,0 +1,15 @@ +package org.codenil.comm.connections; + +import org.codenil.comm.message.DisconnectReason; +import org.codenil.comm.message.Message; + +public interface PeerConnection { + + String peerIdentity(); + + void send(final Message message) throws Exception; + + void disconnect(DisconnectReason reason) throws Exception; + + void terminateConnection(); +} diff --git a/src/main/java/org/codenil/comm/connections/PeerConnectionEvents.java b/src/main/java/org/codenil/comm/connections/PeerConnectionEvents.java new file mode 100644 index 0000000..fb93f46 --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/PeerConnectionEvents.java @@ -0,0 +1,30 @@ +package org.codenil.comm.connections; + +import org.codenil.comm.message.*; + +public class PeerConnectionEvents { + + private final Subscribers disconnectSubscribers = Subscribers.create(true); + + private final Subscribers messageSubscribers = Subscribers.create(true); + + public PeerConnectionEvents() {} + + public void dispatchDisconnect( + final PeerConnection connection) { + disconnectSubscribers.forEach(s -> s.onDisconnect(connection)); + } + + public void dispatchMessage(final PeerConnection connection, final RawMessage message) { + final DefaultMessage msg = new DefaultMessage(connection, message); + messageSubscribers.forEach(s -> s.onMessage(msg)); + } + + public void subscribeDisconnect(final DisconnectCallback callback) { + disconnectSubscribers.subscribe(callback); + } + + public void subscribeMessage(final MessageCallback callback) { + messageSubscribers.subscribe(callback); + } +} diff --git a/src/main/java/org/codenil/comm/connections/Subscribers.java b/src/main/java/org/codenil/comm/connections/Subscribers.java new file mode 100644 index 0000000..f212461 --- /dev/null +++ b/src/main/java/org/codenil/comm/connections/Subscribers.java @@ -0,0 +1,89 @@ +package org.codenil.comm.connections; + +import com.google.common.collect.ImmutableSet; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +public class Subscribers { + private static final Subscribers NONE = new EmptySubscribers<>(); + + private final AtomicLong subscriberId = new AtomicLong(); + private final Map subscribers = new ConcurrentHashMap<>(); + + private final boolean suppressCallbackExceptions; + + private Subscribers(final boolean suppressCallbackExceptions) { + this.suppressCallbackExceptions = suppressCallbackExceptions; + } + + @SuppressWarnings("unchecked") + public static Subscribers none() { + return (Subscribers) NONE; + } + + public static Subscribers create() { + return new Subscribers(false); + } + + + public static Subscribers create(final boolean catchCallbackExceptions) { + return new Subscribers(catchCallbackExceptions); + } + + public long subscribe(final T subscriber) { + final long id = subscriberId.getAndIncrement(); + subscribers.put(id, subscriber); + return id; + } + + public boolean unsubscribe(final long subscriberId) { + return subscribers.remove(subscriberId) != null; + } + + public void forEach(final Consumer action) { + ImmutableSet.copyOf(subscribers.values()) + .forEach(subscriber -> { + try { + action.accept(subscriber); + } catch (final Exception e) { + if (suppressCallbackExceptions) { +// LOG.debug("Error in callback: {}", e); + } else { + throw e; + } + } + }); + } + + public int getSubscriberCount() { + return subscribers.size(); + } + + private static class EmptySubscribers extends Subscribers { + + private EmptySubscribers() { + super(false); + } + + @Override + public long subscribe(final T subscriber) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean unsubscribe(final long subscriberId) { + return false; + } + + @Override + public void forEach(final Consumer action) {} + + @Override + public int getSubscriberCount() { + return 0; + } + } +} diff --git a/src/main/java/org/codenil/comm/handshake/AbstractHandshakeHandler.java b/src/main/java/org/codenil/comm/handshake/AbstractHandshakeHandler.java new file mode 100644 index 0000000..0322395 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/AbstractHandshakeHandler.java @@ -0,0 +1,132 @@ +package org.codenil.comm.handshake; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.MessageToByteEncoder; +import org.codenil.comm.connections.PeerConnectionEvents; +import org.codenil.comm.message.*; +import org.codenil.comm.connections.PeerConnection; +import org.codenil.comm.netty.handler.MessageFrameDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +public abstract class AbstractHandshakeHandler extends SimpleChannelInboundHandler { + + private static final Logger logger = LoggerFactory.getLogger(AbstractHandshakeHandler.class); + + private final CompletableFuture connectionFuture; + private final PeerConnectionEvents connectionEvents; + protected final Handshaker handshaker; + + protected AbstractHandshakeHandler( + final CompletableFuture connectionFuture, + final PeerConnectionEvents connectionEvents, + final Handshaker handshaker) { + this.connectionFuture = connectionFuture; + this.connectionEvents = connectionEvents; + this.handshaker = handshaker; + } + + @Override + protected void channelRead0(final ChannelHandlerContext ctx, final ByteBuf msg) { + final Optional nextMsg = nextHandshakeMessage(msg); + + if (nextMsg.isPresent()) { + ctx.writeAndFlush(nextMsg.get()); + } else if (handshaker.getStatus() != HandshakeStatus.SUCCESS) { + logger.debug("waiting for more bytes"); + } else { + /* + * 握手成功后替换掉握手消息处理器 + * 替换为消息解码器 + * 同时添加一个消息编码器 + * 形成一个完整的Message处理链 + * validate处理器只负责检测帧合法性,尝试封帧,封帧成功后移除这个处理器 + */ + ctx.channel() + .pipeline() + .replace(this, "DeFramer", new MessageFrameDecoder(connectionEvents, connectionFuture)) + .addBefore("DeFramer", "validate", new FirstMessageFrameEncoder()); + + /* + * 替换完编解码器后发送Hello消息 + */ + HelloMessage helloMessage = HelloMessage.create(); + RawMessage rawMessage = new RawMessage(helloMessage.getCode(), helloMessage.getData()); + rawMessage.setRequestId(helloMessage.getRequestId()); + ctx.writeAndFlush(rawMessage) + .addListener(ff -> { + if (ff.isSuccess()) { + logger.trace("Successfully wrote hello message"); + } + }); + + msg.retain(); + ctx.fireChannelRead(msg); + } + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable throwable) { + logger.trace("Handshake error:", throwable); + connectionFuture.completeExceptionally(throwable); + ctx.close(); + } + + protected abstract Optional nextHandshakeMessage(ByteBuf msg); + + /** Ensures that wire hello message is the first message written. */ + private static class FirstMessageFrameEncoder extends MessageToByteEncoder { + + private FirstMessageFrameEncoder() {} + + @Override + protected void encode( + final ChannelHandlerContext context, + final RawMessage msg, + final ByteBuf out) { + if (msg.getCode() != MessageCodes.HELLO) { + throw new IllegalStateException("First message sent wasn't a HELLO."); + } + byte[] idBytes = Optional.ofNullable(msg.getRequestId()).orElse("").getBytes(StandardCharsets.UTF_8); + int channelsSize = msg.getChannels().size(); + + int channelBytesLength = 0; + for (String channel : msg.getChannels()) { + byte[] channelBytes = channel.getBytes(StandardCharsets.UTF_8); + channelBytesLength = channelBytesLength + 4 + channelBytes.length; + } + + int payloadLength = 4 + 4 + idBytes.length + 4 + channelBytesLength + 4 + msg.getData().length; + + // 写入协议头:消息总长度 + out.writeInt(payloadLength + 4); + + // 写入payload + // 写入code + out.writeInt(msg.getCode()); + + // 写入id + out.writeInt(idBytes.length); + out.writeBytes(idBytes); + + // 写入channels + out.writeInt(channelsSize); + for (String channel : msg.getChannels()) { + byte[] channelBytes = channel.getBytes(StandardCharsets.UTF_8); + out.writeInt(channelBytes.length); + out.writeBytes(channelBytes); + } + + // 写入data + out.writeInt(msg.getData().length); + out.writeBytes(msg.getData()); + context.pipeline().remove(this); + } + } +} diff --git a/src/main/java/org/codenil/comm/handshake/HandshakeHandlerInbound.java b/src/main/java/org/codenil/comm/handshake/HandshakeHandlerInbound.java new file mode 100644 index 0000000..5cc2db9 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/HandshakeHandlerInbound.java @@ -0,0 +1,33 @@ +package org.codenil.comm.handshake; + +import io.netty.buffer.ByteBuf; +import org.codenil.comm.connections.PeerConnection; +import org.codenil.comm.connections.PeerConnectionEvents; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +/** + * 入站握手信息处理器 + */ +public class HandshakeHandlerInbound extends AbstractHandshakeHandler { + + public HandshakeHandlerInbound( + final CompletableFuture connectionFuture, + final PeerConnectionEvents connectionEvent, + final Handshaker handshaker) { + super(connectionFuture, connectionEvent, handshaker); + handshaker.prepareResponder(); + } + + @Override + protected Optional nextHandshakeMessage(ByteBuf msg) { + final Optional nextMsg; + if (handshaker.getStatus() == HandshakeStatus.IN_PROGRESS) { + nextMsg = handshaker.handleMessage(msg); + } else { + nextMsg = Optional.empty(); + } + return nextMsg; + } +} diff --git a/src/main/java/org/codenil/comm/handshake/HandshakeHandlerOutbound.java b/src/main/java/org/codenil/comm/handshake/HandshakeHandlerOutbound.java new file mode 100644 index 0000000..6e2a88b --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/HandshakeHandlerOutbound.java @@ -0,0 +1,53 @@ +package org.codenil.comm.handshake; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import org.codenil.comm.connections.PeerConnection; +import org.codenil.comm.connections.PeerConnectionEvents; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +/** + * 出站握手信息处理器 + */ +public class HandshakeHandlerOutbound extends AbstractHandshakeHandler { + + private static final Logger logger = LoggerFactory.getLogger(AbstractHandshakeHandler.class); + + private final ByteBuf first; + + public HandshakeHandlerOutbound( + final CompletableFuture connectionFuture, + final PeerConnectionEvents connectionEvent, + final Handshaker handshaker) { + super(connectionFuture, connectionEvent, handshaker); + + handshaker.prepareInitiator(); + this.first = handshaker.firstMessage(); + } + + @Override + protected Optional nextHandshakeMessage(ByteBuf msg) { + final Optional nextMsg; + if (handshaker.getStatus() == HandshakeStatus.IN_PROGRESS) { + nextMsg = handshaker.handleMessage(msg); + } else { + nextMsg = Optional.empty(); + } + return nextMsg; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + ctx.writeAndFlush(first) + .addListener(f -> { + if (f.isSuccess()) { + logger.trace("Wrote initial crypto handshake message to {}.", ctx.channel().remoteAddress()); + } + }); + } +} diff --git a/src/main/java/org/codenil/comm/handshake/HandshakeSecrets.java b/src/main/java/org/codenil/comm/handshake/HandshakeSecrets.java new file mode 100644 index 0000000..4292b47 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/HandshakeSecrets.java @@ -0,0 +1,113 @@ +package org.codenil.comm.handshake; + +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.bouncycastle.crypto.digests.KeccakDigest; + +import java.util.Arrays; +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkArgument; + +public class HandshakeSecrets { + private final byte[] aesSecret; + private final byte[] macSecret; + private final byte[] token; + private final KeccakDigest egressMac = new KeccakDigest(Bytes32.SIZE * 8); + private final KeccakDigest ingressMac = new KeccakDigest(Bytes32.SIZE * 8); + + public HandshakeSecrets(final byte[] aesSecret, final byte[] macSecret, final byte[] token) { + checkArgument(aesSecret.length == Bytes32.SIZE, "aes secret must be exactly 32 bytes long"); + checkArgument(macSecret.length == Bytes32.SIZE, "mac secret must be exactly 32 bytes long"); + checkArgument(token.length == Bytes32.SIZE, "token must be exactly 32 bytes long"); + + this.aesSecret = aesSecret; + this.macSecret = macSecret; + this.token = token; + } + + public HandshakeSecrets updateEgress(final byte[] bytes) { + egressMac.update(bytes, 0, bytes.length); + return this; + } + + public HandshakeSecrets updateIngress(final byte[] bytes) { + ingressMac.update(bytes, 0, bytes.length); + return this; + } + + @Override + public String toString() { + return "HandshakeSecrets{" + + "aesSecret=" + + Bytes.wrap(aesSecret) + + ", macSecret=" + + Bytes.wrap(macSecret) + + ", token=" + + Bytes.wrap(token) + + ", egressMac=" + + Bytes.wrap(snapshot(egressMac)) + + ", ingressMac=" + + Bytes.wrap(snapshot(ingressMac)) + + '}'; + } + + public byte[] getAesSecret() { + return aesSecret; + } + + public byte[] getMacSecret() { + return macSecret; + } + + public byte[] getToken() { + return token; + } + + public byte[] getEgressMac() { + return snapshot(egressMac); + } + + public byte[] getIngressMac() { + return snapshot(ingressMac); + } + + private static byte[] snapshot(final KeccakDigest digest) { + final byte[] out = new byte[Bytes32.SIZE]; + new KeccakDigest(digest).doFinal(out, 0); + return out; + } + + @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") // checked in delegated method + @Override + public boolean equals(final Object obj) { + return equals(obj, false); + } + + public boolean equals(final Object o, final boolean flipMacs) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final HandshakeSecrets that = (HandshakeSecrets) o; + final KeccakDigest vsEgress = flipMacs ? that.ingressMac : that.egressMac; + final KeccakDigest vsIngress = flipMacs ? that.egressMac : that.ingressMac; + return Arrays.equals(aesSecret, that.aesSecret) + && Arrays.equals(macSecret, that.macSecret) + && Arrays.equals(token, that.token) + && Arrays.equals(snapshot(egressMac), snapshot(vsEgress)) + && Arrays.equals(snapshot(ingressMac), snapshot(vsIngress)); + } + + @Override + public int hashCode() { + return Objects.hash( + Arrays.hashCode(aesSecret), + Arrays.hashCode(macSecret), + Arrays.hashCode(token), + egressMac, + ingressMac); + } +} diff --git a/src/main/java/org/codenil/comm/handshake/HandshakeStatus.java b/src/main/java/org/codenil/comm/handshake/HandshakeStatus.java new file mode 100644 index 0000000..b0da579 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/HandshakeStatus.java @@ -0,0 +1,13 @@ +package org.codenil.comm.handshake; + +public enum HandshakeStatus { + UNINITIALIZED, + + PREPARED, + + IN_PROGRESS, + + SUCCESS, + + FAILED +} diff --git a/src/main/java/org/codenil/comm/handshake/Handshaker.java b/src/main/java/org/codenil/comm/handshake/Handshaker.java new file mode 100644 index 0000000..aac4ac3 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/Handshaker.java @@ -0,0 +1,20 @@ +package org.codenil.comm.handshake; + +import io.netty.buffer.ByteBuf; + +import java.util.Optional; + +public interface Handshaker { + + void prepareInitiator(); + + void prepareResponder(); + + HandshakeStatus getStatus(); + + ByteBuf firstMessage(); + + HandshakeSecrets secrets(); + + Optional handleMessage(ByteBuf buf); +} diff --git a/src/main/java/org/codenil/comm/handshake/PlainHandshaker.java b/src/main/java/org/codenil/comm/handshake/PlainHandshaker.java new file mode 100644 index 0000000..5871ba4 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/PlainHandshaker.java @@ -0,0 +1,92 @@ +package org.codenil.comm.handshake; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.codenil.comm.message.MessageType; +import org.codenil.comm.netty.handler.MessageHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import static com.google.common.base.Preconditions.checkState; + +public class PlainHandshaker implements Handshaker { + + private static final Logger logger = LoggerFactory.getLogger(AbstractHandshakeHandler.class); + + private final AtomicReference status = + new AtomicReference<>(HandshakeStatus.UNINITIALIZED); + + private boolean initiator; + private byte[] initiatorMsg; + private byte[] responderMsg; + + @Override + public void prepareInitiator() { + checkState(status.compareAndSet( + HandshakeStatus.UNINITIALIZED, HandshakeStatus.PREPARED), + "handshake was already prepared"); + this.initiator = true; + } + + @Override + public void prepareResponder() { + checkState(status.compareAndSet( + HandshakeStatus.UNINITIALIZED, HandshakeStatus.IN_PROGRESS), + "handshake was already prepared"); + this.initiator = false; + } + + @Override + public HandshakeStatus getStatus() { + return status.get(); + } + + @Override + public ByteBuf firstMessage() { + checkState(initiator, "illegal invocation of firstMessage on non-initiator end of handshake"); + checkState(status.compareAndSet(HandshakeStatus.PREPARED, HandshakeStatus.IN_PROGRESS), + "illegal invocation of firstMessage, handshake had already started"); + initiatorMsg = MessageHandler.buildMessage(MessageType.PING, MessageType.PING.getValue(), new byte[0]); + logger.trace("First plain handshake message under INITIATOR role"); + return Unpooled.wrappedBuffer(initiatorMsg); + } + + @Override + public Optional handleMessage(ByteBuf buf) { + checkState(status.get() == HandshakeStatus.IN_PROGRESS, + "illegal invocation of onMessage on handshake that is not in progress"); + + PlainMessage message = MessageHandler.parseMessage(buf); + + Optional nextMsg = Optional.empty(); + if (initiator) { + checkState(responderMsg == null, + "unexpected message: responder message had " + "already been received"); + + checkState(message.messageType().equals(MessageType.PONG), + "unexpected message: needs to be a pong"); + responderMsg = message.data(); + + } else { + checkState(initiatorMsg == null, + "unexpected message: initiator message " + "had already been received"); + checkState(message.messageType().equals(MessageType.PING), + "unexpected message: needs to be a ping"); + + initiatorMsg = message.data(); + responderMsg = MessageHandler.buildMessage(MessageType.PONG, MessageType.PONG.getValue(), new byte[0]); + nextMsg = Optional.of(responderMsg); + } + status.set(HandshakeStatus.SUCCESS); + logger.trace("Handshake status set to {}", status.get()); + return nextMsg.map(Unpooled::wrappedBuffer); + } + + @Override + public HandshakeSecrets secrets() { + return null; + } +} diff --git a/src/main/java/org/codenil/comm/handshake/PlainMessage.java b/src/main/java/org/codenil/comm/handshake/PlainMessage.java new file mode 100644 index 0000000..aa376e1 --- /dev/null +++ b/src/main/java/org/codenil/comm/handshake/PlainMessage.java @@ -0,0 +1,31 @@ +package org.codenil.comm.handshake; + +import org.codenil.comm.message.MessageType; + +public class PlainMessage { + private final MessageType messageType; + private final int code; + private final byte[] data; + + public PlainMessage(final MessageType messageType, final byte[] data) { + this(messageType, -1, data); + } + + public PlainMessage(final MessageType messageType, final int code, final byte[] data) { + this.messageType = messageType; + this.code = code; + this.data = data; + } + + public MessageType messageType() { + return messageType; + } + + public byte[] data() { + return data; + } + + public int code() { + return code; + } +} diff --git a/src/main/java/org/codenil/comm/message/AbstractMessage.java b/src/main/java/org/codenil/comm/message/AbstractMessage.java new file mode 100644 index 0000000..18d416f --- /dev/null +++ b/src/main/java/org/codenil/comm/message/AbstractMessage.java @@ -0,0 +1,32 @@ +package org.codenil.comm.message; + +public abstract class AbstractMessage implements Message { + + private String requestId; + + private byte[] data = new byte[]{}; + + public AbstractMessage( + final byte[] data) { + this.data = data; + } + + @Override + public String getRequestId() { + return requestId; + } + + @Override + public final int getSize() { + return data.length; + } + + @Override + public byte[] getData() { + return data; + } + + public void setRequestId(String requestId) { + this.requestId = requestId; + } +} diff --git a/src/main/java/org/codenil/comm/message/DefaultMessage.java b/src/main/java/org/codenil/comm/message/DefaultMessage.java new file mode 100644 index 0000000..aecb4ed --- /dev/null +++ b/src/main/java/org/codenil/comm/message/DefaultMessage.java @@ -0,0 +1,26 @@ +package org.codenil.comm.message; + +import org.codenil.comm.connections.PeerConnection; + +public class DefaultMessage { + + private final RawMessage message; + + private final PeerConnection connection; + + public DefaultMessage( + final PeerConnection connection, + final RawMessage message) { + this.message = message; + this.connection = connection; + } + + public RawMessage message() { + return message; + } + + public PeerConnection connection() { + return connection; + } +} + diff --git a/src/main/java/org/codenil/comm/message/DisconnectCallback.java b/src/main/java/org/codenil/comm/message/DisconnectCallback.java new file mode 100644 index 0000000..b6597ef --- /dev/null +++ b/src/main/java/org/codenil/comm/message/DisconnectCallback.java @@ -0,0 +1,8 @@ +package org.codenil.comm.message; + +import org.codenil.comm.connections.PeerConnection; + +@FunctionalInterface +public interface DisconnectCallback { + void onDisconnect(final PeerConnection connection); +} diff --git a/src/main/java/org/codenil/comm/message/DisconnectMessage.java b/src/main/java/org/codenil/comm/message/DisconnectMessage.java new file mode 100644 index 0000000..02aef0b --- /dev/null +++ b/src/main/java/org/codenil/comm/message/DisconnectMessage.java @@ -0,0 +1,31 @@ +package org.codenil.comm.message; + +import com.google.gson.Gson; + +public class DisconnectMessage extends AbstractMessage { + + private DisconnectMessage(final byte[] data) { + super(data); + } + + public static DisconnectMessage create(final DisconnectReason reason) { + return new DisconnectMessage(new Gson().toJson(reason).getBytes()); + } + + public static DisconnectMessage readFrom(final Message message) { + if (message instanceof DisconnectMessage) { + return (DisconnectMessage) message; + } + final int code = message.getCode(); + if (code != MessageCodes.DISCONNECT) { + throw new IllegalArgumentException( + String.format("Message has code %d and thus is not a DisconnectMessage.", code)); + } + return new DisconnectMessage(message.getData()); + } + + @Override + public int getCode() { + return MessageCodes.DISCONNECT; + } +} diff --git a/src/main/java/org/codenil/comm/message/DisconnectReason.java b/src/main/java/org/codenil/comm/message/DisconnectReason.java new file mode 100644 index 0000000..213a5ee --- /dev/null +++ b/src/main/java/org/codenil/comm/message/DisconnectReason.java @@ -0,0 +1,27 @@ +package org.codenil.comm.message; + +import java.util.Optional; + +public enum DisconnectReason { + + UNKNOWN(null), + + TIMEOUT((byte) 0x0b), + + INVALID_MESSAGE_RECEIVED((byte) 0x02, "An exception was caught decoding message"), + ; + + private final Optional code; + private final Optional message; + + DisconnectReason(final Byte code) { + this.code = Optional.ofNullable(code); + this.message = Optional.empty(); + } + + DisconnectReason(final Byte code, final String message) { + this.code = Optional.ofNullable(code); + this.message = Optional.of(message); + } + +} diff --git a/src/main/java/org/codenil/comm/message/EmptyMessage.java b/src/main/java/org/codenil/comm/message/EmptyMessage.java new file mode 100644 index 0000000..bfd9a54 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/EmptyMessage.java @@ -0,0 +1,19 @@ +package org.codenil.comm.message; + +public abstract class EmptyMessage implements Message { + + @Override + public final int getSize() { + return 0; + } + + @Override + public byte[] getData() { + return new byte[]{}; + } + + @Override + public String toString() { + return getClass().getSimpleName() + "{ code=" + getCode() + ", size=0}"; + } +} diff --git a/src/main/java/org/codenil/comm/message/HelloMessage.java b/src/main/java/org/codenil/comm/message/HelloMessage.java new file mode 100644 index 0000000..401fb70 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/HelloMessage.java @@ -0,0 +1,17 @@ +package org.codenil.comm.message; + +public class HelloMessage extends AbstractMessage { + + public HelloMessage(final byte[] data) { + super(data); + } + + public static HelloMessage create() { + return new HelloMessage(new byte[0]); + } + + @Override + public int getCode() { + return MessageCodes.HELLO; + } +} diff --git a/src/main/java/org/codenil/comm/message/Message.java b/src/main/java/org/codenil/comm/message/Message.java new file mode 100644 index 0000000..3ee6f3d --- /dev/null +++ b/src/main/java/org/codenil/comm/message/Message.java @@ -0,0 +1,18 @@ +package org.codenil.comm.message; + +public interface Message { + + String getRequestId(); + + int getSize(); + + int getCode(); + + byte[] getData(); + + default RawMessage wrapMessage(final String requestId) { + RawMessage rawMessage = new RawMessage(getCode(), getData()); + rawMessage.setRequestId(requestId); + return rawMessage; + } +} diff --git a/src/main/java/org/codenil/comm/message/MessageCallback.java b/src/main/java/org/codenil/comm/message/MessageCallback.java new file mode 100644 index 0000000..cd2d368 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/MessageCallback.java @@ -0,0 +1,7 @@ +package org.codenil.comm.message; + +@FunctionalInterface +public interface MessageCallback { + + void onMessage(final DefaultMessage message); +} diff --git a/src/main/java/org/codenil/comm/message/MessageCodes.java b/src/main/java/org/codenil/comm/message/MessageCodes.java new file mode 100644 index 0000000..6943556 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/MessageCodes.java @@ -0,0 +1,22 @@ +package org.codenil.comm.message; + +public class MessageCodes { + public static final int HELLO = 0x00; + public static final int DISCONNECT = 0x01; + public static final int PING = 0x02; + public static final int PONG = 0x03; + + public static final int DATA_UPDATE = 0x04; + private MessageCodes() {} + + public static String messageName(final int code) { + return switch (code) { + case HELLO -> "Hello"; + case DISCONNECT -> "Disconnect"; + case PING -> "Ping"; + case PONG -> "Pong"; + default -> "invalid"; + }; + } +} + diff --git a/src/main/java/org/codenil/comm/message/MessageType.java b/src/main/java/org/codenil/comm/message/MessageType.java new file mode 100644 index 0000000..6a048c3 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/MessageType.java @@ -0,0 +1,29 @@ +package org.codenil.comm.message; + +public enum MessageType { + PING(0), + PONG(1), + DATA(2), + UNRECOGNIZED(-1), + ; + + private final int value; + + MessageType(final int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + public static MessageType forNumber(final int value) { + return switch (value) { + case 0 -> PING; + case 1 -> PONG; + case 2 -> DATA; + case -1 -> UNRECOGNIZED; + default -> null; + }; + } +} diff --git a/src/main/java/org/codenil/comm/message/PingMessage.java b/src/main/java/org/codenil/comm/message/PingMessage.java new file mode 100644 index 0000000..579dd51 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/PingMessage.java @@ -0,0 +1,26 @@ +package org.codenil.comm.message; + +public class PingMessage extends EmptyMessage { + private static final PingMessage INSTANCE = new PingMessage(); + + public static PingMessage get() { + return INSTANCE; + } + + private PingMessage() {} + + @Override + public String getRequestId() { + return ""; + } + + @Override + public int getCode() { + return MessageCodes.PING; + } + + @Override + public String toString() { + return "PingMessage{data=''}"; + } +} diff --git a/src/main/java/org/codenil/comm/message/PongMessage.java b/src/main/java/org/codenil/comm/message/PongMessage.java new file mode 100644 index 0000000..715ebf6 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/PongMessage.java @@ -0,0 +1,21 @@ +package org.codenil.comm.message; + +public class PongMessage extends EmptyMessage { + private static final PongMessage INSTANCE = new PongMessage(); + + public static PongMessage get() { + return INSTANCE; + } + + private PongMessage() {} + + @Override + public String getRequestId() { + return ""; + } + + @Override + public int getCode() { + return MessageCodes.PONG; + } +} diff --git a/src/main/java/org/codenil/comm/message/RawMessage.java b/src/main/java/org/codenil/comm/message/RawMessage.java new file mode 100644 index 0000000..2594a36 --- /dev/null +++ b/src/main/java/org/codenil/comm/message/RawMessage.java @@ -0,0 +1,18 @@ +package org.codenil.comm.message; + +public class RawMessage extends AbstractMessage { + + private final int code; + + public RawMessage( + final int code, + final byte[] data) { + super(data); + this.code = code; + } + + @Override + public int getCode() { + return code; + } +} diff --git a/src/main/java/org/codenil/comm/netty/NettyConnectionInitializer.java b/src/main/java/org/codenil/comm/netty/NettyConnectionInitializer.java new file mode 100644 index 0000000..d581cc3 --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/NettyConnectionInitializer.java @@ -0,0 +1,200 @@ +package org.codenil.comm.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.*; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import org.codenil.comm.RemotePeer; +import org.codenil.comm.connections.*; +import org.codenil.comm.NetworkConfig; +import org.codenil.comm.handshake.HandshakeHandlerInbound; +import org.codenil.comm.handshake.HandshakeHandlerOutbound; +import org.codenil.comm.handshake.PlainHandshaker; +import org.codenil.comm.netty.handler.TimeoutHandler; + +import javax.annotation.Nonnull; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.security.GeneralSecurityException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * netty初始化 + */ +public class NettyConnectionInitializer implements ConnectionInitializer { + + private static final int TIMEOUT_SECONDS = 10; + + private final Subscribers connectSubscribers = Subscribers.create(); + private final PeerConnectionEvents eventDispatcher; + + private final EventLoopGroup boss = new NioEventLoopGroup(1); + private final EventLoopGroup workers = new NioEventLoopGroup(10); + private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean stopped = new AtomicBoolean(false); + + private final NetworkConfig config; + + private ChannelFuture server; + + public NettyConnectionInitializer( + final NetworkConfig config, + final PeerConnectionEvents eventDispatcher) { + this.config = config; + this.eventDispatcher = eventDispatcher; + } + + /** + * 启动netty服务器 + */ + @Override + public CompletableFuture start() { + final CompletableFuture listeningPortFuture = new CompletableFuture<>(); + if (!started.compareAndSet(false, true)) { + listeningPortFuture.completeExceptionally( + new IllegalStateException( + "Attempt to start an already started " + this.getClass().getSimpleName())); + return listeningPortFuture; + } + + this.server = new ServerBootstrap() + .group(boss, workers) + .channel(NioServerSocketChannel.class) + .childHandler(inboundChannelInitializer()) + .bind(config.bindHost(), config.bindPort()); + server.addListener(future -> { + final InetSocketAddress socketAddress = + (InetSocketAddress) server.channel().localAddress(); + if (!future.isSuccess() || socketAddress == null) { + final String message = + String.format("Unable to start listening on %s:%s. Check for port conflicts.", + config.bindHost(), config.bindPort()); + listeningPortFuture.completeExceptionally( + new IllegalStateException(message, future.cause())); + return; + } + + listeningPortFuture.complete(socketAddress); + }); + + return listeningPortFuture; + } + + /** + * 停止netty服务器 + */ + @Override + public CompletableFuture stop() { + final CompletableFuture stoppedFuture = new CompletableFuture<>(); + if (!started.get() || !stopped.compareAndSet(false, true)) { + stoppedFuture.completeExceptionally( + new IllegalStateException("Illegal attempt to stop " + this.getClass().getSimpleName())); + return stoppedFuture; + } + + workers.shutdownGracefully(); + boss.shutdownGracefully(); + server.channel() + .closeFuture() + .addListener((future) -> { + if (future.isSuccess()) { + stoppedFuture.complete(null); + } else { + stoppedFuture.completeExceptionally(future.cause()); + } + }); + return stoppedFuture; + } + + @Override + public void subscribeIncomingConnect(ConnectCallback callback) { + connectSubscribers.subscribe(callback); + } + + /** + * 连接到远程 + */ + @Override + public CompletableFuture connect(RemotePeer remotePeer) { + final CompletableFuture connectionFuture = new CompletableFuture<>(); + + new Bootstrap() + .group(workers) + .channel(NioSocketChannel.class) + .remoteAddress(new InetSocketAddress(remotePeer.ip(), remotePeer.listeningPort())) + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, TIMEOUT_SECONDS * 1000) + .handler(outboundChannelInitializer(remotePeer, connectionFuture)) + .connect() + .addListener( + (f) -> { + if (!f.isSuccess()) { + connectionFuture.completeExceptionally(f.cause()); + } + }); + + return connectionFuture; + } + + private ChannelInitializer inboundChannelInitializer() { + return new ChannelInitializer<>() { + @Override + protected void initChannel(final SocketChannel ch) throws Exception { + final CompletableFuture connectionFuture = new CompletableFuture<>(); + connectionFuture.thenAccept(connection -> connectSubscribers.forEach(c -> c.onConnect(connection))); + //连接处理器 + ch.pipeline().addLast(timeoutHandler(connectionFuture, "Timed out waiting to fully establish incoming connection")); + //其他处理器,TLS之类的 + addAdditionalInboundHandlers(ch); + //握手消息处理器 + ch.pipeline().addLast(inboundHandler(connectionFuture)); + } + }; + } + + @Nonnull + private ChannelInitializer outboundChannelInitializer( + final RemotePeer remotePeer, final CompletableFuture connectionFuture) { + return new ChannelInitializer<>() { + @Override + protected void initChannel(final SocketChannel ch) throws Exception { + //连接处理器 + ch.pipeline().addLast(timeoutHandler(connectionFuture, "Timed out waiting to establish connection with peer: " + remotePeer.ip())); + //其他处理器 + addAdditionalOutboundHandlers(ch, remotePeer); + //握手消息处理器 + ch.pipeline().addLast(outboundHandler(remotePeer, connectionFuture)); + } + }; + } + + @Nonnull + private TimeoutHandler timeoutHandler( + final CompletableFuture connectionFuture, final String message) { + return new TimeoutHandler<>(connectionFuture::isDone, TIMEOUT_SECONDS, + () -> connectionFuture.completeExceptionally(new TimeoutException(message))); + } + + @Nonnull + private HandshakeHandlerInbound inboundHandler( + final CompletableFuture connectionFuture) { + return new HandshakeHandlerInbound(connectionFuture, eventDispatcher, new PlainHandshaker()); + } + + @Nonnull + private HandshakeHandlerOutbound outboundHandler( + final RemotePeer remotePeer, final CompletableFuture connectionFuture) { + return new HandshakeHandlerOutbound(connectionFuture, eventDispatcher, new PlainHandshaker()); + } + + void addAdditionalOutboundHandlers(final Channel channel, final RemotePeer remotePeer) + throws GeneralSecurityException, IOException {} + + void addAdditionalInboundHandlers(final Channel channel) + throws GeneralSecurityException, IOException {} +} diff --git a/src/main/java/org/codenil/comm/netty/NettyPeerConnection.java b/src/main/java/org/codenil/comm/netty/NettyPeerConnection.java new file mode 100644 index 0000000..b10a40c --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/NettyPeerConnection.java @@ -0,0 +1,45 @@ +package org.codenil.comm.netty; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import org.codenil.comm.connections.AbstractPeerConnection; +import org.codenil.comm.connections.PeerConnectionEvents; +import org.codenil.comm.message.Message; +import org.codenil.comm.message.RawMessage; + +import java.util.concurrent.Callable; + +import static java.util.concurrent.TimeUnit.SECONDS; + +public class NettyPeerConnection extends AbstractPeerConnection { + + private final ChannelHandlerContext ctx; + + public NettyPeerConnection( + final ChannelHandlerContext ctx, + final PeerConnectionEvents connectionEvents) { + super(connectionEvents); + this.ctx = ctx; + } + + + @Override + public String peerIdentity() { + return ctx.channel().remoteAddress().toString(); + } + + @Override + protected void doSendMessage(final Message message) { + ctx.channel().writeAndFlush(new RawMessage(message.getCode(), message.getData())); + } + + @Override + protected void closeConnectionImmediately() { + ctx.close(); + } + + @Override + protected void closeConnection() { + ctx.channel().eventLoop().schedule((Callable) ctx::close, 2L, SECONDS); + } +} diff --git a/src/main/java/org/codenil/comm/netty/OutboundMessage.java b/src/main/java/org/codenil/comm/netty/OutboundMessage.java new file mode 100644 index 0000000..3101426 --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/OutboundMessage.java @@ -0,0 +1,17 @@ +package org.codenil.comm.netty; + +import org.codenil.comm.message.Message; + +public class OutboundMessage { + + private final Message message; + + public OutboundMessage( + final Message message) { + this.message = message; + } + + public Message message() { + return message; + } +} diff --git a/src/main/java/org/codenil/comm/netty/handler/CommonHandler.java b/src/main/java/org/codenil/comm/netty/handler/CommonHandler.java new file mode 100644 index 0000000..dee1b54 --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/handler/CommonHandler.java @@ -0,0 +1,67 @@ +package org.codenil.comm.netty.handler; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.codenil.comm.connections.PeerConnection; +import org.codenil.comm.connections.PeerConnectionEvents; +import org.codenil.comm.message.Message; +import org.codenil.comm.message.MessageCodes; +import org.codenil.comm.message.PongMessage; +import org.codenil.comm.message.RawMessage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.atomic.AtomicBoolean; + +public class CommonHandler extends SimpleChannelInboundHandler { + + private static final Logger logger = LoggerFactory.getLogger(CommonHandler.class); + + private final AtomicBoolean waitingForPong; + private final PeerConnection connection; + private final PeerConnectionEvents connectionEvents; + + public CommonHandler( + final PeerConnection connection, + final PeerConnectionEvents connectionEvents, + final AtomicBoolean waitingForPong) { + this.connection = connection; + this.connectionEvents = connectionEvents; + this.waitingForPong = waitingForPong; + } + + @Override + protected void channelRead0(final ChannelHandlerContext ctx, final RawMessage originalMessage) { + logger.debug("Received a message from {}", originalMessage.getCode()); + switch (originalMessage.getCode()) { + case MessageCodes.PING: + logger.trace("Received Wire PING"); + try { + connection.send(PongMessage.get()); + } catch (Exception e) { + // Nothing to do + } + break; + case MessageCodes.PONG: + logger.trace("Received Wire PONG"); + waitingForPong.set(false); + break; + case MessageCodes.DISCONNECT: + try { + logger.trace("Received DISCONNECT Message"); + } catch (final Exception e) { + logger.error("Received Wire DISCONNECT, but unable to parse reason. "); + } + connection.terminateConnection(); + } + + connectionEvents.dispatchMessage(connection, originalMessage); + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable throwable) { + logger.error("Error:", throwable); + connectionEvents.dispatchDisconnect(connection); + ctx.close(); + } +} diff --git a/src/main/java/org/codenil/comm/netty/handler/MessageFrameDecoder.java b/src/main/java/org/codenil/comm/netty/handler/MessageFrameDecoder.java new file mode 100644 index 0000000..0759ff6 --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/handler/MessageFrameDecoder.java @@ -0,0 +1,139 @@ +package org.codenil.comm.netty.handler; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.timeout.IdleStateHandler; +import org.codenil.comm.connections.KeepAlive; +import org.codenil.comm.connections.PeerConnection; +import org.codenil.comm.connections.PeerConnectionEvents; +import org.codenil.comm.message.DisconnectMessage; +import org.codenil.comm.message.DisconnectReason; +import org.codenil.comm.message.MessageCodes; +import org.codenil.comm.message.RawMessage; +import org.codenil.comm.netty.NettyPeerConnection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; + +public class MessageFrameDecoder extends ByteToMessageDecoder { + + private static final Logger logger = LoggerFactory.getLogger(MessageFrameDecoder.class); + + private final CompletableFuture connectFuture; + private final PeerConnectionEvents connectionEvents; + + private boolean hellosExchanged; + + public MessageFrameDecoder( + final PeerConnectionEvents connectionEvents, + final CompletableFuture connectFuture) { + this.connectionEvents = connectionEvents; + this.connectFuture = connectFuture; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List out) throws Exception { + if (byteBuf.readableBytes() < 4) { + return; // 不足4字节长度字段,等待更多数据 + } + byteBuf.readerIndex(0); + + // 读取协议头:消息总长度 + int totalLength = byteBuf.readInt(); + + if (byteBuf.readableBytes() < totalLength - 4) { + return; // 不足消息总长度,等待更多数据 + } + + // 读取payload + + // 读取id + int idLength = byteBuf.readInt(); + byte[] idBytes = new byte[idLength]; + byteBuf.readBytes(idBytes); + String id = new String(idBytes, StandardCharsets.UTF_8); + + // 读取code + int code = byteBuf.readInt(); + + // 读取data + int dataLength = byteBuf.readInt();; + byte[] data = new byte[dataLength]; + byteBuf.readBytes(data); + + // 创建消息对象 + RawMessage message = new RawMessage(code, data); + message.setRequestId(id); + + if (hellosExchanged) { + out.add(message); + } else if (message.getCode() == MessageCodes.HELLO) { + hellosExchanged = true; + final PeerConnection connection = new NettyPeerConnection(ctx, connectionEvents); + + /* + * 如果收到的消息是Hello消息 + * 添加一个空闲链接检测处理器 + * 添加一个连接保活处理器,检测到连接空闲后发送一个Ping消息 + * 通用消息处理器,处理所有的协议消息 + * 添加一个消息封帧处理器 + */ + final AtomicBoolean waitingForPong = new AtomicBoolean(false); + ctx.channel() + .pipeline() + .addLast(new IdleStateHandler(15, 0, 0), + new KeepAlive(connection, waitingForPong), + new CommonHandler(connection, connectionEvents, waitingForPong), + new MessageFrameEncoder()); + connectFuture.complete(connection); + } else if (message.getCode() == MessageCodes.DISCONNECT) { + logger.debug("Disconnected before sending HELLO."); + ctx.close(); + connectFuture.completeExceptionally(new RuntimeException("Disconnect")); + } else { + logger.debug( + "Message received before HELLO's exchanged, disconnecting. Code: {}, Data: {}", + message.getCode(), Arrays.toString(message.getData())); + + DisconnectMessage disconnectMessage = DisconnectMessage.create(DisconnectReason.UNKNOWN); + ctx.writeAndFlush(new RawMessage(disconnectMessage.getCode(), disconnectMessage.getData())) + .addListener((f) -> ctx.close()); + connectFuture.completeExceptionally(new RuntimeException("Message received before HELLO's exchanged")); + } + } + + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable throwable) + throws Exception { + final Throwable cause = + throwable instanceof DecoderException && throwable.getCause() != null + ? throwable.getCause() + : throwable; + if (cause instanceof IllegalArgumentException) { + logger.debug("Invalid incoming message ", throwable); + if (connectFuture.isDone() && !connectFuture.isCompletedExceptionally()) { + connectFuture.get().disconnect(DisconnectReason.INVALID_MESSAGE_RECEIVED); + return; + } + } else if (cause instanceof IOException) { + // IO failures are routine when communicating with random peers across the network. + logger.debug("IO error while processing incoming message", throwable); + } else { + logger.error("Exception while processing incoming message", throwable); + } + if (connectFuture.isDone() && !connectFuture.isCompletedExceptionally()) { + connectFuture.get().terminateConnection(); + } else { + connectFuture.completeExceptionally(throwable); + ctx.close(); + } + } +} diff --git a/src/main/java/org/codenil/comm/netty/handler/MessageFrameEncoder.java b/src/main/java/org/codenil/comm/netty/handler/MessageFrameEncoder.java new file mode 100644 index 0000000..adc8173 --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/handler/MessageFrameEncoder.java @@ -0,0 +1,32 @@ +package org.codenil.comm.netty.handler; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; +import org.codenil.comm.message.RawMessage; +import org.codenil.comm.serialize.SerializeHelper; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +public class MessageFrameEncoder extends MessageToByteEncoder { + + public MessageFrameEncoder() {} + + @Override + protected void encode( + final ChannelHandlerContext ctx, + final RawMessage msg, + final ByteBuf out) { + byte[] idBytes = Optional.ofNullable(msg.getRequestId()).orElse("").getBytes(StandardCharsets.UTF_8); + + SerializeHelper builder = new SerializeHelper(); + ByteBuf buf = builder.writeBytes(idBytes) + .writeInt(msg.getCode()) + .writeBytes(msg.getData()) + .build(); + + out.writeBytes(buf); + buf.release(); + } +} \ No newline at end of file diff --git a/src/main/java/org/codenil/comm/netty/handler/MessageHandler.java b/src/main/java/org/codenil/comm/netty/handler/MessageHandler.java new file mode 100644 index 0000000..779fb2f --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/handler/MessageHandler.java @@ -0,0 +1,51 @@ +package org.codenil.comm.netty.handler; + +import io.netty.buffer.ByteBuf; +import org.codenil.comm.handshake.PlainMessage; +import org.codenil.comm.message.MessageType; +import org.codenil.comm.serialize.SerializeHelper; + +public class MessageHandler { + public static byte[] buildMessage(final PlainMessage message) { + SerializeHelper builder = new SerializeHelper(); + ByteBuf buf = builder.writeInt(message.messageType().getValue()) + .writeInt(message.code()) + .writeBytes(message.data()).build(); + + byte[] result = new byte[buf.readableBytes()]; + buf.readBytes(result); + buf.release(); + return result; + } + + public static byte[] buildMessage( + final MessageType messageType, + final int code, + final byte[] data) { + return buildMessage(new PlainMessage(messageType, code, data)); + } + + public static PlainMessage parseMessage(final ByteBuf buf) { + PlainMessage ret = null; + + buf.readerIndex(0); + + //跳过版本 + int versionLength = buf.readInt(); + buf.skipBytes(versionLength); + + int payloadLength = buf.readInt(); + if(payloadLength < 8) { + return ret; + } + + int messageType = buf.readInt(); + int code = buf.readInt(); + int dataLength = buf.readInt(); + byte[] data = new byte[dataLength]; + buf.readBytes(data); + + ret = new PlainMessage(MessageType.forNumber(messageType), code, data); + return ret; + } +} diff --git a/src/main/java/org/codenil/comm/netty/handler/TimeoutHandler.java b/src/main/java/org/codenil/comm/netty/handler/TimeoutHandler.java new file mode 100644 index 0000000..5442458 --- /dev/null +++ b/src/main/java/org/codenil/comm/netty/handler/TimeoutHandler.java @@ -0,0 +1,38 @@ +package org.codenil.comm.netty.handler; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; + +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public class TimeoutHandler extends ChannelInitializer { + private final Supplier condition; + private final int timeoutInSeconds; + private final OnTimeoutCallback callback; + + public TimeoutHandler( + final Supplier condition, + final int timeoutInSeconds, + final OnTimeoutCallback callback) { + this.condition = condition; + this.timeoutInSeconds = timeoutInSeconds; + this.callback = callback; + } + + @Override + protected void initChannel(final C ch) throws Exception { + ch.eventLoop().schedule(() -> { + if (!condition.get()) { + callback.invoke(); + ch.close(); + } + }, timeoutInSeconds, TimeUnit.SECONDS); + } + + @FunctionalInterface + public interface OnTimeoutCallback { + void invoke(); + } +} + diff --git a/src/main/java/org/codenil/comm/serialize/SerializeHelper.java b/src/main/java/org/codenil/comm/serialize/SerializeHelper.java new file mode 100644 index 0000000..5eaba23 --- /dev/null +++ b/src/main/java/org/codenil/comm/serialize/SerializeHelper.java @@ -0,0 +1,374 @@ +package org.codenil.comm.serialize; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.apache.commons.lang3.StringUtils; + +import java.nio.charset.StandardCharsets; +import java.util.List; + +/** + * 根据协议格式写入数据 + * 返回组装好协议的ByteBuf + * 只能通过这个方法创建协议,其他方式创建的ByteBuf + * 后续修改协议头时不确定兼容 + */ +public class SerializeHelper { + + private final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + private ByteBuf buf; + + private boolean hasHeader = false; + + public SerializeHelper() { + buf = allocator.buffer(); + } + + public SerializeHelper( + final ByteBuf buf, + final boolean hasHeader) { + this.buf = buf; + this.hasHeader = hasHeader; + } + + /** + * 协议版本号,固定长度4 + */ + public SerializeHelper writeVersion(final String version) { + if (!Version.supported(version)) { + throw new IllegalArgumentException("Unsupported protocol version: " + version); + } + + if (StringUtils.isBlank(version)) { + throw new IllegalArgumentException("Version cannot be null or empty"); + } + byte[] bytes = version.getBytes(); + if (bytes.length > 4) { + throw new IllegalArgumentException("Version length exceeds 4 bytes"); + } + + ByteBuf buf = allocator.buffer(); + if (this.buf.readableBytes() == 0) { + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + buf.writeInt(0); + this.buf.release(); + this.buf = buf; + return this; + } + + String oldVersion = readVersion(this.buf); + if (oldVersion.equals(version)) { + return this; + } + + //version长度 + buf.writeInt(bytes.length); + //version内容 + buf.writeBytes(bytes); + //payload长度 + buf.writeInt(0); + + this.buf.release(); + this.buf = buf; + this.hasHeader = true; + return this; + } + + public SerializeHelper writeInt(int value) { + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeInt(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + if (hasHeader) { + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + //写入总长度 + int payloadLength = this.buf.readInt(); + buf.writeInt(payloadLength + 4); + } + + byte[] payloadBytes = new byte[buf.readableBytes()]; + this.buf.readBytes(payloadBytes); + buf.writeBytes(payloadBytes); + buf.writeInt(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + public SerializeHelper writeLong(long value) { + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeLong(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + if (hasHeader) { + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + int payloadLength = this.buf.readInt(); + buf.writeInt(payloadLength + 8); + } + + byte[] payloadBytes = new byte[buf.readableBytes()]; + this.buf.readBytes(payloadBytes); + buf.writeBytes(payloadBytes); + buf.writeLong(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + public SerializeHelper writeString(String value) { + if (value == null) { + return this; + } + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(value.getBytes(StandardCharsets.UTF_8)); + + this.buf.release(); + this.buf = buf; + return this; + } + + if (hasHeader) { + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + int payloadLength = this.buf.readInt(); + buf.writeInt(payloadLength + value.getBytes(StandardCharsets.UTF_8).length + 4); + } + + byte[] payloadBytes = new byte[buf.readableBytes()]; + this.buf.readBytes(payloadBytes); + buf.writeBytes(payloadBytes); + + buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(value.getBytes(StandardCharsets.UTF_8)); + + this.buf.release(); + this.buf = buf; + return this; + } + + public SerializeHelper writeBoolean(boolean value) { + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeInt(1); + buf.writeBoolean(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + if (hasHeader) { + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + int payloadLength = this.buf.readInt(); + buf.writeInt(payloadLength + 1 + 4); + } + + byte[] payloadBytes = new byte[buf.readableBytes()]; + this.buf.readBytes(payloadBytes); + buf.writeBytes(payloadBytes); + + buf.writeInt(1); + buf.writeBoolean(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + public SerializeHelper writeBytes(byte[] value) { + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeInt(value.length); + buf.writeBytes(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + if (hasHeader) { + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + int payloadLength = this.buf.readInt(); + buf.writeInt(payloadLength + value.length + 4); + } + + byte[] payloadBytes = new byte[buf.readableBytes()]; + this.buf.readBytes(payloadBytes); + buf.writeBytes(payloadBytes); + + buf.writeInt(value.length); + buf.writeBytes(value); + + this.buf.release(); + this.buf = buf; + return this; + } + + public SerializeHelper writeBytes(List values) { + if (values == null) { + return this; + } + + int totalLength = values.stream() + .mapToInt(bytes -> bytes.length + 4).sum(); + + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeInt(totalLength); + values.forEach(value -> { + buf.writeInt(value.length); + buf.writeBytes(value); + }); + + this.buf.release(); + this.buf = buf; + return this; + } + + if (hasHeader) { + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + int payloadLength = this.buf.readInt(); + buf.writeInt(payloadLength + totalLength); + } + + byte[] payloadBytes = new byte[buf.readableBytes()]; + this.buf.readBytes(payloadBytes); + buf.writeBytes(payloadBytes); + + buf.writeInt(totalLength); + values.forEach(value -> { + buf.writeInt(value.length); + buf.writeBytes(value); + }); + this.buf.release(); + this.buf = buf; + return this; + } + + public SerializeHelper writeList(List values) { + if (values == null) { + return this; + } + + int totalLength = values.stream() + .map(value -> value.getBytes(StandardCharsets.UTF_8)) + .mapToInt(bytes -> bytes.length + 4) + .sum(); + + ByteBuf buf = allocator.buffer(); + + this.buf.readerIndex(0); + if (this.buf.readableBytes() == 0) { + buf.writeInt(Version.defaultVersion().getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(Version.defaultVersion().getBytes(StandardCharsets.UTF_8)); + + buf.writeInt(totalLength); + + values.forEach(value -> { + buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(value.getBytes(StandardCharsets.UTF_8)); + }); + + this.buf.release(); + this.buf = buf; + return this; + } + + String version = readVersion(this.buf); + buf.writeInt(version.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(version.getBytes(StandardCharsets.UTF_8)); + + int payloadLength = this.buf.readInt(); + byte[] payloadBytes = new byte[payloadLength]; + this.buf.readBytes(payloadBytes); + + buf.writeInt(payloadLength + totalLength); + + buf.writeBytes(payloadBytes); + + values.forEach(value -> { + buf.writeInt(value.getBytes(StandardCharsets.UTF_8).length); + buf.writeBytes(value.getBytes(StandardCharsets.UTF_8)); + }); + this.buf.release(); + this.buf = buf; + return this; + } + + public String readVersion(ByteBuf buf) { + this.buf.readerIndex(0); + int versionLength = this.buf.readInt(); + byte[] versionBytes = new byte[versionLength]; + this.buf.readBytes(versionBytes); + + return new String(versionBytes); + } + + public byte[] readPayload(ByteBuf buf) { + this.buf.readerIndex(0); + + if (hasHeader) { + int versionLength = this.buf.readInt(); + this.buf.skipBytes(versionLength); + } + + if (this.buf.readableBytes() == 0) { + return new byte[]{}; + } + + int payloadLength = this.buf.readInt(); + byte[] payloadBytes = new byte[payloadLength]; + this.buf.readBytes(payloadBytes); + + return payloadBytes; + } + + public ByteBuf build() { + return buf; + } +} \ No newline at end of file diff --git a/src/main/java/org/codenil/comm/serialize/Version.java b/src/main/java/org/codenil/comm/serialize/Version.java new file mode 100644 index 0000000..4f56d0f --- /dev/null +++ b/src/main/java/org/codenil/comm/serialize/Version.java @@ -0,0 +1,14 @@ +package org.codenil.comm.serialize; + +public class Version { + + public static final String VERSION_1_0 = "1.0"; + + public static boolean supported(String version) { + return VERSION_1_0.equals(version); + } + + public static String defaultVersion() { + return VERSION_1_0; + } +} diff --git a/src/main/resources/logback.xml b/src/main/resources/logback.xml new file mode 100644 index 0000000..e9b764e --- /dev/null +++ b/src/main/resources/logback.xml @@ -0,0 +1,12 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + \ No newline at end of file