Add answer worker
Subsystem for message processing: parses messages, generates LLM responses, and replies with SMTP. Introduces: - answer: message processing worker. - llm: OpenAI API compatible client with support for tool execution. - message: message parsing and response logic. - tool: converts YAML configuration into executable subprocesses. - smtp: simple config and client wrapper for sending email.
This commit is contained in:
263
internal/answer/answer.go
Normal file
263
internal/answer/answer.go
Normal file
@@ -0,0 +1,263 @@
|
||||
// Package answer implements the email response pipeline.
|
||||
// Workers receive message UIDs, fetch and parse messages, query an LLM
|
||||
// for responses, and send replies via SMTP.
|
||||
package answer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"raven/internal/filter"
|
||||
"raven/internal/imap"
|
||||
"raven/internal/llm"
|
||||
"raven/internal/message"
|
||||
"raven/internal/smtp"
|
||||
"raven/internal/tracker"
|
||||
|
||||
goimap "github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-message/mail"
|
||||
)
|
||||
|
||||
const fallbackResponse = "Your message was received, but I was unable to generate a response. Please try again later or contact the administrator."
|
||||
|
||||
// Worker processes incoming messages through the response pipeline.
|
||||
// Each worker maintains its own IMAP connection and processes UIDs
|
||||
// received from a shared work channel. The pipeline:
|
||||
// 1. Fetch message from IMAP
|
||||
// 2. Check sender against allowlist
|
||||
// 3. Query LLM for response
|
||||
// 4. Send reply via SMTP
|
||||
// 5. Mark original message as seen
|
||||
type Worker struct {
|
||||
// from is the sender address for outgoing replies.
|
||||
from *mail.Address
|
||||
// ic is the IMAP client for fetching messages and marking seen.
|
||||
ic *imap.Client
|
||||
// id identifies this worker in logs.
|
||||
id int
|
||||
// filters determines which senders are allowed.
|
||||
filters filter.Filters
|
||||
// llm generates response content.
|
||||
llm *llm.Client
|
||||
// log is the worker's logger with worker ID context.
|
||||
log *slog.Logger
|
||||
// smtp sends composed replies.
|
||||
smtp smtp.SMTP
|
||||
// tracker prevents duplicate processing across workers.
|
||||
tracker *tracker.Tracker
|
||||
// work receives UIDs to process.
|
||||
work <-chan goimap.UID
|
||||
}
|
||||
|
||||
// NewWorker creates a Worker with its own IMAP connection.
|
||||
// The from address is parsed from smtp.From for reply composition.
|
||||
func NewWorker(
|
||||
id int,
|
||||
filters filter.Filters,
|
||||
imapConfig imap.Config,
|
||||
smtp smtp.SMTP,
|
||||
llm *llm.Client,
|
||||
tracker *tracker.Tracker,
|
||||
log *slog.Logger,
|
||||
work <-chan goimap.UID,
|
||||
) (*Worker, error) {
|
||||
ic, err := imap.NewClient(imapConfig, filters, log)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create imap client: %v", err)
|
||||
}
|
||||
|
||||
from, err := mail.ParseAddress(smtp.From)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse from: %v", err)
|
||||
}
|
||||
|
||||
return &Worker{
|
||||
from: from,
|
||||
ic: ic,
|
||||
id: id,
|
||||
filters: filters,
|
||||
smtp: smtp,
|
||||
llm: llm,
|
||||
log: log.With(slog.String(
|
||||
"worker", fmt.Sprintf("answer[%d]", id),
|
||||
)),
|
||||
tracker: tracker,
|
||||
work: work,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run processes UIDs from the work channel until ctx is canceled
|
||||
// or the channel is closed. Each UID is processed independently;
|
||||
// errors are logged but do not stop the worker.
|
||||
func (w *Worker) Run(ctx context.Context) {
|
||||
defer w.log.Info("worker stopped", slog.Int("id", w.id))
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := ctx.Err(); err != nil {
|
||||
w.log.InfoContext(
|
||||
ctx,
|
||||
"context closed",
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
return
|
||||
|
||||
case uid, ok := <-w.work:
|
||||
if !ok {
|
||||
w.log.InfoContext(
|
||||
ctx,
|
||||
"channel closed",
|
||||
)
|
||||
return
|
||||
}
|
||||
w.respond(ctx, uid)
|
||||
w.tracker.Release(uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// respond handles a single message through the full pipeline.
|
||||
func (w *Worker) respond(ctx context.Context, uid goimap.UID) {
|
||||
defer w.ic.Disconnect(ctx)
|
||||
w.log.InfoContext(ctx, "processing", slog.Any("uid", uid))
|
||||
|
||||
// Connect to the IMAP server and retrieve the message.
|
||||
if err := w.ic.Connect(ctx, w.log, nil); err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to connect IMAP client",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
fetchedMail, err := w.ic.Fetch(uid, true)
|
||||
if err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to fetch message",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := message.New(fetchedMail, w.log)
|
||||
if err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"skipping: failed to parse message",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
if err := w.ic.MarkSeen(uid); err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to mark message as seen",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Enforce allowlist.
|
||||
if !w.filters.MatchSender(msg.Envelope.From[0].Addr()) {
|
||||
w.log.InfoContext(
|
||||
ctx,
|
||||
"skipping: sender not in allowlist",
|
||||
slog.String("from", msg.Envelope.From[0].Addr()),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
if err := w.ic.MarkSeen(uid); err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to mark message as seen",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Disconnect during LLM query to avoid IMAP timeout.
|
||||
// LLM queries can take significant time depending on model and load.
|
||||
w.ic.Disconnect(ctx)
|
||||
|
||||
res, err := w.llm.Query(ctx, msg)
|
||||
if err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to query LLM",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
res = fallbackResponse
|
||||
}
|
||||
|
||||
reply, err := msg.ComposeReply(time.Now(), w.from, res)
|
||||
if err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to compose reply",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
to, err := reply.Recipients()
|
||||
if err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to retrieve reply recipients",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
body, err := reply.Bytes()
|
||||
if err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to retrieve reply body",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.smtp.Send(to, body); err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to reply",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Reconnect to flag the message as processed.
|
||||
if err := w.ic.Connect(ctx, w.log, nil); err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to connect IMAP client",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := w.ic.MarkSeen(uid); err != nil {
|
||||
w.log.ErrorContext(
|
||||
ctx,
|
||||
"failed to mark message as seen",
|
||||
slog.Any("error", err),
|
||||
slog.Any("uid", uid),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
w.log.InfoContext(ctx, "completed", slog.Any("uid", uid))
|
||||
}
|
||||
207
internal/llm/llm.go
Normal file
207
internal/llm/llm.go
Normal file
@@ -0,0 +1,207 @@
|
||||
// Package llm provides an OpenAI-compatible client for LLM interactions.
|
||||
// It handles chat completions with automatic tool call execution.
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"raven/internal/message"
|
||||
"raven/internal/tool"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Config holds the configuration for an LLM client.
|
||||
type Config struct {
|
||||
// Key is the API key for authentication.
|
||||
Key string `yaml:"key"`
|
||||
// Model is the model identifier (e.g., "gpt-oss-120b", "qwen3-32b").
|
||||
Model string `yaml:"model"`
|
||||
// SystemPrompt is prepended to all conversations.
|
||||
SystemPrompt string `yaml:"system_prompt"`
|
||||
// Timeout is the maximum duration for a query (e.g., "15m").
|
||||
// Defaults to 15 minutes if empty.
|
||||
Timeout string `yaml:"timeout"`
|
||||
// URL is the base URL of the OpenAI-compatible API endpoint.
|
||||
URL string `yaml:"url"`
|
||||
}
|
||||
|
||||
// Validate checks that required configuration values are present and valid.
|
||||
func (cfg Config) Validate() error {
|
||||
if cfg.Model == "" {
|
||||
return fmt.Errorf("missing model")
|
||||
}
|
||||
if cfg.Timeout != "" {
|
||||
if _, err := time.ParseDuration(cfg.Timeout); err != nil {
|
||||
return fmt.Errorf("invalid duration")
|
||||
}
|
||||
}
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("missing URL")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Client wraps an OpenAI-compatible client with tool execution support.
|
||||
type Client struct {
|
||||
client *openai.Client
|
||||
log *slog.Logger
|
||||
model string
|
||||
registry *tool.Registry
|
||||
systemPrompt string
|
||||
timeout time.Duration
|
||||
tools []openai.Tool
|
||||
}
|
||||
|
||||
// NewClient creates a new LLM client with the provided configuration.
|
||||
// The registry is optional; if nil, tool calling is disabled.
|
||||
func NewClient(
|
||||
cfg Config,
|
||||
registry *tool.Registry,
|
||||
log *slog.Logger,
|
||||
) (*Client, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
}
|
||||
var llm = &Client{
|
||||
log: log,
|
||||
model: cfg.Model,
|
||||
systemPrompt: cfg.SystemPrompt,
|
||||
registry: registry,
|
||||
}
|
||||
if cfg.Timeout == "" {
|
||||
llm.timeout = time.Duration(15 * time.Minute)
|
||||
} else {
|
||||
d, err := time.ParseDuration(cfg.Timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse timeout: %v", err)
|
||||
}
|
||||
llm.timeout = d
|
||||
}
|
||||
|
||||
// Setup client.
|
||||
clientConfig := openai.DefaultConfig(cfg.Key)
|
||||
clientConfig.BaseURL = cfg.URL
|
||||
llm.client = openai.NewClientWithConfig(clientConfig)
|
||||
|
||||
// Parse tools.
|
||||
if llm.registry != nil {
|
||||
for _, name := range llm.registry.List() {
|
||||
tool, _ := llm.registry.Get(name)
|
||||
llm.tools = append(llm.tools, tool.OpenAI())
|
||||
}
|
||||
}
|
||||
|
||||
return llm, nil
|
||||
}
|
||||
|
||||
// Query sends a message to the LLM and returns its response.
|
||||
// It automatically executes any tool calls requested by the model,
|
||||
// looping until the model returns a final text response.
|
||||
// Supports multimodal content (text and images) via the Message type.
|
||||
func (c *Client) Query(
|
||||
ctx context.Context, msg *message.Message,
|
||||
) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build user message from email content.
|
||||
// Uses MultiContent to support both text and image parts.
|
||||
userMessage := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
MultiContent: msg.ToOpenAIMessages(),
|
||||
}
|
||||
|
||||
// Set the system message and the user prompt.
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: c.systemPrompt,
|
||||
},
|
||||
userMessage,
|
||||
}
|
||||
|
||||
// Loop for tool calls.
|
||||
for {
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: c.model,
|
||||
Messages: messages,
|
||||
}
|
||||
if len(c.tools) > 0 {
|
||||
req.Tools = c.tools
|
||||
}
|
||||
|
||||
res, err := c.client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("chat completion: %w", err)
|
||||
}
|
||||
if len(res.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response choices returned")
|
||||
}
|
||||
|
||||
choice := res.Choices[0]
|
||||
message := choice.Message
|
||||
|
||||
// If no tool calls, we're done.
|
||||
if len(message.ToolCalls) == 0 {
|
||||
return message.Content, nil
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls to history.
|
||||
messages = append(messages, message)
|
||||
|
||||
// Process each tool call.
|
||||
for _, tc := range message.ToolCalls {
|
||||
c.log.InfoContext(
|
||||
ctx,
|
||||
"calling tool",
|
||||
slog.String("name", tc.Function.Name),
|
||||
slog.String("args", tc.Function.Arguments),
|
||||
)
|
||||
|
||||
res, err := c.registry.Execute(
|
||||
ctx, tc.Function.Name, tc.Function.Arguments,
|
||||
)
|
||||
if err != nil {
|
||||
// Return error to LLM so it can recover.
|
||||
c.log.Error(
|
||||
"failed to call tool",
|
||||
slog.Any("error", err),
|
||||
slog.String("name", tc.Function.Name),
|
||||
)
|
||||
// Assume JSON is a more helpful response
|
||||
// for an LLM.
|
||||
res = fmt.Sprintf(
|
||||
`{"ok": false,"error": %q}`, err,
|
||||
)
|
||||
} else {
|
||||
c.log.Info(
|
||||
"called tool",
|
||||
slog.String("name", tc.Function.Name),
|
||||
)
|
||||
}
|
||||
|
||||
// Content cannot be empty.
|
||||
// Assume JSON is better than OK or (no output).
|
||||
if strings.TrimSpace(res) == "" {
|
||||
res = `{"ok":true,"result":null}`
|
||||
}
|
||||
|
||||
// Add tool result to messages.
|
||||
messages = append(
|
||||
messages,
|
||||
openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleTool,
|
||||
Content: res,
|
||||
ToolCallID: tc.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
// Loop to get LLM's response after tool execution.
|
||||
}
|
||||
}
|
||||
91
internal/llm/llm_test.go
Normal file
91
internal/llm/llm_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty config",
|
||||
cfg: Config{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
cfg: Config{
|
||||
URL: "https://api.example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing URL",
|
||||
cfg: Config{
|
||||
Model: "gpt-4o",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid timeout",
|
||||
cfg: Config{
|
||||
Model: "gpt-4o",
|
||||
URL: "https://api.example.com",
|
||||
Timeout: "not-a-duration",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid minimal config",
|
||||
cfg: Config{
|
||||
Model: "gpt-4o",
|
||||
URL: "https://api.example.com",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with timeout",
|
||||
cfg: Config{
|
||||
Model: "gpt-4o",
|
||||
URL: "https://api.example.com",
|
||||
Timeout: "15m",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with complex timeout",
|
||||
cfg: Config{
|
||||
Model: "gpt-4o",
|
||||
URL: "https://api.example.com",
|
||||
Timeout: "1h30m45s",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid full config",
|
||||
cfg: Config{
|
||||
Key: "sk-test-key",
|
||||
Model: "gpt-4o",
|
||||
SystemPrompt: "You are a helpful assistant.",
|
||||
Timeout: "30m",
|
||||
URL: "https://api.example.com",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.cfg.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"Validate() error = %v, wantErr %v",
|
||||
err, tt.wantErr,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
529
internal/message/message.go
Normal file
529
internal/message/message.go
Normal file
@@ -0,0 +1,529 @@
|
||||
// Package message handles email message parsing and reply composition.
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-imap/v2/imapclient"
|
||||
_ "github.com/emersion/go-message/charset"
|
||||
"github.com/emersion/go-message/mail"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// RFC5322 defines the date format specified in RFC 5322 §3.3.
|
||||
const RFC5322 = "Mon, 2 Jan 2006 15:04:05 -0700"
|
||||
|
||||
// Supported image MIME types for vision models.
|
||||
var supportedImageTypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
}
|
||||
|
||||
// maxPartSize is the maximum size in bytes for a single MIME part (32MB).
|
||||
const maxPartSize = 32 << 20
|
||||
|
||||
// Part represents a single piece of message content.
|
||||
type Part struct {
|
||||
// Content holds text content for text parts, empty for images.
|
||||
Content string
|
||||
// ContentType is the MIME type (e.g., "text/plain", "image/png").
|
||||
ContentType string
|
||||
// Data holds raw bytes for binary content like images.
|
||||
Data []byte
|
||||
// Filename is set for attachment parts.
|
||||
Filename string
|
||||
// IsAttachment distinguishes attachments from inline content.
|
||||
IsAttachment bool
|
||||
}
|
||||
|
||||
// IsImage returns true if the part is a supported image type.
|
||||
func (p *Part) IsImage() bool {
|
||||
return supportedImageTypes[p.ContentType]
|
||||
}
|
||||
|
||||
// IsText returns true if the part contains text content.
|
||||
func (p *Part) IsText() bool {
|
||||
return strings.HasPrefix(p.ContentType, "text/")
|
||||
}
|
||||
|
||||
// Message represents a parsed email with its metadata and content.
|
||||
type Message struct {
|
||||
// Attachments are parts with Content-Disposition: attachment.
|
||||
// Stored separately to allow appending after inline content.
|
||||
Attachments []Part
|
||||
Envelope *imap.Envelope
|
||||
// Parts contains inline content in order of appearance.
|
||||
Parts []Part
|
||||
// References are Message-IDs from the References header,
|
||||
// used for threading.
|
||||
References []string
|
||||
UID imap.UID
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func (msg *Message) TextFrom() string {
|
||||
var str strings.Builder
|
||||
if msg.Envelope == nil || len(msg.Envelope.From) == 0 {
|
||||
return ""
|
||||
}
|
||||
str.WriteString("From: ")
|
||||
for i, a := range msg.Envelope.From {
|
||||
if a.Name == "" {
|
||||
fmt.Fprintf(&str, "%s", a.Addr())
|
||||
} else {
|
||||
fmt.Fprintf(
|
||||
&str, "%s <%s>",
|
||||
a.Name, a.Addr(),
|
||||
)
|
||||
}
|
||||
if i+1 < len(msg.Envelope.From) {
|
||||
str.WriteString(", ")
|
||||
}
|
||||
}
|
||||
|
||||
return str.String()
|
||||
}
|
||||
|
||||
// TextBody returns the concatenated text content from all text parts.
|
||||
// Used for reply composition and quoting.
|
||||
func (msg *Message) TextBody() string {
|
||||
var sb strings.Builder
|
||||
first := true
|
||||
for _, p := range msg.Parts {
|
||||
if p.IsText() {
|
||||
if !first {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(p.Content)
|
||||
first = false
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ToOpenAIMessages converts the message content to OpenAI chat message parts.
|
||||
// Inline parts appear first in order, followed by attachments.
|
||||
// Text parts become text content, supported images become image_url content.
|
||||
func (msg *Message) ToOpenAIMessages() []openai.ChatMessagePart {
|
||||
var parts = []openai.ChatMessagePart{}
|
||||
|
||||
if msg != nil && msg.Envelope != nil {
|
||||
var str strings.Builder
|
||||
if v := msg.TextFrom(); v != "" {
|
||||
fmt.Fprintf(&str, "%s\n", v)
|
||||
}
|
||||
if !msg.Envelope.Date.IsZero() {
|
||||
fmt.Fprintf(
|
||||
&str,
|
||||
"Date: %s\n",
|
||||
msg.Envelope.Date.Format(time.RFC3339),
|
||||
)
|
||||
}
|
||||
if msg.Envelope.Subject != "" {
|
||||
fmt.Fprintf(&str, "Subject: %v\n", msg.Envelope.Subject)
|
||||
}
|
||||
if v := str.String(); v != "" {
|
||||
parts = append(
|
||||
parts,
|
||||
openai.ChatMessagePart{
|
||||
Type: openai.ChatMessagePartTypeText,
|
||||
Text: v,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Process inline parts first, preserving order.
|
||||
for _, p := range msg.Parts {
|
||||
if part, ok := convertPart(p); ok {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
// Append attachments at the end.
|
||||
for _, p := range msg.Attachments {
|
||||
if part, ok := convertPart(p); ok {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// convertPart converts a Part to an OpenAI ChatMessagePart.
|
||||
// Returns false if the part type is not supported for LLM input.
|
||||
func convertPart(p Part) (openai.ChatMessagePart, bool) {
|
||||
switch {
|
||||
case p.IsText():
|
||||
return openai.ChatMessagePart{
|
||||
Type: openai.ChatMessagePartTypeText,
|
||||
Text: p.Content,
|
||||
}, true
|
||||
case p.IsImage():
|
||||
dataURI := fmt.Sprintf(
|
||||
"data:%s;base64,%s",
|
||||
p.ContentType,
|
||||
base64.StdEncoding.EncodeToString(p.Data),
|
||||
)
|
||||
return openai.ChatMessagePart{
|
||||
Type: openai.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &openai.ChatMessageImageURL{
|
||||
URL: dataURI,
|
||||
Detail: openai.ImageURLDetailAuto,
|
||||
},
|
||||
}, true
|
||||
default:
|
||||
return openai.ChatMessagePart{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// composeAttribution builds the attribution line for quoted replies.
|
||||
// Returns sender and timestamp in a standard format like:
|
||||
// "On Mon, 2 Jan 2006 15:04:05 -0700, raven <raven@example.com> wrote:"
|
||||
func (msg *Message) composeAttribution() string {
|
||||
if len(msg.Envelope.From) == 0 {
|
||||
return "> \n"
|
||||
}
|
||||
|
||||
from := msg.Envelope.From[0]
|
||||
sender := from.Addr()
|
||||
if from.Name != "" {
|
||||
sender = fmt.Sprintf("%s <%s>", from.Name, from.Addr())
|
||||
}
|
||||
|
||||
if msg.Envelope.Date.IsZero() {
|
||||
return fmt.Sprintf("%s wrote:\n", sender)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"On %s, %s wrote:\n",
|
||||
msg.Envelope.Date.Format(RFC5322),
|
||||
sender,
|
||||
)
|
||||
}
|
||||
|
||||
// composeBody assembles the reply body: user's response, attribution line,
|
||||
// and quoted original message.
|
||||
func (msg *Message) composeBody(res string) string {
|
||||
var body strings.Builder
|
||||
body.WriteString(strings.TrimRight(res, "\n\r \t"))
|
||||
body.WriteString("\n\n")
|
||||
body.WriteString(msg.composeAttribution())
|
||||
body.WriteString(msg.QuotedBody())
|
||||
|
||||
return body.String()
|
||||
}
|
||||
|
||||
// composeHeader builds RFC 5322-compliant headers for a reply.
|
||||
// Sets From, To, Subject, Date, Message-ID, In-Reply-To, and References.
|
||||
func (msg *Message) composeHeader(
|
||||
date time.Time, from *mail.Address,
|
||||
) (*mail.Header, error) {
|
||||
h := &mail.Header{}
|
||||
h.SetDate(date)
|
||||
h.SetContentType("text/plain", map[string]string{"charset": "utf-8"})
|
||||
h.SetAddressList("From", []*mail.Address{from})
|
||||
h.SetAddressList("Reply-To", []*mail.Address{from})
|
||||
|
||||
to := msg.composeRecipients()
|
||||
if len(to) == 0 {
|
||||
return nil, errors.New("missing recipients")
|
||||
}
|
||||
h.SetAddressList("To", to)
|
||||
h.SetSubject(msg.composeSubject())
|
||||
|
||||
// Use sender's domain for Message-ID per RFC 5322 recommendation.
|
||||
parts := strings.SplitN(from.Address, "@", 2)
|
||||
if len(parts) == 2 {
|
||||
if err := h.GenerateMessageIDWithHostname(parts[1]); err != nil {
|
||||
return nil, fmt.Errorf("generate message id: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := h.GenerateMessageID(); err != nil {
|
||||
return nil, fmt.Errorf("generate message id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
inReplyTo, refs := msg.composeReferences()
|
||||
if inReplyTo != "" {
|
||||
h.SetMsgIDList("In-Reply-To", []string{inReplyTo})
|
||||
}
|
||||
if len(refs) > 0 {
|
||||
h.SetMsgIDList("References", refs)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// composeReferences builds threading headers per RFC 5322 §3.6.4.
|
||||
// In-Reply-To contains the parent's Message-ID.
|
||||
// References contains the full thread ancestry.
|
||||
func (msg *Message) composeReferences() (inReplyTo string, refs []string) {
|
||||
if msg.Envelope.MessageID != "" {
|
||||
inReplyTo = msg.Envelope.MessageID
|
||||
refs = append([]string(nil), msg.References...)
|
||||
refs = append(refs, msg.Envelope.MessageID)
|
||||
} else {
|
||||
refs = append([]string(nil), msg.References...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// composeRecipients determines the To address for a reply.
|
||||
// Uses Reply-To if present, otherwise From (per RFC 5322 §3.6.2).
|
||||
func (msg *Message) composeRecipients() []*mail.Address {
|
||||
src := msg.Envelope.ReplyTo
|
||||
if len(src) == 0 {
|
||||
src = msg.Envelope.From
|
||||
}
|
||||
to := make([]*mail.Address, 0, len(src))
|
||||
for _, v := range src {
|
||||
to = append(to, &mail.Address{
|
||||
Name: v.Name,
|
||||
Address: v.Addr(),
|
||||
})
|
||||
}
|
||||
return to
|
||||
}
|
||||
|
||||
// composeSubject prepends "Re: " if not already present (case-insensitive).
|
||||
func (msg *Message) composeSubject() string {
|
||||
s := msg.Envelope.Subject
|
||||
if !strings.HasPrefix(strings.ToLower(s), "re:") {
|
||||
s = "Re: " + s
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ComposeReply creates a reply to this message.
|
||||
// The reply includes proper threading headers and the original message quoted.
|
||||
func (msg *Message) ComposeReply(
|
||||
date time.Time, from *mail.Address, res string,
|
||||
) (*Reply, error) {
|
||||
if msg == nil || msg.Envelope == nil {
|
||||
return nil, errors.New("missing envelope")
|
||||
}
|
||||
if from == nil || from.Address == "" {
|
||||
return nil, errors.New("missing from address")
|
||||
}
|
||||
|
||||
header, err := msg.composeHeader(date, from)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compose header: %v", err)
|
||||
}
|
||||
|
||||
return &Reply{body: msg.composeBody(res), header: header}, nil
|
||||
}
|
||||
|
||||
// QuotedBody returns the message text body with each line prefixed by "> ".
|
||||
func (msg *Message) QuotedBody() string {
|
||||
var quoted strings.Builder
|
||||
for line := range strings.SplitSeq(msg.TextBody(), "\n") {
|
||||
line = strings.TrimSuffix(line, "\r")
|
||||
quoted.WriteString("> ")
|
||||
quoted.WriteString(line)
|
||||
quoted.WriteString("\n")
|
||||
}
|
||||
return quoted.String()
|
||||
}
|
||||
|
||||
// New creates a Message from an IMAP fetch buffer.
|
||||
// Handles multipart messages by collecting text parts and images inline,
|
||||
// with attachments stored separately for appending later.
|
||||
// Logs skipped parts and non-fatal errors.
|
||||
func New(mb *imapclient.FetchMessageBuffer, log *slog.Logger) (*Message, error) {
|
||||
if mb == nil {
|
||||
return nil, errors.New("nil message buffer")
|
||||
}
|
||||
if mb.UID == 0 {
|
||||
return nil, errors.New("message has no UID")
|
||||
}
|
||||
if mb.Envelope == nil {
|
||||
return nil, errors.New("message has no envelope")
|
||||
}
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
|
||||
msg := &Message{
|
||||
UID: mb.UID,
|
||||
Envelope: mb.Envelope,
|
||||
log: log,
|
||||
}
|
||||
|
||||
// Try each body section until we successfully parse one.
|
||||
var parseErr error
|
||||
for _, section := range mb.BodySection {
|
||||
if len(section.Bytes) == 0 {
|
||||
continue
|
||||
}
|
||||
if err := parseBody(
|
||||
msg,
|
||||
bytes.NewReader(section.Bytes),
|
||||
log,
|
||||
); err != nil {
|
||||
parseErr = err
|
||||
continue
|
||||
}
|
||||
if len(msg.Parts) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to parse any content — return an error.
|
||||
if len(msg.Parts) == 0 && parseErr != nil {
|
||||
return msg, fmt.Errorf("parse body: %w", parseErr)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// parseBody extracts content from a MIME message body.
|
||||
// mail.Reader automatically flattens nested multipart structures, returning
|
||||
// only leaf parts (text/plain, attachments, etc.).
|
||||
func parseBody(msg *Message, r io.Reader, log *slog.Logger) error {
|
||||
reader, err := mail.CreateReader(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create reader: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Extract References header for threading.
|
||||
refs, err := reader.Header.MsgIDList("References")
|
||||
if err == nil && len(refs) > 0 {
|
||||
msg.References = refs
|
||||
}
|
||||
|
||||
// Process all parts.
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("next part: %w", err)
|
||||
}
|
||||
|
||||
if err := processPart(msg, part, log); err != nil {
|
||||
log.Debug("skipped part",
|
||||
slog.Any("error", err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processPart handles a single MIME part returned by mail.Reader.
|
||||
// Inline text and images are added to Parts; attachments go to Attachments.
|
||||
func processPart(msg *Message, part *mail.Part, log *slog.Logger) error {
|
||||
switch h := part.Header.(type) {
|
||||
case *mail.InlineHeader:
|
||||
ct, _, err := h.Header.ContentType()
|
||||
if err != nil {
|
||||
ct = "text/plain"
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(part.Body, maxPartSize))
|
||||
if err != nil {
|
||||
return fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(ct, "text/"):
|
||||
msg.Parts = append(msg.Parts, Part{
|
||||
Content: string(body),
|
||||
ContentType: ct,
|
||||
})
|
||||
case supportedImageTypes[ct]:
|
||||
msg.Parts = append(msg.Parts, Part{
|
||||
ContentType: ct,
|
||||
Data: body,
|
||||
})
|
||||
default:
|
||||
log.Debug("skipped unsupported inline content type",
|
||||
slog.String("content_type", ct),
|
||||
)
|
||||
}
|
||||
|
||||
case *mail.AttachmentHeader:
|
||||
filename, _ := h.Filename()
|
||||
ct, _, _ := h.Header.ContentType()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(part.Body, maxPartSize))
|
||||
if err != nil {
|
||||
return fmt.Errorf("read attachment: %w", err)
|
||||
}
|
||||
|
||||
// Only store attachments we can use (text or images).
|
||||
switch {
|
||||
case strings.HasPrefix(ct, "text/"):
|
||||
msg.Attachments = append(msg.Attachments, Part{
|
||||
Content: string(body),
|
||||
ContentType: ct,
|
||||
Filename: filename,
|
||||
IsAttachment: true,
|
||||
})
|
||||
case supportedImageTypes[ct]:
|
||||
msg.Attachments = append(msg.Attachments, Part{
|
||||
ContentType: ct,
|
||||
Data: body,
|
||||
Filename: filename,
|
||||
IsAttachment: true,
|
||||
})
|
||||
default:
|
||||
log.Debug("skipped unsupported attachment type",
|
||||
slog.String("content_type", ct),
|
||||
slog.String("filename", filename),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reply holds a composed reply ready for sending.
|
||||
type Reply struct {
|
||||
body string
|
||||
header *mail.Header
|
||||
}
|
||||
|
||||
// Bytes serializes the reply to RFC 5322 wire format.
|
||||
func (r *Reply) Bytes() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
mw, err := mail.CreateSingleInlineWriter(&buf, *r.header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create writer: %w", err)
|
||||
}
|
||||
if _, err := mw.Write([]byte(r.body)); err != nil {
|
||||
return nil, fmt.Errorf("write body: %w", err)
|
||||
}
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, fmt.Errorf("close writer: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Recipients returns the To addresses as formatted strings.
|
||||
func (r *Reply) Recipients() ([]string, error) {
|
||||
addrs, err := r.header.AddressList("To")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("address list: %w", err)
|
||||
}
|
||||
to := make([]string, len(addrs))
|
||||
for i, v := range addrs {
|
||||
to[i] = v.Address
|
||||
}
|
||||
return to, nil
|
||||
}
|
||||
1352
internal/message/message_test.go
Normal file
1352
internal/message/message_test.go
Normal file
File diff suppressed because it is too large
Load Diff
66
internal/smtp/smtp.go
Normal file
66
internal/smtp/smtp.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package smtp provides a config wrapper around net/smtp for sending email.
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"net/smtp"
|
||||
)
|
||||
|
||||
// SMTP holds configuration for sending emails via an SMTP server.
|
||||
type SMTP struct {
|
||||
From string `yaml:"from"`
|
||||
Host string `yaml:"host"`
|
||||
Password string `yaml:"password"`
|
||||
Port string `yaml:"port"`
|
||||
User string `yaml:"user"`
|
||||
}
|
||||
|
||||
// Address returns the host:port string.
|
||||
func (s *SMTP) Address() string {
|
||||
return fmt.Sprintf("%s:%s", s.Host, s.Port)
|
||||
}
|
||||
|
||||
// Auth creates a PlainAuth using the configured credentials.
|
||||
func (s *SMTP) Auth() smtp.Auth {
|
||||
return smtp.PlainAuth("", s.User, s.Password, s.Host)
|
||||
}
|
||||
|
||||
// Send emails the given recipients using net/smtp.
|
||||
// Refer to smtp.SendMail for its parameters and limitations.
|
||||
func (s *SMTP) Send(to []string, msg []byte) error {
|
||||
if err := smtp.SendMail(
|
||||
s.Address(),
|
||||
s.Auth(),
|
||||
s.From,
|
||||
to,
|
||||
msg,
|
||||
); err != nil {
|
||||
return fmt.Errorf("smtp send: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks that all required configuration fields are set.
|
||||
func (s *SMTP) Validate() error {
|
||||
if s.From == "" {
|
||||
return fmt.Errorf("missing from")
|
||||
}
|
||||
if _, err := mail.ParseAddress(s.From); err != nil {
|
||||
return fmt.Errorf("invalid from: %w", err)
|
||||
}
|
||||
if s.Host == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if s.Password == "" {
|
||||
return fmt.Errorf("missing password")
|
||||
}
|
||||
if s.Port == "" {
|
||||
return fmt.Errorf("missing port")
|
||||
}
|
||||
if s.User == "" {
|
||||
return fmt.Errorf("missing user")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
111
internal/smtp/smtp_test.go
Normal file
111
internal/smtp/smtp_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddress(t *testing.T) {
|
||||
s := &SMTP{
|
||||
Host: "smtp.example.com",
|
||||
Port: "587",
|
||||
}
|
||||
expected := "smtp.example.com:587"
|
||||
if got := s.Address(); got != expected {
|
||||
t.Errorf("Address() = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
s := &SMTP{
|
||||
User: "user",
|
||||
Password: "password",
|
||||
Host: "smtp.example.com",
|
||||
}
|
||||
|
||||
// Verify return of non-nil Auth mechanism.
|
||||
if got := s.Auth(); got == nil {
|
||||
t.Error("Auth() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s *SMTP
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid configuration",
|
||||
s: &SMTP{
|
||||
From: "sender@example.com",
|
||||
Host: "smtp.example.com",
|
||||
Password: "secretpassword",
|
||||
Port: "587",
|
||||
User: "user",
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "missing from",
|
||||
s: &SMTP{
|
||||
Host: "smtp.example.com",
|
||||
Password: "p",
|
||||
Port: "587",
|
||||
User: "u",
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
s: &SMTP{
|
||||
From: "f",
|
||||
Password: "p",
|
||||
Port: "587",
|
||||
User: "u",
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "missing password",
|
||||
s: &SMTP{
|
||||
From: "f",
|
||||
Host: "h",
|
||||
Port: "587",
|
||||
User: "u",
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "missing port",
|
||||
s: &SMTP{
|
||||
From: "f",
|
||||
Host: "h",
|
||||
Password: "p",
|
||||
User: "u",
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "missing user",
|
||||
s: &SMTP{
|
||||
From: "f",
|
||||
Host: "h",
|
||||
Password: "p",
|
||||
Port: "587",
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.s.Validate()
|
||||
if tt.shouldError && err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !tt.shouldError && err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
235
internal/tool/tools.go
Normal file
235
internal/tool/tools.go
Normal file
@@ -0,0 +1,235 @@
|
||||
// Package tool provides a registry for external tools that can be invoked by
|
||||
// LLMs.
|
||||
//
|
||||
// The package bridges YAML configuration to exec.CommandContext, allowing
|
||||
// tools to be defined declaratively without writing Go code. Each tool
|
||||
// specifies a command, argument templates using Go's text/template syntax,
|
||||
// JSON Schema parameters for LLM input, and an optional execution timeout.
|
||||
//
|
||||
// The registry validates all tool definitions at construction time,
|
||||
// failing fast on configuration errors. At execution time, it expands
|
||||
// argument templates with LLM-provided JSON, runs the subprocess, and
|
||||
// returns stdout.
|
||||
package tool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// fn provides template functions available to argument templates.
|
||||
var fn = template.FuncMap{
|
||||
"json": func(v any) string {
|
||||
b, _ := json.Marshal(v)
|
||||
return string(b)
|
||||
},
|
||||
}
|
||||
|
||||
// Tool represents an external tool that can be invoked by LLMs.
|
||||
// Tools are executed as subprocesses with templated arguments.
|
||||
type Tool struct {
|
||||
// Name uniquely identifies the tool within a registry.
|
||||
Name string `yaml:"name"`
|
||||
// Description explains the tool's purpose for the LLM.
|
||||
Description string `yaml:"description"`
|
||||
// Command is the executable path or name.
|
||||
Command string `yaml:"command"`
|
||||
// Arguments are Go templates expanded with LLM-provided parameters.
|
||||
// Empty results after expansion are filtered out.
|
||||
Arguments []string `yaml:"arguments"`
|
||||
// Parameters is a JSON Schema describing expected input from the LLM.
|
||||
Parameters map[string]any `yaml:"parameters"`
|
||||
// Timeout limits execution time (e.g., "30s", "5m").
|
||||
// Empty means no timeout.
|
||||
Timeout string `yaml:"timeout"`
|
||||
// timeout is the parsed duration, set during registry construction.
|
||||
timeout time.Duration `yaml:"-"`
|
||||
}
|
||||
|
||||
// OpenAI converts the tool to an OpenAI function definition for API calls.
|
||||
func (t Tool) OpenAI() openai.Tool {
|
||||
return openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.Parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ParseArguments expands argument templates with the provided JSON data.
|
||||
// The args parameter should be a JSON object string; empty string or "{}"
|
||||
// results in an empty data map. Templates producing empty strings are
|
||||
// filtered from the result, allowing conditional arguments.
|
||||
func (t Tool) ParseArguments(args string) ([]string, error) {
|
||||
var data = map[string]any{}
|
||||
if args != "" && args != "{}" {
|
||||
if err := json.Unmarshal([]byte(args), &data); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid arguments JSON: %w", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, v := range t.Arguments {
|
||||
tmpl, err := template.New("").Funcs(fn).Parse(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid template %q: %w", v, err,
|
||||
)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"execute template %q: %w", v, err,
|
||||
)
|
||||
}
|
||||
|
||||
// Filter out empty strings (unused conditional arguments).
|
||||
if s := buf.String(); s != "" {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Validate checks that the tool definition is complete and valid.
|
||||
// It verifies required fields are present, the timeout (if specified)
|
||||
// is parseable, and all argument templates are syntactically valid.
|
||||
func (t Tool) Validate() error {
|
||||
if t.Name == "" {
|
||||
return fmt.Errorf("missing name")
|
||||
}
|
||||
if t.Description == "" {
|
||||
return fmt.Errorf("missing description")
|
||||
}
|
||||
if t.Command == "" {
|
||||
return fmt.Errorf("missing command")
|
||||
}
|
||||
if t.Timeout != "" {
|
||||
if _, err := time.ParseDuration(t.Timeout); err != nil {
|
||||
return fmt.Errorf("invalid timeout: %v", err)
|
||||
}
|
||||
}
|
||||
for _, arg := range t.Arguments {
|
||||
if _, err := template.New("").Funcs(fn).Parse(arg); err != nil {
|
||||
return fmt.Errorf("invalid argument template")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Registry holds tools indexed by name and handles their execution.
|
||||
// It validates all tools at construction time to fail fast on
|
||||
// configuration errors.
|
||||
type Registry struct {
|
||||
tools map[string]*Tool
|
||||
}
|
||||
|
||||
// NewRegistry creates a registry from the provided tool definitions.
|
||||
// Returns an error if any tool fails validation or if duplicate names exist.
|
||||
func NewRegistry(tools []Tool) (*Registry, error) {
|
||||
var r = &Registry{
|
||||
tools: make(map[string]*Tool),
|
||||
}
|
||||
for _, t := range tools {
|
||||
if err := t.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid tool: %v", err)
|
||||
}
|
||||
if t.Timeout != "" {
|
||||
d, err := time.ParseDuration(t.Timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"parse timeout: %v", err,
|
||||
)
|
||||
}
|
||||
t.timeout = d
|
||||
}
|
||||
|
||||
if _, exists := r.tools[t.Name]; exists {
|
||||
return nil, fmt.Errorf(
|
||||
"duplicate tool name: %s", t.Name,
|
||||
)
|
||||
}
|
||||
r.tools[t.Name] = &t
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Get returns a tool by name and a boolean indicating if it was found.
|
||||
func (r *Registry) Get(name string) (*Tool, bool) {
|
||||
tool, ok := r.tools[name]
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
// List returns all registered tool names in arbitrary order.
|
||||
func (r *Registry) List() []string {
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Execute runs a tool by name with the provided JSON arguments.
|
||||
// It expands argument templates, executes the command as a subprocess,
|
||||
// and returns stdout on success. The context can be used for cancellation;
|
||||
// tool-specific timeouts are applied on top of any context deadline.
|
||||
func (r *Registry) Execute(
|
||||
ctx context.Context, name string, args string,
|
||||
) (string, error) {
|
||||
tool, ok := r.tools[name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown tool: %s", name)
|
||||
}
|
||||
|
||||
// Evaluate argument templates.
|
||||
cmdArgs, err := tool.ParseArguments(args)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse arguments: %w", err)
|
||||
}
|
||||
|
||||
// If defined, use the timeout.
|
||||
if tool.timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, tool.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Setup and run the command.
|
||||
var (
|
||||
stdout, stderr bytes.Buffer
|
||||
cmd = exec.CommandContext(
|
||||
ctx, tool.Command, cmdArgs...,
|
||||
)
|
||||
)
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded && tool.timeout > 0 {
|
||||
return "", fmt.Errorf(
|
||||
"tool %s timed out after %v",
|
||||
name, tool.timeout,
|
||||
)
|
||||
}
|
||||
return "", fmt.Errorf(
|
||||
"tool %s: %w\nstderr: %s",
|
||||
name, err, stderr.String(),
|
||||
)
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
}
|
||||
374
internal/tool/tools_test.go
Normal file
374
internal/tool/tools_test.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestToolValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tool Tool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid tool",
|
||||
tool: Tool{
|
||||
Name: "echo",
|
||||
Description: "echoes input",
|
||||
Command: "echo",
|
||||
Arguments: []string{"{{.message}}"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool with timeout",
|
||||
tool: Tool{
|
||||
Name: "slow",
|
||||
Description: "slow command",
|
||||
Command: "sleep",
|
||||
Arguments: []string{"1"},
|
||||
Timeout: "5s",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
tool: Tool{
|
||||
Description: "test",
|
||||
Command: "echo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing description",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Command: "echo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing command",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid timeout",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Command: "echo",
|
||||
Timeout: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid template",
|
||||
tool: Tool{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Command: "echo",
|
||||
Arguments: []string{"{{.unclosed"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.tool.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolParseArguments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tool Tool
|
||||
args string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple substitution",
|
||||
tool: Tool{
|
||||
Arguments: []string{"{{.message}}"},
|
||||
},
|
||||
args: `{"message": "hello"}`,
|
||||
want: []string{"hello"},
|
||||
},
|
||||
{
|
||||
name: "multiple arguments",
|
||||
tool: Tool{
|
||||
Arguments: []string{"-n", "{{.count}}", "{{.file}}"},
|
||||
},
|
||||
args: `{"count": "10", "file": "test.txt"}`,
|
||||
want: []string{"-n", "10", "test.txt"},
|
||||
},
|
||||
{
|
||||
name: "conditional with if",
|
||||
tool: Tool{
|
||||
Arguments: []string{"{{.required}}", "{{if .optional}}{{.optional}}{{end}}"},
|
||||
},
|
||||
args: `{"required": "value"}`,
|
||||
want: []string{"value"},
|
||||
},
|
||||
{
|
||||
name: "json function",
|
||||
tool: Tool{
|
||||
Arguments: []string{`{{json .data}}`},
|
||||
},
|
||||
args: `{"data": {"key": "value"}}`,
|
||||
want: []string{`{"key":"value"}`},
|
||||
},
|
||||
{
|
||||
name: "empty JSON object",
|
||||
tool: Tool{
|
||||
Arguments: []string{"fixed"},
|
||||
},
|
||||
args: "{}",
|
||||
want: []string{"fixed"},
|
||||
},
|
||||
{
|
||||
name: "empty string args",
|
||||
tool: Tool{
|
||||
Arguments: []string{"fixed"},
|
||||
},
|
||||
args: "",
|
||||
want: []string{"fixed"},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
tool: Tool{
|
||||
Arguments: []string{"{{.x}}"},
|
||||
},
|
||||
args: "not json",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.tool.ParseArguments(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseArguments() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && !slices.Equal(got, tt.want) {
|
||||
t.Errorf("ParseArguments() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []Tool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid registry",
|
||||
tools: []Tool{
|
||||
{Name: "a", Description: "a", Command: "echo"},
|
||||
{Name: "b", Description: "b", Command: "cat"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty registry",
|
||||
tools: []Tool{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "duplicate names",
|
||||
tools: []Tool{
|
||||
{Name: "dup", Description: "first", Command: "echo"},
|
||||
{Name: "dup", Description: "second", Command: "cat"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool",
|
||||
tools: []Tool{
|
||||
{Name: "", Description: "missing name", Command: "echo"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewRegistry(tt.tools)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewRegistry() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGet(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{Name: "echo", Description: "echo", Command: "echo"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
tool, ok := r.Get("echo")
|
||||
if !ok {
|
||||
t.Error("Get(echo) returned false, want true")
|
||||
}
|
||||
if tool.Name != "echo" {
|
||||
t.Errorf("Get(echo).Name = %q, want %q", tool.Name, "echo")
|
||||
}
|
||||
|
||||
_, ok = r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("Get(nonexistent) returned true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryList(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{Name: "a", Description: "a", Command: "echo"},
|
||||
{Name: "b", Description: "b", Command: "cat"},
|
||||
{Name: "c", Description: "c", Command: "ls"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
names := r.List()
|
||||
if len(names) != 3 {
|
||||
t.Errorf("List() returned %d names, want 3", len(names))
|
||||
}
|
||||
|
||||
slices.Sort(names)
|
||||
want := []string{"a", "b", "c"}
|
||||
if !slices.Equal(names, want) {
|
||||
t.Errorf("List() = %v, want %v", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExecute(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{
|
||||
Name: "echo",
|
||||
Description: "echo message",
|
||||
Command: "echo",
|
||||
Arguments: []string{"{{.message}}"},
|
||||
},
|
||||
{
|
||||
Name: "cat",
|
||||
Description: "cat stdin",
|
||||
Command: "cat",
|
||||
Arguments: []string{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Successful execution.
|
||||
out, err := r.Execute(ctx, "echo", `{"message": "hello world"}`)
|
||||
if err != nil {
|
||||
t.Errorf("Execute(echo) error = %v", err)
|
||||
}
|
||||
if out != "hello world\n" {
|
||||
t.Errorf("Execute(echo) = %q, want %q", out, "hello world\n")
|
||||
}
|
||||
|
||||
// Unknown tool.
|
||||
_, err = r.Execute(ctx, "unknown", "{}")
|
||||
if err == nil {
|
||||
t.Error("Execute(unknown) expected error, got nil")
|
||||
}
|
||||
|
||||
// Invalid JSON arguments.
|
||||
_, err = r.Execute(ctx, "echo", "not json")
|
||||
if err == nil {
|
||||
t.Error("Execute with invalid JSON expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExecuteTimeout(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{
|
||||
Name: "slow",
|
||||
Description: "slow command",
|
||||
Command: "sleep",
|
||||
Arguments: []string{"10"},
|
||||
Timeout: "50ms",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
_, err = r.Execute(ctx, "slow", "{}")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Execute(slow) expected timeout error, got nil")
|
||||
}
|
||||
if elapsed > time.Second {
|
||||
t.Errorf("Execute(slow) took %v, expected ~50ms timeout", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExecuteFailure(t *testing.T) {
|
||||
r, err := NewRegistry([]Tool{
|
||||
{
|
||||
Name: "fail",
|
||||
Description: "always fails",
|
||||
Command: "false",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = r.Execute(context.Background(), "fail", "{}")
|
||||
if err == nil {
|
||||
t.Error("Execute(fail) expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolOpenAI(t *testing.T) {
|
||||
tool := Tool{
|
||||
Name: "test",
|
||||
Description: "test tool",
|
||||
Command: "echo",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "the message",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
openaiTool := tool.OpenAI()
|
||||
if openaiTool.Function.Name != "test" {
|
||||
t.Errorf("OpenAI().Function.Name = %q, want %q", openaiTool.Function.Name, "test")
|
||||
}
|
||||
if openaiTool.Function.Description != "test tool" {
|
||||
t.Errorf("OpenAI().Function.Description = %q, want %q", openaiTool.Function.Description, "test tool")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user