forked from coder/coder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhttpapi.go
More file actions
295 lines (259 loc) · 7.82 KB
/
httpapi.go
File metadata and controls
295 lines (259 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
package httpapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"strings"
"time"
"github.com/go-playground/validator/v10"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/codersdk"
)
// Validate is the package-wide validator instance. It is assigned in init and
// used by Read to validate decoded request bodies.
var Validate *validator.Validate

// init builds the shared validator and registers the HTTP API's custom
// validation rules on it.
//
// One instance is shared by the whole package because the validator caches the
// result of parsing struct definitions.
func init() {
	Validate = validator.New()

	// Report fields under their JSON name (the part of the `json` tag before
	// the first comma). A tag of "-" yields an empty name.
	Validate.RegisterTagNameFunc(func(field reflect.StructField) string {
		jsonName, _, _ := strings.Cut(field.Tag.Get("json"), ",")
		if jsonName == "-" {
			return ""
		}
		return jsonName
	})

	// Both rules reject non-string fields and otherwise defer to the
	// corresponding package-level checker.
	nameRule := func(fl validator.FieldLevel) bool {
		s, isString := fl.Field().Interface().(string)
		return isString && NameValid(s) == nil
	}
	displayNameRule := func(fl validator.FieldLevel) bool {
		s, isString := fl.Field().Interface().(string)
		return isString && TemplateDisplayNameValid(s) == nil
	}

	rules := []struct {
		tag string
		fn  validator.Func
	}{
		{tag: "username", fn: nameRule},
		{tag: "template_name", fn: nameRule},
		{tag: "workspace_name", fn: nameRule},
		{tag: "template_display_name", fn: displayNameRule},
	}
	for _, rule := range rules {
		// A registration failure is a programming error, so fail at startup.
		if err := Validate.RegisterValidation(rule.tag, rule.fn); err != nil {
			panic(err)
		}
	}
}
// Convenience error functions don't take contexts: their responses are
// static, so it doesn't make much sense to trace them.
// ResourceNotFound is intentionally vague. All 404 responses should be identical
// to prevent leaking existence of resources.
func ResourceNotFound(rw http.ResponseWriter) {
Write(context.Background(), rw, http.StatusNotFound, codersdk.Response{
Message: "Resource not found or you do not have access to this resource",
})
}
func Forbidden(rw http.ResponseWriter) {
Write(context.Background(), rw, http.StatusForbidden, codersdk.Response{
Message: "Forbidden.",
})
}
func InternalServerError(rw http.ResponseWriter, err error) {
var details string
if err != nil {
details = err.Error()
}
Write(context.Background(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "An internal server error occurred.",
Detail: details,
})
}
func RouteNotFound(rw http.ResponseWriter) {
Write(context.Background(), rw, http.StatusNotFound, codersdk.Response{
Message: "Route not found.",
})
}
// Write outputs a standardized format to an HTTP response body: response is
// JSON-encoded (with HTML escaping) and sent with the given status and a JSON
// Content-Type.
//
// ctx is used for tracing. The existing contract says it can be nil to disable
// tracing; that relies on tracing.StartSpan accepting a nil context, which is
// not visible from this function. Tracing this function is helpful because
// JSON marshaling can sometimes take a significant amount of time, and it could
// help us catch outliers. Additionally, we can enrich span data a bit more
// since we have access to the actual interface{} we're marshaling, such as the
// number of elements in an array, which could help us spot routes that need to
// be paginated.
//
// If encoding fails, a plain-text 500 is written instead and status is not
// used. If writing the encoded body fails, nothing further is written.
func Write(ctx context.Context, rw http.ResponseWriter, status int, response interface{}) {
	_, span := tracing.StartSpan(ctx)
	defer span.End()

	// Encode into a buffer first so an encoding failure can still be reported
	// with a 500 before any part of the response has been committed.
	buf := &bytes.Buffer{}
	enc := json.NewEncoder(buf)
	enc.SetEscapeHTML(true)
	if err := enc.Encode(response); err != nil {
		http.Error(rw, err.Error(), http.StatusInternalServerError)
		return
	}

	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
	rw.WriteHeader(status)

	// The status line and headers are committed at this point, so a failed
	// body write cannot be turned into a 500: calling http.Error here would
	// only trigger a superfluous WriteHeader and could append plain text to a
	// partially written JSON body. There is nothing useful left to do with the
	// error, so it is intentionally dropped.
	_, _ = rw.Write(buf.Bytes())
}
// Read decodes the JSON request body into value and then validates it with the
// package validator. It returns true when both steps succeed. On failure it
// writes the error response itself and returns false, so callers should simply
// return:
//
//   - body is not decodable JSON: 400 with the decoder error as detail
//   - validation rules failed:    400 with one entry per failing field
//   - any other validator error:  500 with the error as detail
//
// ctx is used for tracing, for consistency with Write.
func Read(ctx context.Context, rw http.ResponseWriter, r *http.Request, value interface{}) bool {
	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	if decodeErr := json.NewDecoder(r.Body).Decode(value); decodeErr != nil {
		Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Request body must be valid JSON.",
			Detail:  decodeErr.Error(),
		})
		return false
	}

	validateErr := Validate.Struct(value)
	if validateErr == nil {
		return true
	}

	// Anything other than a set of field validation failures is unexpected.
	var fieldErrs validator.ValidationErrors
	if !errors.As(validateErr, &fieldErrs) {
		Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error validating request body payload.",
			Detail:  validateErr.Error(),
		})
		return false
	}

	// Translate each failing field into its API representation.
	failures := make([]codersdk.ValidationError, len(fieldErrs))
	for i, fieldErr := range fieldErrs {
		failures[i] = codersdk.ValidationError{
			Field:  fieldErr.Field(),
			Detail: fmt.Sprintf("Validation failed for tag %q with value: \"%v\"", fieldErr.Tag(), fieldErr.Value()),
		}
	}
	Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
		Message:     "Validation failed.",
		Validations: failures,
	})
	return false
}
// websocketCloseMaxLen is the maximum length, in bytes, of the message
// returned by WebsocketCloseSprintf. nhooyr/websocket only allows close
// messages of this length.
const websocketCloseMaxLen = 123

