410 lines
7.8 KiB
Go
410 lines
7.8 KiB
Go
|
|
package tool
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"slices"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestToolValidate(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
tool Tool
|
||
|
|
wantErr bool
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "valid tool",
|
||
|
|
tool: Tool{
|
||
|
|
Name: "echo",
|
||
|
|
Description: "echoes input",
|
||
|
|
Command: "echo",
|
||
|
|
Arguments: []string{"{{.message}}"},
|
||
|
|
},
|
||
|
|
wantErr: false,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "valid tool with timeout",
|
||
|
|
tool: Tool{
|
||
|
|
Name: "slow",
|
||
|
|
Description: "slow command",
|
||
|
|
Command: "sleep",
|
||
|
|
Arguments: []string{"1"},
|
||
|
|
Timeout: "5s",
|
||
|
|
},
|
||
|
|
wantErr: false,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "missing name",
|
||
|
|
tool: Tool{
|
||
|
|
Description: "test",
|
||
|
|
Command: "echo",
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "missing description",
|
||
|
|
tool: Tool{
|
||
|
|
Name: "test",
|
||
|
|
Command: "echo",
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "missing command",
|
||
|
|
tool: Tool{
|
||
|
|
Name: "test",
|
||
|
|
Description: "test",
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "invalid timeout",
|
||
|
|
tool: Tool{
|
||
|
|
Name: "test",
|
||
|
|
Description: "test",
|
||
|
|
Command: "echo",
|
||
|
|
Timeout: "invalid",
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "invalid template",
|
||
|
|
tool: Tool{
|
||
|
|
Name: "test",
|
||
|
|
Description: "test",
|
||
|
|
Command: "echo",
|
||
|
|
Arguments: []string{"{{.unclosed"},
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
err := tt.tool.Validate()
|
||
|
|
if (err != nil) != tt.wantErr {
|
||
|
|
t.Errorf(
|
||
|
|
"Validate() error = %v, wantErr %v",
|
||
|
|
err, tt.wantErr,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestToolParseArguments(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
tool Tool
|
||
|
|
args string
|
||
|
|
want []string
|
||
|
|
wantErr bool
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "simple substitution",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{"{{.message}}"},
|
||
|
|
},
|
||
|
|
args: `{"message": "hello"}`,
|
||
|
|
want: []string{"hello"},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "multiple arguments",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{"-n", "{{.count}}", "{{.file}}"},
|
||
|
|
},
|
||
|
|
args: `{"count": "10", "file": "test.txt"}`,
|
||
|
|
want: []string{"-n", "10", "test.txt"},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "conditional with if",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{
|
||
|
|
"{{.required}}", "{{if .optional}}{{.optional}}{{end}}",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
args: `{"required": "value"}`,
|
||
|
|
want: []string{"value"},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "json function",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{`{{json .data}}`},
|
||
|
|
},
|
||
|
|
args: `{"data": {"key": "value"}}`,
|
||
|
|
want: []string{`{"key":"value"}`},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "empty JSON object",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{"fixed"},
|
||
|
|
},
|
||
|
|
args: "{}",
|
||
|
|
want: []string{"fixed"},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "empty string args",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{"fixed"},
|
||
|
|
},
|
||
|
|
args: "",
|
||
|
|
want: []string{"fixed"},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "invalid JSON",
|
||
|
|
tool: Tool{
|
||
|
|
Arguments: []string{"{{.x}}"},
|
||
|
|
},
|
||
|
|
args: "not json",
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
got, err := tt.tool.ParseArguments(tt.args)
|
||
|
|
if (err != nil) != tt.wantErr {
|
||
|
|
t.Errorf(
|
||
|
|
"ParseArguments() error = %v, wantErr %v",
|
||
|
|
err, tt.wantErr,
|
||
|
|
)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if !tt.wantErr && !slices.Equal(got, tt.want) {
|
||
|
|
t.Errorf(
|
||
|
|
"ParseArguments() = %v, want %v",
|
||
|
|
got, tt.want,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestNewRegistry(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
tools []Tool
|
||
|
|
wantErr bool
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "valid registry",
|
||
|
|
tools: []Tool{
|
||
|
|
{Name: "a", Description: "a", Command: "echo"},
|
||
|
|
{Name: "b", Description: "b", Command: "cat"},
|
||
|
|
},
|
||
|
|
wantErr: false,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "empty registry",
|
||
|
|
tools: []Tool{},
|
||
|
|
wantErr: false,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "duplicate names",
|
||
|
|
tools: []Tool{
|
||
|
|
{
|
||
|
|
Name: "dup",
|
||
|
|
Description: "first",
|
||
|
|
Command: "echo",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
Name: "dup",
|
||
|
|
Description: "second",
|
||
|
|
Command: "cat",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "invalid tool",
|
||
|
|
tools: []Tool{
|
||
|
|
{
|
||
|
|
Name: "",
|
||
|
|
Description: "missing name",
|
||
|
|
Command: "echo",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
wantErr: true,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
_, err := NewRegistry(tt.tools)
|
||
|
|
if (err != nil) != tt.wantErr {
|
||
|
|
t.Errorf(
|
||
|
|
"NewRegistry() error = %v, wantErr %v",
|
||
|
|
err, tt.wantErr,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRegistryGet(t *testing.T) {
|
||
|
|
r, err := NewRegistry([]Tool{
|
||
|
|
{Name: "echo", Description: "echo", Command: "echo"},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
tool, ok := r.Get("echo")
|
||
|
|
if !ok {
|
||
|
|
t.Error("Get(echo) returned false, want true")
|
||
|
|
}
|
||
|
|
if tool.Name != "echo" {
|
||
|
|
t.Errorf("Get(echo).Name = %q, want %q", tool.Name, "echo")
|
||
|
|
}
|
||
|
|
|
||
|
|
_, ok = r.Get("nonexistent")
|
||
|
|
if ok {
|
||
|
|
t.Error("Get(nonexistent) returned true, want false")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRegistryList(t *testing.T) {
|
||
|
|
r, err := NewRegistry([]Tool{
|
||
|
|
{Name: "a", Description: "a", Command: "echo"},
|
||
|
|
{Name: "b", Description: "b", Command: "cat"},
|
||
|
|
{Name: "c", Description: "c", Command: "ls"},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
names := r.List()
|
||
|
|
if len(names) != 3 {
|
||
|
|
t.Errorf("List() returned %d names, want 3", len(names))
|
||
|
|
}
|
||
|
|
|
||
|
|
slices.Sort(names)
|
||
|
|
want := []string{"a", "b", "c"}
|
||
|
|
if !slices.Equal(names, want) {
|
||
|
|
t.Errorf("List() = %v, want %v", names, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRegistryExecute(t *testing.T) {
|
||
|
|
r, err := NewRegistry([]Tool{
|
||
|
|
{
|
||
|
|
Name: "echo",
|
||
|
|
Description: "echo message",
|
||
|
|
Command: "echo",
|
||
|
|
Arguments: []string{"{{.message}}"},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
Name: "cat",
|
||
|
|
Description: "cat stdin",
|
||
|
|
Command: "cat",
|
||
|
|
Arguments: []string{},
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx := context.Background()
|
||
|
|
|
||
|
|
// Successful execution.
|
||
|
|
out, err := r.Execute(ctx, "echo", `{"message": "hello world"}`)
|
||
|
|
if err != nil {
|
||
|
|
t.Errorf("Execute(echo) error = %v", err)
|
||
|
|
}
|
||
|
|
if out != "hello world\n" {
|
||
|
|
t.Errorf("Execute(echo) = %q, want %q", out, "hello world\n")
|
||
|
|
}
|
||
|
|
|
||
|
|
// Unknown tool.
|
||
|
|
_, err = r.Execute(ctx, "unknown", "{}")
|
||
|
|
if err == nil {
|
||
|
|
t.Error("Execute(unknown) expected error, got nil")
|
||
|
|
}
|
||
|
|
|
||
|
|
// Invalid JSON arguments.
|
||
|
|
_, err = r.Execute(ctx, "echo", "not json")
|
||
|
|
if err == nil {
|
||
|
|
t.Error("Execute with invalid JSON expected error, got nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRegistryExecuteTimeout(t *testing.T) {
|
||
|
|
r, err := NewRegistry([]Tool{
|
||
|
|
{
|
||
|
|
Name: "slow",
|
||
|
|
Description: "slow command",
|
||
|
|
Command: "sleep",
|
||
|
|
Arguments: []string{"10"},
|
||
|
|
Timeout: "50ms",
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx := context.Background()
|
||
|
|
start := time.Now()
|
||
|
|
_, err = r.Execute(ctx, "slow", "{}")
|
||
|
|
elapsed := time.Since(start)
|
||
|
|
|
||
|
|
if err == nil {
|
||
|
|
t.Error("Execute(slow) expected timeout error, got nil")
|
||
|
|
}
|
||
|
|
if elapsed > time.Second {
|
||
|
|
t.Errorf(
|
||
|
|
"Execute(slow) took %v, expected ~50ms timeout",
|
||
|
|
elapsed,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRegistryExecuteFailure(t *testing.T) {
|
||
|
|
r, err := NewRegistry([]Tool{
|
||
|
|
{
|
||
|
|
Name: "fail",
|
||
|
|
Description: "always fails",
|
||
|
|
Command: "false",
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err = r.Execute(context.Background(), "fail", "{}")
|
||
|
|
if err == nil {
|
||
|
|
t.Error("Execute(fail) expected error, got nil")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestToolOpenAI(t *testing.T) {
|
||
|
|
tool := Tool{
|
||
|
|
Name: "test",
|
||
|
|
Description: "test tool",
|
||
|
|
Command: "echo",
|
||
|
|
Parameters: map[string]any{
|
||
|
|
"type": "object",
|
||
|
|
"properties": map[string]any{
|
||
|
|
"message": map[string]any{
|
||
|
|
"type": "string",
|
||
|
|
"description": "the message",
|
||
|
|
},
|
||
|
|
},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
openaiTool := tool.OpenAI()
|
||
|
|
if openaiTool.Function.Name != "test" {
|
||
|
|
t.Errorf(
|
||
|
|
"OpenAI().Function.Name = %q, want %q",
|
||
|
|
openaiTool.Function.Name, "test",
|
||
|
|
)
|
||
|
|
}
|
||
|
|
if openaiTool.Function.Description != "test tool" {
|
||
|
|
t.Errorf(
|
||
|
|
"OpenAI().Function.Description = %q, want %q",
|
||
|
|
openaiTool.Function.Description, "test tool",
|
||
|
|
)
|
||
|
|
}
|
||
|
|
}
|