-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathwebpush.go
More file actions
536 lines (469 loc) · 17.1 KB
/
webpush.go
File metadata and controls
536 lines (469 loc) · 17.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
package webpush
import (
"context"
"database/sql"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/netip"
"slices"
"sync"
"syscall"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// defaultSubscriptionCacheTTL is the lifetime of an in-memory subscription
// cache entry when New is given no TTL option, or a non-positive one.
const defaultSubscriptionCacheTTL = 3 * time.Minute
// isStaleSubscriptionStatus tells the caller whether an HTTP status code
// returned by a push service means the subscription is permanently invalid
// and its row should be deleted. Every other status, including other 4xx
// and 5xx responses such as rate limits and transient failures, reports
// false so the subscription stays in place for the next dispatch.
func isStaleSubscriptionStatus(statusCode int) bool {
	// 400: malformed subscription per the push service.
	// 403: Apple BadJwtToken / VAPID rejected, key rotation.
	// 404: FCM/Mozilla endpoint no longer valid.
	// 410: standard "subscription expired" signal.
	return statusCode == http.StatusBadRequest ||
		statusCode == http.StatusForbidden ||
		statusCode == http.StatusNotFound ||
		statusCode == http.StatusGone
}
// Dispatcher is an interface that can be used to dispatch
// web push notifications to clients such as browsers.
type Dispatcher interface {
// Dispatch sends a web push notification to all subscriptions
// for a user. Delivery is best-effort: the Webpusher implementation
// does not retry failed sends and reports them through the returned
// error.
Dispatch(ctx context.Context, userID uuid.UUID, notification codersdk.WebpushMessage) error
// Test sends a test web push notification to a subscription to ensure it is valid.
Test(ctx context.Context, req codersdk.WebpushSubscription) error
// PublicKey returns the VAPID public key for the webpush dispatcher.
PublicKey() string
}
// SubscriptionCacheInvalidator is an optional interface that lets local
// subscription mutation handlers invalidate cached subscriptions.
// Webpusher implements it.
type SubscriptionCacheInvalidator interface {
// InvalidateUser drops whatever subscriptions are cached for userID so
// they are not served from the cache again.
InvalidateUser(userID uuid.UUID)
}
// options collects the optional settings that Option functions may
// override before New builds a Webpusher.
type options struct {
// clock is the time source used for subscription cache expiry.
clock quartz.Clock
// subscriptionCacheTTL is how long a cached subscription list stays valid.
subscriptionCacheTTL time.Duration
// httpClient delivers push requests; New installs the SSRF-safe client
// when this is left nil.
httpClient *http.Client
}
// Option configures optional behavior for a Webpusher. New applies options
// in the order given, on top of its defaults.
type Option func(*options)
// WithClock overrides the clock that drives subscription cache expiry. New
// falls back to a real clock when this option is omitted or given nil.
func WithClock(clock quartz.Clock) Option {
	return func(cfg *options) {
		cfg.clock = clock
	}
}
// WithSubscriptionCacheTTL overrides how long subscriptions stay in the
// in-memory cache. New substitutes the three-minute default when this option
// is omitted or the duration is zero or negative.
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
	return func(cfg *options) {
		cfg.subscriptionCacheTTL = ttl
	}
}
// WithHTTPClient replaces the default SSRF-safe HTTP client used to deliver
// push notifications. It exists for tests that must deliver to localhost
// test servers; a nil client leaves the default in place.
func WithHTTPClient(client *http.Client) Option {
	return func(cfg *options) {
		cfg.httpClient = client
	}
}
// New creates a new Dispatcher to dispatch web push notifications.
//
// The VAPID key pair is loaded from db. When no complete pair is stored, a
// fresh one is generated and persisted via RegenerateVAPIDKeys, which also
// deletes all existing push subscriptions. opts are applied in order on top
// of the defaults (real clock, default cache TTL, SSRF-safe HTTP client).
//
// This is *not* integrated into the enqueue system unfortunately.
// That's because the notifications system has a enqueue system,
// and push notifications at time of implementation are being used
// for updates inside of a workspace, which we want to be immediate.
//
// See: https://github.com/coder/internal/issues/528
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
	settings := options{
		clock:                quartz.NewReal(),
		subscriptionCacheTTL: defaultSubscriptionCacheTTL,
	}
	for _, apply := range opts {
		apply(&settings)
	}

	// An option may have written an unusable value; restore the default.
	if settings.clock == nil {
		settings.clock = quartz.NewReal()
	}
	if settings.subscriptionCacheTTL <= 0 {
		settings.subscriptionCacheTTL = defaultSubscriptionCacheTTL
	}
	if settings.httpClient == nil {
		settings.httpClient = newSSRFSafeHTTPClient()
	}

	// A missing row is not an error: it leaves keys empty and triggers
	// generation below.
	keys, err := db.GetWebpushVAPIDKeys(ctx)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, xerrors.Errorf("get notification vapid keys: %w", err)
	}

	if keys.VapidPublicKey == "" || keys.VapidPrivateKey == "" {
		// Generate new VAPID keys. This also deletes all existing push
		// subscriptions as part of the transaction, as they are no longer
		// valid.
		privateKey, publicKey, err := RegenerateVAPIDKeys(ctx, db)
		if err != nil {
			return nil, xerrors.Errorf("regenerate vapid keys: %w", err)
		}
		keys.VapidPublicKey = publicKey
		keys.VapidPrivateKey = privateKey
	}

	pusher := &Webpusher{
		store:                   db,
		log:                     log,
		vapidSub:                vapidSub,
		VAPIDPublicKey:          keys.VapidPublicKey,
		VAPIDPrivateKey:         keys.VapidPrivateKey,
		httpClient:              settings.httpClient,
		clock:                   settings.clock,
		subscriptionCacheTTL:    settings.subscriptionCacheTTL,
		subscriptionCache:       make(map[uuid.UUID]cachedSubscriptions),
		subscriptionGenerations: make(map[uuid.UUID]uint64),
	}
	return pusher, nil
}
// cachedSubscriptions is one user's entry in the in-memory subscription
// cache.
type cachedSubscriptions struct {
// subscriptions is the cached list of the user's subscriptions.
subscriptions []database.WebpushSubscription
// expiresAt is the instant from which the entry is treated as expired.
expiresAt time.Time
}
// Webpusher is the Dispatcher implementation built by New. It delivers web
// push notifications over HTTP and keeps a short-lived per-user cache of
// subscriptions read from the database.
type Webpusher struct {
store database.Store
log *slog.Logger
// VAPID allows us to identify the sender of the message.
// This must be a https:// URL or an email address.
// Some push services (such as Apple's) require this to be set.
vapidSub string
// public and private keys for VAPID. These are used to sign and encrypt
// the message payload.
VAPIDPublicKey string
VAPIDPrivateKey string
// httpClient is an SSRF-safe HTTP client that rejects connections to
// private, loopback, and link-local IP addresses at dial time. This
// closes the DNS rebinding TOCTOU gap where a hostname passes URL
// validation but resolves to a private IP when the connection is made.
// WithHTTPClient can replace it with a caller-supplied client.
httpClient *http.Client
// clock is the time source used to stamp and check cache expiry.
clock quartz.Clock
// cacheMu guards subscriptionCache and subscriptionGenerations.
cacheMu sync.RWMutex
// subscriptionCache maps a user ID to that user's cached subscriptions.
subscriptionCache map[uuid.UUID]cachedSubscriptions
// subscriptionGenerations counts InvalidateUser calls per user. A fetch
// records the value before querying and its result is only cached if the
// value is unchanged afterwards.
subscriptionGenerations map[uuid.UUID]uint64
// subscriptionCacheTTL is how long a stored cache entry stays valid.
subscriptionCacheTTL time.Duration
// subscriptionFetches collapses concurrent database fetches for the same
// user (keyed by the user ID string) into one call.
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
}
// Dispatch delivers msg to every web push subscription of userID, sending to
// all subscriptions concurrently. Subscriptions the push service reports as
// permanently invalid (see isStaleSubscriptionStatus) are deleted afterwards.
// It returns nil when the user has no subscriptions or every delivery
// succeeded, and a wrapped error when looking up subscriptions, marshaling
// msg, or any delivery fails.
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
	subs, err := n.subscriptionsForUser(ctx, userID)
	if err != nil {
		return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
	}
	if len(subs) == 0 {
		return nil
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return xerrors.Errorf("marshal webpush notification: %w", err)
	}

	var (
		staleMu  sync.Mutex // guards staleIDs across the delivery goroutines
		staleIDs []uuid.UUID
		group    errgroup.Group
	)
	for _, sub := range subs {
		group.Go(func() error {
			// TODO: Implement some retry logic here. For now, this is just a
			// best-effort attempt.
			endpointKeys := webpush.Keys{
				Auth:   sub.EndpointAuthKey,
				P256dh: sub.EndpointP256dhKey,
			}
			status, respBody, sendErr := n.webpushSend(ctx, payload, sub.Endpoint, endpointKeys)
			if sendErr != nil {
				return xerrors.Errorf("send webpush notification: %w", sendErr)
			}

			// Remove subscriptions that the push service has marked as
			// permanently invalid (Apple returns 403 BadJwtToken and 404
			// for invalidated subscriptions, FCM returns 404 for
			// expired endpoints, all push services return 410 for
			// permanently gone subscriptions, and 400 indicates a
			// malformed subscription that cannot be retried). Without
			// this, stale rows accumulate after PWA reinstalls and the
			// in-memory cache keeps trying to deliver to dead
			// subscriptions.
			if isStaleSubscriptionStatus(status) {
				staleMu.Lock()
				staleIDs = append(staleIDs, sub.ID)
				staleMu.Unlock()
			}

			switch {
			case status == http.StatusGone:
				// 410 Gone is informational, not a delivery error.
				return nil
			case status > http.StatusAccepted:
				// 200, 201, and 202 are common for successful delivery;
				// anything above likely means the delivery failed.
				return xerrors.Errorf("web push dispatch failed with status code %d: %s", status, string(respBody))
			default:
				return nil
			}
		})
	}
	waitErr := group.Wait()

	// Always remove subscriptions that the push service rejected as
	// permanently invalid, even when sibling deliveries returned a
	// non-stale error. The cleanup must run before the error return so a
	// transient delivery failure on one subscription cannot block the
	// deletion of a 410/404/403/400 sibling. Without this ordering,
	// stale rows accumulate after PWA reinstalls and silently mask the
	// new subscription on every subsequent dispatch.
	n.cleanupStaleSubscriptions(ctx, userID, staleIDs)

	if waitErr != nil {
		return xerrors.Errorf("send webpush notifications: %w", waitErr)
	}
	return nil
}
// cleanupStaleSubscriptions removes from the database the subscriptions the
// push service flagged as permanently invalid (see
// isStaleSubscriptionStatus) and then drops them from the user's cache
// entry. It does nothing when ids is empty. A failed delete is logged at
// error level instead of returned, so it cannot shadow the delivery error
// the caller is about to return. The cache is only pruned after the delete
// succeeds, so the cache never drops rows that still exist in the database.
func (n *Webpusher) cleanupStaleSubscriptions(ctx context.Context, userID uuid.UUID, ids []uuid.UUID) {
	if len(ids) == 0 {
		return
	}

	// nolint:gocritic // These are known to be invalid subscriptions.
	deleteErr := n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), ids)
	if deleteErr != nil {
		n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(deleteErr))
		return
	}

	n.pruneSubscriptions(userID, ids)
}
// subscriptionsForUser returns the web push subscriptions of userID, from
// the in-memory cache when a fresh entry exists and from the database
// otherwise. Concurrent database fetches for the same user are collapsed
// into one through subscriptionFetches. The returned slice is a copy the
// caller may modify.
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
	if hit, ok := n.cachedSubscriptions(userID); ok {
		return hit, nil
	}

	load := func() ([]database.WebpushSubscription, error) {
		// Another fetch may have filled the cache while this call was
		// waiting its turn.
		if hit, ok := n.cachedSubscriptions(userID); ok {
			return hit, nil
		}

		// Record the generation before querying so a result that raced
		// with InvalidateUser is not written into the cache.
		gen := n.subscriptionGeneration(userID)
		rows, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
		if err != nil {
			return nil, err
		}
		n.storeSubscriptions(userID, gen, rows)
		return slices.Clone(rows), nil
	}

	shared, err, _ := n.subscriptionFetches.Do(userID.String(), load)
	if err != nil {
		return nil, err
	}
	// The result may be handed to several callers, so give each its own copy.
	return slices.Clone(shared), nil
}
// cachedSubscriptions looks up the cache entry for userID. It returns a copy
// of the cached subscriptions and true while the entry has not expired. It
// returns nil and false when there is no entry or the entry has expired; an
// expired entry is also evicted.
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
	n.cacheMu.RLock()
	cached, found := n.subscriptionCache[userID]
	n.cacheMu.RUnlock()

	switch {
	case !found:
		return nil, false
	case n.clock.Now().Before(cached.expiresAt):
		return slices.Clone(cached.subscriptions), true
	}

	// The entry looked expired under the read lock. Re-check under the write
	// lock before deleting, since it may have been replaced in between.
	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()
	latest, stillPresent := n.subscriptionCache[userID]
	if stillPresent && !n.clock.Now().Before(latest.expiresAt) {
		delete(n.subscriptionCache, userID)
	}
	return nil, false
}
// subscriptionGeneration reads the current invalidation generation for
// userID under the read lock. A user that was never invalidated reports 0.
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
	n.cacheMu.RLock()
	defer n.cacheMu.RUnlock()
	return n.subscriptionGenerations[userID]
}
// storeSubscriptions caches a copy of subscriptions for userID, expiring
// subscriptionCacheTTL from now. The write is skipped when the user's
// generation no longer equals generation, meaning InvalidateUser ran after
// the caller sampled it and the data may be outdated.
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()

	if current := n.subscriptionGenerations[userID]; current != generation {
		return
	}

	fresh := cachedSubscriptions{
		subscriptions: slices.Clone(subscriptions),
		expiresAt:     n.clock.Now().Add(n.subscriptionCacheTTL),
	}
	n.subscriptionCache[userID] = fresh
}
// pruneSubscriptions removes the subscriptions listed in staleIDs from the
// cache entry of userID. The entry is deleted outright when it has already
// expired or when no subscriptions remain after filtering. Nothing happens
// when staleIDs is empty or the user has no cache entry.
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
	if len(staleIDs) == 0 {
		return
	}

	doomed := make(map[uuid.UUID]struct{}, len(staleIDs))
	for i := range staleIDs {
		doomed[staleIDs[i]] = struct{}{}
	}

	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()

	cached, found := n.subscriptionCache[userID]
	if !found {
		return
	}
	if expired := !n.clock.Now().Before(cached.expiresAt); expired {
		delete(n.subscriptionCache, userID)
		return
	}

	// Build a new slice rather than filtering in place: readers copy the
	// cached slice after releasing the lock.
	kept := make([]database.WebpushSubscription, 0, len(cached.subscriptions))
	for _, sub := range cached.subscriptions {
		if _, isStale := doomed[sub.ID]; !isStale {
			kept = append(kept, sub)
		}
	}

	if len(kept) == 0 {
		delete(n.subscriptionCache, userID)
		return
	}
	cached.subscriptions = kept
	n.subscriptionCache[userID] = cached
}
// InvalidateUser drops the cached subscriptions of a user and bumps that
// user's invalidation generation, so a database fetch already in flight
// cannot write its result into the cache. It also forgets any in-flight
// singleflight fetch for the user. Local subscribe and unsubscribe handlers
// call this after mutating subscriptions in the same process.
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
	fetchKey := userID.String()

	n.cacheMu.Lock()
	n.subscriptionGenerations[userID]++
	delete(n.subscriptionCache, userID)
	n.cacheMu.Unlock()

	n.subscriptionFetches.Forget(fetchKey)
}
// webpushSend delivers msg to a single push endpoint using the dispatcher's
// VAPID credentials and HTTP client.
//
// It returns the HTTP status code and response body of the push service. If
// the request cannot be sent or the body cannot be read, it returns a status
// of -1, a nil body, and a wrapped error; send failures are also logged with
// the endpoint.
//
// The endpoint comes from a user-supplied subscription, so the response is
// untrusted. At most maxResponseBodyBytes of the body are read, so a hostile
// or broken push endpoint cannot exhaust memory or inflate the error messages
// callers build from the body. Longer bodies are truncated.
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
	// Push services answer with short status payloads; 64 KiB is far more
	// than any diagnostic body needs.
	const maxResponseBodyBytes = 64 << 10

	// Copy the message to avoid modifying the original.
	cpy := slices.Clone(msg)
	resp, err := webpush.SendNotificationWithContext(ctx, cpy, &webpush.Subscription{
		Endpoint: endpoint,
		Keys:     keys,
	}, &webpush.Options{
		HTTPClient:      n.httpClient,
		Subscriber:      n.vapidSub,
		VAPIDPublicKey:  n.VAPIDPublicKey,
		VAPIDPrivateKey: n.VAPIDPrivateKey,
	})
	if err != nil {
		n.log.Error(ctx, "failed to send webpush notification", slog.Error(err), slog.F("endpoint", endpoint))
		return -1, nil, xerrors.Errorf("send webpush notification: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
	if err != nil {
		return -1, nil, xerrors.Errorf("read response body: %w", err)
	}
	return resp.StatusCode, body, nil
}
// Test sends a fixed confirmation notification to the subscription described
// by req. It returns nil when the push service answers with a status of 202
// or lower, and a wrapped error when marshaling or sending fails or the
// service answers with a higher status.
func (n *Webpusher) Test(ctx context.Context, req codersdk.WebpushSubscription) error {
	confirmation := codersdk.WebpushMessage{
		Title: "It's working!",
		Body:  "You've subscribed to push notifications.",
	}
	payload, err := json.Marshal(confirmation)
	if err != nil {
		return xerrors.Errorf("marshal webpush notification: %w", err)
	}

	endpointKeys := webpush.Keys{
		Auth:   req.AuthKey,
		P256dh: req.P256DHKey,
	}
	status, respBody, err := n.webpushSend(ctx, payload, req.Endpoint, endpointKeys)
	if err != nil {
		return xerrors.Errorf("send test webpush notification: %w", err)
	}

	// 200, 201, and 202 are common for successful delivery.
	if status <= http.StatusAccepted {
		return nil
	}
	// It's likely the subscription failed to deliver for some reason.
	return xerrors.Errorf("web push dispatch failed with status code %d: %s", status, string(respBody))
}
// PublicKey returns the VAPID public key for the webpush dispatcher, i.e.
// the value of the VAPIDPublicKey field.
// Clients need this, so it's exposed via the BuildInfo endpoint.
func (n *Webpusher) PublicKey() string {
return n.VAPIDPublicKey
}
// NoopWebpusher is a Dispatcher that always fails, returning Msg as
// the error. It is used as a fallback when VAPID key setup fails.
// The underlying error is not included to avoid leaking internal
// details (e.g. database errors) in API responses; it is logged at
// the call site instead.
type NoopWebpusher struct {
// Msg is the text of the error returned by Dispatch and Test.
Msg string
}
// Dispatch sends nothing and always returns a new error whose text is Msg.
func (n *NoopWebpusher) Dispatch(context.Context, uuid.UUID, codersdk.WebpushMessage) error {
return xerrors.New(n.Msg)
}
// Test sends nothing and always returns a new error whose text is Msg.
func (n *NoopWebpusher) Test(context.Context, codersdk.WebpushSubscription) error {
return xerrors.New(n.Msg)
}
// PublicKey always returns the empty string; a NoopWebpusher has no VAPID
// key.
func (*NoopWebpusher) PublicKey() string {
return ""
}
// newSSRFSafeHTTPClient builds an HTTP client whose dialer refuses to
// connect to private, loopback, link-local, multicast, and unspecified IP
// addresses. The check runs in the dialer's Control hook against the address
// actually being dialed, which prevents DNS rebinding attacks where a
// hostname passes URL-level validation but resolves to an internal IP at
// dial time.
func newSSRFSafeHTTPClient() *http.Client {
	// rejectNonPublic vetoes a connection attempt by returning an error.
	rejectNonPublic := func(_ string, address string, _ syscall.RawConn) error {
		host, _, err := net.SplitHostPort(address)
		if err != nil {
			return xerrors.Errorf("split host/port: %w", err)
		}
		addr, err := netip.ParseAddr(host)
		if err != nil {
			return xerrors.Errorf("parse resolved IP: %w", err)
		}

		switch {
		case addr.IsPrivate(),
			addr.IsLoopback(),
			addr.IsLinkLocalUnicast(),
			addr.IsLinkLocalMulticast(),
			addr.IsMulticast(),
			addr.IsUnspecified():
			return xerrors.Errorf(
				"webpush endpoint resolved to non-public address %s", addr.String(),
			)
		}
		return nil
	}

	dialer := &net.Dialer{Control: rejectNonPublic}
	transport := &http.Transport{DialContext: dialer.DialContext}
	return &http.Client{Transport: transport}
}
// RegenerateVAPIDKeys generates a new VAPID key pair and, inside a single
// transaction, deletes every existing push subscription (they are no longer
// valid under the new keys) and stores the new pair. It returns the new
// private and public keys, or empty strings and a wrapped error when key
// generation or the transaction fails.
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
	privateKey, publicKey, genErr := webpush.GenerateVAPIDKeys()
	if genErr != nil {
		return "", "", xerrors.Errorf("generate new vapid keypair: %w", genErr)
	}

	// replaceKeys runs inside the transaction: drop the subscriptions first,
	// then persist the new key pair.
	replaceKeys := func(tx database.Store) error {
		if delErr := tx.DeleteAllWebpushSubscriptions(ctx); delErr != nil {
			return xerrors.Errorf("delete all webpush subscriptions: %w", delErr)
		}
		upsertErr := tx.UpsertWebpushVAPIDKeys(ctx, database.UpsertWebpushVAPIDKeysParams{
			VapidPrivateKey: privateKey,
			VapidPublicKey:  publicKey,
		})
		if upsertErr != nil {
			return xerrors.Errorf("upsert notification vapid key: %w", upsertErr)
		}
		return nil
	}

	if txErr := db.InTx(replaceKeys, nil); txErr != nil {
		return "", "", xerrors.Errorf("regenerate vapid keypair: %w", txErr)
	}
	return privateKey, publicKey, nil
}