aboutsummaryrefslogtreecommitdiff
path: root/mod/src/main/java/moe/ymc/acron/net/AuthHandler.java
blob: 3e42e14164522eccb3cdde403edb71abd08d7cec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package moe.ymc.acron.net;

import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.ReferenceCountUtil;
import moe.ymc.acron.auth.Client;
import moe.ymc.acron.config.Config;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.List;
import java.util.Map;

/**
 * Handle handshake request and authentication.
 * We cannot use WebSocketServerProtocolHandler because it does not allow
 * us doing anything before handshaking.
 */
public class AuthHandler extends SimpleChannelInboundHandler<HttpRequest> {
    private static final Logger LOGGER = LogManager.getLogger();

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, HttpRequest msg) throws Exception {
        LOGGER.debug("channelRead0: {}", msg.uri());
        if (msg.method() != HttpMethod.GET) {
            ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                            HttpResponseStatus.BAD_REQUEST))
                    .addListener(ChannelFutureListener.CLOSE);
            return;
        }
        HttpHeaders headers = msg.headers();

        if (!"Upgrade".equalsIgnoreCase(headers.get(HttpHeaderNames.CONNECTION)) ||
                !"WebSocket".equalsIgnoreCase(headers.get(HttpHeaderNames.UPGRADE))) {
            ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                            HttpResponseStatus.BAD_REQUEST))
                    .addListener(ChannelFutureListener.CLOSE);
            return;
        }

        final QueryStringDecoder decoder = new QueryStringDecoder(msg.uri());
        if (!decoder.path().equals("/ws")) {
            ctx.fireChannelRead(msg);
            return;
        }
        if (decoder.parameters().isEmpty() ||
                decoder.parameters().get("id") == null ||
                decoder.parameters().get("id").size() != 1 ||
                decoder.parameters().get("token") == null ||
                decoder.parameters().get("token").size() != 1) {
            ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                            HttpResponseStatus.BAD_REQUEST))
                    .addListener(ChannelFutureListener.CLOSE);
            return;
        }

        final String id = decoder.parameters().get("id").get(0);
        final String token = decoder.parameters().get("token").get(0);
        final String versionRaw = (decoder.parameters().get("version") == null ||
                decoder.parameters().get("version").isEmpty()) ? "0" :
                decoder.parameters().get("version").get(0);
        try {
            if (Integer.parseInt(versionRaw) != 0) {
                ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                                HttpResponseStatus.BAD_REQUEST))
                        .addListener(ChannelFutureListener.CLOSE);
                return;
            }
        } catch (NumberFormatException ignored) {
            ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                            HttpResponseStatus.BAD_REQUEST))
                    .addListener(ChannelFutureListener.CLOSE);
            return;
        }

        final Client client = Config.getGlobalConfig().clients().get(id);
        if (client == null ||
                !client.token().equals(DigestUtils.sha256Hex(token))) {
            ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                            HttpResponseStatus.UNAUTHORIZED))
                    .addListener(ChannelFutureListener.CLOSE);
            return;
        }
        ctx.channel().attr(Attributes.ID).set(new ClientIdentification(0, client));
        WebSocketServerHandshakerFactory wsFactory =
                new WebSocketServerHandshakerFactory("/ws", null, true);
        final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(msg);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        ctx.channel().attr(Attributes.HANDSHAKER).set(handshaker);
        handshaker.handshake(ctx.channel(), msg);
        ctx.pipeline().replace(this, "websocketHandler", new WSFrameHandler());
        ctx.fireUserEventTriggered(new HandshakeComplete());
    }
}