diff --git a/go/internal/feast/metrics/metrics.go b/go/internal/feast/metrics/metrics.go index 804eef6fa1b..d4783f257b7 100644 --- a/go/internal/feast/metrics/metrics.go +++ b/go/internal/feast/metrics/metrics.go @@ -30,7 +30,6 @@ var ( TimeHistogramType = reflect.TypeOf((*TimeHistogram)(nil)).Elem() ) - func RegisterTimeHistogram(name, help, namespace string, labelNames []string, tag reflect.StructTag) (func(prometheus.Labels) interface{}, prometheus.Collector, error) { f, collector, err := prometheusvanilla.BuildHistogram(name, help, namespace, labelNames, tag) if err != nil { diff --git a/go/internal/feast/onlinestore/postgresonlinestore.go b/go/internal/feast/onlinestore/postgresonlinestore.go index 4813f341db7..4077a9e06fa 100644 --- a/go/internal/feast/onlinestore/postgresonlinestore.go +++ b/go/internal/feast/onlinestore/postgresonlinestore.go @@ -194,4 +194,4 @@ func buildPostgresConnString(config map[string]interface{}) string { } return connURL.String() -} \ No newline at end of file +} diff --git a/go/internal/feast/server/http_server.go b/go/internal/feast/server/http_server.go index adfd40110e7..876f42f846b 100644 --- a/go/internal/feast/server/http_server.go +++ b/go/internal/feast/server/http_server.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/tls" "encoding/json" "fmt" "net/http" @@ -396,6 +397,34 @@ func (s *httpServer) Serve(host string, port int) error { return err } +func (s *httpServer) ServeTLS(host string, port int, certFile string, keyFile string) error { + mux := http.NewServeMux() + mux.Handle("/get-online-features", metricsMiddleware(recoverMiddleware(http.HandlerFunc(s.getOnlineFeatures)))) + mux.Handle("/health", metricsMiddleware(http.HandlerFunc(healthCheckHandler))) + s.server = &http.Server{ + Addr: fmt.Sprintf("%s:%d", host, port), + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 15 * time.Second, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519MLKEM768, + //tls.SecP256r1MLKEM768, // Only available in Go 1.26 + }, + }, + } + err := s.server.ListenAndServeTLS(certFile, keyFile) + // Don't return the error if it's caused by graceful shutdown using Stop() + if err == http.ErrServerClosed { + return nil + } + log.Fatal().Stack().Err(err).Msg("Failed to start HTTPS server") + return err +} + func healthCheckHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Healthy") diff --git a/go/main.go b/go/main.go index f49a27efa46..7f89fe66c3b 100644 --- a/go/main.go +++ b/go/main.go @@ -11,6 +11,7 @@ import ( "strings" "sync" "syscall" + "time" "github.com/feast-dev/feast/go/internal/feast" "github.com/feast-dev/feast/go/internal/feast/registry" @@ -36,15 +37,28 @@ import ( var tracer trace.Tracer +var newSignalStopChannel = func() (chan os.Signal, func()) { + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + return stop, func() { + signal.Stop(stop) + } +} + type ServerStarter interface { StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error StartGrpcServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error + StartHttpsServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions, certFile string, keyFile string) error } type RealServerStarter struct{} func (s *RealServerStarter) StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error { - return StartHttpServer(fs, host, port, metricsPort, writeLoggedFeaturesCallback, loggingOpts) + return StartHttpServer(fs, host, port, metricsPort, writeLoggedFeaturesCallback, loggingOpts, false, "", "") +} + +func (s *RealServerStarter) StartHttpsServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions, certFile string, keyFile string) error { + return StartHttpServer(fs, host, port, metricsPort, writeLoggedFeaturesCallback, loggingOpts, true, certFile, keyFile) } func (s *RealServerStarter) StartGrpcServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error { @@ -58,18 +72,22 @@ func main() { port := 8080 metricsPort := 9090 server := RealServerStarter{} + certFile := "" + keyFile := "" // Current Directory repoPath, err := os.Getwd() if err != nil { log.Error().Stack().Err(err).Msg("Failed to get current directory") } - flag.StringVar(&serverType, "type", serverType, "Specify the server type (http or grpc)") + flag.StringVar(&serverType, "type", serverType, "Specify the server type (http, https or grpc)") flag.StringVar(&repoPath, "chdir", repoPath, "Repository path where feature store yaml file is stored") flag.StringVar(&host, "host", host, "Specify a host for the server") flag.IntVar(&port, "port", port, "Specify a port for the server") flag.IntVar(&metricsPort, "metrics-port", metricsPort, "Specify a port for the metrics server") + flag.StringVar(&certFile, "tls-cert-file", "", "Path to the TLS certificate file") + flag.StringVar(&keyFile, "tls-key-file", "", "Path to the TLS key file") flag.Parse() // Initialize tracer @@ -119,8 +137,10 @@ func main() { err = server.StartHttpServer(fs, host, port, metricsPort, nil, loggingOptions) } else if serverType == "grpc" { err = server.StartGrpcServer(fs, host, port, metricsPort, nil, loggingOptions) + } else if serverType == "https" { + err = server.StartHttpsServer(fs, host, port, metricsPort, nil, loggingOptions, certFile, keyFile) } else { - fmt.Println("Unknown server type. Please specify 'http' or 'grpc'.") + fmt.Println("Unknown server type. Please specify 'http' or 'grpc' or 'https'.") } if err != nil { @@ -227,27 +247,34 @@ func StartGrpcServer(fs *feast.FeatureStore, host string, port int, metricsPort // StartHttpServerWithLogging starts HTTP server with enabled feature logging // Go does not allow direct assignment to package-level functions as a way to // mock them for tests -func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions) error { +func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts *logging.LoggingOptions, httpsEnable bool, certFile string, keyFile string) error { + if httpsEnable && (certFile == "" || keyFile == "") { + return fmt.Errorf("--tls-cert-file and --tls-key-file must be provided for HTTPS server.") + } + loggingService, err := constructLoggingService(fs, writeLoggedFeaturesCallback, loggingOpts) if err != nil { return err } ser := server.NewHttpServer(fs, loggingService) log.Info().Msgf("Starting a HTTP server on host %s, port %d", host, port) + // Start metrics server - metricsServer := &http.Server{Addr: fmt.Sprintf(":%d", metricsPort)} + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) + metricsServer := &http.Server{ + Addr: fmt.Sprintf(":%d", metricsPort), + Handler: mux, + } go func() { log.Info().Msgf("Starting metrics server on port %d", metricsPort) - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.Handler()) - metricsServer.Handler = mux if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Error().Err(err).Msg("Failed to start metrics server") } }() - stop := make(chan os.Signal, 1) - signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + stop, stopCleanup := newSignalStopChannel() + defer stopCleanup() var wg sync.WaitGroup wg.Add(1) @@ -263,7 +290,9 @@ func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort log.Error().Err(err).Msg("Error when stopping the HTTP server") } log.Info().Msg("Stopping metrics server...") - if err := metricsServer.Shutdown(context.Background()); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := metricsServer.Shutdown(ctx); err != nil { log.Error().Err(err).Msg("Error stopping metrics server") } if loggingService != nil { @@ -279,7 +308,11 @@ func StartHttpServer(fs *feast.FeatureStore, host string, port int, metricsPort } }() - err = ser.Serve(host, port) + if httpsEnable { + err = ser.ServeTLS(host, port, certFile, keyFile) + } else { + err = ser.Serve(host, port) + } close(serverExited) wg.Wait() return err diff --git a/go/main_test.go b/go/main_test.go index f1f2ae98698..7eb0e0a6676 100644 --- a/go/main_test.go +++ b/go/main_test.go @@ -1,12 +1,28 @@ package main import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io" + "math/big" + "net" + "net/http" + "os" + "strings" + "syscall" "testing" + "time" "github.com/feast-dev/feast/go/internal/feast" "github.com/feast-dev/feast/go/internal/feast/server/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) // MockServerStarter is a mock of ServerStarter interface for testing @@ -67,3 +83,130 @@ func TestConstructLoggingService(t *testing.T) { assert.NoError(t, err) // Further assertions can be added here based on the expected behavior of constructLoggingService } + +func TestStartHttpsServerHealthEndpoint(t *testing.T) { + certPath, keyPath := createSelfSignedTLSFiles(t) + host := "127.0.0.1" + port := getFreePort(t) + metricsPort := getFreePort(t) + + stop := make(chan os.Signal, 1) + prevNewSignalStopChannel := newSignalStopChannel + newSignalStopChannel = func() (chan os.Signal, func()) { + return stop, func() {} + } + t.Cleanup(func() { + newSignalStopChannel = prevNewSignalStopChannel + }) + + errCh := make(chan error, 1) + go func() { + errCh <- StartHttpServer(&feast.FeatureStore{}, host, port, metricsPort, nil, &logging.LoggingOptions{}, true, certPath, keyPath) + }() + + httpsClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + }, + } + t.Cleanup(httpsClient.CloseIdleConnections) + + url := fmt.Sprintf("https://%s:%d/health", host, port) + + var ( + resp *http.Response + err error + ) + require.Eventually(t, func() bool { + resp, err = httpsClient.Get(url) + if err != nil { + return false + } + return true + }, 5*time.Second, 100*time.Millisecond) + require.NoError(t, err) + t.Cleanup(func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }) + + body, readErr := io.ReadAll(resp.Body) + require.NoError(t, readErr) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "Healthy", strings.TrimSpace(string(body))) + + stop <- syscall.SIGTERM + + select { + case startErr := <-errCh: + require.NoError(t, startErr) + case <-time.After(5 * time.Second): + t.Fatal("StartHttpsServer did not shutdown within timeout") + } +} + +func TestStartHttpsServerTLSFilesRequired(t *testing.T) { + err := StartHttpServer(&feast.FeatureStore{}, "127.0.0.1", 0, 0, nil, &logging.LoggingOptions{}, true, "", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "--tls-cert-file and --tls-key-file must be provided") +} + +func getFreePort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + + addr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + return addr.Port +} + +func createSelfSignedTLSFiles(t *testing.T) (string, string) { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + tmpl := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + require.NoError(t, err) + + certFile, err := os.CreateTemp(t.TempDir(), "feast-test-cert-*.pem") + require.NoError(t, err) + defer func() { + _ = certFile.Close() + }() + + keyFile, err := os.CreateTemp(t.TempDir(), "feast-test-key-*.pem") + require.NoError(t, err) + defer func() { + _ = keyFile.Close() + }() + + err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: der}) + require.NoError(t, err) + + err = pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + require.NoError(t, err) + + return certFile.Name(), keyFile.Name() +}