diff --git a/coderd/azureidentity/azureidentity.go b/coderd/azureidentity/azureidentity.go index 7a06ccbc51175..d16fab58b65b3 100644 --- a/coderd/azureidentity/azureidentity.go +++ b/coderd/azureidentity/azureidentity.go @@ -8,7 +8,9 @@ import ( "encoding/pem" "errors" "io" + "net" "net/http" + "net/url" "regexp" "sync" "time" @@ -25,6 +27,158 @@ var allowedSigners = regexp.MustCompile(`^(.*\.)?metadata\.(azure\.(com|us|cn)|m // each time a parse occurs. var pkcs7Mutex sync.Mutex +// allowedCertHosts contains the hosts Azure intermediate +// certificates are served from. Only these hosts are permitted +// when fetching issuing certificates referenced in the signer +// certificate. This prevents SSRF via crafted +// IssuingCertificateURL values. +// +// Source: https://learn.microsoft.com/en-us/azure/security/fundamentals/azure-ca-details +var allowedCertHosts = map[string]bool{ + "www.microsoft.com": true, + "cacerts.digicert.com": true, +} + +// maxCertResponseBytes is the maximum size of a certificate +// response body we will read. Azure intermediate certificates +// are typically under 4 KiB; 1 MiB is a generous upper bound +// that prevents memory exhaustion from malicious responses. +const maxCertResponseBytes = 1 << 20 // 1 MiB + +// extraBlockedNetworks lists special-use CIDR ranges that the +// stdlib classification methods (IsLoopback, IsPrivate, etc.) do +// not cover. Blocking these prevents SSRF against carrier-grade +// NAT, network-benchmarking, documentation, discard-only, and +// the all-zeros "this network" range. +// +// IPv6 ranges already handled by stdlib: +// - ::1/128 (IsLoopback) +// - fc00::/7 (IsPrivate, ULA) +// - fe80::/10 (IsLinkLocalUnicast) +// - ff00::/8 (IsMulticast) +// - ::/128 (IsUnspecified) +var extraBlockedNetworks []*net.IPNet + +func init() { + for _, cidr := range []string{ + // IPv4 special-use ranges. + "0.0.0.0/8", // RFC 1122 "this network". + "100.64.0.0/10", // RFC 6598 carrier-grade NAT. + "198.18.0.0/15", // RFC 2544 benchmarking. + + // IPv6 special-use ranges not covered by stdlib. + "64:ff9b:1::/48", // RFC 8215 IPv4/IPv6 translation. + "100::/64", // RFC 6666 discard-only. + "2001:2::/48", // RFC 5180 benchmarking. + "2001:db8::/32", // RFC 3849 documentation. + } { + _, network, _ := net.ParseCIDR(cidr) + extraBlockedNetworks = append(extraBlockedNetworks, network) + } +} + +// isPrivateIP reports whether the IP is on a network that must +// not be reachable when fetching certificates. IPv4-mapped IPv6 +// addresses are canonicalized to IPv4 first so a literal like +// ::ffff:169.254.169.254 cannot bypass the IPv4 ranges. +func isPrivateIP(ip net.IP) bool { + if v4 := ip.To4(); v4 != nil { + ip = v4 + } + if ip.IsLoopback() || + ip.IsPrivate() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + ip.IsMulticast() || + ip.IsUnspecified() || + ip.IsInterfaceLocalMulticast() { + return true + } + for _, network := range extraBlockedNetworks { + if network.Contains(ip) { + return true + } + } + return false +} + +// certFetchClient is an HTTP client that refuses to connect +// to private or link-local IP addresses. This provides +// defense-in-depth against SSRF even if the host allowlist is +// somehow bypassed (e.g. via DNS rebinding). +var certFetchClient = &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, xerrors.Errorf("split host/port: %w", err) + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, xerrors.Errorf("resolve host: %w", err) + } + if len(ips) == 0 { + return nil, xerrors.Errorf("no addresses for %q", host) + } + // Reject up front so a single tainted answer + // short-circuits the dial rather than racing it. + for _, ip := range ips { + if isPrivateIP(ip.IP) { + return nil, xerrors.Errorf( + "certificate fetch blocked: %q resolved to private IP %s", + host, ip.IP, + ) + } + } + // Dial the validated IP directly. If we dialed by + // hostname here, Go's stdlib would re-resolve and a + // hostile resolver could swap in a private IP after + // validation (DNS rebinding). TLS verification still + // uses the URL host via the Transport's TLS config. + var d net.Dialer + var firstErr error + for _, ip := range ips { + conn, derr := d.DialContext(ctx, network, net.JoinHostPort(ip.IP.String(), port)) + if derr == nil { + return conn, nil + } + if firstErr == nil { + firstErr = derr + } + } + return nil, firstErr + }, + }, +} + +// IsAllowedCertificateURL reports whether rawURL points to a +// host on the allowlist, uses http or https, and targets a +// standard PKI distribution port. Microsoft and DigiCert serve +// these artifacts on 80/443 only; any other port is rejected to +// keep the SSRF surface as narrow as the hostname itself. +func IsAllowedCertificateURL(rawURL string) bool { + if rawURL == "" { + return false + } + u, err := url.Parse(rawURL) + if err != nil { + return false + } + if u.Scheme != "http" && u.Scheme != "https" { + return false + } + if !allowedCertHosts[u.Hostname()] { + return false + } + switch u.Port() { + case "", "80", "443": + return true + default: + return false + } +} + type metadata struct { VMID string `json:"vmId"` } @@ -95,29 +249,42 @@ func Validate(ctx context.Context, signature string, options Options) (string, e ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) defer cancelFunc() for _, certURL := range signer.IssuingCertificateURL { + if !IsAllowedCertificateURL(certURL) { + return "", xerrors.New("issuing certificate URL not on allowlist") + } req, err := http.NewRequestWithContext(ctx, "GET", certURL, nil) if err != nil { - return "", xerrors.Errorf("new request %q: %w", certURL, err) + return "", xerrors.New("construct certificate request") } - res, err := http.DefaultClient.Do(req) + res, err := certFetchClient.Do(req) if err != nil { - return "", xerrors.Errorf("no cached certificate for %q found. error fetching: %w", certURL, err) + return "", xerrors.New("certificate fetch unsuccessful") } - data, err := io.ReadAll(res.Body) + limited := io.LimitReader(res.Body, maxCertResponseBytes+1) + data, err := io.ReadAll(limited) + _ = res.Body.Close() if err != nil { - _ = res.Body.Close() - return "", xerrors.Errorf("read body %q: %w", certURL, err) + return "", xerrors.New("read certificate response body") + } + if int64(len(data)) > maxCertResponseBytes { + return "", xerrors.New( + "certificate response exceeds maximum size", + ) } - _ = res.Body.Close() cert, err := x509.ParseCertificate(data) if err != nil { - return "", xerrors.Errorf("parse certificate %q: %w", certURL, err) + // Do not wrap the parse error; it may contain + // fragments of the HTTP response body, which + // could leak internal data to the caller. + return "", xerrors.New( + "fetched data is not a valid certificate", + ) } options.Intermediates.AddCert(cert) } _, err = signer.Verify(options.VerifyOptions) if err != nil { - return "", err + return "", xerrors.New("signature verification failed after fetching issuing certificates") } } diff --git a/coderd/azureidentity/azureidentity_internal_test.go b/coderd/azureidentity/azureidentity_internal_test.go new file mode 100644 index 0000000000000..a4b9ddcdb4d93 --- /dev/null +++ b/coderd/azureidentity/azureidentity_internal_test.go @@ -0,0 +1,76 @@ +package azureidentity + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsPrivateIP(t *testing.T) { + t.Parallel() + cases := []struct { + name string + ip string + blocked bool + }{ + {"loopback v4", "127.0.0.1", true}, + {"loopback v6", "::1", true}, + {"link local v4 (azure metadata)", "169.254.169.254", true}, + {"link local v6", "fe80::1", true}, + {"rfc1918 10/8", "10.0.0.1", true}, + {"rfc1918 172.16/12", "172.16.0.1", true}, + {"rfc1918 192.168/16", "192.168.0.1", true}, + {"ipv6 ula", "fc00::1", true}, + {"unspecified v4", "0.0.0.0", true}, + {"unspecified v6", "::", true}, + {"this-network 0.0.0.0/8", "0.1.2.3", true}, + {"cgnat 100.64/10", "100.64.0.1", true}, + {"benchmarking 198.18/15", "198.18.0.1", true}, + {"multicast v4", "224.0.0.1", true}, + {"ipv6 nat64 well-known", "64:ff9b:1::1", true}, + {"ipv6 discard-only", "100::1", true}, + {"ipv6 benchmarking", "2001:2::1", true}, + {"ipv6 documentation", "2001:db8::1", true}, + // IPv4-mapped IPv6: must canonicalize to v4 before + // classification, otherwise an attacker could bypass + // the metadata block via ::ffff:169.254.169.254. + {"ipv4-mapped metadata", "::ffff:169.254.169.254", true}, + {"ipv4-mapped rfc1918", "::ffff:10.0.0.1", true}, + + {"public v4", "8.8.8.8", false}, + {"public v6", "2606:4700:4700::1111", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ip := net.ParseIP(tc.ip) + require.NotNil(t, ip, "parse %q", tc.ip) + require.Equal(t, tc.blocked, isPrivateIP(ip)) + }) + } +} + +// TestCertFetchClientRejectsLoopback proves the dialer refuses +// to connect even when the URL itself would have passed an +// allowlist (httptest.Server always binds to 127.0.0.1, so a +// successful fetch here would mean the SSRF guard had failed). +func TestCertFetchClientRejectsLoopback(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("should never be reached")) + })) + t.Cleanup(srv.Close) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + resp, err := certFetchClient.Do(req) + if resp != nil { + defer resp.Body.Close() + } + require.Error(t, err) + require.Contains(t, err.Error(), "private IP") +} diff --git a/coderd/azureidentity/azureidentity_test.go b/coderd/azureidentity/azureidentity_test.go index 93627ff9279e1..2bc5643b5f040 100644 --- a/coderd/azureidentity/azureidentity_test.go +++ b/coderd/azureidentity/azureidentity_test.go @@ -116,3 +116,37 @@ func TestExpiresSoon(t *testing.T) { } } } + +func TestIsAllowedCertificateURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + url string + allowed bool + }{ + {"microsoft http", "http://www.microsoft.com/pki/mscorp/cert.crt", true}, + {"microsoft https", "https://www.microsoft.com/pkiops/certs/cert.crt", true}, + {"digicert http", "http://cacerts.digicert.com/DigiCertGlobalRootG2.crt", true}, + {"digicert https", "https://cacerts.digicert.com/DigiCertGlobalRootG3.crt", true}, + {"evil domain", "http://evil.example.com/cert.crt", false}, + {"metadata endpoint", "http://169.254.169.254/latest/meta-data/", false}, + {"localhost", "http://localhost/secret", false}, + {"subdomain trick", "http://www.microsoft.com.evil.com/cert.crt", false}, + {"empty string", "", false}, + {"ftp scheme", "ftp://www.microsoft.com/cert.crt", false}, + {"no scheme", "www.microsoft.com/cert.crt", false}, + {"javascript scheme", "javascript:alert(1)", false}, + {"microsoft with path", "http://www.microsoft.com/pkiops/certs/cert.crt", true}, + {"microsoft explicit port 80", "http://www.microsoft.com:80/cert.crt", true}, + {"microsoft explicit port 443", "https://www.microsoft.com:443/cert.crt", true}, + {"microsoft non-standard port", "http://www.microsoft.com:8080/cert.crt", false}, + {"microsoft port 22", "http://www.microsoft.com:22/cert.crt", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := azureidentity.IsAllowedCertificateURL(tc.url) + require.Equal(t, tc.allowed, result, "URL: %s", tc.url) + }) + } +} diff --git a/coderd/workspaceresourceauth.go b/coderd/workspaceresourceauth.go index a9bf320c95391..ba310988a1b41 100644 --- a/coderd/workspaceresourceauth.go +++ b/coderd/workspaceresourceauth.go @@ -8,6 +8,7 @@ import ( "github.com/mitchellh/mapstructure" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/awsidentity" "github.com/coder/coder/v2/coderd/azureidentity" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -38,9 +39,17 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r VerifyOptions: api.AzureCertificates, }) if err != nil { + // Log the full error for operators but return only a + // generic message to the caller. Errors from the + // certificate fetch path may contain fragments of + // internal HTTP responses, so exposing them would be + // an information disclosure risk. + api.Logger.Warn(ctx, "azure identity validation failed", + slog.Error(err), + ) httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ Message: "Invalid Azure identity.", - Detail: err.Error(), + Detail: "Signature verification failed.", }) return }