// WebsocketCloseSprintf formats a websocket close message and ensures it is
// truncated to the maximum allowed length.
//
// The format and vars are passed to fmt.Sprintf. If the result is at most
// websocketCloseMaxLen bytes it is returned unchanged. Otherwise only the
// first websocketCloseMaxLen bytes are kept, and any bytes in that prefix that
// are not valid UTF-8 (such as a multi-byte character cut in half by the
// truncation) are removed, so the result may be slightly shorter than the
// limit.
func WebsocketCloseSprintf(format string, vars ...any) string {
	msg := fmt.Sprintf(format, vars...)
	if len(msg) <= websocketCloseMaxLen {
		return msg
	}

	// Slice by bytes, not runes: the limit is a byte limit. If the cut landed
	// in the middle of a UTF-8 character, ToValidUTF8 drops the partial bytes.
	return strings.ToValidUTF8(msg[:websocketCloseMaxLen], "")
}
// ServerSentEventSender prepares rw for a server-sent event stream and returns
// a function for sending events on it.
//
// It sets the event-stream response headers, then starts a goroutine that is
// the only writer to rw. That goroutine writes events submitted through
// sendEvent and, every 15 seconds, a ping event. It stops when the request
// context is done or when a write to rw fails, and closes the returned closed
// channel when it stops.
//
// sendEvent serializes the event as "event: <type>\n", followed by
// "data: <JSON>\n" when sse.Data is non-nil, followed by a blank line, and
// hands it to the writer goroutine. It returns the result of writing the
// event, or an error if the request context or ctx is done, or if the sender
// has already closed.
//
// The function panics if rw does not implement http.Flusher. The returned err
// is always nil in the code below.
func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent func(ctx context.Context, sse codersdk.ServerSentEvent) error, closed chan struct{}, err error) {
// Response headers for an event stream that should not be cached.
h := rw.Header()
h.Set("Content-Type", "text/event-stream")
h.Set("Cache-Control", "no-cache")
h.Set("Connection", "keep-alive")
// NOTE(review): presumably asks buffering reverse proxies (e.g. nginx) not to
// buffer the stream; confirm against the deployment setup.
h.Set("X-Accel-Buffering", "no")
// Every event is flushed right after it is written, so the writer must
// support flushing.
f, ok := rw.(http.Flusher)
if !ok {
panic("http.ResponseWriter is not http.Flusher")
}
closed = make(chan struct{})
// sseEvent is one already-serialized event. errC, when non-nil, receives the
// result of writing payload to rw.
type sseEvent struct {
payload []byte
errC chan error
}
// Unbuffered: a send completes only once the writer goroutine has taken the
// event.
eventC := make(chan sseEvent)
// Synchronized handling of events (no guarantee of order): this goroutine
// performs every write to rw, one event at a time. When several sendEvent
// calls are pending, which one is picked up next is not defined.
go func() {
defer close(closed)
// Send a heartbeat every 15 seconds to avoid the connection being killed.
ticker := time.NewTicker(time.Second * 15)
defer ticker.Stop()
for {
var event sseEvent
select {
case <-r.Context().Done():
return
case event = <-eventC:
case <-ticker.C:
// Ping events have no errC; nobody is waiting on their result.
event = sseEvent{
payload: []byte(fmt.Sprintf("event: %s\n\n", codersdk.ServerSentEventTypePing)),
}
}
_, err := rw.Write(event.payload)
// Report the write result (nil or not) to the sendEvent call that
// submitted this event.
if event.errC != nil {
event.errC <- err
}
// A failed write ends the goroutine, which closes the closed channel.
if err != nil {
return
}
f.Flush()
}
}()
sendEvent = func(ctx context.Context, sse codersdk.ServerSentEvent) error {
// Serialize the whole event up front so the writer goroutine only has to
// write bytes.
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)
_, err := buf.WriteString(fmt.Sprintf("event: %s\n", sse.Type))
if err != nil {
return err
}
if sse.Data != nil {
_, err = buf.WriteString("data: ")
if err != nil {
return err
}
// Encode appends a newline after the JSON, which terminates the data line.
err = enc.Encode(sse.Data)
if err != nil {
return err
}
}
// The blank line marks the end of the event.
err = buf.WriteByte('\n')
if err != nil {
return err
}
event := sseEvent{
payload: buf.Bytes(),
errC: make(chan error, 1), // Buffered to prevent deadlock.
}
// Hand the event to the writer goroutine, giving up if either context is
// done or the writer has already stopped.
select {
case <-r.Context().Done():
return r.Context().Err()
case <-ctx.Done():
return ctx.Err()
case <-closed:
return xerrors.New("server sent event sender closed")
case eventC <- event:
// Re-check closure signals after sending the event to allow
// for early exit. We don't check closed here because it
// can't happen while processing the event.
//
// The writer goroutine already holds the event at this point, so
// returning early on cancellation does not stop it from being written.
select {
case <-r.Context().Done():
return r.Context().Err()
case <-ctx.Done():
return ctx.Err()
case err := <-event.errC:
return err
}
}
}
return sendEvent, closed, nil
}