1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
package moe.ymc.acron.net;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import moe.ymc.acron.auth.Client;
import moe.ymc.acron.config.Config;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
* Handle handshake request and authentication.
* We cannot use WebSocketServerProtocolHandler because it does not allow
* us doing anything before handshaking.
*/
public class AuthHandler extends SimpleChannelInboundHandler<HttpRequest> {
private static final Logger LOGGER = LogManager.getLogger();
@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpRequest msg) throws Exception {
LOGGER.debug("channelRead0: {}", msg.uri());
if (msg.method() != HttpMethod.GET) {
ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.BAD_REQUEST))
.addListener(ChannelFutureListener.CLOSE);
return;
}
HttpHeaders headers = msg.headers();
if (!"Upgrade".equalsIgnoreCase(headers.get(HttpHeaderNames.CONNECTION)) ||
!"WebSocket".equalsIgnoreCase(headers.get(HttpHeaderNames.UPGRADE))) {
ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.BAD_REQUEST))
.addListener(ChannelFutureListener.CLOSE);
return;
}
final QueryStringDecoder decoder = new QueryStringDecoder(msg.uri());
if (!decoder.path().equals("/ws")) {
ctx.fireChannelRead(msg);
return;
}
if (decoder.parameters().isEmpty() ||
decoder.parameters().get("id") == null ||
decoder.parameters().get("id").size() != 1 ||
decoder.parameters().get("token") == null ||
decoder.parameters().get("token").size() != 1) {
ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.BAD_REQUEST))
.addListener(ChannelFutureListener.CLOSE);
return;
}
final String id = decoder.parameters().get("id").get(0);
final String token = decoder.parameters().get("token").get(0);
final String versionRaw = (decoder.parameters().get("version") == null ||
decoder.parameters().get("version").isEmpty()) ? "0" :
decoder.parameters().get("version").get(0);
try {
if (Integer.parseInt(versionRaw) != 0) {
ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.BAD_REQUEST))
.addListener(ChannelFutureListener.CLOSE);
return;
}
} catch (NumberFormatException ignored) {
ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.BAD_REQUEST))
.addListener(ChannelFutureListener.CLOSE);
return;
}
final Client client = Config.getGlobalConfig().clients().get(id);
if (client == null ||
!client.token().equals(DigestUtils.sha256Hex(token))) {
ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.UNAUTHORIZED))
.addListener(ChannelFutureListener.CLOSE);
return;
}
ctx.channel().attr(Attributes.ID).set(new ClientIdentification(0, client));
WebSocketServerHandshakerFactory wsFactory =
new WebSocketServerHandshakerFactory("/ws", null, true);
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(msg);
if (handshaker == null) {
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
return;
}
ctx.channel().attr(Attributes.HANDSHAKER).set(handshaker);
handshaker.handshake(ctx.channel(), msg);
ctx.pipeline().replace(this, "websocketHandler", new WSFrameHandler());
ctx.fireUserEventTriggered(new HandshakeComplete());
}
}
|