Skip to content

Commit 09e3119

Browse files
authored
Optionally include reachable fieldpaths in prompt (#1285)
* Optionally include reachable fieldpaths in prompt Update the AI prompt template to optionally include all of the reachable fields from the input variables along with any documentation. The fields are described in terms of CEL types. * Add bazel rule for preserving proto source info
1 parent ae49cd0 commit 09e3119

43 files changed

Lines changed: 10410 additions & 27 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

bazel/BUILD.bazel

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
2+
3+
bzl_library(
4+
name = "proto_source_info",
5+
srcs = ["proto_source_info.bzl"],
6+
)

bazel/proto_source_info.bzl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Build rule for preserving source information in proto descriptor sets."""
2+
3+
load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo")
4+
5+
def _source_info_proto_descriptor_set(ctx):
6+
"""Returns a proto descriptor set with source information preserved."""
7+
srcs = depset([s for dep in ctx.attr.proto_libs for s in dep[ProtoInfo].direct_sources])
8+
deps = depset(transitive = [dep[ProtoInfo].transitive_descriptor_sets for dep in ctx.attr.proto_libs])
9+
10+
src_files = srcs.to_list()
11+
dep_files = deps.to_list()
12+
13+
args = ctx.actions.args()
14+
args.add("--descriptor_set_out=" + ctx.outputs.out.path)
15+
args.add("--include_imports")
16+
args.add("--include_source_info=true")
17+
args.add("--proto_path=.")
18+
args.add("--proto_path=" + ctx.configuration.genfiles_dir.path)
19+
args.add("--descriptor_set_in=" + ":".join([d.path for d in dep_files]))
20+
args.add_all(src_files)
21+
22+
ctx.actions.run(
23+
executable = ctx.executable._protoc,
24+
inputs = src_files + dep_files,
25+
outputs = [ctx.outputs.out],
26+
arguments = [args],
27+
mnemonic = "SourceInfoProtoDescriptorSet",
28+
progress_message = "Generating proto descriptor set with source information for %{label}",
29+
)
30+
31+
source_info_proto_descriptor_set = rule(
32+
doc = """
33+
Rule for generating a proto descriptor set for the transitive dependencies of proto libraries
34+
with source information preserved.
35+
36+
This can dramatically increase the size of the descriptor set, so only use it
37+
when necessary (e.g. for formatting documentation about a CEL environment).
38+
39+
Source info is only preserved for input files for each proto_library label in
40+
protolibs. Transitive dependencies are included with source info stripped.
41+
""",
42+
attrs = {
43+
"proto_libs": attr.label_list(providers = [[ProtoInfo]]),
44+
"_protoc": attr.label(
45+
default = "@com_google_protobuf//:protoc",
46+
executable = True,
47+
cfg = "exec",
48+
),
49+
},
50+
outputs = {
51+
"out": "%{name}-transitive-descriptor-set-source-info.proto.bin",
52+
},
53+
implementation = _source_info_proto_descriptor_set,
54+
)

cel/BUILD.bazel

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ go_library(
1010
"cel.go",
1111
"decls.go",
1212
"env.go",
13+
"fieldpaths.go",
1314
"folding.go",
1415
"inlining.go",
1516
"io.go",
@@ -43,6 +44,7 @@ go_library(
4344
"//interpreter:go_default_library",
4445
"//parser:go_default_library",
4546
"@dev_cel_expr//:expr",
47+
"@dev_cel_expr//conformance/proto3:go_default_library",
4648
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
4749
"@org_golang_google_protobuf//proto:go_default_library",
4850
"@org_golang_google_protobuf//reflect/protodesc:go_default_library",
@@ -63,6 +65,7 @@ go_test(
6365
"cel_test.go",
6466
"decls_test.go",
6567
"env_test.go",
68+
"fieldpaths_test.go",
6669
"folding_test.go",
6770
"inlining_test.go",
6871
"io_test.go",
@@ -78,6 +81,7 @@ go_test(
7881
],
7982
embedsrcs = [
8083
"//cel/testdata:prompts",
84+
"//cel/testdata:test_fds_with_source_info",
8185
],
8286
deps = [
8387
"//common/operators:go_default_library",
@@ -89,6 +93,7 @@ go_test(
8993
"//test:go_default_library",
9094
"//test/proto2pb:go_default_library",
9195
"//test/proto3pb:go_default_library",
96+
"@com_github_google_go_cmp//cmp:go_default_library",
9297
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
9398
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
9499
"@org_golang_google_protobuf//proto:go_default_library",
@@ -100,4 +105,4 @@ go_test(
100105
exports_files(
101106
["templates/authoring.tmpl"],
102107
visibility = ["//visibility:public"],
103-
)
108+
)

cel/fieldpaths.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package cel
2+
3+
import (
4+
"slices"
5+
"strings"
6+
7+
"github.com/google/cel-go/common"
8+
"github.com/google/cel-go/common/types"
9+
)
10+
11+
// fieldPath represents a selection path to a field from a variable in a CEL environment.
12+
type fieldPath struct {
13+
celType *Type
14+
// path represents the selection path to the field.
15+
path string
16+
description string
17+
isLeaf bool
18+
}
19+
20+
// Documentation implements the Documentor interface.
21+
func (f *fieldPath) Documentation() *common.Doc {
22+
return common.NewFieldDoc(f.path, f.celType.String(), f.description)
23+
}
24+
25+
type documentationProvider interface {
26+
// FindStructFieldDescription returns documentation for a field if available.
27+
// Returns false if the field could not be found.
28+
FindStructFieldDescription(typeName, fieldName string) (string, bool)
29+
}
30+
31+
type backtrack struct {
32+
// provider used to resolve types.
33+
provider types.Provider
34+
// paths of fields that have been visited along the path.
35+
path []string
36+
// types of fields that have been visited along the path. used to avoid cycles.
37+
types []*Type
38+
}
39+
40+
func (b *backtrack) push(pathStep string, celType *Type) {
41+
b.path = append(b.path, pathStep)
42+
b.types = append(b.types, celType)
43+
}
44+
45+
func (b *backtrack) pop() {
46+
b.path = b.path[:len(b.path)-1]
47+
b.types = b.types[:len(b.types)-1]
48+
}
49+
50+
func formatPath(path []string) string {
51+
var buffer strings.Builder
52+
for i, p := range path {
53+
if i == 0 {
54+
buffer.WriteString(p)
55+
continue
56+
}
57+
if strings.HasPrefix(p, "[") {
58+
buffer.WriteString(p)
59+
continue
60+
}
61+
buffer.WriteString(".")
62+
buffer.WriteString(p)
63+
}
64+
return buffer.String()
65+
}
66+
67+
func (b *backtrack) expandFieldPaths(celType *Type, paths []*fieldPath) []*fieldPath {
68+
if slices.ContainsFunc(b.types[:len(b.types)-1], func(t *Type) bool { return t.String() == celType.String() }) {
69+
// Cycle detected, so stop expanding.
70+
paths[len(paths)-1].isLeaf = false
71+
return paths
72+
}
73+
switch celType.Kind() {
74+
case types.StructKind:
75+
fields, ok := b.provider.FindStructFieldNames(celType.String())
76+
if !ok {
77+
// Caller added this type to the path, so it must be a leaf.
78+
paths[len(paths)-1].isLeaf = true
79+
return paths
80+
}
81+
for _, field := range fields {
82+
fieldType, ok := b.provider.FindStructFieldType(celType.String(), field)
83+
if !ok {
84+
// Field not found, either hidden or an error.
85+
continue
86+
}
87+
b.push(field, celType)
88+
description := ""
89+
if docProvider, ok := b.provider.(documentationProvider); ok {
90+
description, _ = docProvider.FindStructFieldDescription(celType.String(), field)
91+
}
92+
path := &fieldPath{
93+
celType: fieldType.Type,
94+
path: formatPath(b.path),
95+
description: description,
96+
isLeaf: false,
97+
}
98+
paths = append(paths, path)
99+
paths = b.expandFieldPaths(fieldType.Type, paths)
100+
b.pop()
101+
}
102+
return paths
103+
case types.MapKind:
104+
if len(celType.Parameters()) != 2 {
105+
// dynamic map, so treat as a leaf.
106+
paths[len(paths)-1].isLeaf = true
107+
return paths
108+
}
109+
mapKeyType := celType.Parameters()[0]
110+
mapValueType := celType.Parameters()[1]
111+
// Add a placeholder for the map key kind (the zero value).
112+
keyIdentifier := ""
113+
switch mapKeyType.Kind() {
114+
case types.StringKind:
115+
keyIdentifier = "[\"\"]"
116+
case types.IntKind:
117+
keyIdentifier = "[0]"
118+
case types.UintKind:
119+
keyIdentifier = "[0u]"
120+
case types.BoolKind:
121+
keyIdentifier = "[false]"
122+
default:
123+
// Caller added this type to the path, so it must be a leaf.
124+
paths[len(paths)-1].isLeaf = true
125+
return paths
126+
}
127+
b.push(keyIdentifier, mapValueType)
128+
defer b.pop()
129+
return b.expandFieldPaths(mapValueType, paths)
130+
case types.ListKind:
131+
if len(celType.Parameters()) != 1 {
132+
// dynamic list, so treat as a leaf.
133+
paths[len(paths)-1].isLeaf = true
134+
return paths
135+
}
136+
listElemType := celType.Parameters()[0]
137+
b.push("[0]", listElemType)
138+
defer b.pop()
139+
return b.expandFieldPaths(listElemType, paths)
140+
default:
141+
paths[len(paths)-1].isLeaf = true
142+
}
143+
144+
return paths
145+
}
146+
147+
// fieldPathsForType expands the reachable fields from the given root identifier.
148+
func fieldPathsForType(provider types.Provider, identifier string, celType *Type) []*fieldPath {
149+
b := &backtrack{
150+
provider: provider,
151+
path: []string{identifier},
152+
types: []*Type{celType},
153+
}
154+
paths := []*fieldPath{
155+
{
156+
celType: celType,
157+
path: identifier,
158+
isLeaf: false,
159+
},
160+
}
161+
162+
return b.expandFieldPaths(celType, paths)
163+
}

0 commit comments

Comments
 (0)