aboutsummaryrefslogtreecommitdiff
path: root/mod/src/main/java/moe/ymc/acron/net
diff options
context:
space:
mode:
Diffstat (limited to 'mod/src/main/java/moe/ymc/acron/net')
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/AcronInitializer.java25
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/Attributes.java13
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/AuthHandler.java98
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/ClientConfiguration.java20
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/ClientIdentification.java11
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/HandshakeComplete.java7
-rw-r--r--mod/src/main/java/moe/ymc/acron/net/WSFrameHandler.java174
7 files changed, 348 insertions, 0 deletions
diff --git a/mod/src/main/java/moe/ymc/acron/net/AcronInitializer.java b/mod/src/main/java/moe/ymc/acron/net/AcronInitializer.java
new file mode 100644
index 0000000..c9953e3
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/AcronInitializer.java
@@ -0,0 +1,25 @@
+package moe.ymc.acron.net;
+
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.http.HttpObjectAggregator;
+import io.netty.handler.codec.http.HttpServerCodec;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * A channel initializer for all Acron handlers.
+ */
+public class AcronInitializer extends ChannelInitializer<SocketChannel> {
+ private static final Logger LOGGER = LogManager.getLogger();
+
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ LOGGER.debug("initChannel");
+ ch.pipeline()
+ .addLast(new HttpServerCodec())
+ .addLast(new HttpObjectAggregator(65536))
+ .addLast(new AuthHandler())
+ ;
+ }
+}
diff --git a/mod/src/main/java/moe/ymc/acron/net/Attributes.java b/mod/src/main/java/moe/ymc/acron/net/Attributes.java
new file mode 100644
index 0000000..ddb0f5c
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/Attributes.java
@@ -0,0 +1,13 @@
+package moe.ymc.acron.net;
+
+import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
+import io.netty.util.AttributeKey;
+
+final class Attributes {
+ public static final AttributeKey<ClientIdentification> ID =
+ AttributeKey.newInstance("CLENT_ID");
+ public static final AttributeKey<ClientConfiguration> CONFIGURATION =
+ AttributeKey.newInstance("CLIENT_CONFIG");
+ public static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER =
+ AttributeKey.newInstance("HANDSHAKER");
+}
diff --git a/mod/src/main/java/moe/ymc/acron/net/AuthHandler.java b/mod/src/main/java/moe/ymc/acron/net/AuthHandler.java
new file mode 100644
index 0000000..3e42e14
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/AuthHandler.java
@@ -0,0 +1,98 @@
+package moe.ymc.acron.net;
+
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.codec.http.*;
+import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
+import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
+import moe.ymc.acron.auth.Client;
+import moe.ymc.acron.config.Config;
+import org.apache.commons.codec.digest.DigestUtils;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Handle handshake request and authentication.
+ * We cannot use WebSocketServerProtocolHandler because it does not allow
+ * us doing anything before handshaking.
+ */
+public class AuthHandler extends SimpleChannelInboundHandler<HttpRequest> {
+ private static final Logger LOGGER = LogManager.getLogger();
+
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest msg) throws Exception {
+ LOGGER.debug("channelRead0: {}", msg.uri());
+ if (msg.method() != HttpMethod.GET) {
+ ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
+ HttpResponseStatus.BAD_REQUEST))
+ .addListener(ChannelFutureListener.CLOSE);
+ return;
+ }
+ HttpHeaders headers = msg.headers();
+
+ if (!"Upgrade".equalsIgnoreCase(headers.get(HttpHeaderNames.CONNECTION)) ||
+ !"WebSocket".equalsIgnoreCase(headers.get(HttpHeaderNames.UPGRADE))) {
+ ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
+ HttpResponseStatus.BAD_REQUEST))
+ .addListener(ChannelFutureListener.CLOSE);
+ return;
+ }
+
+ final QueryStringDecoder decoder = new QueryStringDecoder(msg.uri());
+ if (!decoder.path().equals("/ws")) {
+ ctx.fireChannelRead(msg);
+ return;
+ }
+ if (decoder.parameters().isEmpty() ||
+ decoder.parameters().get("id") == null ||
+ decoder.parameters().get("id").size() != 1 ||
+ decoder.parameters().get("token") == null ||
+ decoder.parameters().get("token").size() != 1) {
+ ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
+ HttpResponseStatus.BAD_REQUEST))
+ .addListener(ChannelFutureListener.CLOSE);
+ return;
+ }
+
+ final String id = decoder.parameters().get("id").get(0);
+ final String token = decoder.parameters().get("token").get(0);
+ final String versionRaw = (decoder.parameters().get("version") == null ||
+ decoder.parameters().get("version").isEmpty()) ? "0" :
+ decoder.parameters().get("version").get(0);
+ try {
+ if (Integer.parseInt(versionRaw) != 0) {
+ ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
+ HttpResponseStatus.BAD_REQUEST))
+ .addListener(ChannelFutureListener.CLOSE);
+ return;
+ }
+ } catch (NumberFormatException ignored) {
+ ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
+ HttpResponseStatus.BAD_REQUEST))
+ .addListener(ChannelFutureListener.CLOSE);
+ return;
+ }
+
+ final Client client = Config.getGlobalConfig().clients().get(id);
+ if (client == null ||
+ !client.token().equals(DigestUtils.sha256Hex(token))) {
+ ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
+ HttpResponseStatus.UNAUTHORIZED))
+ .addListener(ChannelFutureListener.CLOSE);
+ return;
+ }
+ ctx.channel().attr(Attributes.ID).set(new ClientIdentification(0, client));
+ WebSocketServerHandshakerFactory wsFactory =
+ new WebSocketServerHandshakerFactory("/ws", null, true);
+ final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(msg);
+ if (handshaker == null) {
+ WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
+ return;
+ }
+ ctx.channel().attr(Attributes.HANDSHAKER).set(handshaker);
+ handshaker.handshake(ctx.channel(), msg);
+ ctx.pipeline().replace(this, "websocketHandler", new WSFrameHandler());
+ ctx.fireUserEventTriggered(new HandshakeComplete());
+ }
+}
diff --git a/mod/src/main/java/moe/ymc/acron/net/ClientConfiguration.java b/mod/src/main/java/moe/ymc/acron/net/ClientConfiguration.java
new file mode 100644
index 0000000..450ccd4
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/ClientConfiguration.java
@@ -0,0 +1,20 @@
+package moe.ymc.acron.net;
+
+import net.minecraft.server.world.ServerWorld;
+import net.minecraft.util.math.Vec2f;
+import net.minecraft.util.math.Vec3d;
+import org.jetbrains.annotations.NotNull;
+
+public record ClientConfiguration(@NotNull Vec3d pos,
+ @NotNull Vec2f rot,
+ @NotNull ServerWorld world,
+ @NotNull String name) {
+ public ClientConfiguration(@NotNull ServerWorld world,
+ @NotNull String name) {
+ // Rcon defaults. @see RconCommandOutput
+ this(Vec3d.of(world.getSpawnPos()),
+ Vec2f.ZERO,
+ world,
+ name);
+ }
+}
diff --git a/mod/src/main/java/moe/ymc/acron/net/ClientIdentification.java b/mod/src/main/java/moe/ymc/acron/net/ClientIdentification.java
new file mode 100644
index 0000000..1cb4375
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/ClientIdentification.java
@@ -0,0 +1,11 @@
+package moe.ymc.acron.net;
+
+import moe.ymc.acron.auth.Client;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Read only client configurations.
+ */
+public record ClientIdentification(int version,
+ @NotNull Client client) {
+}
diff --git a/mod/src/main/java/moe/ymc/acron/net/HandshakeComplete.java b/mod/src/main/java/moe/ymc/acron/net/HandshakeComplete.java
new file mode 100644
index 0000000..348b5e2
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/HandshakeComplete.java
@@ -0,0 +1,7 @@
+package moe.ymc.acron.net;
+
+/**
+ * User event used to tell WSFrameHandler that the handshake is complete.
+ */
+public class HandshakeComplete {
+}
diff --git a/mod/src/main/java/moe/ymc/acron/net/WSFrameHandler.java b/mod/src/main/java/moe/ymc/acron/net/WSFrameHandler.java
new file mode 100644
index 0000000..912e73a
--- /dev/null
+++ b/mod/src/main/java/moe/ymc/acron/net/WSFrameHandler.java
@@ -0,0 +1,174 @@
+package moe.ymc.acron.net;
+
+import com.google.gson.JsonParseException;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.codec.http.websocketx.*;
+import moe.ymc.acron.MinecraftServerHolder;
+import moe.ymc.acron.auth.Action;
+import moe.ymc.acron.auth.PolicyChecker;
+import moe.ymc.acron.c2s.ReqCmd;
+import moe.ymc.acron.c2s.ReqSetConfig;
+import moe.ymc.acron.c2s.Request;
+import moe.ymc.acron.cmd.CmdQueue;
+import moe.ymc.acron.jvav.Pair;
+import moe.ymc.acron.s2c.Event;
+import moe.ymc.acron.s2c.EventQueue;
+import moe.ymc.acron.s2c.response.EventError;
+import moe.ymc.acron.s2c.response.EventOk;
+import moe.ymc.acron.serialization.Serializer;
+import net.minecraft.server.world.ServerWorld;
+import net.minecraft.util.math.Vec2f;
+import net.minecraft.util.math.Vec3d;
+import net.minecraft.world.World;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * The handler for WebSocket requests.
+ */
+public class WSFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
+ private static final Logger LOGGER = LogManager.getLogger();
+
+ @Override
+ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+ super.handlerAdded(ctx);
+ }
+
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {
+ LOGGER.debug("channelRead0: {} {}",
+ this,
+ ctx.channel());
+ final WebSocketServerHandshaker handshaker =
+ ctx.channel().attr(Attributes.HANDSHAKER).get();
+ if (msg instanceof CloseWebSocketFrame) {
+ handshaker.close(ctx.channel(), (CloseWebSocketFrame) msg.retain());
+ return;
+ }
+ if (msg instanceof PingWebSocketFrame) {
+ ctx.write(new PongWebSocketFrame(msg.content().retain()));
+ return;
+ }
+ if (msg instanceof BinaryWebSocketFrame) {
+ throw new UnsupportedOperationException("Only text frames are accepted.");
+ }
+ final TextWebSocketFrame frame = (TextWebSocketFrame) msg;
+
+ final ClientIdentification identification = ctx.channel().attr(Attributes.ID).get();
+ final ClientConfiguration configuration = ctx.channel().attr(Attributes.CONFIGURATION).get();
+ int id;
+ final Request request;
+ try {
+ request = Serializer.deserialize(frame);
+ id = request.getId();
+ } catch (JsonParseException | IllegalArgumentException | IllegalStateException e) {
+ ctx.channel().writeAndFlush(
+ Serializer.serialize(new EventError(-2, EventError.Code.BAD_REQUEST.value, e.getMessage()))
+ );
+ return;
+ }
+ try {
+ ctx.channel().writeAndFlush(Serializer.serialize(handle(request,
+ identification,
+ configuration,
+ ctx.channel())));
+ } catch (Throwable e) {
+ LOGGER.info("An error occurred while processing the request. " +
+ "This may just be a malformed request. " +
+ "It is reported to the client.",
+ e);
+ ctx.channel().writeAndFlush(
+ Serializer.serialize(new EventError(id, EventError.Code.SERVER_ERROR.value, e.getMessage()))
+ );
+ }
+ }
+
+ @NotNull
+ private Event handle(@NotNull Request request,
+ @NotNull ClientIdentification identification,
+ @NotNull ClientConfiguration configuration,
+ @NotNull Channel channel) throws Throwable {
+ if (request instanceof final ReqCmd reqCmd) {
+ LOGGER.info("Client {} executed a command: `{}`.",
+ identification.client().id(),
+ reqCmd.cmd());
+ final Pair<Action, Boolean> res = PolicyChecker.check(identification.client(),
+ reqCmd.cmd());
+ if (res.l() == Action.DENY) {
+ return new EventError(reqCmd.id(),
+ EventError.Code.FORBIDDEN.value, "This client is not allowed to " +
+ "execute this command.");
+ }
+ // TODO: Ok event may be sent after executing the command.
+ CmdQueue.enqueue(reqCmd.id(),
+ res.r(),
+ channel,
+ reqCmd.config() == null ?
+ configuration :
+ convertConfiguration(reqCmd.config()),
+ reqCmd.cmd());
+ return new EventOk(request.getId());
+ } else if (request instanceof final ReqSetConfig reqSetConfig) {
+ channel.attr(Attributes.CONFIGURATION).set(convertConfiguration(reqSetConfig));
+ return new EventOk(request.getId());
+ }
+ // This should not occur.
+ throw new IllegalStateException("This should not occur.");
+ }
+
+ private ClientConfiguration convertConfiguration(@NotNull ReqSetConfig request) {
+ final ServerWorld world;
+ if (request.world() != null) {
+ switch (request.world()) {
+ case OVERWORLD -> world = MinecraftServerHolder.getServer().getWorld(World.OVERWORLD);
+ case NETHER -> world = MinecraftServerHolder.getServer().getWorld(World.NETHER);
+ case END -> world = MinecraftServerHolder.getServer().getWorld(World.END);
+ default -> throw new IllegalArgumentException();
+ }
+ } else {
+ world = MinecraftServerHolder.getServer().getOverworld();
+ }
+ if (world == null) {
+ throw new IllegalStateException(String.format("The requested world %s is not available at this time.",
+ request.world()));
+ }
+ return new ClientConfiguration(
+ request.pos() == null ?
+ Vec3d.of(world.getSpawnPos()) :
+ new Vec3d(request.pos().x(), request.pos().y(), request.pos().z()),
+ request.rot() == null ?
+ Vec2f.ZERO :
+ new Vec2f(request.rot().x(), request.rot().y()),
+ world,
+ request.name() == null ? this.toString() : request.name()
+ );
+ }
+
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+ LOGGER.debug("handshakeComplete: {} {}",
+ this,
+ ctx.channel());
+ if (evt instanceof HandshakeComplete) {
+ final ClientIdentification identification = ctx.channel().attr(Attributes.ID).get();
+ LOGGER.info("Client {} connected. It has {} rules with {} policy mode.",
+ identification.client().id(),
+ identification.client().rules().length,
+ identification.client().policyMode());
+ final ServerWorld defaultWorld = MinecraftServerHolder.getServer().getOverworld();
+ if (defaultWorld == null) {
+ throw new IllegalStateException("The default world is not available at this time.");
+ }
+ final ClientConfiguration configuration =
+ new ClientConfiguration(defaultWorld,
+ identification.client().id());
+ ctx.channel().attr(Attributes.CONFIGURATION).set(configuration);
+ EventQueue.registerMessageRecipient(ctx.channel());
+ } else {
+ ctx.fireUserEventTriggered(evt);
+ }
+ }
+}