Skip to content

Commit 5271368

Browse files
authored
fix(skill): fix env variable propagation (#2645)
This PR refactors the tool file parsing logic by introducing a new ToolsFileParser struct. This encapsulation allows for stateful parsing, specifically enabling the tracking of resolved environment variables during configuration loading. This is useful for skill generation, where we can now identify and omit default values that correspond to environment variables, ensuring more precise generated skill.
1 parent 3d6ae4e commit 5271368

8 files changed

Lines changed: 89 additions & 39 deletions

File tree

cmd/internal/invoke/command.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func runInvoke(cmd *cobra.Command, args []string, opts *internal.ToolboxOptions)
5454
_ = shutdown(ctx)
5555
}()
5656

57-
_, err = opts.LoadConfig(ctx)
57+
_, err = opts.LoadConfig(ctx, &internal.ToolsFileParser{})
5858
if err != nil {
5959
return err
6060
}

cmd/internal/options.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (opts *ToolboxOptions) Setup(ctx context.Context) (context.Context, func(co
131131
}
132132

133133
// LoadConfig checks and merge files that should be loaded into the server
134-
func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) {
134+
func (opts *ToolboxOptions) LoadConfig(ctx context.Context, parser *ToolsFileParser) (bool, error) {
135135
// Determine if Custom Files should be loaded
136136
// Check for explicit custom flags
137137
isCustomConfigured := opts.ToolsFile != "" || len(opts.ToolsFiles) > 0 || opts.ToolsFolder != ""
@@ -167,7 +167,7 @@ func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) {
167167
}
168168

169169
// Parse into ToolsFile struct
170-
parsed, err := parseToolsFile(ctx, buf)
170+
parsed, err := parser.ParseToolsFile(ctx, buf)
171171
if err != nil {
172172
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err)
173173
logger.ErrorContext(ctx, errMsg.Error())
@@ -194,11 +194,11 @@ func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) {
194194
if len(opts.ToolsFiles) > 0 {
195195
// Use tools-files
196196
logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(opts.ToolsFiles)))
197-
customTools, err = LoadAndMergeToolsFiles(ctx, opts.ToolsFiles)
197+
customTools, err = parser.LoadAndMergeToolsFiles(ctx, opts.ToolsFiles)
198198
} else if opts.ToolsFolder != "" {
199199
// Use tools-folder
200200
logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", opts.ToolsFolder))
201-
customTools, err = LoadAndMergeToolsFolder(ctx, opts.ToolsFolder)
201+
customTools, err = parser.LoadAndMergeToolsFolder(ctx, opts.ToolsFolder)
202202
} else {
203203
// Use single file (tools-file or default `tools.yaml`)
204204
buf, readFileErr := os.ReadFile(opts.ToolsFile)
@@ -207,7 +207,7 @@ func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) {
207207
logger.ErrorContext(ctx, errMsg.Error())
208208
return isCustomConfigured, errMsg
209209
}
210-
customTools, err = parseToolsFile(ctx, buf)
210+
customTools, err = parser.ParseToolsFile(ctx, buf)
211211
if err != nil {
212212
err = fmt.Errorf("unable to parse tool file at %q: %w", opts.ToolsFile, err)
213213
}

cmd/internal/skills/command.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ func run(cmd *skillsCmd, opts *internal.ToolboxOptions) error {
7272
_ = shutdown(ctx)
7373
}()
7474

75-
_, err = opts.LoadConfig(ctx)
75+
parser := internal.ToolsFileParser{}
76+
_, err = opts.LoadConfig(ctx, &parser)
7677
if err != nil {
7778
return err
7879
}
@@ -165,7 +166,7 @@ func run(cmd *skillsCmd, opts *internal.ToolboxOptions) error {
165166
}
166167

