package moe.ymc.acron.net; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.*; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; import moe.ymc.acron.auth.Client; import moe.ymc.acron.config.Config; import org.apache.commons.codec.digest.DigestUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; /** * Handle handshake request and authentication. * We cannot use WebSocketServerProtocolHandler because it does not allow * us doing anything before handshaking. */ public class AuthHandler extends SimpleChannelInboundHandler { private static final Logger LOGGER = LogManager.getLogger(); @Override protected void channelRead0(ChannelHandlerContext ctx, HttpRequest msg) throws Exception { LOGGER.debug("channelRead0: {}", msg.uri()); if (msg.method() != HttpMethod.GET) { ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST)) .addListener(ChannelFutureListener.CLOSE); return; } HttpHeaders headers = msg.headers(); if (!"Upgrade".equalsIgnoreCase(headers.get(HttpHeaderNames.CONNECTION)) || !"WebSocket".equalsIgnoreCase(headers.get(HttpHeaderNames.UPGRADE))) { ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST)) .addListener(ChannelFutureListener.CLOSE); return; } final QueryStringDecoder decoder = new QueryStringDecoder(msg.uri()); if (!decoder.path().equals("/ws")) { ctx.fireChannelRead(msg); return; } if (decoder.parameters().isEmpty() || decoder.parameters().get("id") == null || decoder.parameters().get("id").size() != 1 || decoder.parameters().get("token") == null || decoder.parameters().get("token").size() != 1) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST)) .addListener(ChannelFutureListener.CLOSE); return; } final String id = decoder.parameters().get("id").get(0); final String token = decoder.parameters().get("token").get(0); final String versionRaw = (decoder.parameters().get("version") == null || decoder.parameters().get("version").isEmpty()) ? "0" : decoder.parameters().get("version").get(0); try { if (Integer.parseInt(versionRaw) != 0) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST)) .addListener(ChannelFutureListener.CLOSE); return; } } catch (NumberFormatException ignored) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST)) .addListener(ChannelFutureListener.CLOSE); return; } final Client client = Config.getGlobalConfig().clients().get(id); if (client == null || !client.token().equals(DigestUtils.sha256Hex(token))) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED)) .addListener(ChannelFutureListener.CLOSE); return; } ctx.channel().attr(Attributes.ID).set(new ClientIdentification(0, client)); WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory("/ws", null, true); final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(msg); if (handshaker == null) { WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); return; } ctx.channel().attr(Attributes.HANDSHAKER).set(handshaker); handshaker.handshake(ctx.channel(), msg); ctx.pipeline().replace(this, "websocketHandler", new WSFrameHandler()); ctx.fireUserEventTriggered(new HandshakeComplete()); } }