forked from anomalyco/opencode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnon_interactive_mode.go
More file actions
292 lines (241 loc) · 7.81 KB
/
non_interactive_mode.go
File metadata and controls
292 lines (241 loc) · 7.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
package cmd
import (
"context"
"fmt"
"io"
"os"
"sync"
"time"
"log/slog"
charmlog "github.com/charmbracelet/log"
"github.com/sst/opencode/internal/app"
"github.com/sst/opencode/internal/config"
"github.com/sst/opencode/internal/db"
"github.com/sst/opencode/internal/format"
"github.com/sst/opencode/internal/llm/agent"
"github.com/sst/opencode/internal/llm/tools"
"github.com/sst/opencode/internal/message"
"github.com/sst/opencode/internal/permission"
"github.com/sst/opencode/internal/tui/components/spinner"
"github.com/sst/opencode/internal/tui/theme"
)
// syncWriter serializes Write calls to an underlying io.Writer with a mutex,
// so goroutines sharing one syncWriter never interleave bytes within a single
// Write call.
type syncWriter struct {
	mu sync.Mutex // guards w
	w  io.Writer
}

// Compile-time check that *syncWriter satisfies io.Writer.
var _ io.Writer = (*syncWriter)(nil)

// Write forwards p to the wrapped writer while holding the lock and returns
// exactly what the wrapped writer returns.
func (s *syncWriter) Write(p []byte) (int, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	n, err := s.w.Write(p)
	return n, err
}

// newSyncWriter wraps w in a *syncWriter and returns it as an io.Writer.
func newSyncWriter(w io.Writer) io.Writer {
	sw := new(syncWriter)
	sw.w = w
	return sw
}
// filterTools returns the subset of allTools permitted by the allow/deny lists,
// matching on each tool's Info().Name.
//
//   - If both lists are empty, allTools is returned unchanged (the same slice).
//   - If allowedTools is non-empty, only tools named in it are kept.
//   - Any tool named in excludedTools is dropped, even if it is also allowed,
//     so an exclusion always wins. (Previously excludedTools was silently
//     ignored whenever allowedTools was non-empty.)
//
// The relative order of the kept tools is preserved. The result is nil when no
// tool survives the filter.
func filterTools(allTools []tools.BaseTool, allowedTools, excludedTools []string) []tools.BaseTool {
	if len(allowedTools) == 0 && len(excludedTools) == 0 {
		return allTools
	}

	// Build name sets for O(1) membership tests.
	allowed := make(map[string]struct{}, len(allowedTools))
	for _, name := range allowedTools {
		allowed[name] = struct{}{}
	}
	excluded := make(map[string]struct{}, len(excludedTools))
	for _, name := range excludedTools {
		excluded[name] = struct{}{}
	}

	var filtered []tools.BaseTool
	for _, tool := range allTools {
		name := tool.Info().Name
		if len(allowedTools) > 0 {
			if _, ok := allowed[name]; !ok {
				continue
			}
		}
		if _, ok := excluded[name]; ok {
			continue
		}
		filtered = append(filtered, tool)
	}
	return filtered
}
// handleNonInteractiveMode runs a single prompt through an agent without the
// TUI and prints the formatted reply to stdout.
//
// Parameters:
//   - ctx: parent context; a cancellable child context is derived for the run.
//   - prompt: the user prompt sent to the agent.
//   - outputFormat: passed to format.FormatOutput to render the reply text.
//   - quiet: suppresses the "Thinking..." spinner.
//   - verbose: installs a debug-level logger writing to stderr.
//   - allowedTools / excludedTools: optional tool-name lists; when either is
//     non-empty, a dedicated agent is built with only the tools kept by
//     filterTools.
//
// It returns an error when quiet and verbose are both set, or when any setup,
// agent, or formatting step fails. Once the app has been created, it is shut
// down on every return path, not just on success.
func handleNonInteractiveMode(ctx context.Context, prompt string, outputFormat format.OutputFormat, quiet bool, verbose bool, allowedTools, excludedTools []string) error {
	// Reject contradictory flags before doing any work.
	if quiet && verbose {
		return fmt.Errorf("--quiet and --verbose flags cannot be used together")
	}

	slog.Info("Running in non-interactive mode", "prompt", prompt, "format", outputFormat, "quiet", quiet, "verbose", verbose,
		"allowedTools", allowedTools, "excludedTools", excludedTools)

	if verbose {
		enableNonInteractiveVerboseLogging()
	}

	// stopSpinner is idempotent: it is deferred for the error paths and also
	// called explicitly just before the reply is printed.
	stopSpinner := startNonInteractiveSpinner(quiet)
	defer stopSpinner()

	// Connect DB, this will also run migrations.
	conn, err := db.Connect()
	if err != nil {
		return err
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Named appInstance so the local does not shadow the imported app package.
	appInstance, err := app.New(ctx, conn)
	if err != nil {
		slog.Error("Failed to create app", "error", err)
		return err
	}
	// Shut the app down on every exit path, including errors below.
	defer appInstance.Shutdown()

	session, err := appInstance.Sessions.Create(ctx, "Non-interactive prompt")
	if err != nil {
		return fmt.Errorf("failed to create session: %w", err)
	}
	appInstance.CurrentSession = &session

	// Auto-approve all permissions for this session: there is no interactive
	// user to answer permission prompts.
	permission.AutoApproveSession(ctx, session.ID)

	// NOTE(review): the agent's Run below also receives the prompt; confirm it
	// does not persist its own user message, otherwise this one is a duplicate.
	_, err = appInstance.Messages.Create(ctx, session.ID, message.CreateMessageParams{
		Role:  message.User,
		Parts: []message.ContentPart{message.TextContent{Text: prompt}},
	})
	if err != nil {
		return fmt.Errorf("failed to create message: %w", err)
	}

	var response message.Message
	if len(allowedTools) > 0 || len(excludedTools) > 0 {
		// Initialize MCP tools synchronously (bounded to 10s) so they are
		// included in the list that gets filtered. The return value is
		// discarded here; presumably the call populates state that
		// PrimaryAgentTools reads — TODO confirm.
		mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second)
		agent.GetMcpTools(mcpCtx, appInstance.Permissions)
		mcpCancel()

		allTools := agent.PrimaryAgentTools(
			appInstance.Permissions,
			appInstance.Sessions,
			appInstance.Messages,
			appInstance.History,
			appInstance.LSPClients,
		)
		filteredTools := filterTools(allTools, allowedTools, excludedTools)

		toolNames := make([]string, 0, len(filteredTools))
		for _, tool := range filteredTools {
			toolNames = append(toolNames, tool.Info().Name)
		}
		slog.Debug("Using filtered tools", "count", len(filteredTools), "tools", toolNames)

		restrictedAgent, err := agent.NewAgent(
			config.AgentPrimary,
			appInstance.Sessions,
			appInstance.Messages,
			filteredTools,
		)
		if err != nil {
			return fmt.Errorf("failed to create restricted agent: %w", err)
		}

		eventCh, err := restrictedAgent.Run(ctx, session.ID, prompt)
		if err != nil {
			return fmt.Errorf("failed to run restricted agent: %w", err)
		}
		// Keep the last response seen; abort on the first event error.
		for event := range eventCh {
			if event.Err() != nil {
				return fmt.Errorf("agent error: %w", event.Err())
			}
			response = event.Response()
		}
	} else {
		eventCh, err := appInstance.PrimaryAgent.Run(ctx, session.ID, prompt)
		if err != nil {
			return fmt.Errorf("failed to run agent: %w", err)
		}
		// Keep the last response seen; abort on the first event error.
		for event := range eventCh {
			if event.Err() != nil {
				return fmt.Errorf("agent error: %w", event.Err())
			}
			response = event.Response()
		}
	}

	return printNonInteractiveResponse(response, outputFormat, stopSpinner)
}

