forked from coder/coder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstatus_writer.go
More file actions
130 lines (111 loc) · 2.97 KB
/
status_writer.go
File metadata and controls
130 lines (111 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package tracing
import (
"bufio"
"flag"
"fmt"
"log"
"net"
"net/http"
"runtime"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/buildinfo"
)
// Compile-time checks that *StatusWriter implements the optional
// ResponseWriter interfaces it claims to support.
var (
	_ http.ResponseWriter = (*StatusWriter)(nil)
	_ http.Hijacker       = (*StatusWriter)(nil)
	_ http.Flusher        = (*StatusWriter)(nil)
)
// StatusWriter intercepts the status of the request and the response body up
// to maxBodySize if Status >= 400. It is guaranteed to be the ResponseWriter
// directly downstream from Middleware.
type StatusWriter struct {
// ResponseWriter is the wrapped writer that every call is forwarded to.
http.ResponseWriter
// Status is the response status code recorded by WriteHeader, or
// http.StatusOK when Write is called before any status was recorded.
Status int
// Hijacked is set to true by Hijack.
Hijacked bool
// responseBody holds up to maxBodySize bytes of body written by Write while
// Status >= 400; it is exposed through ResponseBody.
responseBody []byte
// wroteHeader reports whether Status has been recorded (by WriteHeader or
// implicitly by Write).
wroteHeader bool
// wroteHeaderStack is the caller stack of the WriteHeader call that first
// recorded the status. It is only captured in dev builds and tests, and is
// empty when the status was recorded implicitly by Write.
wroteHeaderStack string
}
// StatusWriterMiddleware returns a handler that wraps each request's
// ResponseWriter in a fresh *StatusWriter before invoking next, so handlers
// further down the chain write through the StatusWriter.
func StatusWriterMiddleware(next http.Handler) http.Handler {
	wrap := func(rw http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(&StatusWriter{ResponseWriter: rw}, r)
	}
	return http.HandlerFunc(wrap)
}
// WriteHeader records the response status and forwards the call to the
// wrapped ResponseWriter.
//
// Informational 1xx codes other than 101 Switching Protocols (for example 103
// Early Hints) may legitimately be written any number of times before the
// final status, so they are forwarded without being recorded. Only the first
// final status code is stored in w.Status.
//
// In dev builds and tests, a repeated final WriteHeader call is logged with
// the current stack and the stack of the call that first wrote the header.
// The repeated call is still forwarded so the wrapped writer can apply its
// own handling.
func (w *StatusWriter) WriteHeader(status int) {
	if status >= 100 && status <= 199 && status != http.StatusSwitchingProtocols {
		w.ResponseWriter.WriteHeader(status)
		return
	}

	if buildinfo.IsDev() || flag.Lookup("test.v") != nil {
		if w.wroteHeader {
			// Skip getStackString and WriteHeader so the stack starts at our caller.
			stack := getStackString(2)
			wroteHeaderStack := w.wroteHeaderStack
			if wroteHeaderStack == "" {
				// The status was recorded implicitly by Write, or capture
				// produced no frames.
				wroteHeaderStack = "unknown"
			}
			// It's fine that this logs to stdlib logger since it only happens
			// in dev builds and tests.
			log.Printf("duplicate call to (*StatusWriter).WriteHeader(%d):\n\nstack: %s\n\nheader written at: %s", status, stack, wroteHeaderStack)
		} else {
			w.wroteHeaderStack = getStackString(2)
		}
	}

	if !w.wroteHeader {
		w.Status = status
		w.wroteHeader = true
	}
	w.ResponseWriter.WriteHeader(status)
}
// Write writes b to the wrapped ResponseWriter. If no status has been
// recorded yet, it first records http.StatusOK, because net/http sends an
// implicit 200 in that case.
//
// While the recorded status is an error (>= 400), the first maxBodySize bytes
// written across all Write calls are retained for ResponseBody. Later bytes
// are forwarded but not retained.
func (w *StatusWriter) Write(b []byte) (int, error) {
	const maxBodySize = 4096

	if !w.wroteHeader {
		w.Status = http.StatusOK
		w.wroteHeader = true
	}

	if w.Status >= http.StatusBadRequest {
		// Accumulate instead of overwriting so multi-write error bodies are
		// captured from the start, bounded by maxBodySize.
		if remaining := maxBodySize - len(w.responseBody); remaining > 0 {
			w.responseBody = append(w.responseBody, b[:minInt(len(b), remaining)]...)
		}
	}

	return w.ResponseWriter.Write(b)
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// Hijack takes over the underlying connection when the wrapped ResponseWriter
// implements http.Hijacker, and returns an error otherwise. Hijacked is set
// only when the hijack succeeds, so a failed attempt is not reported as a
// hijacked request. Errors from the wrapped Hijacker are returned unwrapped,
// so callers comparing against sentinels such as http.ErrHijacked keep
// working.
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter)
	}

	conn, rw, err := hijacker.Hijack()
	if err != nil {
		return nil, nil, err
	}
	w.Hijacked = true
	return conn, rw, nil
}
// ResponseBody returns the portion of the response body that Write captured
// while the status was an error (>= 400). It is nil if Write never captured
// anything. The returned slice is the writer's own buffer, not a copy.
func (w *StatusWriter) ResponseBody() []byte {
	captured := w.responseBody
	return captured
}
// Flush delegates to the wrapped ResponseWriter's http.Flusher
// implementation. It panics if the wrapped writer does not implement
// http.Flusher.
func (w *StatusWriter) Flush() {
	f, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		panic("http.ResponseWriter is not http.Flusher")
	}
	f.Flush()
}

// Unwrap returns the wrapped ResponseWriter. http.ResponseController uses
// this to reach optional methods (such as SetReadDeadline and
// SetWriteDeadline) that StatusWriter does not implement itself.
func (w *StatusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// getStackString formats up to 5 caller frames as "file:line" entries joined
// by " -> ". skip counts frames starting with getStackString itself:
// skip=1 starts at the caller of getStackString, skip=2 at that caller's
// caller, and so on. It returns "" when no frames remain after skipping.
func getStackString(skip int) string {
	pcs := make([]uintptr, 5)
	// +1 also skips runtime.Callers' own frame.
	n := runtime.Callers(skip+1, pcs)
	if n == 0 {
		// Without this guard, CallersFrames would yield one zero Frame and
		// the result would be the meaningless ":0".
		return ""
	}

	frames := runtime.CallersFrames(pcs[:n])
	callers := make([]string, 0, n)
	for {
		frame, more := frames.Next()
		callers = append(callers, fmt.Sprintf("%s:%v", frame.File, frame.Line))
		if !more {
			break
		}
	}
	return strings.Join(callers, " -> ")
}