From 5e703a3dd41e47f615f13474f0634b3309411a23 Mon Sep 17 00:00:00 2001
From: dwrz
Date: Fri, 13 Feb 2026 15:02:07 +0000
Subject: [PATCH] Add tool registry

---
 internal/tool/tool.go      | 235 +++++++++++++++++++++
 internal/tool/tool_test.go | 409 +++++++++++++++++++++++++++++++++++++
 2 files changed, 644 insertions(+)
 create mode 100644 internal/tool/tool.go
 create mode 100644 internal/tool/tool_test.go

diff --git a/internal/tool/tool.go b/internal/tool/tool.go
new file mode 100644
index 0000000..a6244c9
--- /dev/null
+++ b/internal/tool/tool.go
@@ -0,0 +1,235 @@
+// Package tool provides a registry for external tools that can be invoked by
+// LLMs.
+//
+// The package bridges YAML configuration to exec.CommandContext, allowing
+// tools to be defined declaratively without writing Go code. Each tool
+// specifies a command, argument templates using Go's text/template syntax,
+// JSON Schema parameters for LLM input, and an optional execution timeout.
+//
+// The registry validates all tool definitions at construction time,
+// failing fast on configuration errors. At execution time, it expands
+// argument templates with LLM-provided JSON, runs the subprocess, and
+// returns stdout.
+package tool
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"os/exec"
+	"text/template"
+	"time"
+
+	"github.com/sashabaranov/go-openai"
+)
+
+// fn provides template functions available to argument templates.
+var fn = template.FuncMap{
+	"json": func(v any) (string, error) {
+		b, err := json.Marshal(v)
+		return string(b), err
+	},
+}
+
+// Tool represents an external tool that can be invoked by LLMs.
+// Tools are executed as subprocesses with templated arguments.
+type Tool struct {
+	// Name uniquely identifies the tool within a registry.
+	Name string `yaml:"name"`
+	// Description explains the tool's purpose for the LLM.
+	Description string `yaml:"description"`
+	// Command is the executable path or name.
+	Command string `yaml:"command"`
+	// Arguments are Go templates expanded with LLM-provided parameters.
+	// Empty results after expansion are filtered out.
+	Arguments []string `yaml:"arguments"`
+	// Parameters is a JSON Schema describing expected input from the LLM.
+	Parameters map[string]any `yaml:"parameters"`
+	// Timeout limits execution time (e.g., "30s", "5m").
+	// Empty means no timeout.
+	Timeout string `yaml:"timeout"`
+	// timeout is the parsed duration, set during registry construction.
+	timeout time.Duration `yaml:"-"`
+}
+
+// OpenAI converts the tool to an OpenAI function definition for API calls.
+func (t Tool) OpenAI() openai.Tool {
+	return openai.Tool{
+		Type: openai.ToolTypeFunction,
+		Function: &openai.FunctionDefinition{
+			Name:        t.Name,
+			Description: t.Description,
+			Parameters:  t.Parameters,
+		},
+	}
+}
+
+// ParseArguments expands argument templates with the provided JSON data.
+// The args parameter should be a JSON object string; empty string or "{}"
+// results in an empty data map. Templates producing empty strings are
+// filtered from the result, allowing conditional arguments.
+func (t Tool) ParseArguments(args string) ([]string, error) {
+	var data = map[string]any{}
+	if args != "" && args != "{}" {
+		if err := json.Unmarshal([]byte(args), &data); err != nil {
+			return nil, fmt.Errorf(
+				"invalid arguments JSON: %w", err,
+			)
+		}
+	}
+
+	var result []string
+	for _, v := range t.Arguments {
+		tmpl, err := template.New("").Funcs(fn).Parse(v)
+		if err != nil {
+			return nil, fmt.Errorf(
+				"invalid template %q: %w", v, err,
+			)
+		}
+
+		var buf bytes.Buffer
+		if err := tmpl.Execute(&buf, data); err != nil {
+			return nil, fmt.Errorf(
+				"execute template %q: %w", v, err,
+			)
+		}
+
+		// Filter out empty strings (unused conditional arguments).
+		if s := buf.String(); s != "" {
+			result = append(result, s)
+		}
+	}
+
+	return result, nil
+}
+
+// Validate checks that the tool definition is complete and valid.
+// It verifies required fields are present, the timeout (if specified)
+// is parseable, and all argument templates are syntactically valid.
+func (t Tool) Validate() error {
+	if t.Name == "" {
+		return fmt.Errorf("missing name")
+	}
+	if t.Description == "" {
+		return fmt.Errorf("missing description")
+	}
+	if t.Command == "" {
+		return fmt.Errorf("missing command")
+	}
+	if t.Timeout != "" {
+		if _, err := time.ParseDuration(t.Timeout); err != nil {
+			return fmt.Errorf("invalid timeout: %w", err)
+		}
+	}
+	for _, arg := range t.Arguments {
+		if _, err := template.New("").Funcs(fn).Parse(arg); err != nil {
+			return fmt.Errorf("invalid argument template %q: %w", arg, err)
+		}
+	}
+
+	return nil
+}
+
+// Registry holds tools indexed by name and handles their execution.
+// It validates all tools at construction time to fail fast on
+// configuration errors.
+type Registry struct {
+	tools map[string]*Tool
+}
+
+// NewRegistry creates a registry from the provided tool definitions.
+// Returns an error if any tool fails validation or if duplicate names exist.
+func NewRegistry(tools []Tool) (*Registry, error) {
+	var r = &Registry{
+		tools: make(map[string]*Tool, len(tools)),
+	}
+	for i := range tools {
+		// Take a per-iteration copy: its address is stored in the map,
+		// so it must not be shared with other entries or the caller.
+		t := tools[i]
+		if err := t.Validate(); err != nil {
+			return nil, fmt.Errorf("invalid tool %q: %w", t.Name, err)
+		}
+		if t.Timeout != "" {
+			d, err := time.ParseDuration(t.Timeout)
+			if err != nil {
+				return nil, fmt.Errorf("parse timeout: %w", err)
+			}
+			t.timeout = d
+		}
+
+		if _, exists := r.tools[t.Name]; exists {
+			return nil, fmt.Errorf("duplicate tool name: %s", t.Name)
+		}
+		r.tools[t.Name] = &t
+	}
+
+	return r, nil
+}
+
+// Get returns a tool by name and a boolean indicating if it was found.
+// The pointer refers to the registry's own copy of the tool definition.
+func (r *Registry) Get(name string) (*Tool, bool) {
+	tool, ok := r.tools[name]
+	return tool, ok
+}
+
+// List returns all registered tool names in arbitrary order.
+func (r *Registry) List() []string {
+	names := make([]string, 0, len(r.tools))
+	for name := range r.tools {
+		names = append(names, name)
+	}
+	return names
+}
+
+// Execute runs a tool by name with the provided JSON arguments.
+// It expands argument templates, executes the command as a subprocess,
+// and returns stdout on success. The context can be used for cancellation;
+// tool-specific timeouts are applied on top of any context deadline.
+func (r *Registry) Execute(
+	ctx context.Context, name string, args string,
+) (string, error) {
+	tool, ok := r.tools[name]
+	if !ok {
+		return "", fmt.Errorf("unknown tool: %s", name)
+	}
+
+	// Evaluate argument templates.
+	cmdArgs, err := tool.ParseArguments(args)
+	if err != nil {
+		return "", fmt.Errorf("parse arguments: %w", err)
+	}
+
+	// If defined, use the timeout.
+	if tool.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, tool.timeout)
+		defer cancel()
+	}
+
+	// Setup and run the command.
+	var (
+		stdout, stderr bytes.Buffer
+		cmd            = exec.CommandContext(
+			ctx, tool.Command, cmdArgs...,
+		)
+	)
+	cmd.Stdout = &stdout
+	cmd.Stderr = &stderr
+	if err := cmd.Run(); err != nil {
+		if ctx.Err() == context.DeadlineExceeded && tool.timeout > 0 {
+			return "", fmt.Errorf(
+				"tool %s timed out after %v",
+				name, tool.timeout,
+			)
+		}
+		return "", fmt.Errorf(
+			"tool %s: %w\nstderr: %s",
+			name, err, stderr.String(),
+		)
+	}
+
+	return stdout.String(), nil
+}
diff --git a/internal/tool/tool_test.go b/internal/tool/tool_test.go
new file mode 100644
index 0000000..c274bac
--- /dev/null
+++ b/internal/tool/tool_test.go
@@ -0,0 +1,409 @@
+package tool
+
+import (
+	"context"
+	"slices"
+	"testing"
+	"time"
+)
+
+func TestToolValidate(t *testing.T) {
+	tests := []struct {
+		name    string
+		tool    Tool
+		wantErr bool
+	}{
+		{
+			name: "valid tool",
+			tool: Tool{
+				Name:        "echo",
+				Description: "echoes input",
+				Command:     "echo",
+				Arguments:   []string{"{{.message}}"},
+			},
+			wantErr: false,
+		},
+		{
+			name: "valid tool with timeout",
+			tool: Tool{
+				Name:        "slow",
+				Description: "slow command",
+				Command:     "sleep",
+				Arguments:   []string{"1"},
+				Timeout:     "5s",
+			},
+			wantErr: false,
+		},
+		{
+			name: "missing name",
+			tool: Tool{
+				Description: "test",
+				Command:     "echo",
+			},
+			wantErr: true,
+		},
+		{
+			name: "missing description",
+			tool: Tool{
+				Name:    "test",
+				Command: "echo",
+			},
+			wantErr: true,
+		},
+		{
+			name: "missing command",
+			tool: Tool{
+				Name:        "test",
+				Description: "test",
+			},
+			wantErr: true,
+		},
+		{
+			name: "invalid timeout",
+			tool: Tool{
+				Name:        "test",
+				Description: "test",
+				Command:     "echo",
+				Timeout:     "invalid",
+			},
+			wantErr: true,
+		},
+		{
+			name: "invalid template",
+			tool: Tool{
+				Name:        "test",
+				Description: "test",
+				Command:     "echo",
+				Arguments:   []string{"{{.unclosed"},
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := tt.tool.Validate()
+			if (err != nil) != tt.wantErr {
+				t.Errorf(
+					"Validate() error = %v, wantErr %v",
+					err, tt.wantErr,
+				)
+			}
+		})
+	}
+}
+
+func TestToolParseArguments(t *testing.T) {
+	tests := []struct {
+		name    string
+		tool    Tool
+		args    string
+		want    []string
+		wantErr bool
+	}{
+		{
+			name: "simple substitution",
+			tool: Tool{
+				Arguments: []string{"{{.message}}"},
+			},
+			args: `{"message": "hello"}`,
+			want: []string{"hello"},
+		},
+		{
+			name: "multiple arguments",
+			tool: Tool{
+				Arguments: []string{"-n", "{{.count}}", "{{.file}}"},
+			},
+			args: `{"count": "10", "file": "test.txt"}`,
+			want: []string{"-n", "10", "test.txt"},
+		},
+		{
+			name: "conditional with if",
+			tool: Tool{
+				Arguments: []string{
+					"{{.required}}", "{{if .optional}}{{.optional}}{{end}}",
+				},
+			},
+			args: `{"required": "value"}`,
+			want: []string{"value"},
+		},
+		{
+			name: "json function",
+			tool: Tool{
+				Arguments: []string{`{{json .data}}`},
+			},
+			args: `{"data": {"key": "value"}}`,
+			want: []string{`{"key":"value"}`},
+		},
+		{
+			name: "empty JSON object",
+			tool: Tool{
+				Arguments: []string{"fixed"},
+			},
+			args: "{}",
+			want: []string{"fixed"},
+		},
+		{
+			name: "empty string args",
+			tool: Tool{
+				Arguments: []string{"fixed"},
+			},
+			args: "",
+			want: []string{"fixed"},
+		},
+		{
+			name: "invalid JSON",
+			tool: Tool{
+				Arguments: []string{"{{.x}}"},
+			},
+			args:    "not json",
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.tool.ParseArguments(tt.args)
+			if (err != nil) != tt.wantErr {
+				t.Errorf(
+					"ParseArguments() error = %v, wantErr %v",
+					err, tt.wantErr,
+				)
+				return
+			}
+			if !tt.wantErr && !slices.Equal(got, tt.want) {
+				t.Errorf(
+					"ParseArguments() = %v, want %v",
+					got, tt.want,
+				)
+			}
+		})
+	}
+}
+
+func TestNewRegistry(t *testing.T) {
+	tests := []struct {
+		name    string
+		tools   []Tool
+		wantErr bool
+	}{
+		{
+			name: "valid registry",
+			tools: []Tool{
+				{Name: "a", Description: "a", Command: "echo"},
+				{Name: "b", Description: "b", Command: "cat"},
+			},
+			wantErr: false,
+		},
+		{
+			name:    "empty registry",
+			tools:   []Tool{},
+			wantErr: false,
+		},
+		{
+			name: "duplicate names",
+			tools: []Tool{
+				{
+					Name:        "dup",
+					Description: "first",
+					Command:     "echo",
+				},
+				{
+					Name:        "dup",
+					Description: "second",
+					Command:     "cat",
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name: "invalid tool",
+			tools: []Tool{
+				{
+					Name:        "",
+					Description: "missing name",
+					Command:     "echo",
+				},
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := NewRegistry(tt.tools)
+			if (err != nil) != tt.wantErr {
+				t.Errorf(
+					"NewRegistry() error = %v, wantErr %v",
+					err, tt.wantErr,
+				)
+			}
+		})
+	}
+}
+
+func TestRegistryGet(t *testing.T) {
+	r, err := NewRegistry([]Tool{
+		{Name: "echo", Description: "echo", Command: "echo"},
+	})
+	if err != nil {
+		t.Fatalf("NewRegistry() error = %v", err)
+	}
+
+	tool, ok := r.Get("echo")
+	if !ok {
+		t.Fatal("Get(echo) returned false, want true")
+	}
+	if tool.Name != "echo" {
+		t.Errorf("Get(echo).Name = %q, want %q", tool.Name, "echo")
+	}
+
+	_, ok = r.Get("nonexistent")
+	if ok {
+		t.Error("Get(nonexistent) returned true, want false")
+	}
+}
+
+func TestRegistryList(t *testing.T) {
+	r, err := NewRegistry([]Tool{
+		{Name: "a", Description: "a", Command: "echo"},
+		{Name: "b", Description: "b", Command: "cat"},
+		{Name: "c", Description: "c", Command: "ls"},
+	})
+	if err != nil {
+		t.Fatalf("NewRegistry() error = %v", err)
+	}
+
+	names := r.List()
+	if len(names) != 3 {
+		t.Errorf("List() returned %d names, want 3", len(names))
+	}
+
+	slices.Sort(names)
+	want := []string{"a", "b", "c"}
+	if !slices.Equal(names, want) {
+		t.Errorf("List() = %v, want %v", names, want)
+	}
+}
+
+func TestRegistryExecute(t *testing.T) {
+	r, err := NewRegistry([]Tool{
+		{
+			Name:        "echo",
+			Description: "echo message",
+			Command:     "echo",
+			Arguments:   []string{"{{.message}}"},
+		},
+		{
+			Name:        "cat",
+			Description: "cat stdin",
+			Command:     "cat",
+			Arguments:   []string{},
+		},
+	})
+	if err != nil {
+		t.Fatalf("NewRegistry() error = %v", err)
+	}
+
+	ctx := context.Background()
+
+	// Successful execution.
+	out, err := r.Execute(ctx, "echo", `{"message": "hello world"}`)
+	if err != nil {
+		t.Errorf("Execute(echo) error = %v", err)
+	}
+	if out != "hello world\n" {
+		t.Errorf("Execute(echo) = %q, want %q", out, "hello world\n")
+	}
+
+	// Unknown tool.
+	_, err = r.Execute(ctx, "unknown", "{}")
+	if err == nil {
+		t.Error("Execute(unknown) expected error, got nil")
+	}
+
+	// Invalid JSON arguments.
+	_, err = r.Execute(ctx, "echo", "not json")
+	if err == nil {
+		t.Error("Execute with invalid JSON expected error, got nil")
+	}
+}
+
+func TestRegistryExecuteTimeout(t *testing.T) {
+	r, err := NewRegistry([]Tool{
+		{
+			Name:        "slow",
+			Description: "slow command",
+			Command:     "sleep",
+			Arguments:   []string{"10"},
+			Timeout:     "50ms",
+		},
+	})
+	if err != nil {
+		t.Fatalf("NewRegistry() error = %v", err)
+	}
+
+	ctx := context.Background()
+	start := time.Now()
+	_, err = r.Execute(ctx, "slow", "{}")
+	elapsed := time.Since(start)
+
+	if err == nil {
+		t.Error("Execute(slow) expected timeout error, got nil")
+	}
+	if elapsed > time.Second {
+		t.Errorf(
+			"Execute(slow) took %v, expected ~50ms timeout",
+			elapsed,
+		)
+	}
+}
+
+func TestRegistryExecuteFailure(t *testing.T) {
+	r, err := NewRegistry([]Tool{
+		{
+			Name:        "fail",
+			Description: "always fails",
+			Command:     "false",
+		},
+	})
+	if err != nil {
+		t.Fatalf("NewRegistry() error = %v", err)
+	}
+
+	_, err = r.Execute(context.Background(), "fail", "{}")
+	if err == nil {
+		t.Error("Execute(fail) expected error, got nil")
+	}
+}
+
+func TestToolOpenAI(t *testing.T) {
+	tool := Tool{
+		Name:        "test",
+		Description: "test tool",
+		Command:     "echo",
+		Parameters: map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"message": map[string]any{
+					"type":        "string",
+					"description": "the message",
+				},
+			},
+		},
+	}
+
+	openaiTool := tool.OpenAI()
+	if openaiTool.Function.Name != "test" {
+		t.Errorf(
+			"OpenAI().Function.Name = %q, want %q",
+			openaiTool.Function.Name, "test",
+		)
+	}
+	if openaiTool.Function.Description != "test tool" {
+		t.Errorf(
+			"OpenAI().Function.Description = %q, want %q",
+			openaiTool.Function.Description, "test tool",
+		)
+	}
+}