forked from coder/coder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspeaker.go
More file actions
376 lines (348 loc) · 12.5 KB
/
speaker.go
File metadata and controls
376 lines (348 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
package vpn
import (
"context"
"fmt"
"io"
"strings"
"sync"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
"cdr.dev/slog"
)
// SpeakerRole names one side of a CoderVPN protocol conversation. headerString writes it into the
// handshake header, and validateHeader compares it against the role announced by the peer.
type SpeakerRole string
// rpcMessage is the constraint for protocol messages handled by the speaker: a protobuf message
// that can expose its RPC header and allocate one when it is missing.
type rpcMessage interface {
proto.Message
// GetRpc returns the message's RPC header. Callers in this file treat a nil result as
// "this message is not part of an RPC".
GetRpc() *RPC
// EnsureRPC isn't autogenerated, but we'll manually add it for RPC types so that the speaker
// can allocate the RPC.
EnsureRPC() *RPC
}
// EnsureRPC returns the RPC header of the TunnelMessage, allocating and attaching an empty one
// first if the message does not have a header yet.
func (t *TunnelMessage) EnsureRPC() *RPC {
	if existing := t.Rpc; existing != nil {
		return existing
	}
	t.Rpc = new(RPC)
	return t.Rpc
}
// EnsureRPC returns the RPC header of the ManagerMessage, allocating and attaching an empty one
// first if the message does not have a header yet.
func (m *ManagerMessage) EnsureRPC() *RPC {
	if existing := m.Rpc; existing != nil {
		return existing
	}
	m.Rpc = new(RPC)
	return m.Rpc
}
// receivableRPCMessage is an rpcMessage that we can receive, and unmarshal, using generics, from a
// byte stream. proto.Unmarshal requires us to have already allocated the memory for the message
// type we are unmarshalling. All our message types are pointers like *TunnelMessage, so to
// allocate, the compiler needs to know:
//
// a) that the type is a pointer type
// b) what type it is pointing to
//
// So, this generic interface requires that the message is a pointer to the type RR. Then, we pass
// both the receivableRPCMessage and RR as type constraints, so that we'll have access to the
// underlying type when it comes time to allocate it. It's a bit messy, but the alternative is
// reflection, which has its own challenges in understandability.
type receivableRPCMessage[RR any] interface {
// rpcMessage supplies the RPC accessors the speaker calls on received messages.
rpcMessage
// *RR restricts the type set to exactly the pointer-to-RR type.
*RR
}
// Role strings exchanged in the handshake header.
const (
// SpeakerRoleManager is the role string for the manager side of the protocol.
SpeakerRoleManager SpeakerRole = "manager"
// SpeakerRoleTunnel is the role string for the tunnel side of the protocol.
SpeakerRoleTunnel SpeakerRole = "tunnel"
)
// speaker is an implementation of the CoderVPN protocol. It handles unary RPCs and their responses,
// as well as the low-level serialization & deserialization to the ReadWriteCloser (rwc).
//
// Data flow:
//
//	outbound: unaryRPC() / sendReply() / wrapping type ──► sendCh ──► serdes ──► rwc
//
//	inbound:  rwc ──► serdes ──► recvCh ──► recvFromSerdes() ─┬─► rpc responses ──► waiting unaryRPC()
//	                                                          └─► requests chan ──► request handling
//	                                                                                (outside speaker)
//
// speaker is implemented as a generic type that accepts the type of message we send (S), the type we receive (R), and
// the underlying type that R points to (RR). The speaker is intended to be wrapped by another, non-generic type for
// the role (manager or tunnel). E.g. Tunnel from this package.
//
// The serdes handles SERialization and DESerialization of the low level message types. The wrapping type may send
// non-RPC messages (that is messages that don't expect an explicit reply) by sending on the sendCh.
//
// Unary RPCs are handled by the unaryRPC() function, which handles sending the message and waiting for the response.
//
// recvFromSerdes() reads all incoming messages from the serdes. If they are RPC responses, it dispatches them to the
// waiting unaryRPC() function call, if any. If they are RPC requests or non-RPC messages, it wraps them in a request
// struct and sends them over the requests chan. The manager/tunnel role type must read from this chan and handle
// the requests. If they are RPC types, it should call sendReply() on the request with the reply message.
type speaker[S rpcMessage, R receivableRPCMessage[RR], RR any] struct {
// serdes is constructed over the connection together with sendCh and recvCh; it is started by
// start and closed by Close.
serdes *serdes[S, R, RR]
// requests carries incoming RPC requests and non-RPC messages to the wrapping type. It is
// closed by recvFromSerdes when that loop exits.
requests chan *request[S, R]
logger slog.Logger
// nextMsgID is the ID the next call to newRPC hands out. It starts at 1 and is guarded by mu.
nextMsgID uint64
// ctx bounds the lifetime of the speaker; cancel ends it and is called by Close.
ctx context.Context
cancel context.CancelFunc
// sendCh is the send side of the channel handed to the serdes for outgoing messages.
sendCh chan<- S
// recvCh is the receive side of the channel handed to the serdes for incoming messages.
recvCh <-chan R
// recvLoopDone is closed when recvFromSerdes returns.
recvLoopDone chan struct{}
// mu guards nextMsgID and responseChans.
mu sync.Mutex
// responseChans maps the message ID of each in-flight unary RPC to the channel on which its
// response is delivered.
responseChans map[uint64]chan R
}
// newSpeaker performs the protocol handshake on conn as role me, expecting the peer to announce
// role them, and on success returns a speaker wired to a new serdes over conn.
//
// The returned speaker is not yet running; nothing is read from or written to the channels until
// start is called. If the handshake fails, the derived context is canceled and the wrapped
// handshake error is returned.
func newSpeaker[S rpcMessage, R receivableRPCMessage[RR], RR any](
	ctx context.Context, logger slog.Logger, conn io.ReadWriteCloser,
	me, them SpeakerRole,
) (
	*speaker[S, R, RR], error,
) {
	// The speaker owns this derived context for its whole lifetime; Close cancels it.
	ctx, cancel := context.WithCancel(ctx)
	err := handshake(ctx, conn, logger, me, them)
	if err != nil {
		cancel()
		return nil, xerrors.Errorf("handshake failed: %w", err)
	}

	// Both channels are unbuffered and shared between the speaker and its serdes.
	outgoing := make(chan S)
	incoming := make(chan R)
	return &speaker[S, R, RR]{
		serdes:        newSerdes(ctx, logger, conn, outgoing, incoming),
		logger:        logger,
		requests:      make(chan *request[S, R]),
		responseChans: make(map[uint64]chan R),
		nextMsgID:     1,
		ctx:           ctx,
		cancel:        cancel,
		sendCh:        outgoing,
		recvCh:        incoming,
		recvLoopDone:  make(chan struct{}),
	}, nil
}
// start starts the serialization/deserialization and then launches the receive loop
// (recvFromSerdes) in its own goroutine. It's important this happens after any assignments of the
// speaker to its owning Tunnel or Manager, since the mutex is copied and that is not threadsafe.
// nolint: revive
func (s *speaker[_, _, _]) start() {
s.serdes.start()
// The receive loop runs until the speaker's context is done or recvCh is closed.
go s.recvFromSerdes()
}
// recvFromSerdes is the speaker's receive loop. It takes every message the serdes puts on recvCh
// and routes it: messages whose RPC header has a non-zero ResponseTo are handed to
// tryToDeliverResponse, everything else is wrapped in a request and offered on the requests chan.
//
// The loop ends when the speaker's context is done or recvCh is closed. On exit it closes the
// requests chan and then recvLoopDone.
func (s *speaker[S, R, _]) recvFromSerdes() {
	// Deferred calls run last-in-first-out: requests is closed first, then recvLoopDone.
	defer close(s.recvLoopDone)
	defer close(s.requests)
	for {
		var incoming R
		select {
		case <-s.ctx.Done():
			s.logger.Debug(s.ctx, "recvFromSerdes context done while waiting for proto", slog.Error(s.ctx.Err()))
			return
		case m, open := <-s.recvCh:
			if !open {
				s.logger.Debug(s.ctx, "recvCh is closed")
				return
			}
			incoming = m
		}

		if hdr := incoming.GetRpc(); hdr != nil && hdr.ResponseTo != 0 {
			// this is a unary response
			s.tryToDeliverResponse(incoming)
			continue
		}

		// Anything else is for the wrapping type; replies to it go back out via sendCh.
		select {
		case <-s.ctx.Done():
			s.logger.Debug(s.ctx, "recvFromSerdes context done while waiting for request handler", slog.Error(s.ctx.Err()))
			return
		case s.requests <- &request[S, R]{
			ctx:     s.ctx,
			msg:     incoming,
			replyCh: s.sendCh,
		}:
		}
	}
}
// Close shuts the speaker down: it cancels the speaker's context and then closes the serdes,
// returning whatever error the serdes reports.
// nolint: revive
func (s *speaker[_, _, _]) Close() error {
	s.cancel()
	return s.serdes.Close()
}
// unaryRPC sends a request/response style RPC over the protocol, waits for the response, then
// returns the response.
//
// It assigns req a fresh message ID and registers a response channel for that ID before sending.
// The call fails with ctx.Err() if ctx ends, with a wrapped "vpn protocol closed" error if the
// speaker's own context ends, and with io.ErrUnexpectedEOF if the receive loop exits; each of
// these can happen either while sending the request or while waiting for the reply. On failure
// the returned response is the zero value of R.
func (s *speaker[S, R, _]) unaryRPC(ctx context.Context, req S) (resp R, err error) {
	rpc := req.EnsureRPC()
	msgID, respCh := s.newRPC()
	// Always deregister the response channel on the way out, so that requests which were never
	// sent, or whose receive loop died, do not leave entries behind in responseChans. After a
	// successful delivery the entry is already gone and this is a no-op.
	defer s.rmResponseChan(msgID)
	rpc.MsgId = msgID
	logger := s.logger.With(slog.F("msg_id", msgID))

	// Phase 1: hand the request to the serdes.
	select {
	case <-ctx.Done():
		return resp, ctx.Err()
	case <-s.ctx.Done():
		return resp, xerrors.Errorf("vpn protocol closed: %w", s.ctx.Err())
	case <-s.recvLoopDone:
		logger.Debug(s.ctx, "recvLoopDone while sending request")
		return resp, io.ErrUnexpectedEOF
	case s.sendCh <- req:
		logger.Debug(s.ctx, "sent rpc request", slog.F("req", req))
	}

	// Phase 2: wait for the receive loop to deliver the matching response.
	select {
	case <-ctx.Done():
		return resp, ctx.Err()
	case <-s.ctx.Done():
		return resp, xerrors.Errorf("vpn protocol closed: %w", s.ctx.Err())
	case <-s.recvLoopDone:
		logger.Debug(s.ctx, "recvLoopDone while waiting for response")
		return resp, io.ErrUnexpectedEOF
	case resp = <-respCh:
		logger.Debug(s.ctx, "got response", slog.F("resp", resp))
		return resp, nil
	}
}
// newRPC reserves the next message ID and registers a response channel for it in responseChans.
// It returns the ID together with the channel on which the response will arrive.
//
// The channel has capacity 1. tryToDeliverResponse sends on it while holding s.mu, and each
// channel receives at most one response, so the buffer guarantees that send completes even if
// the caller of unaryRPC has already given up. With an unbuffered channel that send could block
// forever with the mutex held, stalling the receive loop and every other RPC.
func (s *speaker[_, R, _]) newRPC() (uint64, chan R) {
	s.mu.Lock()
	defer s.mu.Unlock()
	msgID := s.nextMsgID
	s.nextMsgID++
	c := make(chan R, 1)
	s.responseChans[msgID] = c
	return msgID, c
}
// rmResponseChan forgets the response channel registered for msgID, if there is one. Removing an
// ID that is not registered is a no-op.
func (s *speaker[_, _, _]) rmResponseChan(msgID uint64) {
	s.mu.Lock()
	delete(s.responseChans, msgID)
	s.mu.Unlock()
}
// tryToDeliverResponse hands resp to the unary RPC that is waiting for it, identified by the
// ResponseTo field of the response's RPC header. If no channel is registered for that ID, the
// response is silently dropped.
//
// The channel is unregistered right after the send, so each response channel gets _at most_ one
// response.
//
// NOTE(review): the send happens while s.mu is held, so it must not block. That requires the
// channel registered by newRPC to have capacity for one value — confirm against newRPC.
func (s *speaker[_, R, _]) tryToDeliverResponse(resp R) {
	id := resp.GetRpc().GetResponseTo()

	s.mu.Lock()
	defer s.mu.Unlock()

	waiter, pending := s.responseChans[id]
	if !pending {
		return
	}
	waiter <- resp
	delete(s.responseChans, id)
}
// handshake performs the initial CoderVPN protocol handshake over the given conn.
//
// It writes our header (role me plus CurrentSupportedVersions) and concurrently reads the peer's
// header, which is terminated by a newline and must fit in 255 bytes. Once both directions are
// complete, the peer's header is validated against role them and our supported versions.
//
// It returns ctx.Err() if ctx ends first, the wrapped read or write error if either direction
// fails, or a wrapped validation error. On context cancellation and on read/write failure the
// conn is closed so that the helper goroutines unblock.
func handshake(
	ctx context.Context, conn io.ReadWriteCloser, logger slog.Logger, me, them SpeakerRole,
) error {
	// read and write simultaneously to avoid deadlocking if the conn is not buffered
	//
	// errCh has room for one result from each goroutine, so neither can block on it after this
	// function has returned.
	errCh := make(chan error, 2)
	go func() {
		ours := headerString(me, CurrentSupportedVersions)
		_, err := conn.Write([]byte(ours))
		logger.Debug(ctx, "wrote out header")
		if err != nil {
			err = xerrors.Errorf("write header: %w", err)
		}
		errCh <- err
	}()
	headerCh := make(chan string, 1)
	go func() {
		// we can't use bufio.Scanner here because we need to ensure we don't read beyond the
		// first newline. So, we'll read one byte at a time. It's inefficient, but the initial
		// header is only a few characters, so we'll keep this code simple.
		buf := make([]byte, 256)
		have := 0
		for {
			n, err := conn.Read(buf[have : have+1])
			// Per the io.Reader contract, process the byte we got (if any) before looking at
			// the error, and treat a (0, nil) read as "nothing happened" rather than as a
			// received byte.
			if n > 0 {
				if buf[have] == '\n' {
					logger.Debug(ctx, "got newline header delimiter")
					// use have (not have+1) since we don't want the delimiter for verification.
					headerCh <- string(buf[:have])
					return
				}
				have++
				if have >= len(buf) {
					errCh <- xerrors.Errorf("header malformed or too large: %s", string(buf))
					return
				}
			}
			if err != nil {
				errCh <- xerrors.Errorf("read header: %w", err)
				return
			}
		}
	}()

	writeOK := false
	theirHeader := ""
	readOK := false
	for !(readOK && writeOK) {
		select {
		case <-ctx.Done():
			_ = conn.Close() // ensure our read/write goroutines get a chance to clean up
			return ctx.Err()
		case err := <-errCh:
			if err == nil {
				// write goroutine sends nil when completing successfully.
				logger.Debug(ctx, "write ok")
				writeOK = true
				continue
			}
			_ = conn.Close()
			return err
		case theirHeader = <-headerCh:
			logger.Debug(ctx, "read ok")
			readOK = true
		}
	}
	logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader))
	gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions)
	if err != nil {
		return xerrors.Errorf("validate header (%s): %w", theirHeader, err)
	}
	logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion))
	// TODO: actually use the common version to perform different behavior once
	// we have multiple versions
	return nil
}
// headerPreamble is the fixed first field of the handshake header; validateHeader rejects any
// header that does not start with it.
const headerPreamble = "codervpn"
// headerString builds the handshake header line for the given role and version list: the
// preamble, the role and the version list, separated by single spaces and terminated by a
// newline.
func headerString(role SpeakerRole, versions RPCVersionList) string {
	// Sprintln puts one space between operands and appends the trailing newline for us.
	return fmt.Sprintln(headerPreamble, role, versions.String())
}
// validateHeader checks a peer's handshake header (without its trailing newline) and picks the
// RPC version to speak.
//
// The header must consist of exactly three space-separated fields: the preamble, the role we
// expect the peer to have, and a version list that parses and is compatible with
// supportedVersions. On success the compatible version is returned; otherwise the zero
// RPCVersion and an error describing the first failed check.
func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) {
	var none RPCVersion

	fields := strings.Split(header, " ")
	// Cases are evaluated in order, so the indexing below only runs once the length is known.
	switch {
	case len(fields) != 3:
		return none, xerrors.New("wrong number of parts")
	case fields[0] != headerPreamble:
		return none, xerrors.New("invalid preamble")
	case SpeakerRole(fields[1]) != expectedRole:
		return none, xerrors.New("unexpected role")
	}

	rawVersions := fields[2]
	peerVersions, err := ParseRPCVersionList(rawVersions)
	if err != nil {
		return none, xerrors.Errorf("parse version list %q: %w", rawVersions, err)
	}

	if common, ok := supportedVersions.IsCompatibleWith(peerVersions); ok {
		return common, nil
	}
	return none,
		xerrors.Errorf("current supported versions %q is not compatible with peer versions %q", supportedVersions.String(), peerVersions.String())
}
// request wraps a message received from the peer that is not an RPC response, together with what
// is needed to answer it. recvFromSerdes creates one per such message and sends it on the
// speaker's requests chan.
type request[S rpcMessage, R rpcMessage] struct {
// ctx is the speaker's context; sendReply gives up when it is done.
ctx context.Context
// msg is the received message.
msg R
// replyCh is the speaker's sendCh, on which sendReply writes the reply.
replyCh chan<- S
}
// sendReply sends reply as the response to this request. It marks the reply as a response by
// copying the request's message ID into the reply's ResponseTo field and then queues the reply on
// replyCh.
//
// It returns an error if the request carries no RPC header (the peer did not ask for a reply), or
// the request context's error if that context ends before the reply can be queued.
func (r *request[S, _]) sendReply(reply S) error {
	// The reply's header is allocated first, even when the request turns out not to want a reply.
	replyHdr := reply.EnsureRPC()
	reqHdr := r.msg.GetRpc()
	if reqHdr == nil {
		return xerrors.Errorf("message didn't want a reply")
	}
	replyHdr.ResponseTo = reqHdr.MsgId

	select {
	case r.replyCh <- reply:
		return nil
	case <-r.ctx.Done():
		return r.ctx.Err()
	}
}