aboutsummaryrefslogtreecommitdiff
path: root/agent/src/main/java/moe/yuuta/dn42peering/agent/provision/WireGuardProvisioner.java
blob: fee8917211854915656fc1229b9e1ab3f1526a69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
package moe.yuuta.dn42peering.agent.provision;

import io.vertx.core.CompositeFuture;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.ext.web.common.template.TemplateEngine;
import io.vertx.ext.web.templ.freemarker.FreeMarkerTemplateEngine;
import moe.yuuta.dn42peering.agent.ip.AddrInfoItem;
import moe.yuuta.dn42peering.agent.ip.Address;
import moe.yuuta.dn42peering.agent.ip.IP;
import moe.yuuta.dn42peering.agent.ip.IPOptions;
import moe.yuuta.dn42peering.agent.proto.Node;
import moe.yuuta.dn42peering.agent.proto.WireGuardConfig;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.net.Inet6Address;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class WireGuardProvisioner implements IProvisioner<WireGuardConfig> {
    private final Logger logger = LoggerFactory.getLogger(getClass().getSimpleName());

    private final TemplateEngine engine;
    private final Vertx vertx;

    public WireGuardProvisioner(@Nonnull Vertx vertx) {
        this(FreeMarkerTemplateEngine.create(vertx, "ftlh"), vertx);
    }

    public WireGuardProvisioner(@Nonnull TemplateEngine engine,
                                @Nonnull Vertx vertx) {
        this.engine = engine;
        this.vertx = vertx;
    }

    @Nonnull
    private Future<List<Change>> calculateDeleteChanges(@Nonnull List<WireGuardConfig> allDesired) {
        final String[] actualNamesRaw = new File("/etc/wireguard/").list((dir, name) -> name.matches("wg_.*\\.conf"));
        final List<String> actualNames = Arrays.stream(actualNamesRaw == null ? new String[]{} : actualNamesRaw)
                .sorted()
                .collect(Collectors.toList());
        return Future.succeededFuture(actualNames.stream()
                .flatMap(string -> {
                    return Arrays.stream(new Change[]{
                            new CommandChange(new String[]{"systemctl", "disable", "--now", "-q", "wg-quick@" + string.replace(".conf", ".service")}),
                            new FileChange("/etc/wireguard/" + string, null, FileChange.Action.DELETE.toString())
                    });
                })
                .collect(Collectors.toList()));
    }

    @Nonnull
    private Future<Buffer> renderConfig(@Nonnull WireGuardConfig config) {
        final Map<String, Object> params = new HashMap<>(5);
        params.put("listen_port", config.getListenPort());
        params.put("self_priv_key", config.getSelfPrivKey());
        params.put("preshared_key", config.getSelfPresharedSecret());
        if (!config.getEndpoint().equals("")) {
            params.put("endpoint", config.getEndpoint());
        }
        params.put("peer_pub_key", config.getPeerPubKey());

        return engine.render(params, "wg.conf.ftlh");
    }

    @Nullable
    private Address searchActualAddress(@Nonnull List<Address> addresses,
                                        @Nonnull String device) {
        // TODO: Optimize algorithm
        for (final Address address : addresses) {
            if(address.getIfname().equals(device))
                return address;
        }
        return null;
    }

    @Nonnull
    private List<String> calculateSingleNetlinkChanges(@Nonnull Node node,
                                                       @Nonnull WireGuardConfig desired,
                                                       @Nullable Address actual) throws IOException {
        final boolean linkLocal =
                !desired.getPeerIPv6().isEmpty() &&
                        Inet6Address.getByName(desired.getPeerIPv6()).isLinkLocalAddress();

        final boolean desireIP6 = !desired.getPeerIPv6().isEmpty();
        final boolean needCreateInterface = actual == null;
        final boolean needCreateAddrs;
        final boolean needUp;

        if(actual == null) {
            needCreateAddrs = true;
            needUp = true;
        } else {
            needUp = !actual.getOperstate().equals("UP") &&
            !actual.getOperstate().equals("UNKNOWN");
            AddrInfoItem actualIP4 = null;
            AddrInfoItem actualIP6 = null;
            boolean excessiveIPs = false;
            for (final AddrInfoItem item : actual.getAddrInfo()) {
                switch (item.getFamily()) {
                    case "inet":
                        if(actualIP4 != null) {
                            excessiveIPs = true;
                            break;
                        } else {
                            actualIP4 = item;
                        }
                        break;
                    case "inet6":
                        if(actualIP6 != null) {
                            excessiveIPs = true;
                            break;
                        } else {
                            actualIP6 = item;
                        }
                        break;
                    default:
                        excessiveIPs = true;
                        break;
                }
            }
            if(excessiveIPs || actualIP4 == null || (desireIP6 && actualIP6 == null) ||
                    (!desireIP6 && actualIP6 != null)) {
                logger.info("Recreating addresses for " + desired.getId() + " since there are extra addresses or necessary addresses cannot be found.");
                needCreateAddrs = true;
            } else {
                boolean needCreateIP4Addr =
                        actualIP4.getPrefixlen() != 32 ||
                                !node.getIpv4().equals(actualIP4.getLocal()) ||
                                !desired.getPeerIPv4().equals(actualIP4.getAddress());
                boolean needCreateIP6Addr = false;
                if(desireIP6) {
                    needCreateIP6Addr =
                            actualIP6.getPrefixlen() != (linkLocal ? 64 : 128) ||
                                    !(linkLocal ? node.getIpv6() : node.getIpv6NonLL()).equals(actualIP6.getLocal()) ||
                                    (linkLocal ? (actualIP6.getAddress() != null) :
                                            !desired.getPeerIPv6().equals(actualIP6.getAddress()));
                    if(needCreateIP6Addr) {
                        logger.info("IPv6 addresses for " + desired.getId() + " is outdated.\n" +
                                "Prefixes match: " + (actualIP6.getPrefixlen() == (linkLocal ? 64 : 128)) + "\n" +
                                "Local addresses match: " + ((linkLocal ? node.getIpv6() : node.getIpv6NonLL()).equals(actualIP6.getLocal())) + "\n" +
                                "Peer addresses match: " + (linkLocal ? (actualIP6.getAddress() == null) :
                                desired.getPeerIPv6().equals(actualIP6.getAddress())));
                    }
                }
                needCreateAddrs = needCreateIP4Addr || needCreateIP6Addr;
                if(needCreateAddrs)
                    logger.info("Recreating addresses for " + desired.getId() +
                            " since IPv4 or IPv6 information is updated: " + needCreateIP4Addr + ", " + needCreateIP6Addr + ".");
            }
        }

        final List<List<String>> changes = new ArrayList<>();
        if(needCreateInterface)
            changes.add(IP.Link.add(desired.getInterface(), "wireguard"));
        if(needCreateAddrs) {
            changes.add(IP.Addr.flush(desired.getInterface()));
            changes.add(IP.Addr.add(node.getIpv4() + "/32",
                    desired.getInterface(),
                    desired.getPeerIPv4() + "/32"));
            if(!desired.getPeerIPv6().isEmpty()) {
                if(linkLocal)
                    changes.add(IP.Addr.add(node.getIpv6() + "/64",
                            desired.getInterface(),
                            null));
                else
                    changes.add(IP.Addr.add(node.getIpv6NonLL() + "/128",
                            desired.getInterface(),
                            desired.getPeerIPv6() + "/128"));
            }
        }
        if(needUp)
            changes.add(IP.Link.set(desired.getInterface(), "up"));
        return changes
                .stream().map(cmd -> String.join(" ", cmd))
                .collect(Collectors.toList());
    }

    @Nonnull
    private Future<List<Change>> calculateTotalNetlinkChanges(@Nonnull Node node,
                                           @Nonnull List<WireGuardConfig> allDesired) {
        return IP.ip(vertx, new IPOptions(), IP.Addr.show(null))
                .compose(IP.Addr::handler)
                .compose(addrs -> {
                    final List<String> ipCommands = new ArrayList<>();
                    for (final WireGuardConfig desired : allDesired) {
                        final Address actual = searchActualAddress(addrs, desired.getInterface());
                        try {
                            ipCommands.addAll(calculateSingleNetlinkChanges(node,
                                    desired,
                                    actual));
                        } catch (IOException e) {
                            return Future.failedFuture(e);
                        }
                    }
                    final List<Change> changes = new ArrayList<>();
                    if(!ipCommands.isEmpty()) {
                        changes.add(new IPChange(true, ipCommands));
                    }
                    return Future.succeededFuture(changes);
                });
    }

    @Nonnull
    private Future<List<Change>> calculateTotalWireGuardChanges(@Nonnull Node node,
                                                              @Nonnull List<WireGuardConfig> allDesired) {
        return CompositeFuture.join(allDesired.stream().map(desired -> {
            return renderConfig(desired)
                    .compose(desiredConf -> {
                        return Future.succeededFuture(new WireGuardSyncConfChange(desired.getInterface(),
                                desiredConf.toString()));
                    });
        }).collect(Collectors.toList()))
                .compose(compositeFuture -> {
                    final List<Change> changes = new ArrayList<>(allDesired.size());
                    for (int i = 0; i < allDesired.size(); i ++) {
                        final Change change = compositeFuture.resultAt(i);
                        if(change == null) continue;
                        changes.add(change);
                    }
                    return Future.succeededFuture(changes);
                });
    }

    @Nonnull
    @Override
    public Future<List<Change>> calculateChanges(@Nonnull Node node, @Nonnull List<WireGuardConfig> allDesired) {
        return calculateDeleteChanges(allDesired).compose(changes -> {
            return calculateTotalNetlinkChanges(node, allDesired)
                    .compose(netlinkChanges -> {
                        changes.addAll(netlinkChanges);
                        return Future.succeededFuture(changes);
                    });
        }).compose(changes -> {
            return calculateTotalWireGuardChanges(node, allDesired)
                    .compose(wireguardChanges -> {
                        changes.addAll(wireguardChanges);
                        return Future.succeededFuture(changes);
                    });
        });
    }
}