@@ -388,6 +388,20 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum
388388}
389389
390390func buildQueries (conf Config , req * plugin.GenerateRequest , structs []Struct ) ([]Query , error ) {
391+ rlsFieldsByTable := make (map [string ][]string ) // TODO
392+ if len (conf .RLSEnforcedFields ) > 0 {
393+ for i := range structs {
394+ tableName := structs [i ].Table .Name
395+ for _ , f := range structs [i ].Fields {
396+ for _ , enforced := range conf .RLSEnforcedFields {
397+ if f .Name == enforced {
398+ rlsFieldsByTable [tableName ] = append (rlsFieldsByTable [tableName ], f .Name )
399+ }
400+ }
401+ }
402+ }
403+ }
404+
391405 qs := make ([]Query , 0 , len (req .Queries ))
392406 for _ , query := range req .Queries {
393407 if query .Name == "" {
@@ -419,9 +433,20 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
419433 if qpl < 0 {
420434 return nil , errors .New ("invalid query parameter limit" )
421435 }
436+ enforcedFields := make (map [string ]bool )
437+ for _ , c := range query .Columns {
438+ if fields , ok := rlsFieldsByTable [c .GetTable ().GetName ()]; ok {
439+ for _ , f := range fields {
440+ enforcedFields [f ] = false
441+ }
442+ }
443+ }
422444 if len (query .Params ) > qpl || qpl == 0 {
423445 var cols []pyColumn
424446 for _ , p := range query .Params {
447+ if _ , ok := enforcedFields [p .GetColumn ().GetName ()]; ok {
448+ enforcedFields [p .Column .Name ] = true
449+ }
425450 cols = append (cols , pyColumn {
426451 id : p .Number ,
427452 Column : p .Column ,
@@ -435,14 +460,21 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
435460 } else {
436461 args := make ([]QueryValue , 0 , len (query .Params ))
437462 for _ , p := range query .Params {
463+ if _ , ok := enforcedFields [p .GetColumn ().GetName ()]; ok {
464+ enforcedFields [p .Column .Name ] = true
465+ }
438466 args = append (args , QueryValue {
439467 Name : paramName (p ),
440468 Typ : makePyType (req , p .Column ),
441469 })
442470 }
443471 gq .Args = args
444472 }
445-
473+ for field , is_enforced := range enforcedFields {
474+ if ! is_enforced {
475+ return nil , fmt .Errorf ("RLS field %s is not filtered in query %s" , field , query .Name )
476+ }
477+ }
446478 if len (query .Columns ) == 1 {
447479 c := query .Columns [0 ]
448480 gq .Ret = QueryValue {
0 commit comments