diff --git a/go.mod b/go.mod index 5b2592536df..f0dc02d116d 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/dynamodb v1.43.3 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.3 github.com/ghodss/yaml v1.0.0 + github.com/go-sql-driver/mysql v1.8.1 github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.23 @@ -40,6 +41,7 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect cloud.google.com/go/monitoring v1.24.2 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect @@ -81,7 +83,7 @@ require ( github.com/googleapis/gax-go/v2 v2.15.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect - github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index aff636cad1b..3080dbdd50a 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ cloud.google.com/go/storage v1.58.0 h1:PflFXlmFJjG/nBeR9B7pKddLQWaFaRWx4uUi/LyNx cloud.google.com/go/storage v1.58.0/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI= cloud.google.com/go/trace v1.11.6 h1:2O2zjPzqPYAHrn3OKl029qlqG6W8ZdYaOWRyr8NgMT4= cloud.google.com/go/trace v1.11.6/go.mod h1:GA855OeDEBiBMzcckLPE2kDunIpC72N+Pq8WFieFjnI= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 h1:UQUsRi8WTzhZntp5313l+CHIAT95ojUI2lpP/ExlZa4= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0/go.mod h1:Cz6ft6Dkn3Et6l2v2a9/RpN7epQ1GtDlO6lj8bEcOvw= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0 h1:lhhYARPUu3LmHysQ/igznQphfzynnqI3D75oUyw1HXk= @@ -110,6 +112,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -135,8 +139,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnV github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/go/internal/feast/registry/mysql_registry_store.go b/go/internal/feast/registry/mysql_registry_store.go new file mode 100644 index 00000000000..f67e5a2eaf6 --- /dev/null +++ b/go/internal/feast/registry/mysql_registry_store.go @@ -0,0 +1,389 @@ +// Package registry implements Feast registry stores. +// +// MySQL Registry Store: +// The MySQL registry store provides read-only access to a Feast registry stored in MySQL. +// It queries a database schema matching the Python SQLAlchemy schema defined in +// sdk/python/feast/infra/registry/sql.py. When the Python schema evolves, the Go queries +// in this package must be updated accordingly. +package registry + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net" + "net/url" + "strings" + "sync" + "time" + + "github.com/feast-dev/feast/go/protos/feast/core" + "github.com/go-sql-driver/mysql" + "github.com/rs/zerolog/log" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// SQL queries for loading registry objects from the MySQL registry. +// These queries assume the schema defined in sdk/python/feast/infra/registry/sql.py. +const ( + queryProjects = "SELECT project_proto FROM projects WHERE project_id = ?" + queryEntities = "SELECT entity_proto FROM entities WHERE project_id = ?" + queryFeatureViews = "SELECT feature_view_proto FROM feature_views WHERE project_id = ?" + queryStreamFeatureViews = "SELECT feature_view_proto FROM stream_feature_views WHERE project_id = ?" + queryOnDemandFeatureViews = "SELECT feature_view_proto FROM on_demand_feature_views WHERE project_id = ?" + queryFeatureServices = "SELECT feature_service_proto FROM feature_services WHERE project_id = ?" + queryDataSources = "SELECT data_source_proto FROM data_sources WHERE project_id = ?" + querySavedDatasets = "SELECT saved_dataset_proto FROM saved_datasets WHERE project_id = ?" + queryValidationReferences = "SELECT validation_reference_proto FROM validation_references WHERE project_id = ?" + queryPermissions = "SELECT permission_proto FROM permissions WHERE project_id = ?" + queryManagedInfra = "SELECT infra_proto FROM managed_infra WHERE project_id = ?" + queryMaxLastUpdated = "SELECT MAX(last_updated_timestamp) FROM projects WHERE project_id = ?" +) + +type MySQLRegistryStore struct { + dsn string + dsnErr error + db *sql.DB + dbOnce sync.Once + dbErr error + project string + driverName string + registryConfig *RegistryConfig +} + +// NewMySQLRegistryStore creates a MySQLRegistryStore from a SQLAlchemy-style URL or a raw DSN. +func NewMySQLRegistryStore(config *RegistryConfig, repoPath string, project string) *MySQLRegistryStore { + dsn, err := mysqlURLToDSN(config.Path) + return &MySQLRegistryStore{ + dsn: dsn, + dsnErr: err, + project: project, + driverName: "mysql", + registryConfig: config, + } +} + +// newMySQLRegistryStoreWithDB is for tests to inject a pre-configured DB handle. +func newMySQLRegistryStoreWithDB(db *sql.DB, project string) *MySQLRegistryStore { + return &MySQLRegistryStore{ + db: db, + project: project, + driverName: "mysql", + } +} + +func (r *MySQLRegistryStore) GetRegistryProto() (*core.Registry, error) { + if r.project == "" { + return nil, errors.New("mysql registry store requires a project name") + } + db, err := r.getDB() + if err != nil { + return nil, err + } + + queryTimeout := r.getQueryTimeout() + ctx, cancel := context.WithTimeout(context.Background(), queryTimeout) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("failed to ping MySQL registry database: %w", err) + } + + registry := &core.Registry{ + RegistrySchemaVersion: REGISTRY_SCHEMA_VERSION, + } + + projects, err := readProtoRows(ctx, db, + queryProjects, + []any{r.project}, + func() *core.Project { return &core.Project{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load projects: %w", err) + } + registry.Projects = projects + + if lastUpdated, err := r.getMaxProjectUpdatedTimestamp(ctx, db); err == nil && lastUpdated != nil { + registry.LastUpdated = lastUpdated + } + + entities, err := readProtoRows(ctx, db, + queryEntities, + []any{r.project}, + func() *core.Entity { return &core.Entity{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load entities: %w", err) + } + registry.Entities = entities + + featureViews, err := readProtoRows(ctx, db, + queryFeatureViews, + []any{r.project}, + func() *core.FeatureView { return &core.FeatureView{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load feature_views: %w", err) + } + registry.FeatureViews = featureViews + + streamFeatureViews, err := readProtoRows(ctx, db, + queryStreamFeatureViews, + []any{r.project}, + func() *core.StreamFeatureView { return &core.StreamFeatureView{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load stream_feature_views: %w", err) + } + registry.StreamFeatureViews = streamFeatureViews + + onDemandFeatureViews, err := readProtoRows(ctx, db, + queryOnDemandFeatureViews, + []any{r.project}, + func() *core.OnDemandFeatureView { return &core.OnDemandFeatureView{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load on_demand_feature_views: %w", err) + } + registry.OnDemandFeatureViews = onDemandFeatureViews + + featureServices, err := readProtoRows(ctx, db, + queryFeatureServices, + []any{r.project}, + func() *core.FeatureService { return &core.FeatureService{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load feature_services: %w", err) + } + registry.FeatureServices = featureServices + + dataSources, err := readProtoRows(ctx, db, + queryDataSources, + []any{r.project}, + func() *core.DataSource { return &core.DataSource{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load data_sources: %w", err) + } + registry.DataSources = dataSources + + savedDatasets, err := readProtoRows(ctx, db, + querySavedDatasets, + []any{r.project}, + func() *core.SavedDataset { return &core.SavedDataset{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load saved_datasets: %w", err) + } + registry.SavedDatasets = savedDatasets + + validationReferences, err := readProtoRows(ctx, db, + queryValidationReferences, + []any{r.project}, + func() *core.ValidationReference { return &core.ValidationReference{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load validation_references: %w", err) + } + registry.ValidationReferences = validationReferences + + permissions, err := readProtoRows(ctx, db, + queryPermissions, + []any{r.project}, + func() *core.Permission { return &core.Permission{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load permissions: %w", err) + } + registry.Permissions = permissions + + infra, err := readProtoRows(ctx, db, + queryManagedInfra, + []any{r.project}, + func() *core.Infra { return &core.Infra{} }, + ) + if err != nil { + return nil, fmt.Errorf("failed to load managed_infra: %w", err) + } + if len(infra) > 0 { + registry.Infra = infra[0] + } + + log.Debug().Str("project", r.project).Msg("Loaded registry from MySQL") + return registry, nil +} + +func (r *MySQLRegistryStore) UpdateRegistryProto(rp *core.Registry) error { + return errors.New("not implemented in MySQLRegistryStore") +} + +func (r *MySQLRegistryStore) Teardown() error { + if r.db != nil { + return r.db.Close() + } + return nil +} + +func (r *MySQLRegistryStore) getDB() (*sql.DB, error) { + r.dbOnce.Do(func() { + if r.db != nil { + // Already initialized (e.g., via newMySQLRegistryStoreWithDB for tests) + return + } + if r.dsnErr != nil { + r.dbErr = fmt.Errorf("invalid MySQL registry DSN: %w", r.dsnErr) + return + } + if r.dsn == "" { + r.dbErr = errors.New("mysql registry store requires a non-empty DSN") + return + } + db, err := sql.Open(r.driverName, r.dsn) + if err != nil { + r.dbErr = fmt.Errorf("failed to open MySQL registry database: %w", err) + return + } + applyMySQLPoolSettings(db, r.registryConfig) + r.db = db + }) + if r.dbErr != nil { + return nil, r.dbErr + } + return r.db, nil +} + +func (r *MySQLRegistryStore) getMaxProjectUpdatedTimestamp(ctx context.Context, db *sql.DB) (*timestamppb.Timestamp, error) { + var maxUpdated sql.NullInt64 + err := db.QueryRowContext(ctx, + queryMaxLastUpdated, + r.project, + ).Scan(&maxUpdated) + if err != nil { + return nil, err + } + if !maxUpdated.Valid { + return nil, nil + } + return timestamppb.New(time.Unix(maxUpdated.Int64, 0)), nil +} + +func (r *MySQLRegistryStore) getQueryTimeout() time.Duration { + if r.registryConfig != nil && r.registryConfig.MySQLQueryTimeoutSeconds > 0 { + return time.Duration(r.registryConfig.MySQLQueryTimeoutSeconds) * time.Second + } + return time.Duration(defaultMySQLQueryTimeoutSeconds) * time.Second +} + +func readProtoRows[T proto.Message](ctx context.Context, db *sql.DB, query string, args []any, newProto func() T) ([]T, error) { + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + results := make([]T, 0) + for rows.Next() { + var data []byte + if err := rows.Scan(&data); err != nil { + return nil, err + } + msg := newProto() + if err := proto.Unmarshal(data, msg); err != nil { + return nil, err + } + results = append(results, msg) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +func mysqlURLToDSN(registryPath string) (string, error) { + if strings.TrimSpace(registryPath) == "" { + return "", errors.New("mysql registry path is empty") + } + + parsed, err := url.Parse(registryPath) + if err != nil { + return "", err + } + + if parsed.Scheme == "" { + // Assume raw DSN. + return registryPath, nil + } + if !isMySQLScheme(parsed.Scheme) { + return "", fmt.Errorf("unsupported mysql scheme %q", parsed.Scheme) + } + + cfg := mysql.NewConfig() + if parsed.User != nil { + cfg.User = parsed.User.Username() + if pwd, ok := parsed.User.Password(); ok { + cfg.Passwd = pwd + } + } + + cfg.Net = "tcp" + if host := parsed.Hostname(); host != "" { + if port := parsed.Port(); port != "" { + cfg.Addr = net.JoinHostPort(host, port) + } else { + cfg.Addr = host + } + } + + cfg.DBName = strings.TrimPrefix(parsed.Path, "/") + if cfg.DBName == "" { + return "", errors.New("mysql registry path missing database name") + } + + params := parsed.Query() + if socket := params.Get("unix_socket"); socket != "" { + cfg.Net = "unix" + cfg.Addr = socket + params.Del("unix_socket") + } + + if len(params) > 0 { + cfg.Params = map[string]string{} + for key, values := range params { + if len(values) > 0 { + cfg.Params[key] = values[len(values)-1] + } else { + cfg.Params[key] = "" + } + } + } + + if cfg.Params == nil { + cfg.Params = map[string]string{} + } + if _, ok := cfg.Params["parseTime"]; !ok { + cfg.Params["parseTime"] = "true" + } + + return cfg.FormatDSN(), nil +} + +func applyMySQLPoolSettings(db *sql.DB, config *RegistryConfig) { + if config == nil { + return + } + if config.MySQLMaxOpenConns > 0 { + db.SetMaxOpenConns(config.MySQLMaxOpenConns) + } + if config.MySQLMaxIdleConns > 0 { + db.SetMaxIdleConns(config.MySQLMaxIdleConns) + } + if config.MySQLConnMaxLifetimeSeconds > 0 { + db.SetConnMaxLifetime(time.Duration(config.MySQLConnMaxLifetimeSeconds) * time.Second) + } +} + +func isMySQLScheme(scheme string) bool { + return strings.ToLower(scheme) == "mysql" +} diff --git a/go/internal/feast/registry/mysql_registry_store_test.go b/go/internal/feast/registry/mysql_registry_store_test.go new file mode 100644 index 00000000000..ae661c51c70 --- /dev/null +++ b/go/internal/feast/registry/mysql_registry_store_test.go @@ -0,0 +1,127 @@ +package registry + +import ( + "database/sql" + "testing" + + "github.com/feast-dev/feast/go/protos/feast/core" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +// getRegistrySchemaDDL returns the DDL statements for creating the registry schema. +// Schema must match the Python SQLAlchemy schema defined in: +// sdk/python/feast/infra/registry/sql.py +// When the Python schema evolves, this function must be updated accordingly. +func getRegistrySchemaDDL() []string { + return []string{ + `CREATE TABLE projects (project_id TEXT PRIMARY KEY, project_name TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, project_proto BLOB NOT NULL);`, + `CREATE TABLE entities (entity_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, entity_proto BLOB NOT NULL, PRIMARY KEY (entity_name, project_id));`, + `CREATE TABLE feature_views (feature_view_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, materialized_intervals BLOB NULL, feature_view_proto BLOB NOT NULL, user_metadata BLOB NULL, PRIMARY KEY (feature_view_name, project_id));`, + `CREATE TABLE stream_feature_views (feature_view_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, feature_view_proto BLOB NOT NULL, user_metadata BLOB NULL, PRIMARY KEY (feature_view_name, project_id));`, + `CREATE TABLE on_demand_feature_views (feature_view_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, feature_view_proto BLOB NOT NULL, user_metadata BLOB NULL, PRIMARY KEY (feature_view_name, project_id));`, + `CREATE TABLE feature_services (feature_service_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, feature_service_proto BLOB NOT NULL, PRIMARY KEY (feature_service_name, project_id));`, + `CREATE TABLE data_sources (data_source_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, data_source_proto BLOB NOT NULL, PRIMARY KEY (data_source_name, project_id));`, + `CREATE TABLE saved_datasets (saved_dataset_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, saved_dataset_proto BLOB NOT NULL, PRIMARY KEY (saved_dataset_name, project_id));`, + `CREATE TABLE validation_references (validation_reference_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, validation_reference_proto BLOB NOT NULL, PRIMARY KEY (validation_reference_name, project_id));`, + `CREATE TABLE permissions (permission_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, permission_proto BLOB NOT NULL, PRIMARY KEY (permission_name, project_id));`, + `CREATE TABLE managed_infra (infra_name TEXT NOT NULL, project_id TEXT NOT NULL, last_updated_timestamp INTEGER NOT NULL, infra_proto BLOB NOT NULL, PRIMARY KEY (infra_name, project_id));`, + } +} + +func TestMySQLRegistryStore_GetRegistryProto_FromSQLRegistrySchema(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err, "failed to open sqlite db") + defer db.Close() + + project := "feature_repo" + lastUpdated := int64(1710000000) + + for _, stmt := range getRegistrySchemaDDL() { + _, err := db.Exec(stmt) + require.NoError(t, err, "failed to create tables") + } + + projectProto := &core.Project{ + Spec: &core.ProjectSpec{Name: project}, + } + entityProto := &core.Entity{ + Spec: &core.EntitySpecV2{Name: "driver", Project: project}, + } + featureViewProto := &core.FeatureView{ + Spec: &core.FeatureViewSpec{Name: "driver_stats", Project: project}, + } + featureServiceProto := &core.FeatureService{ + Spec: &core.FeatureServiceSpec{Name: "driver_stats_service", Project: project}, + } + infraProto := &core.Infra{} + + projectBlob, err := proto.Marshal(projectProto) + require.NoError(t, err, "failed to marshal project proto") + entityBlob, err := proto.Marshal(entityProto) + require.NoError(t, err, "failed to marshal entity proto") + featureViewBlob, err := proto.Marshal(featureViewProto) + require.NoError(t, err, "failed to marshal feature view proto") + featureServiceBlob, err := proto.Marshal(featureServiceProto) + require.NoError(t, err, "failed to marshal feature service proto") + infraBlob, err := proto.Marshal(infraProto) + require.NoError(t, err, "failed to marshal infra proto") + + _, err = db.Exec( + "INSERT INTO projects (project_id, project_name, last_updated_timestamp, project_proto) VALUES (?, ?, ?, ?)", + project, project, lastUpdated, projectBlob, + ) + require.NoError(t, err, "failed to insert project row") + + _, err = db.Exec( + "INSERT INTO entities (entity_name, project_id, last_updated_timestamp, entity_proto) VALUES (?, ?, ?, ?)", + "driver", project, lastUpdated, entityBlob, + ) + require.NoError(t, err, "failed to insert entity row") + + _, err = db.Exec( + "INSERT INTO feature_views (feature_view_name, project_id, last_updated_timestamp, feature_view_proto) VALUES (?, ?, ?, ?)", + "driver_stats", project, lastUpdated, featureViewBlob, + ) + require.NoError(t, err, "failed to insert feature view row") + + _, err = db.Exec( + "INSERT INTO feature_services (feature_service_name, project_id, last_updated_timestamp, feature_service_proto) VALUES (?, ?, ?, ?)", + "driver_stats_service", project, lastUpdated, featureServiceBlob, + ) + require.NoError(t, err, "failed to insert feature service row") + + _, err = db.Exec( + "INSERT INTO managed_infra (infra_name, project_id, last_updated_timestamp, infra_proto) VALUES (?, ?, ?, ?)", + "infra_obj", project, lastUpdated, infraBlob, + ) + require.NoError(t, err, "failed to insert infra row") + + store := newMySQLRegistryStoreWithDB(db, project) + registryProto, err := store.GetRegistryProto() + require.NoError(t, err, "GetRegistryProto failed") + + assert.Equal(t, REGISTRY_SCHEMA_VERSION, registryProto.RegistrySchemaVersion) + require.Len(t, registryProto.Projects, 1) + assert.Equal(t, project, registryProto.Projects[0].Spec.GetName()) + require.Len(t, registryProto.Entities, 1) + assert.Equal(t, "driver", registryProto.Entities[0].Spec.GetName()) + require.Len(t, registryProto.FeatureViews, 1) + assert.Equal(t, "driver_stats", registryProto.FeatureViews[0].Spec.GetName()) + require.Len(t, registryProto.FeatureServices, 1) + assert.Equal(t, "driver_stats_service", registryProto.FeatureServices[0].Spec.GetName()) + require.NotNil(t, registryProto.LastUpdated) + assert.Equal(t, lastUpdated, registryProto.LastUpdated.GetSeconds()) + assert.NotNil(t, registryProto.Infra) +} + +func TestMySQLRegistryStore_SchemeRouting(t *testing.T) { + registryConfig := &RegistryConfig{ + Path: "mysql://user:pass@localhost:3306/feast", + } + store, err := getRegistryStoreFromScheme(registryConfig.Path, registryConfig, "", "feature_repo") + require.NoError(t, err, "getRegistryStoreFromScheme failed") + assert.IsType(t, &MySQLRegistryStore{}, store) +} diff --git a/go/internal/feast/registry/registry.go b/go/internal/feast/registry/registry.go index 9cd0febe5d3..51aa031bbda 100644 --- a/go/internal/feast/registry/registry.go +++ b/go/internal/feast/registry/registry.go @@ -15,10 +15,11 @@ import ( var REGISTRY_SCHEMA_VERSION string = "1" var REGISTRY_STORE_CLASS_FOR_SCHEME map[string]string = map[string]string{ - "gs": "GCSRegistryStore", - "s3": "S3RegistryStore", - "file": "FileRegistryStore", - "": "FileRegistryStore", + "gs": "GCSRegistryStore", + "s3": "S3RegistryStore", + "file": "FileRegistryStore", + "mysql": "MySQLRegistryStore", + "": "FileRegistryStore", } /* @@ -357,7 +358,7 @@ func getRegistryStoreFromScheme(registryPath string, registryConfig *RegistryCon if registryStoreType, ok := REGISTRY_STORE_CLASS_FOR_SCHEME[uri.Scheme]; ok { return getRegistryStoreFromType(registryStoreType, registryConfig, repoPath, project) } - return nil, fmt.Errorf("registry path %s has unsupported scheme %s. Supported schemes are file, s3 and gcs", registryPath, uri.Scheme) + return nil, fmt.Errorf("registry path %s has unsupported scheme %s. Supported schemes are file, s3, gcs, and mysql", registryPath, uri.Scheme) } func getRegistryStoreFromType(registryStoreType string, registryConfig *RegistryConfig, repoPath string, project string) (RegistryStore, error) { @@ -368,6 +369,8 @@ func getRegistryStoreFromType(registryStoreType string, registryConfig *Registry return NewS3RegistryStore(registryConfig, repoPath), nil case "GCSRegistryStore": return NewGCSRegistryStore(registryConfig, repoPath), nil + case "MySQLRegistryStore": + return NewMySQLRegistryStore(registryConfig, repoPath, project), nil } - return nil, errors.New("only FileRegistryStore, S3RegistryStore, and GCSRegistryStore are supported at this moment") + return nil, errors.New("only FileRegistryStore, S3RegistryStore, GCSRegistryStore, and MySQLRegistryStore are supported at this moment") } diff --git a/go/internal/feast/registry/repoconfig.go b/go/internal/feast/registry/repoconfig.go index f70310f261c..c13537ca021 100644 --- a/go/internal/feast/registry/repoconfig.go +++ b/go/internal/feast/registry/repoconfig.go @@ -12,8 +12,12 @@ import ( ) const ( - defaultCacheTtlSeconds = int64(600) - defaultClientID = "Unknown" + defaultCacheTtlSeconds = int64(600) + defaultClientID = "Unknown" + defaultMySQLMaxOpenConns = 20 + defaultMySQLMaxIdleConns = 10 + defaultMySQLConnMaxLifetimeSeconds = int64(300) + defaultMySQLQueryTimeoutSeconds = int64(30) ) type RepoConfig struct { @@ -39,10 +43,14 @@ type RepoConfig struct { } type RegistryConfig struct { - RegistryStoreType string `json:"registry_store_type"` - Path string `json:"path"` - ClientId string `json:"client_id" default:"Unknown"` - CacheTtlSeconds int64 `json:"cache_ttl_seconds" default:"600"` + RegistryStoreType string `json:"registry_store_type"` + Path string `json:"path"` + ClientId string `json:"client_id" default:"Unknown"` + CacheTtlSeconds int64 `json:"cache_ttl_seconds" default:"600"` + MySQLMaxOpenConns int `json:"mysql_max_open_conns"` + MySQLMaxIdleConns int `json:"mysql_max_idle_conns"` + MySQLConnMaxLifetimeSeconds int64 `json:"mysql_conn_max_lifetime_seconds"` + MySQLQueryTimeoutSeconds int64 `json:"mysql_query_timeout_seconds"` } // NewRepoConfigFromJSON converts a JSON string into a RepoConfig struct and also sets the repo path. @@ -111,7 +119,14 @@ func (r *RepoConfig) GetLoggingOptions() (*logging.LoggingOptions, error) { func (r *RepoConfig) GetRegistryConfig() (*RegistryConfig, error) { if registryConfigMap, ok := r.Registry.(map[string]interface{}); ok { - registryConfig := RegistryConfig{CacheTtlSeconds: defaultCacheTtlSeconds, ClientId: defaultClientID} + registryConfig := RegistryConfig{ + CacheTtlSeconds: defaultCacheTtlSeconds, + ClientId: defaultClientID, + MySQLMaxOpenConns: defaultMySQLMaxOpenConns, + MySQLMaxIdleConns: defaultMySQLMaxIdleConns, + MySQLConnMaxLifetimeSeconds: defaultMySQLConnMaxLifetimeSeconds, + MySQLQueryTimeoutSeconds: defaultMySQLQueryTimeoutSeconds, + } for k, v := range registryConfigMap { switch k { case "path": @@ -127,23 +142,70 @@ func (r *RepoConfig) GetRegistryConfig() (*RegistryConfig, error) { registryConfig.ClientId = value } case "cache_ttl_seconds": - // cache_ttl_seconds defaulted to type float64. Ex: "cache_ttl_seconds": 60 in registryConfigMap - switch value := v.(type) { - case float64: - registryConfig.CacheTtlSeconds = int64(value) - case int: - registryConfig.CacheTtlSeconds = int64(value) - case int32: - registryConfig.CacheTtlSeconds = int64(value) - case int64: - registryConfig.CacheTtlSeconds = value - default: - return nil, fmt.Errorf("unexpected type %T for CacheTtlSeconds", v) + parsed, err := parseInt64Field("cache_ttl_seconds", v) + if err != nil { + return nil, err + } + registryConfig.CacheTtlSeconds = parsed + case "mysql_max_open_conns": + parsed, err := parseIntField("mysql_max_open_conns", v) + if err != nil { + return nil, err + } + registryConfig.MySQLMaxOpenConns = parsed + case "mysql_max_idle_conns": + parsed, err := parseIntField("mysql_max_idle_conns", v) + if err != nil { + return nil, err } + registryConfig.MySQLMaxIdleConns = parsed + case "mysql_conn_max_lifetime_seconds": + parsed, err := parseInt64Field("mysql_conn_max_lifetime_seconds", v) + if err != nil { + return nil, err + } + registryConfig.MySQLConnMaxLifetimeSeconds = parsed + case "mysql_query_timeout_seconds": + parsed, err := parseInt64Field("mysql_query_timeout_seconds", v) + if err != nil { + return nil, err + } + registryConfig.MySQLQueryTimeoutSeconds = parsed } } return ®istryConfig, nil } else { - return &RegistryConfig{Path: r.Registry.(string), ClientId: defaultClientID, CacheTtlSeconds: defaultCacheTtlSeconds}, nil + return &RegistryConfig{ + Path: r.Registry.(string), + ClientId: defaultClientID, + CacheTtlSeconds: defaultCacheTtlSeconds, + MySQLMaxOpenConns: defaultMySQLMaxOpenConns, + MySQLMaxIdleConns: defaultMySQLMaxIdleConns, + MySQLConnMaxLifetimeSeconds: defaultMySQLConnMaxLifetimeSeconds, + MySQLQueryTimeoutSeconds: defaultMySQLQueryTimeoutSeconds, + }, nil + } +} + +func parseInt64Field(field string, value interface{}) (int64, error) { + switch parsed := value.(type) { + case float64: + return int64(parsed), nil + case int: + return int64(parsed), nil + case int32: + return int64(parsed), nil + case int64: + return parsed, nil + default: + return 0, fmt.Errorf("unexpected type %T for %s", value, field) + } +} + +func parseIntField(field string, value interface{}) (int, error) { + parsed, err := parseInt64Field(field, value) + if err != nil { + return 0, err } + return int(parsed), nil } diff --git a/go/internal/feast/registry/repoconfig_test.go b/go/internal/feast/registry/repoconfig_test.go index 4d30bf7bca0..3b139e411e0 100644 --- a/go/internal/feast/registry/repoconfig_test.go +++ b/go/internal/feast/registry/repoconfig_test.go @@ -191,10 +191,14 @@ func TestGetRegistryConfig_Map(t *testing.T) { // Create a RepoConfig with a map Registry config := &RepoConfig{ Registry: map[string]interface{}{ - "path": "data/registry.db", - "registry_store_type": "local", - "client_id": "test_client_id", - "cache_ttl_seconds": 60, + "path": "data/registry.db", + "registry_store_type": "local", + "client_id": "test_client_id", + "cache_ttl_seconds": 60, + "mysql_max_open_conns": 25, + "mysql_max_idle_conns": 12, + "mysql_conn_max_lifetime_seconds": 180, + "mysql_query_timeout_seconds": 60, }, } @@ -206,6 +210,10 @@ func TestGetRegistryConfig_Map(t *testing.T) { assert.Equal(t, "local", registryConfig.RegistryStoreType) assert.Equal(t, int64(60), registryConfig.CacheTtlSeconds) assert.Equal(t, "test_client_id", registryConfig.ClientId) + assert.Equal(t, 25, registryConfig.MySQLMaxOpenConns) + assert.Equal(t, 12, registryConfig.MySQLMaxIdleConns) + assert.Equal(t, int64(180), registryConfig.MySQLConnMaxLifetimeSeconds) + assert.Equal(t, int64(60), registryConfig.MySQLQueryTimeoutSeconds) } func TestGetRegistryConfig_String(t *testing.T) { @@ -223,6 +231,10 @@ func TestGetRegistryConfig_String(t *testing.T) { println(registryConfig.CacheTtlSeconds) assert.Empty(t, registryConfig.RegistryStoreType) assert.Equal(t, defaultCacheTtlSeconds, registryConfig.CacheTtlSeconds) + assert.Equal(t, defaultMySQLMaxOpenConns, registryConfig.MySQLMaxOpenConns) + assert.Equal(t, defaultMySQLMaxIdleConns, registryConfig.MySQLMaxIdleConns) + assert.Equal(t, defaultMySQLConnMaxLifetimeSeconds, registryConfig.MySQLConnMaxLifetimeSeconds) + assert.Equal(t, defaultMySQLQueryTimeoutSeconds, registryConfig.MySQLQueryTimeoutSeconds) } func TestGetRegistryConfig_CacheTtlSecondsTypes(t *testing.T) {