167168
// Generate SKILL.md
168-
skillContent, err := generateSkillMarkdown(cmd.name, cmd.description, allTools)
169+
skillContent, err := generateSkillMarkdown(cmd.name, cmd.description, allTools, parser.EnvVars)
169170
if err != nil {
170171
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
171172
opts.Logger.ErrorContext(ctx, errMsg.Error())

cmd/internal/skills/generator.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ type skillTemplateData struct {
7272
// generateSkillMarkdown generates the content of the SKILL.md file.
7373
// It includes usage instructions and a reference section for each tool in the skill,
7474
// detailing its description and parameters.
75-
func generateSkillMarkdown(skillName, skillDescription string, toolsMap map[string]tools.Tool) (string, error) {
75+
func generateSkillMarkdown(skillName, skillDescription string, toolsMap map[string]tools.Tool, envVars map[string]string) (string, error) {
7676
var toolsData []toolTemplateData
7777

7878
// Order tools based on name
@@ -86,7 +86,7 @@ func generateSkillMarkdown(skillName, skillDescription string, toolsMap map[stri
8686
tool := toolsMap[name]
8787
manifest := tool.Manifest()
8888

89-
parametersSchema, err := formatParameters(manifest.Parameters)
89+
parametersSchema, err := formatParameters(manifest.Parameters, envVars)
9090
if err != nil {
9191
return "", err
9292
}
@@ -200,7 +200,7 @@ func generateScriptContent(name string, toolsFileName string) (string, error) {
200200

201201
// formatParameters converts a list of parameter manifests into a formatted JSON schema string.
202202
// This schema is used in the skill documentation to describe the input parameters for a tool.
203-
func formatParameters(params []parameters.ParameterManifest) (string, error) {
203+
func formatParameters(params []parameters.ParameterManifest, envVars map[string]string) (string, error) {
204204
if len(params) == 0 {
205205
return "", nil
206206
}
@@ -214,7 +214,20 @@ func formatParameters(params []parameters.ParameterManifest) (string, error) {
214214
"description": p.Description,
215215
}
216216
if p.Default != nil {
217-
paramMap["default"] = p.Default
217+
defaultValue := p.Default
218+
// Check if default value is pre-configured, if so, remove it as the the value will be
219+
// read by the tool at runtime and the agent does not need to be aware of it.
220+
if strVal, ok := defaultValue.(string); ok {
221+
for _, envVal := range envVars {
222+
if envVal == strVal {
223+
defaultValue = nil
224+
break
225+
}
226+
}
227+
}
228+
if defaultValue != nil {
229+
paramMap["default"] = defaultValue
230+
}
218231
}
219232
properties[p.Name] = paramMap
220233
if p.Required {

cmd/internal/skills/generator_test.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func TestFormatParameters(t *testing.T) {
6060
tests := []struct {
6161
name string
6262
params []parameters.ParameterManifest
63+
envVars map[string]string
6364
wantContains []string
6465
wantErr bool
6566
}{
@@ -115,11 +116,29 @@ func TestFormatParameters(t *testing.T) {
115116
`"param1"`,
116117
},
117118
},
119+
{
120+
name: "parameter with env var default",
121+
params: []parameters.ParameterManifest{
122+
{
123+
Name: "param1",
124+
Description: "Param 1",
125+
Type: "string",
126+
Default: "default-value",
127+
Required: false,
128+
},
129+
},
130+
envVars: map[string]string{
131+
"MY_ENV_VAR": "default-value",
132+
},
133+
wantContains: []string{
134+
`"param1": {`,
135+
},
136+
},
118137
}
119138

120139
for _, tt := range tests {
121140
t.Run(tt.name, func(t *testing.T) {
122-
got, err := formatParameters(tt.params)
141+
got, err := formatParameters(tt.params, tt.envVars)
123142
if (err != nil) != tt.wantErr {
124143
t.Errorf("formatParameters() error = %v, wantErr %v", err, tt.wantErr)
125144
return
@@ -154,7 +173,7 @@ func TestGenerateSkillMarkdown(t *testing.T) {
154173
},
155174
}
156175

157-
got, err := generateSkillMarkdown("MySkill", "My Description", toolsMap)
176+
got, err := generateSkillMarkdown("MySkill", "My Description", toolsMap, nil)
158177
if err != nil {
159178
t.Fatalf("generateSkillMarkdown() error = %v", err)
160179
}

cmd/internal/tools_file.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,34 +38,45 @@ type ToolsFile struct {
3838
Prompts server.PromptConfigs `yaml:"prompts"`
3939
}
4040

41+
type ToolsFileParser struct {
42+
EnvVars map[string]string
43+
}
44+
4145
// parseEnv replaces environment variables ${ENV_NAME} with their values.
4246
// also support ${ENV_NAME:default_value}.
43-
func parseEnv(input string) (string, error) {
47+
func (p *ToolsFileParser) parseEnv(input string) (string, error) {
4448
re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`)
4549

50+
if p.EnvVars == nil {
51+
p.EnvVars = make(map[string]string)
52+
}
53+
4654
var err error
4755
output := re.ReplaceAllStringFunc(input, func(match string) string {
4856
parts := re.FindStringSubmatch(match)
4957

5058
// extract the variable name
5159
variableName := parts[1]
5260
if value, found := os.LookupEnv(variableName); found {
61+
p.EnvVars[variableName] = value
5362
return value
5463
}
5564
if len(parts) >= 4 && parts[2] != "" {
56-
return parts[3]
65+
value := parts[3]
66+
p.EnvVars[variableName] = value
67+
return value
5768
}
5869
err = fmt.Errorf("environment variable not found: %q", variableName)
5970
return ""
6071
})
6172
return output, err
6273
}
6374

64-
// parseToolsFile parses the provided yaml into appropriate configs.
65-
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
75+
// ParseToolsFile parses the provided yaml into appropriate configs.
76+
func (p *ToolsFileParser) ParseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
6677
var toolsFile ToolsFile
6778
// Replace environment variables if found
68-
output, err := parseEnv(string(raw))
79+
output, err := p.parseEnv(string(raw))
6980
if err != nil {
7081
return toolsFile, fmt.Errorf("error parsing environment variables: %s", err)
7182
}
@@ -157,7 +168,7 @@ func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) {
157168
if !ok {
158169
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
159170
}
160-
entryBody := ProcessValue(entry.Value, kind == "toolsets")
171+
entryBody := processValue(entry.Value, kind == "toolsets")
161172

162173
currentTransformed := yaml.MapSlice{
163174
{Key: "kind", Value: kind},
@@ -175,8 +186,8 @@ func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) {
175186
return transformed, nil
176187
}
177188

178-
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
179-
func ProcessValue(v any, isToolset bool) any {
189+
// processValue recursively looks for MapSlices to rename 'kind' -> 'type'
190+
func processValue(v any, isToolset bool) any {
180191
switch val := v.(type) {
181192
case yaml.MapSlice:
182193
// creating a new MapSlice is safer for recursive transformation
@@ -187,7 +198,7 @@ func ProcessValue(v any, isToolset bool) any {
187198
item.Key = "type"
188199
}
189200
// Recursive call for nested values (e.g., nested objects or lists)
190-
item.Value = ProcessValue(item.Value, false)
201+
item.Value = processValue(item.Value, false)
191202
newVal[i] = item
192203
}
193204
return newVal
@@ -199,7 +210,7 @@ func ProcessValue(v any, isToolset bool) any {
199210
// Otherwise, recurse into list items (to catch nested objects)
200211
newVal := make([]any, len(val))
201212
for i := range val {
202-
newVal[i] = ProcessValue(val[i], false)
213+
newVal[i] = processValue(val[i], false)
203214
}
204215
return newVal
205216
default:
@@ -287,7 +298,7 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
287298
}
288299

289300
// LoadAndMergeToolsFiles loads multiple YAML files and merges them
290-
func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) {
301+
func (p *ToolsFileParser) LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) {
291302
var toolsFiles []ToolsFile
292303

293304
for _, filePath := range filePaths {
@@ -296,7 +307,7 @@ func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile,
296307
return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err)
297308
}
298309

299-
toolsFile, err := parseToolsFile(ctx, buf)
310+
toolsFile, err := p.ParseToolsFile(ctx, buf)
300311
if err != nil {
301312
return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err)
302313
}
@@ -313,7 +324,7 @@ func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile,
313324
}
314325

315326
// LoadAndMergeToolsFolder loads all YAML files from a directory and merges them
316-
func LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) {
327+
func (p *ToolsFileParser) LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) {
317328
// Check if directory exists
318329
info, err := os.Stat(folderPath)
319330
if err != nil {
@@ -345,5 +356,5 @@ func LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile,
345356
}
346357

347358
// Use existing LoadAndMergeToolsFiles function
348-
return LoadAndMergeToolsFiles(ctx, allFiles)
359+
return p.LoadAndMergeToolsFiles(ctx, allFiles)
349360
}

cmd/internal/tools_file_test.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ func TestParseEnv(t *testing.T) {
8585
t.Setenv(k, v)
8686
}
8787
}
88-
got, err := parseEnv(tc.in)
88+
parser := &ToolsFileParser{}
89+
got, err := parser.parseEnv(tc.in)
8990
if tc.err {
9091
if err == nil {
9192
t.Fatalf("expected error not found")
@@ -754,7 +755,8 @@ func TestParseToolFile(t *testing.T) {
754755
}
755756
for _, tc := range tcs {
756757
t.Run(tc.description, func(t *testing.T) {
757-
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
758+
parser := ToolsFileParser{}
759+
toolsFile, err := parser.ParseToolsFile(ctx, testutils.FormatYaml(tc.in))
758760
if err != nil {
759761
t.Fatalf("failed to parse input: %v", err)
760762
}
@@ -1100,7 +1102,8 @@ func TestParseToolFileWithAuth(t *testing.T) {
11001102
}
11011103
for _, tc := range tcs {
11021104
t.Run(tc.description, func(t *testing.T) {
1103-
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
1105+
parser := ToolsFileParser{}
1106+
toolsFile, err := parser.ParseToolsFile(ctx, testutils.FormatYaml(tc.in))
11041107
if err != nil {
11051108
t.Fatalf("failed to parse input: %v", err)
11061109
}
@@ -1437,7 +1440,8 @@ func TestEnvVarReplacement(t *testing.T) {
14371440
}
14381441
for _, tc := range tcs {
14391442
t.Run(tc.description, func(t *testing.T) {
1440-
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
1443+
parser := ToolsFileParser{}
1444+
toolsFile, err := parser.ParseToolsFile(ctx, testutils.FormatYaml(tc.in))
14411445
if err != nil {
14421446
t.Fatalf("failed to parse input: %v", err)
14431447
}
@@ -1458,6 +1462,7 @@ func TestEnvVarReplacement(t *testing.T) {
14581462
}
14591463
})
14601464
}
1465+
14611466
}
14621467

14631468
func TestPrebuiltTools(t *testing.T) {
@@ -1949,7 +1954,8 @@ func TestPrebuiltTools(t *testing.T) {
19491954

19501955
for _, tc := range tcs {
19511956
t.Run(tc.name, func(t *testing.T) {
1952-
toolsFile, err := parseToolsFile(ctx, tc.in)
1957+
parser := ToolsFileParser{}
1958+
toolsFile, err := parser.ParseToolsFile(ctx, tc.in)
19531959
if err != nil {
19541960
t.Fatalf("failed to parse input: %v", err)
19551961
}
@@ -2146,8 +2152,8 @@ tools:
21462152
t.Run(tc.desc, func(t *testing.T) {
21472153
// Indent parameters to match YAML structure
21482154
yamlContent := fmt.Sprintf(baseYaml, tc.params)
2149-
2150-
_, err := parseToolsFile(ctx, []byte(yamlContent))
2155+
parser := ToolsFileParser{}
2156+
_, err := parser.ParseToolsFile(ctx, []byte(yamlContent))
21512157

21522158
if tc.wantErr {
21532159
if err == nil {

cmd/root.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,17 +367,17 @@ func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles m
367367
case <-debounce.C:
368368
debounce.Stop()
369369
var reloadedToolsFile internal.ToolsFile
370-
370+
parser := internal.ToolsFileParser{}
371371
if watchingFolder {
372372
logger.DebugContext(ctx, "Reloading tools folder.")
373-
reloadedToolsFile, err = internal.LoadAndMergeToolsFolder(ctx, folderToWatch)
373+
reloadedToolsFile, err = parser.LoadAndMergeToolsFolder(ctx, folderToWatch)
374374
if err != nil {
375375
logger.WarnContext(ctx, fmt.Sprintf("error loading tools folder %s", err))
376376
continue
377377
}
378378
} else {
379379
logger.DebugContext(ctx, "Reloading tools file(s).")
380-
reloadedToolsFile, err = internal.LoadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles)))
380+
reloadedToolsFile, err = parser.LoadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles)))
381381
if err != nil {
382382
logger.WarnContext(ctx, fmt.Sprintf("error loading tools files %s", err))
383383
continue
@@ -453,7 +453,7 @@ func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error {
453453
_ = shutdown(ctx)
454454
}()
455455

456-
isCustomConfigured, err := opts.LoadConfig(ctx)
456+
isCustomConfigured, err := opts.LoadConfig(ctx, &internal.ToolsFileParser{})
457457
if err != nil {
458458
return err
459459
}

0 commit comments

Comments
 (0)