// enableNonInteractiveVerboseLogging installs a debug-level charmbracelet/log
// logger, writing to stderr through a mutex-guarded writer, as the default
// logger for both charmlog and slog.
func enableNonInteractiveVerboseLogging() {
	// Synchronized writer so concurrent log writes do not interleave.
	stderr := newSyncWriter(os.Stderr)
	charmLogger := charmlog.NewWithOptions(stderr, charmlog.Options{
		Level:           charmlog.DebugLevel,
		ReportCaller:    true,
		ReportTimestamp: true,
		TimeFormat:      time.RFC3339,
		Prefix:          "OpenCode",
	})
	charmlog.SetDefault(charmLogger)
	// Use the charm logger as the slog handler so slog output goes through it too.
	slog.SetDefault(slog.New(charmLogger))
	charmLogger.Info("Verbose logging enabled")
}

// startNonInteractiveSpinner starts a "Thinking..." spinner unless quiet is
// set. The spinner uses the current theme's primary color when a theme is
// available. It returns a stop function that may be called more than once
// (sequentially); calls after the first are no-ops. When quiet is set, the
// returned function does nothing.
func startNonInteractiveSpinner(quiet bool) func() {
	if quiet {
		return func() {}
	}
	var s *spinner.Spinner
	if currentTheme := theme.CurrentTheme(); currentTheme != nil {
		s = spinner.NewThemedSpinner("Thinking...", currentTheme.Primary())
	} else {
		s = spinner.NewSpinner("Thinking...")
	}
	s.Start()
	return func() {
		if s != nil {
			s.Stop()
			s = nil
		}
	}
}

// printNonInteractiveResponse renders the text content of response with
// format.FormatOutput, stops the spinner, and prints the result to stdout. A
// response with no text content is formatted as the empty string.
func printNonInteractiveResponse(response message.Message, outputFormat format.OutputFormat, stopSpinner func()) error {
	content := ""
	if textContent := response.Content(); textContent != nil {
		content = textContent.Text
	}
	formattedOutput, err := format.FormatOutput(content, outputFormat)
	if err != nil {
		return fmt.Errorf("failed to format output: %w", err)
	}
	// Stop the spinner first so its animation does not overlap the output.
	stopSpinner()
	fmt.Println(formattedOutput)
	return nil
}