Add tool registry
This commit is contained in:
409
internal/tool/tool_test.go
Normal file
409
internal/tool/tool_test.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestToolValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tool Tool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid tool",
|
||||
tool: Tool{
|
||||
Name: "echo",
|
||||
Description: "echoes input",
|
||||
Command: "echo",
|
||||
Arguments: []string{"{{.message}}"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool with timeout",
|
||||
tool: Tool{
|
||||
Name: "slow",
|
||||
Description: "slow command",
|
||||
Command: "sleep",
|
||||
Arguments: []string{"1"},
|
||||
Timeout: "5s",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
tool: Tool{
|
||||
Description: "test",
|
||||
Command: "echo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing description",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Command: "echo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing command",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid timeout",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Command: "echo",
|
||||
Timeout: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid template",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Command: "echo",
|
||||
Arguments: []string{"{{.unclosed"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.tool.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"Validate() error = %v, wantErr %v",
|
||||
err, tt.wantErr,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolParseArguments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tool Tool
|
||||
args string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple substitution",
|
||||
tool: Tool{
|
||||
Arguments: []string{"{{.message}}"},
|
||||
},
|
||||
args: `{"message": "hello"}`,
|
||||
want: []string{"hello"},
|
||||
},
|
||||
{
|
||||
name: "multiple arguments",
|
||||
tool: Tool{
|
||||
Arguments: []string{"-n", "{{.count}}", "{{.file}}"},
|
||||
},
|
||||
args: `{"count": "10", "file": "test.txt"}`,
|
||||
want: []string{"-n", "10", "test.txt"},
|
||||
},
|
||||
{
|
||||
name: "conditional with if",
|
||||
tool: Tool{
|
||||
Arguments: []string{
|
||||
"{{.required}}", "{{if .optional}}{{.optional}}{{end}}",
|
||||
},
|
||||
},
|
||||
args: `{"required": "value"}`,
|
||||
want: []string{"value"},
|
||||
},
|
||||
{
|
||||
name: "json function",
|
||||
tool: Tool{
|
||||
Arguments: []string{`{{json .data}}`},
|
||||
},
|
||||
args: `{"data": {"key": "value"}}`,
|
||||
want: []string{`{"key":"value"}`},
|
||||
},
|
||||
{
|
||||
name: "empty JSON object",
|
||||
tool: Tool{
|
||||
Arguments: []string{"fixed"},
|
||||
},
|
||||
args: "{}",
|
||||
want: []string{"fixed"},
|
||||
},
|
||||
{
|
||||
name: "empty string args",
|
||||
tool: Tool{
|
||||
Arguments: []string{"fixed"},
|
||||
},
|
||||
args: "",
|
||||
want: []string{"fixed"},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
tool: Tool{
|
||||
Arguments: []string{"{{.x}}"},
|
||||
},
|
||||
args: "not json",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.tool.ParseArguments(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"ParseArguments() error = %v, wantErr %v",
|
||||
err, tt.wantErr,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && !slices.Equal(got, tt.want) {
|
||||
t.Errorf(
|
||||
"ParseArguments() = %v, want %v",
|
||||
got, tt.want,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []Tool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid registry",
|
||||
tools: []Tool{
|
||||
{Name: "a", Description: "a", Command: "echo"},
|
||||
{Name: "b", Description: "b", Command: "cat"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty registry",
|
||||
tools: []Tool{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate names",
|
||||
tools: []Tool{
|
||||
{
|
||||
Name: "dup",
|
||||
Description: "first",
|
||||
Command: "echo",
|
||||
},
|
||||
{
|
||||
Name: "dup",
|
||||
Description: "second",
|
||||
Command: "cat",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool",
|
||||
tools: []Tool{
|
||||
{
|
||||
Name: "",
|
||||
Description: "missing name",
|
||||
Command: "echo",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewRegistry(tt.tools)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"NewRegistry() error = %v, wantErr %v",
|
||||
err, tt.wantErr,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGet(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{Name: "echo", Description: "echo", Command: "echo"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
tool, ok := r.Get("echo")
|
||||
if !ok {
|
||||
t.Error("Get(echo) returned false, want true")
|
||||
}
|
||||
if tool.Name != "echo" {
|
||||
t.Errorf("Get(echo).Name = %q, want %q", tool.Name, "echo")
|
||||
}
|
||||
|
||||
_, ok = r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("Get(nonexistent) returned true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryList(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{Name: "a", Description: "a", Command: "echo"},
|
||||
{Name: "b", Description: "b", Command: "cat"},
|
||||
{Name: "c", Description: "c", Command: "ls"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
names := r.List()
|
||||
if len(names) != 3 {
|
||||
t.Errorf("List() returned %d names, want 3", len(names))
|
||||
}
|
||||
|
||||
slices.Sort(names)
|
||||
want := []string{"a", "b", "c"}
|
||||
if !slices.Equal(names, want) {
|
||||
t.Errorf("List() = %v, want %v", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExecute(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{
|
||||
Name: "echo",
|
||||
Description: "echo message",
|
||||
Command: "echo",
|
||||
Arguments: []string{"{{.message}}"},
|
||||
},
|
||||
{
|
||||
Name: "cat",
|
||||
Description: "cat stdin",
|
||||
Command: "cat",
|
||||
Arguments: []string{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Successful execution.
|
||||
out, err := r.Execute(ctx, "echo", `{"message": "hello world"}`)
|
||||
if err != nil {
|
||||
t.Errorf("Execute(echo) error = %v", err)
|
||||
}
|
||||
if out != "hello world\n" {
|
||||
t.Errorf("Execute(echo) = %q, want %q", out, "hello world\n")
|
||||
}
|
||||
|
||||
// Unknown tool.
|
||||
_, err = r.Execute(ctx, "unknown", "{}")
|
||||
if err == nil {
|
||||
t.Error("Execute(unknown) expected error, got nil")
|
||||
}
|
||||
|
||||
// Invalid JSON arguments.
|
||||
_, err = r.Execute(ctx, "echo", "not json")
|
||||
if err == nil {
|
||||
t.Error("Execute with invalid JSON expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExecuteTimeout(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{
|
||||
Name: "slow",
|
||||
Description: "slow command",
|
||||
Command: "sleep",
|
||||
Arguments: []string{"10"},
|
||||
Timeout: "50ms",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
_, err = r.Execute(ctx, "slow", "{}")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Execute(slow) expected timeout error, got nil")
|
||||
}
|
||||
if elapsed > time.Second {
|
||||
t.Errorf(
|
||||
"Execute(slow) took %v, expected ~50ms timeout",
|
||||
elapsed,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExecuteFailure(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{
|
||||
Name: "fail",
|
||||
Description: "always fails",
|
||||
Command: "false",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = r.Execute(context.Background(), "fail", "{}")
|
||||
if err == nil {
|
||||
t.Error("Execute(fail) expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolOpenAI(t *testing.T) {
|
||||
tool := Tool{
|
||||
Name: "test",
|
||||
Description: "test tool",
|
||||
Command: "echo",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "the message",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
openaiTool := tool.OpenAI()
|
||||
if openaiTool.Function.Name != "test" {
|
||||
t.Errorf(
|
||||
"OpenAI().Function.Name = %q, want %q",
|
||||
openaiTool.Function.Name, "test",
|
||||
)
|
||||
}
|
||||
if openaiTool.Function.Description != "test tool" {
|
||||
t.Errorf(
|
||||
"OpenAI().Function.Description = %q, want %q",
|
||||
openaiTool.Function.Description, "test tool",
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user