88
99#include " openssl/sha.h" // Sha-1 hash
1010
11+ #include < map>
1112#include < string.h>
12- #include < vector>
1313
1414#define ACCEPT_KEY_LENGTH base64_encoded_size (20 )
1515#define BUFFER_GROWTH_CHUNK_SIZE 1024
@@ -63,7 +63,7 @@ class ProtocolHandler {
6363 virtual void Write (const std::vector<char > data) = 0;
6464 virtual void CancelHandshake () = 0;
6565
66- std::string GetHost ();
66+ std::string GetHost () const ;
6767
6868 InspectorSocket* inspector () {
6969 return inspector_;
@@ -160,6 +160,48 @@ static void generate_accept_string(const std::string& client_key,
160160 node::base64_encode (hash, sizeof (hash), *buffer, sizeof (*buffer));
161161}
162162
163+ static bool IsOneOf (const std::string& host,
164+ const std::vector<std::string>& hosts) {
165+ for (const std::string& candidate : hosts) {
166+ if (node::StringEqualNoCase (host.data (), candidate.data ()))
167+ return true ;
168+ }
169+ return false ;
170+ }
171+
172+ static std::string TrimPort (const std::string& host) {
173+ size_t last_colon_pos = host.rfind (" :" );
174+ if (last_colon_pos == std::string::npos)
175+ return host;
176+ size_t bracket = host.rfind (" ]" );
177+ if (bracket == std::string::npos || last_colon_pos > bracket)
178+ return host.substr (0 , last_colon_pos);
179+ return host;
180+ }
181+
182+ static bool IsIPAddress (const std::string& host) {
183+ if (host.length () >= 4 && host.front () == ' [' && host.back () == ' ]' )
184+ return true ;
185+ int quads = 0 ;
186+ for (char c : host) {
187+ if (c == ' .' )
188+ quads++;
189+ else if (!isdigit (c))
190+ return false ;
191+ }
192+ return quads == 3 ;
193+ }
194+
195+ // This is a value coming from the interface, it can only be IPv4 or IPv6
196+ // address string.
197+ static bool IsIPv4Localhost (const std::string& host) {
198+ std::string v6_tunnel_prefix = " ::ffff:" ;
199+ if (host.substr (0 , v6_tunnel_prefix.length ()) == v6_tunnel_prefix)
200+ return IsIPv4Localhost (host.substr (v6_tunnel_prefix.length ()));
201+ std::string localhost_net = " 127." ;
202+ return host.substr (0 , localhost_net.length ()) == localhost_net;
203+ }
204+
163205// Constants for hybi-10 frame format.
164206
165207typedef int OpCode;
@@ -298,7 +340,6 @@ static ws_decode_result decode_frame_hybi17(const std::vector<char>& buffer,
298340 return closed ? FRAME_CLOSE : FRAME_OK;
299341}
300342
301-
302343// WS protocol
303344class WsHandler : public ProtocolHandler {
304345 public:
@@ -400,17 +441,16 @@ class WsHandler : public ProtocolHandler {
400441// HTTP protocol
401442class HttpEvent {
402443 public:
403- HttpEvent (const std::string& path, bool upgrade,
404- bool isGET, const std::string& ws_key) : path(path),
405- upgrade (upgrade),
406- isGET(isGET),
407- ws_key(ws_key) { }
444+ HttpEvent (const std::string& path, bool upgrade, bool isGET,
445+ const std::string& ws_key, const std::string& host)
446+ : path(path), upgrade(upgrade), isGET(isGET), ws_key(ws_key),
447+ host (host) { }
408448
409449 std::string path;
410450 bool upgrade;
411451 bool isGET;
412452 std::string ws_key;
413- std::string current_header_ ;
453+ std::string host ;
414454};
415455
416456class HttpHandler : public ProtocolHandler {
@@ -472,18 +512,17 @@ class HttpHandler : public ProtocolHandler {
472512 std::vector<HttpEvent> events;
473513 std::swap (events, events_);
474514 for (const HttpEvent& event : events) {
475- bool shouldContinue = event.isGET && !event.upgrade ;
476- if (!event.isGET ) {
515+ if (!IsAllowedHost (event.host ) || !event.isGET ) {
477516 CancelHandshake ();
517+ return ;
478518 } else if (!event.upgrade ) {
479519 delegate ()->OnHttpGet (event.path );
480520 } else if (event.ws_key .empty ()) {
481521 CancelHandshake ();
522+ return ;
482523 } else {
483524 delegate ()->OnSocketUpgrade (event.path , event.ws_key );
484525 }
485- if (!shouldContinue)
486- return ;
487526 }
488527 }
489528
@@ -504,16 +543,9 @@ class HttpHandler : public ProtocolHandler {
504543 }
505544
506545 static int OnHeaderValue (http_parser* parser, const char * at, size_t length) {
507- static const char SEC_WEBSOCKET_KEY_HEADER[] = " Sec-WebSocket-Key" ;
508546 HttpHandler* handler = From (parser);
509547 handler->parsing_value_ = true ;
510- if (handler->current_header_ .size () ==
511- sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 &&
512- node::StringEqualNoCaseN (handler->current_header_ .data (),
513- SEC_WEBSOCKET_KEY_HEADER,
514- sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 )) {
515- handler->ws_key_ .append (at, length);
516- }
548+ handler->headers_ [handler->current_header_ ].append (at, length);
517549 return 0 ;
518550 }
519551
@@ -540,23 +572,53 @@ class HttpHandler : public ProtocolHandler {
540572 static int OnMessageComplete (http_parser* parser) {
541573 // Event needs to be fired after the parser is done.
542574 HttpHandler* handler = From (parser);
543- handler->events_ .push_back (HttpEvent (handler->path_ , parser->upgrade ,
544- parser->method == HTTP_GET,
545- handler->ws_key_ ));
575+ handler->events_ .push_back (
576+ HttpEvent (handler->path_ , parser->upgrade , parser->method == HTTP_GET,
577+ handler->HeaderValue (" Sec-WebSocket-Key" ),
578+ handler->HeaderValue (" Host" )));
546579 handler->path_ = " " ;
547- handler->ws_key_ = " " ;
548580 handler->parsing_value_ = false ;
581+ handler->headers_ .clear ();
549582 handler->current_header_ = " " ;
550-
551583 return 0 ;
552584 }
553585
586+ std::string HeaderValue (const std::string& header) const {
587+ bool header_found = false ;
588+ std::string value;
589+ for (const auto & header_value : headers_) {
590+ if (node::StringEqualNoCaseN (header_value.first .data (), header.data (),
591+ header.length ())) {
592+ if (header_found)
593+ return " " ;
594+ value = header_value.second ;
595+ header_found = true ;
596+ }
597+ }
598+ return value;
599+ }
600+
601+ bool IsAllowedHost (const std::string& host_with_port) const {
602+ std::string host = TrimPort (host_with_port);
603+ if (host.empty ())
604+ return false ;
605+ if (IsIPAddress (host))
606+ return true ;
607+ std::string socket_host = GetHost ();
608+ if (IsIPv4Localhost (socket_host)) {
609+ return IsOneOf (host, { " localhost" });
610+ } else if (socket_host == " ::1" ) {
611+ return IsOneOf (host, { " localhost" , " localhost6" });
612+ }
613+ return true ;
614+ }
615+
554616 bool parsing_value_;
555617 http_parser parser_;
556618 http_parser_settings parser_settings;
557619 std::vector<HttpEvent> events_;
558620 std::string current_header_;
559- std::string ws_key_ ;
621+ std::map<std:: string, std::string> headers_ ;
560622 std::string path_;
561623};
562624
@@ -579,7 +641,7 @@ InspectorSocket::Delegate* ProtocolHandler::delegate() {
579641 return tcp_->delegate ();
580642}
581643
582- std::string ProtocolHandler::GetHost () {
644+ std::string ProtocolHandler::GetHost () const {
583645 char ip[INET6_ADDRSTRLEN];
584646 sockaddr_storage addr;
585647 int len = sizeof (addr);
@@ -622,8 +684,6 @@ TcpHolder::Pointer TcpHolder::Accept(
622684 if (err == 0 ) {
623685 return { result, DisconnectAndDispose };
624686 } else {
625- fprintf (stderr, " [%s:%d@%s]\n " , __FILE__, __LINE__, __FUNCTION__);
626-
627687 delete result;
628688 return { nullptr , nullptr };
629689 }
0 commit comments