@@ -10,68 +10,56 @@ import (
1010 "net/http"
1111 "net/textproto"
1212 "net/url"
13+ "nhooyr.io/websocket/internal/errd"
1314 "strings"
1415)
1516
1617// AcceptOptions represents the options available to pass to Accept.
1718type AcceptOptions struct {
18- // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client.
19+ // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
1920 // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
20- // reject it, close the connection if c.Subprotocol() == "".
21+ // reject it, close the connection when c.Subprotocol() == "".
2122 Subprotocols []string
2223
23- // InsecureSkipVerify disables Accept's origin verification
24- // behaviour. By default Accept only allows the handshake to
25- // succeed if the javascript that is initiating the handshake
26- // is on the same domain as the server. This is to prevent CSRF
27- // attacks when secure data is stored in a cookie as there is no same
28- // origin policy for WebSockets. In other words, javascript from
29- // any domain can perform a WebSocket dial on an arbitrary server.
30- // This dial will include cookies which means the arbitrary javascript
31- // can perform actions as the authenticated user.
24+ // InsecureSkipVerify disables Accept's origin verification behaviour. By default,
25+ // the connection will only be accepted if the request origin is equal to the request
26+ // host.
27+ //
28+ // This is only required if you want javascript served from a different domain
29+ // to access your WebSocket server.
3230 //
3331 // See https://stackoverflow.com/a/37837709/4283659
3432 //
35- // The only time you need this is if your javascript is running on a different domain
36- // than your WebSocket server.
37- // Think carefully about whether you really need this option before you use it.
38- // If you do, remember that if you store secure data in cookies, you wil need to verify the
39- // Origin header yourself otherwise you are exposing yourself to a CSRF attack.
33+ // Please ensure you understand the ramifications of enabling this.
34+ // If used incorrectly your WebSocket server will be open to CSRF attacks.
4035 InsecureSkipVerify bool
4136
4237 // CompressionMode sets the compression mode.
43- // See docs on the CompressionMode type and defined constants .
38+ // See docs on the CompressionMode type.
4439 CompressionMode CompressionMode
4540}
4641
47- // Accept accepts a WebSocket HTTP handshake from a client and upgrades the
42+ // Accept accepts a WebSocket handshake from a client and upgrades the
4843// the connection to a WebSocket.
4944//
50- // Accept will reject the handshake if the Origin domain is not the same as the Host unless
51- // the InsecureSkipVerify option is set. In other words, by default it does not allow
52- // cross origin requests.
45+ // Accept will not allow cross origin requests by default.
46+ // See the InsecureSkipVerify option to allow cross origin requests.
5347//
54- // If an error occurs, Accept will write a response with a safe error message to w .
48+ // Accept will write a response to w on all errors .
5549func Accept (w http.ResponseWriter , r * http.Request , opts * AcceptOptions ) (* Conn , error ) {
56- c , err := accept (w , r , opts )
57- if err != nil {
58- return nil , fmt .Errorf ("failed to accept websocket connection: %w" , err )
59- }
60- return c , nil
50+ return accept (w , r , opts )
6151}
6252
63- func (opts * AcceptOptions ) ensure () * AcceptOptions {
53+ func accept (w http.ResponseWriter , r * http.Request , opts * AcceptOptions ) (_ * Conn , err error ) {
54+ defer errd .Wrap (& err , "failed to accept WebSocket connection" )
55+
6456 if opts == nil {
65- return & AcceptOptions {}
57+ opts = & AcceptOptions {}
6658 }
67- return opts
68- }
69-
70- func accept (w http.ResponseWriter , r * http.Request , opts * AcceptOptions ) (* Conn , error ) {
71- opts = opts .ensure ()
7259
73- err : = verifyClientRequest (w , r )
60+ err = verifyClientRequest (r )
7461 if err != nil {
62+ http .Error (w , err .Error (), http .StatusBadRequest )
7563 return nil , err
7664 }
7765
@@ -85,15 +73,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
8573
8674 hj , ok := w .(http.Hijacker )
8775 if ! ok {
88- err = errors .New ("passed ResponseWriter does not implement http.Hijacker" )
76+ err = errors .New ("http. ResponseWriter does not implement http.Hijacker" )
8977 http .Error (w , http .StatusText (http .StatusNotImplemented ), http .StatusNotImplemented )
9078 return nil , err
9179 }
9280
9381 w .Header ().Set ("Upgrade" , "websocket" )
9482 w .Header ().Set ("Connection" , "Upgrade" )
9583
96- handleSecWebSocketKey (w , r )
84+ key := r .Header .Get ("Sec-WebSocket-Key" )
85+ w .Header ().Set ("Sec-WebSocket-Accept" , secWebSocketAccept (key ))
9786
9887 subproto := selectSubprotocol (r , opts .Subprotocols )
9988 if subproto != "" {
@@ -102,7 +91,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
10291
10392 copts , err := acceptCompression (r , w , opts .CompressionMode )
10493 if err != nil {
105- http .Error (w , err .Error (), http .StatusBadRequest )
10694 return nil , err
10795 }
10896
@@ -129,72 +117,50 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
129117 }), nil
130118}
131119
132- func verifyClientRequest (w http. ResponseWriter , r * http.Request ) error {
120+ func verifyClientRequest (r * http.Request ) error {
133121 if ! r .ProtoAtLeast (1 , 1 ) {
134- err := fmt .Errorf ("websocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
135- http .Error (w , err .Error (), http .StatusBadRequest )
136- return err
122+ return fmt .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
137123 }
138124
139125 if ! headerContainsToken (r .Header , "Connection" , "Upgrade" ) {
140- err := fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
141- http .Error (w , err .Error (), http .StatusBadRequest )
142- return err
126+ return fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
143127 }
144128
145- if ! headerContainsToken (r .Header , "Upgrade" , "WebSocket" ) {
146- err := fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
147- http .Error (w , err .Error (), http .StatusBadRequest )
148- return err
129+ if ! headerContainsToken (r .Header , "Upgrade" , "websocket" ) {
130+ return fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
149131 }
150132
151133 if r .Method != "GET" {
152- err := fmt .Errorf ("websocket protocol violation: handshake request method is not GET but %q" , r .Method )
153- http .Error (w , err .Error (), http .StatusBadRequest )
154- return err
134+ return fmt .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
155135 }
156136
157137 if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
158- err := fmt .Errorf ("unsupported websocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
159- http .Error (w , err .Error (), http .StatusBadRequest )
160- return err
138+ return fmt .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
161139 }
162140
163141 if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
164- err := errors .New ("websocket protocol violation: missing Sec-WebSocket-Key" )
165- http .Error (w , err .Error (), http .StatusBadRequest )
166- return err
142+ return errors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
167143 }
168144
169145 return nil
170146}
171147
172148func authenticateOrigin (r * http.Request ) error {
173149 origin := r .Header .Get ("Origin" )
174- if origin == "" {
175- return nil
176- }
177- u , err := url .Parse (origin )
178- if err != nil {
179- return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
180- }
181- if ! strings .EqualFold (u .Host , r .Host ) {
182- return fmt .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
150+ if origin != "" {
151+ u , err := url .Parse (origin )
152+ if err != nil {
153+ return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
154+ }
155+ if ! strings .EqualFold (u .Host , r .Host ) {
156+ return fmt .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
157+ }
183158 }
184159 return nil
185160}
186161
187- func handleSecWebSocketKey (w http.ResponseWriter , r * http.Request ) {
188- key := r .Header .Get ("Sec-WebSocket-Key" )
189- w .Header ().Set ("Sec-WebSocket-Accept" , secWebSocketAccept (key ))
190- }
191-
192162func selectSubprotocol (r * http.Request , subprotocols []string ) string {
193163 cps := headerTokens (r .Header , "Sec-WebSocket-Protocol" )
194- if len (cps ) == 0 {
195- return ""
196- }
197-
198164 for _ , sp := range subprotocols {
199165 for _ , cp := range cps {
200166 if strings .EqualFold (sp , cp ) {
@@ -236,7 +202,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
236202 continue
237203 }
238204
239- return nil , fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
205+ err := fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
206+ http .Error (w , err .Error (), http .StatusBadRequest )
207+ return nil , err
240208 }
241209
242210 copts .setHeader (w .Header ())
@@ -264,7 +232,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
264232 //
265233 // Either way, we're only implementing this for webkit which never sends the max_window_bits
266234 // parameter so we don't need to worry about it.
267- return nil , fmt .Errorf ("unsupported x-webkit-deflate-frame parameter: %q" , p )
235+ err := fmt .Errorf ("unsupported x-webkit-deflate-frame parameter: %q" , p )
236+ http .Error (w , err .Error (), http .StatusBadRequest )
237+ return nil , err
268238 }
269239
270240 s := "x-webkit-deflate-frame"
0 commit comments