diff --git a/internal/answer/answer.go b/internal/answer/answer.go new file mode 100644 index 0000000..dc21771 --- /dev/null +++ b/internal/answer/answer.go @@ -0,0 +1,263 @@ +// Package answer implements the email response pipeline. +// Workers receive message UIDs, fetch and parse messages, query an LLM +// for responses, and send replies via SMTP. +package answer + +import ( + "context" + "fmt" + "log/slog" + "time" + + "raven/internal/filter" + "raven/internal/imap" + "raven/internal/llm" + "raven/internal/message" + "raven/internal/smtp" + "raven/internal/tracker" + + goimap "github.com/emersion/go-imap/v2" + "github.com/emersion/go-message/mail" +) + +const fallbackResponse = "Your message was received, but I was unable to generate a response. Please try again later or contact the administrator." + +// Worker processes incoming messages through the response pipeline. +// Each worker maintains its own IMAP connection and processes UIDs +// received from a shared work channel. The pipeline: +// 1. Fetch message from IMAP +// 2. Check sender against allowlist +// 3. Query LLM for response +// 4. Send reply via SMTP +// 5. Mark original message as seen +type Worker struct { + // from is the sender address for outgoing replies. + from *mail.Address + // ic is the IMAP client for fetching messages and marking seen. + ic *imap.Client + // id identifies this worker in logs. + id int + // filters determines which senders are allowed. + filters filter.Filters + // llm generates response content. + llm *llm.Client + // log is the worker's logger with worker ID context. + log *slog.Logger + // smtp sends composed replies. + smtp smtp.SMTP + // tracker prevents duplicate processing across workers. + tracker *tracker.Tracker + // work receives UIDs to process. + work <-chan goimap.UID +} + +// NewWorker creates a Worker with its own IMAP connection. +// The from address is parsed from smtp.From for reply composition. +func NewWorker( + id int, + filters filter.Filters, + imapConfig imap.Config, + smtp smtp.SMTP, + llm *llm.Client, + tracker *tracker.Tracker, + log *slog.Logger, + work <-chan goimap.UID, +) (*Worker, error) { + ic, err := imap.NewClient(imapConfig, filters, log) + if err != nil { + return nil, fmt.Errorf("create imap client: %v", err) + } + + from, err := mail.ParseAddress(smtp.From) + if err != nil { + return nil, fmt.Errorf("parse from: %v", err) + } + + return &Worker{ + from: from, + ic: ic, + id: id, + filters: filters, + smtp: smtp, + llm: llm, + log: log.With(slog.String( + "worker", fmt.Sprintf("answer[%d]", id), + )), + tracker: tracker, + work: work, + }, nil +} + +// Run processes UIDs from the work channel until ctx is canceled +// or the channel is closed. Each UID is processed independently; +// errors are logged but do not stop the worker. +func (w *Worker) Run(ctx context.Context) { + defer w.log.Info("worker stopped", slog.Int("id", w.id)) + for { + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + w.log.InfoContext( + ctx, + "context closed", + slog.Any("error", err), + ) + } + return + + case uid, ok := <-w.work: + if !ok { + w.log.InfoContext( + ctx, + "channel closed", + ) + return + } + w.respond(ctx, uid) + w.tracker.Release(uid) + } + } +} + +// respond handles a single message through the full pipeline. +func (w *Worker) respond(ctx context.Context, uid goimap.UID) { + defer w.ic.Disconnect(ctx) + w.log.InfoContext(ctx, "processing", slog.Any("uid", uid)) + + // Connect to the IMAP server and retrieve the message. + if err := w.ic.Connect(ctx, w.log, nil); err != nil { + w.log.ErrorContext( + ctx, + "failed to connect IMAP client", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + fetchedMail, err := w.ic.Fetch(uid, true) + if err != nil { + w.log.ErrorContext( + ctx, + "failed to fetch message", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + + msg, err := message.New(fetchedMail, w.log) + if err != nil { + w.log.ErrorContext( + ctx, + "skipping: failed to parse message", + slog.Any("error", err), + slog.Any("uid", uid), + ) + if err := w.ic.MarkSeen(uid); err != nil { + w.log.ErrorContext( + ctx, + "failed to mark message as seen", + slog.Any("error", err), + slog.Any("uid", uid), + ) + } + return + } + + // Enforce allowlist. + if !w.filters.MatchSender(msg.Envelope.From[0].Addr()) { + w.log.InfoContext( + ctx, + "skipping: sender not in allowlist", + slog.String("from", msg.Envelope.From[0].Addr()), + slog.Any("uid", uid), + ) + if err := w.ic.MarkSeen(uid); err != nil { + w.log.ErrorContext( + ctx, + "failed to mark message as seen", + slog.Any("error", err), + slog.Any("uid", uid), + ) + } + return + } + + // Disconnect during LLM query to avoid IMAP timeout. + // LLM queries can take significant time depending on model and load. + w.ic.Disconnect(ctx) + + res, err := w.llm.Query(ctx, msg) + if err != nil { + w.log.ErrorContext( + ctx, + "failed to query LLM", + slog.Any("error", err), + slog.Any("uid", uid), + ) + res = fallbackResponse + } + + reply, err := msg.ComposeReply(time.Now(), w.from, res) + if err != nil { + w.log.ErrorContext( + ctx, + "failed to compose reply", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + to, err := reply.Recipients() + if err != nil { + w.log.ErrorContext( + ctx, + "failed to retrieve reply recipients", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + body, err := reply.Bytes() + if err != nil { + w.log.ErrorContext( + ctx, + "failed to retrieve reply body", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + + if err := w.smtp.Send(to, body); err != nil { + w.log.ErrorContext( + ctx, + "failed to reply", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + + // Reconnect to flag the message as processed. + if err := w.ic.Connect(ctx, w.log, nil); err != nil { + w.log.ErrorContext( + ctx, + "failed to connect IMAP client", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + if err := w.ic.MarkSeen(uid); err != nil { + w.log.ErrorContext( + ctx, + "failed to mark message as seen", + slog.Any("error", err), + slog.Any("uid", uid), + ) + return + } + + w.log.InfoContext(ctx, "completed", slog.Any("uid", uid)) +} diff --git a/internal/llm/llm.go b/internal/llm/llm.go new file mode 100644 index 0000000..3b912f7 --- /dev/null +++ b/internal/llm/llm.go @@ -0,0 +1,207 @@ +// Package llm provides an OpenAI-compatible client for LLM interactions. +// It handles chat completions with automatic tool call execution. +package llm + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "raven/internal/message" + "raven/internal/tool" + + openai "github.com/sashabaranov/go-openai" +) + +// Config holds the configuration for an LLM client. +type Config struct { + // Key is the API key for authentication. + Key string `yaml:"key"` + // Model is the model identifier (e.g., "gpt-oss-120b", "qwen3-32b"). + Model string `yaml:"model"` + // SystemPrompt is prepended to all conversations. + SystemPrompt string `yaml:"system_prompt"` + // Timeout is the maximum duration for a query (e.g., "15m"). + // Defaults to 15 minutes if empty. + Timeout string `yaml:"timeout"` + // URL is the base URL of the OpenAI-compatible API endpoint. + URL string `yaml:"url"` +} + +// Validate checks that required configuration values are present and valid. +func (cfg Config) Validate() error { + if cfg.Model == "" { + return fmt.Errorf("missing model") + } + if cfg.Timeout != "" { + if _, err := time.ParseDuration(cfg.Timeout); err != nil { + return fmt.Errorf("invalid duration") + } + } + if cfg.URL == "" { + return fmt.Errorf("missing URL") + } + + return nil +} + +// Client wraps an OpenAI-compatible client with tool execution support. +type Client struct { + client *openai.Client + log *slog.Logger + model string + registry *tool.Registry + systemPrompt string + timeout time.Duration + tools []openai.Tool +} + +// NewClient creates a new LLM client with the provided configuration. +// The registry is optional; if nil, tool calling is disabled. +func NewClient( + cfg Config, + registry *tool.Registry, + log *slog.Logger, +) (*Client, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %v", err) + } + var llm = &Client{ + log: log, + model: cfg.Model, + systemPrompt: cfg.SystemPrompt, + registry: registry, + } + if cfg.Timeout == "" { + llm.timeout = time.Duration(15 * time.Minute) + } else { + d, err := time.ParseDuration(cfg.Timeout) + if err != nil { + return nil, fmt.Errorf("parse timeout: %v", err) + } + llm.timeout = d + } + + // Setup client. + clientConfig := openai.DefaultConfig(cfg.Key) + clientConfig.BaseURL = cfg.URL + llm.client = openai.NewClientWithConfig(clientConfig) + + // Parse tools. + if llm.registry != nil { + for _, name := range llm.registry.List() { + tool, _ := llm.registry.Get(name) + llm.tools = append(llm.tools, tool.OpenAI()) + } + } + + return llm, nil +} + +// Query sends a message to the LLM and returns its response. +// It automatically executes any tool calls requested by the model, +// looping until the model returns a final text response. +// Supports multimodal content (text and images) via the Message type. +func (c *Client) Query( + ctx context.Context, msg *message.Message, +) (string, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + // Build user message from email content. + // Uses MultiContent to support both text and image parts. + userMessage := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + MultiContent: msg.ToOpenAIMessages(), + } + + // Set the system message and the user prompt. + messages := []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: c.systemPrompt, + }, + userMessage, + } + + // Loop for tool calls. + for { + req := openai.ChatCompletionRequest{ + Model: c.model, + Messages: messages, + } + if len(c.tools) > 0 { + req.Tools = c.tools + } + + res, err := c.client.CreateChatCompletion(ctx, req) + if err != nil { + return "", fmt.Errorf("chat completion: %w", err) + } + if len(res.Choices) == 0 { + return "", fmt.Errorf("no response choices returned") + } + + choice := res.Choices[0] + message := choice.Message + + // If no tool calls, we're done. + if len(message.ToolCalls) == 0 { + return message.Content, nil + } + + // Add assistant message with tool calls to history. + messages = append(messages, message) + + // Process each tool call. + for _, tc := range message.ToolCalls { + c.log.InfoContext( + ctx, + "calling tool", + slog.String("name", tc.Function.Name), + slog.String("args", tc.Function.Arguments), + ) + + res, err := c.registry.Execute( + ctx, tc.Function.Name, tc.Function.Arguments, + ) + if err != nil { + // Return error to LLM so it can recover. + c.log.Error( + "failed to call tool", + slog.Any("error", err), + slog.String("name", tc.Function.Name), + ) + // Assume JSON is a more helpful response + // for an LLM. + res = fmt.Sprintf( + `{"ok": false,"error": %q}`, err, + ) + } else { + c.log.Info( + "called tool", + slog.String("name", tc.Function.Name), + ) + } + + // Content cannot be empty. + // Assume JSON is better than OK or (no output). + if strings.TrimSpace(res) == "" { + res = `{"ok":true,"result":null}` + } + + // Add tool result to messages. + messages = append( + messages, + openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: res, + ToolCallID: tc.ID, + }, + ) + } + // Loop to get LLM's response after tool execution. + } +} diff --git a/internal/llm/llm_test.go b/internal/llm/llm_test.go new file mode 100644 index 0000000..f2f3a77 --- /dev/null +++ b/internal/llm/llm_test.go @@ -0,0 +1,91 @@ +package llm + +import ( + "testing" +) + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "empty config", + cfg: Config{}, + wantErr: true, + }, + { + name: "missing model", + cfg: Config{ + URL: "https://api.example.com", + }, + wantErr: true, + }, + { + name: "missing URL", + cfg: Config{ + Model: "gpt-4o", + }, + wantErr: true, + }, + { + name: "invalid timeout", + cfg: Config{ + Model: "gpt-4o", + URL: "https://api.example.com", + Timeout: "not-a-duration", + }, + wantErr: true, + }, + { + name: "valid minimal config", + cfg: Config{ + Model: "gpt-4o", + URL: "https://api.example.com", + }, + wantErr: false, + }, + { + name: "valid config with timeout", + cfg: Config{ + Model: "gpt-4o", + URL: "https://api.example.com", + Timeout: "15m", + }, + wantErr: false, + }, + { + name: "valid config with complex timeout", + cfg: Config{ + Model: "gpt-4o", + URL: "https://api.example.com", + Timeout: "1h30m45s", + }, + wantErr: false, + }, + { + name: "valid full config", + cfg: Config{ + Key: "sk-test-key", + Model: "gpt-4o", + SystemPrompt: "You are a helpful assistant.", + Timeout: "30m", + URL: "https://api.example.com", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.Validate() + if (err != nil) != tt.wantErr { + t.Errorf( + "Validate() error = %v, wantErr %v", + err, tt.wantErr, + ) + } + }) + } +} diff --git a/internal/message/message.go b/internal/message/message.go new file mode 100644 index 0000000..05edd6b --- /dev/null +++ b/internal/message/message.go @@ -0,0 +1,529 @@ +// Package message handles email message parsing and reply composition. +package message + +import ( + "bytes" + "encoding/base64" + "errors" + "fmt" + "io" + "log/slog" + "strings" + "time" + + "github.com/emersion/go-imap/v2" + "github.com/emersion/go-imap/v2/imapclient" + _ "github.com/emersion/go-message/charset" + "github.com/emersion/go-message/mail" + openai "github.com/sashabaranov/go-openai" +) + +// RFC5322 defines the date format specified in RFC 5322 §3.3. +const RFC5322 = "Mon, 2 Jan 2006 15:04:05 -0700" + +// Supported image MIME types for vision models. +var supportedImageTypes = map[string]bool{ + "image/png": true, + "image/jpeg": true, + "image/gif": true, + "image/webp": true, +} + +// maxPartSize is the maximum size in bytes for a single MIME part (32MB). +const maxPartSize = 32 << 20 + +// Part represents a single piece of message content. +type Part struct { + // Content holds text content for text parts, empty for images. + Content string + // ContentType is the MIME type (e.g., "text/plain", "image/png"). + ContentType string + // Data holds raw bytes for binary content like images. + Data []byte + // Filename is set for attachment parts. + Filename string + // IsAttachment distinguishes attachments from inline content. + IsAttachment bool +} + +// IsImage returns true if the part is a supported image type. +func (p *Part) IsImage() bool { + return supportedImageTypes[p.ContentType] +} + +// IsText returns true if the part contains text content. +func (p *Part) IsText() bool { + return strings.HasPrefix(p.ContentType, "text/") +} + +// Message represents a parsed email with its metadata and content. +type Message struct { + // Attachments are parts with Content-Disposition: attachment. + // Stored separately to allow appending after inline content. + Attachments []Part + Envelope *imap.Envelope + // Parts contains inline content in order of appearance. + Parts []Part + // References are Message-IDs from the References header, + // used for threading. + References []string + UID imap.UID + log *slog.Logger +} + +func (msg *Message) TextFrom() string { + var str strings.Builder + if msg.Envelope == nil || len(msg.Envelope.From) == 0 { + return "" + } + str.WriteString("From: ") + for i, a := range msg.Envelope.From { + if a.Name == "" { + fmt.Fprintf(&str, "%s", a.Addr()) + } else { + fmt.Fprintf( + &str, "%s <%s>", + a.Name, a.Addr(), + ) + } + if i+1 < len(msg.Envelope.From) { + str.WriteString(", ") + } + } + + return str.String() +} + +// TextBody returns the concatenated text content from all text parts. +// Used for reply composition and quoting. +func (msg *Message) TextBody() string { + var sb strings.Builder + first := true + for _, p := range msg.Parts { + if p.IsText() { + if !first { + sb.WriteString("\n") + } + sb.WriteString(p.Content) + first = false + } + } + return sb.String() +} + +// ToOpenAIMessages converts the message content to OpenAI chat message parts. +// Inline parts appear first in order, followed by attachments. +// Text parts become text content, supported images become image_url content. +func (msg *Message) ToOpenAIMessages() []openai.ChatMessagePart { + var parts = []openai.ChatMessagePart{} + + if msg != nil && msg.Envelope != nil { + var str strings.Builder + if v := msg.TextFrom(); v != "" { + fmt.Fprintf(&str, "%s\n", v) + } + if !msg.Envelope.Date.IsZero() { + fmt.Fprintf( + &str, + "Date: %s\n", + msg.Envelope.Date.Format(time.RFC3339), + ) + } + if msg.Envelope.Subject != "" { + fmt.Fprintf(&str, "Subject: %v\n", msg.Envelope.Subject) + } + if v := str.String(); v != "" { + parts = append( + parts, + openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeText, + Text: v, + }, + ) + } + } + + // Process inline parts first, preserving order. + for _, p := range msg.Parts { + if part, ok := convertPart(p); ok { + parts = append(parts, part) + } + } + + // Append attachments at the end. + for _, p := range msg.Attachments { + if part, ok := convertPart(p); ok { + parts = append(parts, part) + } + } + + return parts +} + +// convertPart converts a Part to an OpenAI ChatMessagePart. +// Returns false if the part type is not supported for LLM input. +func convertPart(p Part) (openai.ChatMessagePart, bool) { + switch { + case p.IsText(): + return openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeText, + Text: p.Content, + }, true + case p.IsImage(): + dataURI := fmt.Sprintf( + "data:%s;base64,%s", + p.ContentType, + base64.StdEncoding.EncodeToString(p.Data), + ) + return openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: dataURI, + Detail: openai.ImageURLDetailAuto, + }, + }, true + default: + return openai.ChatMessagePart{}, false + } +} + +// composeAttribution builds the attribution line for quoted replies. +// Returns sender and timestamp in a standard format like: +// "On Mon, 2 Jan 2006 15:04:05 -0700, raven wrote:" +func (msg *Message) composeAttribution() string { + if len(msg.Envelope.From) == 0 { + return "> \n" + } + + from := msg.Envelope.From[0] + sender := from.Addr() + if from.Name != "" { + sender = fmt.Sprintf("%s <%s>", from.Name, from.Addr()) + } + + if msg.Envelope.Date.IsZero() { + return fmt.Sprintf("%s wrote:\n", sender) + } + + return fmt.Sprintf( + "On %s, %s wrote:\n", + msg.Envelope.Date.Format(RFC5322), + sender, + ) +} + +// composeBody assembles the reply body: user's response, attribution line, +// and quoted original message. +func (msg *Message) composeBody(res string) string { + var body strings.Builder + body.WriteString(strings.TrimRight(res, "\n\r \t")) + body.WriteString("\n\n") + body.WriteString(msg.composeAttribution()) + body.WriteString(msg.QuotedBody()) + + return body.String() +} + +// composeHeader builds RFC 5322-compliant headers for a reply. +// Sets From, To, Subject, Date, Message-ID, In-Reply-To, and References. +func (msg *Message) composeHeader( + date time.Time, from *mail.Address, +) (*mail.Header, error) { + h := &mail.Header{} + h.SetDate(date) + h.SetContentType("text/plain", map[string]string{"charset": "utf-8"}) + h.SetAddressList("From", []*mail.Address{from}) + h.SetAddressList("Reply-To", []*mail.Address{from}) + + to := msg.composeRecipients() + if len(to) == 0 { + return nil, errors.New("missing recipients") + } + h.SetAddressList("To", to) + h.SetSubject(msg.composeSubject()) + + // Use sender's domain for Message-ID per RFC 5322 recommendation. + parts := strings.SplitN(from.Address, "@", 2) + if len(parts) == 2 { + if err := h.GenerateMessageIDWithHostname(parts[1]); err != nil { + return nil, fmt.Errorf("generate message id: %w", err) + } + } else { + if err := h.GenerateMessageID(); err != nil { + return nil, fmt.Errorf("generate message id: %w", err) + } + } + + inReplyTo, refs := msg.composeReferences() + if inReplyTo != "" { + h.SetMsgIDList("In-Reply-To", []string{inReplyTo}) + } + if len(refs) > 0 { + h.SetMsgIDList("References", refs) + } + + return h, nil +} + +// composeReferences builds threading headers per RFC 5322 §3.6.4. +// In-Reply-To contains the parent's Message-ID. +// References contains the full thread ancestry. +func (msg *Message) composeReferences() (inReplyTo string, refs []string) { + if msg.Envelope.MessageID != "" { + inReplyTo = msg.Envelope.MessageID + refs = append([]string(nil), msg.References...) + refs = append(refs, msg.Envelope.MessageID) + } else { + refs = append([]string(nil), msg.References...) + } + return +} + +// composeRecipients determines the To address for a reply. +// Uses Reply-To if present, otherwise From (per RFC 5322 §3.6.2). +func (msg *Message) composeRecipients() []*mail.Address { + src := msg.Envelope.ReplyTo + if len(src) == 0 { + src = msg.Envelope.From + } + to := make([]*mail.Address, 0, len(src)) + for _, v := range src { + to = append(to, &mail.Address{ + Name: v.Name, + Address: v.Addr(), + }) + } + return to +} + +// composeSubject prepends "Re: " if not already present (case-insensitive). +func (msg *Message) composeSubject() string { + s := msg.Envelope.Subject + if !strings.HasPrefix(strings.ToLower(s), "re:") { + s = "Re: " + s + } + return s +} + +// ComposeReply creates a reply to this message. +// The reply includes proper threading headers and the original message quoted. +func (msg *Message) ComposeReply( + date time.Time, from *mail.Address, res string, +) (*Reply, error) { + if msg == nil || msg.Envelope == nil { + return nil, errors.New("missing envelope") + } + if from == nil || from.Address == "" { + return nil, errors.New("missing from address") + } + + header, err := msg.composeHeader(date, from) + if err != nil { + return nil, fmt.Errorf("compose header: %v", err) + } + + return &Reply{body: msg.composeBody(res), header: header}, nil +} + +// QuotedBody returns the message text body with each line prefixed by "> ". +func (msg *Message) QuotedBody() string { + var quoted strings.Builder + for line := range strings.SplitSeq(msg.TextBody(), "\n") { + line = strings.TrimSuffix(line, "\r") + quoted.WriteString("> ") + quoted.WriteString(line) + quoted.WriteString("\n") + } + return quoted.String() +} + +// New creates a Message from an IMAP fetch buffer. +// Handles multipart messages by collecting text parts and images inline, +// with attachments stored separately for appending later. +// Logs skipped parts and non-fatal errors. +func New(mb *imapclient.FetchMessageBuffer, log *slog.Logger) (*Message, error) { + if mb == nil { + return nil, errors.New("nil message buffer") + } + if mb.UID == 0 { + return nil, errors.New("message has no UID") + } + if mb.Envelope == nil { + return nil, errors.New("message has no envelope") + } + if log == nil { + log = slog.Default() + } + + msg := &Message{ + UID: mb.UID, + Envelope: mb.Envelope, + log: log, + } + + // Try each body section until we successfully parse one. + var parseErr error + for _, section := range mb.BodySection { + if len(section.Bytes) == 0 { + continue + } + if err := parseBody( + msg, + bytes.NewReader(section.Bytes), + log, + ); err != nil { + parseErr = err + continue + } + if len(msg.Parts) > 0 { + break + } + } + + // Failed to parse any content — return an error. + if len(msg.Parts) == 0 && parseErr != nil { + return msg, fmt.Errorf("parse body: %w", parseErr) + } + + return msg, nil +} + +// parseBody extracts content from a MIME message body. +// mail.Reader automatically flattens nested multipart structures, returning +// only leaf parts (text/plain, attachments, etc.). +func parseBody(msg *Message, r io.Reader, log *slog.Logger) error { + reader, err := mail.CreateReader(r) + if err != nil { + return fmt.Errorf("create reader: %w", err) + } + defer reader.Close() + + // Extract References header for threading. + refs, err := reader.Header.MsgIDList("References") + if err == nil && len(refs) > 0 { + msg.References = refs + } + + // Process all parts. + for { + part, err := reader.NextPart() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("next part: %w", err) + } + + if err := processPart(msg, part, log); err != nil { + log.Debug("skipped part", + slog.Any("error", err), + ) + continue + } + } + + return nil +} + +// processPart handles a single MIME part returned by mail.Reader. +// Inline text and images are added to Parts; attachments go to Attachments. +func processPart(msg *Message, part *mail.Part, log *slog.Logger) error { + switch h := part.Header.(type) { + case *mail.InlineHeader: + ct, _, err := h.Header.ContentType() + if err != nil { + ct = "text/plain" + } + + body, err := io.ReadAll(io.LimitReader(part.Body, maxPartSize)) + if err != nil { + return fmt.Errorf("read body: %w", err) + } + + switch { + case strings.HasPrefix(ct, "text/"): + msg.Parts = append(msg.Parts, Part{ + Content: string(body), + ContentType: ct, + }) + case supportedImageTypes[ct]: + msg.Parts = append(msg.Parts, Part{ + ContentType: ct, + Data: body, + }) + default: + log.Debug("skipped unsupported inline content type", + slog.String("content_type", ct), + ) + } + + case *mail.AttachmentHeader: + filename, _ := h.Filename() + ct, _, _ := h.Header.ContentType() + + body, err := io.ReadAll(io.LimitReader(part.Body, maxPartSize)) + if err != nil { + return fmt.Errorf("read attachment: %w", err) + } + + // Only store attachments we can use (text or images). + switch { + case strings.HasPrefix(ct, "text/"): + msg.Attachments = append(msg.Attachments, Part{ + Content: string(body), + ContentType: ct, + Filename: filename, + IsAttachment: true, + }) + case supportedImageTypes[ct]: + msg.Attachments = append(msg.Attachments, Part{ + ContentType: ct, + Data: body, + Filename: filename, + IsAttachment: true, + }) + default: + log.Debug("skipped unsupported attachment type", + slog.String("content_type", ct), + slog.String("filename", filename), + ) + } + } + + return nil +} + +// Reply holds a composed reply ready for sending. +type Reply struct { + body string + header *mail.Header +} + +// Bytes serializes the reply to RFC 5322 wire format. +func (r *Reply) Bytes() ([]byte, error) { + var buf bytes.Buffer + mw, err := mail.CreateSingleInlineWriter(&buf, *r.header) + if err != nil { + return nil, fmt.Errorf("create writer: %w", err) + } + if _, err := mw.Write([]byte(r.body)); err != nil { + return nil, fmt.Errorf("write body: %w", err) + } + if err := mw.Close(); err != nil { + return nil, fmt.Errorf("close writer: %w", err) + } + return buf.Bytes(), nil +} + +// Recipients returns the To addresses as formatted strings. +func (r *Reply) Recipients() ([]string, error) { + addrs, err := r.header.AddressList("To") + if err != nil { + return nil, fmt.Errorf("address list: %w", err) + } + to := make([]string, len(addrs)) + for i, v := range addrs { + to[i] = v.Address + } + return to, nil +} diff --git a/internal/message/message_test.go b/internal/message/message_test.go new file mode 100644 index 0000000..e1cb0b5 --- /dev/null +++ b/internal/message/message_test.go @@ -0,0 +1,1352 @@ +package message + +import ( + "io" + "log/slog" + "strings" + "testing" + "time" + + "github.com/emersion/go-imap/v2" + "github.com/emersion/go-imap/v2/imapclient" + "github.com/emersion/go-message/mail" + openai "github.com/sashabaranov/go-openai" +) + +var log = slog.New(slog.NewTextHandler(io.Discard, nil)) + +var newMessageTests = []struct { + name string + raw string + bufferOverrides func(*imapclient.FetchMessageBuffer) + nilBuffer bool + Parts []Part + wantAttach []Part + wantRefs []string + wantTextBody string + shouldError bool +}{ + { + name: "Plain Text", + raw: "Content-Type: text/plain; charset=utf-8\r\n" + + "\r\n" + + "Hello, world!\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Hello, world!\r\n", + }, + }, + wantTextBody: "Hello, world!\r\n", + }, + { + name: "Multipart Alternative", + raw: "Content-Type: multipart/alternative; boundary=bound\r\n" + + "\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Plain text\r\n" + + "--bound\r\n" + + "Content-Type: text/html\r\n" + + "\r\n" + + "

HTML

\r\n" + + "--bound--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Plain text", + }, + { + ContentType: "text/html", + Content: "

HTML

", + }, + }, + wantTextBody: "Plain text\n

HTML

", + }, + { + name: "Multipart Mixed With Attachment", + raw: "Content-Type: multipart/mixed; boundary=outer\r\n" + + "\r\n" + + "--outer\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Message body\r\n" + + "--outer\r\n" + + "Content-Type: text/csv\r\n" + + "Content-Disposition: attachment; filename=\"data.csv\"\r\n" + + "\r\n" + + "a,b,c\r\n" + + "--outer--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Message body", + }, + }, + wantAttach: []Part{ + { + ContentType: "text/csv", + Content: "a,b,c", + Filename: "data.csv", + IsAttachment: true, + }, + }, + wantTextBody: "Message body", + }, + { + name: "Inline Image", + raw: "Content-Type: multipart/mixed; boundary=bound\r\n" + + "\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "See image:\r\n" + + "--bound\r\n" + + "Content-Type: image/png\r\n" + + "Content-Disposition: inline\r\n" + + "Content-Transfer-Encoding: base64\r\n" + + "\r\n" + + "iVBORw0KGgo=\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "What do you think?\r\n" + + "--bound--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "See image:", + }, + { + ContentType: "image/png", + Data: []byte{1}, + }, + { + ContentType: "text/plain", + Content: "What do you think?", + }, + }, + wantTextBody: "See image:\nWhat do you think?", + }, + { + name: "Image Attachment", + raw: "Content-Type: multipart/mixed; boundary=bound\r\n" + + "\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Please review attached image.\r\n" + + "--bound\r\n" + + "Content-Type: image/jpeg\r\n" + + "Content-Disposition: attachment; filename=\"photo.jpg\"\r\n" + + "Content-Transfer-Encoding: base64\r\n" + + "\r\n" + + "/9j/4AAQ\r\n" + + "--bound--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Please review attached image.", + }, + }, + wantAttach: []Part{ + { + ContentType: "image/jpeg", + Data: []byte{1}, + Filename: "photo.jpg", + IsAttachment: true, + }, + }, + wantTextBody: "Please review attached image.", + }, + { + name: "Nested Multipart", + raw: "Content-Type: multipart/mixed; boundary=outer\r\n" + + "\r\n" + + "--outer\r\n" + + "Content-Type: multipart/alternative; boundary=inner\r\n" + + "\r\n" + + "--inner\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Plain\r\n" + + "--inner\r\n" + + "Content-Type: text/html\r\n" + + "\r\n" + + "HTML\r\n" + + "--inner--\r\n" + + "--outer\r\n" + + "Content-Type: image/png\r\n" + + "Content-Disposition: attachment; filename=\"img.png\"\r\n" + + "\r\n" + + "PNG-DATA\r\n" + + "--outer--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Plain", + }, + { + ContentType: "text/html", + Content: "HTML", + }, + }, + wantAttach: []Part{ + { + ContentType: "image/png", + Data: []byte{1}, + Filename: "img.png", + IsAttachment: true, + }, + }, + wantTextBody: "Plain\nHTML", + }, + { + name: "References Header", + raw: "References: \r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Body\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Body\r\n", + }, + }, + wantRefs: []string{ + "abc@example.com", + "def@example.com", + }, + wantTextBody: "Body\r\n", + }, + { + name: "Quoted Printable", + raw: "Content-Type: text/plain; charset=utf-8\r\n" + + "Content-Transfer-Encoding: quoted-printable\r\n" + + "\r\n" + + "Hello=20World\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Hello World\r\n", + }, + }, + wantTextBody: "Hello World\r\n", + }, + { + name: "Base64 Text", + raw: "Content-Type: text/plain; charset=utf-8\r\n" + + "Content-Transfer-Encoding: base64\r\n" + + "\r\n" + + "SGVsbG8gV29ybGQ=\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Hello World", + }, + }, + wantTextBody: "Hello World", + }, + { + name: "No Content Type", + raw: "\r\nPlain body with no headers\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Plain body with no headers\r\n", + }, + }, + wantTextBody: "Plain body with no headers\r\n", + }, + { + name: "Multiple Text Parts", + raw: "Content-Type: multipart/mixed; boundary=bound\r\n" + + "\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "First part\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Second part\r\n" + + "--bound--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "First part", + }, + { + ContentType: "text/plain", + Content: "Second part", + }, + }, + wantTextBody: "First part\nSecond part", + }, + { + name: "Skips Unsupported Attachment", + raw: "Content-Type: multipart/mixed; boundary=bound\r\n" + + "\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Message\r\n" + + "--bound\r\n" + + "Content-Type: application/pdf\r\n" + + "Content-Disposition: attachment; filename=\"doc.pdf\"\r\n" + + "\r\n" + + "PDF-DATA\r\n" + + "--bound--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Message", + }, + }, + wantAttach: nil, + wantTextBody: "Message", + }, + { + name: "Skips Unsupported Inline", + raw: "Content-Type: multipart/mixed; boundary=bound\r\n" + + "\r\n" + + "--bound\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Message\r\n" + + "--bound\r\n" + + "Content-Type: audio/mpeg\r\n" + + "\r\n" + + "AUDIO-DATA\r\n" + + "--bound--\r\n", + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Message", + }, + }, + wantTextBody: "Message", + }, + { + name: "Nil Buffer", + nilBuffer: true, + shouldError: true, + }, + { + name: "Missing UID", + raw: "Subject: x\r\n\r\nBody", + bufferOverrides: func( + mb *imapclient.FetchMessageBuffer, + ) { + mb.UID = 0 + }, + shouldError: true, + }, + { + name: "Missing Envelope", + raw: "Subject: x\r\n\r\nBody", + bufferOverrides: func( + mb *imapclient.FetchMessageBuffer, + ) { + mb.Envelope = nil + }, + shouldError: true, + }, + { + name: "Empty Body Section", + raw: "Content-Type: text/plain\r\n\r\nBody\r\n", + bufferOverrides: func( + mb *imapclient.FetchMessageBuffer, + ) { + // Prepend an empty section to verify it gets skipped. + emptySec := imapclient.FetchBodySectionBuffer{ + Section: &imap.FetchItemBodySection{}, + Bytes: []byte{}, + } + mb.BodySection = append( + []imapclient.FetchBodySectionBuffer{ + emptySec, + }, + mb.BodySection..., + ) + }, + Parts: []Part{ + { + ContentType: "text/plain", + Content: "Body\r\n", + }, + }, + wantTextBody: "Body\r\n", + }, +} + +// TestNew verifies MIME parsing across message structures: +// plain text, multipart/alternative, multipart/mixed with attachments, +// nested multipart, transfer encodings (quoted-printable, base64), +// inline images, and error conditions. +func TestNew(t *testing.T) { + for _, tt := range newMessageTests { + t.Run(tt.name, func(t *testing.T) { + var mb *imapclient.FetchMessageBuffer + + if !tt.nilBuffer { + mb = &imapclient.FetchMessageBuffer{ + UID: 1, + Envelope: &imap.Envelope{}, + BodySection: []imapclient.FetchBodySectionBuffer{ + { + Section: &imap.FetchItemBodySection{}, + Bytes: []byte(tt.raw), + }, + }, + } + if tt.bufferOverrides != nil { + tt.bufferOverrides(mb) + } + } + + msg, err := New(mb, log) + + if tt.shouldError { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if msg == nil { + t.Fatal("expected message, got nil") + } + + // Verify parts. + if len(msg.Parts) != len(tt.Parts) { + t.Fatalf( + "Parts count: got %d, want %d", + len(msg.Parts), len(tt.Parts), + ) + } + for i, want := range tt.Parts { + got := msg.Parts[i] + if !strings.HasPrefix( + got.ContentType, want.ContentType, + ) { + t.Errorf( + "Parts[%d] content-type: got %q, want prefix %q", + i, got.ContentType, + want.ContentType, + ) + } + if want.Content != "" && + got.Content != want.Content { + t.Errorf( + "Parts[%d] Content: got %q, want %q", + i, got.Content, want.Content, + ) + } + if len(want.Data) > 0 && len(got.Data) == 0 { + t.Errorf( + "Parts[%d] expected data, got empty", + i, + ) + } + } + + // Verify attachments. + if len(msg.Attachments) != len(tt.wantAttach) { + t.Fatalf( + "Attachments count: got %d, want %d", + len(msg.Attachments), + len(tt.wantAttach)) + } + for i, want := range tt.wantAttach { + got := msg.Attachments[i] + if !strings.HasPrefix( + got.ContentType, want.ContentType, + ) { + t.Errorf( + "Attachments[%d] type: got %q, want prefix %q", + i, got.ContentType, want.ContentType, + ) + } + if want.Filename != "" && + got.Filename != want.Filename { + t.Errorf( + "Attachments[%d] Filename: got %q, want %q", + i, got.Filename, want.Filename, + ) + } + if want.Content != "" && got.Content != want.Content { + t.Errorf( + "Attachments[%d] Content: got %q, want %q", + i, got.Content, want.Content, + ) + } + if len(want.Data) > 0 && len(got.Data) == 0 { + t.Errorf( + "Attachments[%d] expected data, got empty", + i, + ) + } + if !got.IsAttachment { + t.Errorf( + "Attachments[%d] IsAttachment: got false, want true", + i, + ) + } + } + + // Verify TextBody. + if tt.wantTextBody != "" && + msg.TextBody() != tt.wantTextBody { + t.Errorf( + "TextBody: got %q, want %q", + msg.TextBody(), tt.wantTextBody, + ) + } + + // Verify References. + if len(tt.wantRefs) > 0 { + if len(msg.References) != len(tt.wantRefs) { + t.Errorf( + "References count: got %d, want %d", + len(msg.References), + len(tt.wantRefs), + ) + } + for i, wantRef := range tt.wantRefs { + if i < len(msg.References) && + msg.References[i] != wantRef { + t.Errorf( + "References[%d]: got %q, want %q", + i, + msg.References[i], + wantRef, + ) + } + } + } + }) + } +} + +var chatTests = []struct { + name string + parts []Part + attachments []Part + envelope *imap.Envelope + expected []openai.ChatMessagePart +}{ + { + name: "Text Only", + parts: []Part{ + { + ContentType: "text/plain", + Content: "Hello", + }, + { + ContentType: "text/html", + Content: "

World

", + }, + }, + expected: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Hello", + }, + { + Type: openai.ChatMessagePartTypeText, + Text: "

World

", + }, + }, + }, + { + name: "Text And Image", + parts: []Part{ + { + ContentType: "text/plain", + Content: "See this:", + }, + { + ContentType: "image/png", + Data: []byte("PNG"), + }, + }, + expected: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "See this:", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "data:image/png;base64,UE5H", + Detail: openai.ImageURLDetailAuto, + }, + }, + }, + }, + { + name: "Parts Then Attachments", + parts: []Part{ + {ContentType: "text/plain", Content: "Body"}, + }, + attachments: []Part{ + { + ContentType: "image/jpeg", + Data: []byte("JPG"), + IsAttachment: true, + }, + }, + expected: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Body", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "data:image/jpeg;base64,SlBH", + Detail: openai.ImageURLDetailAuto, + }, + }, + }, + }, + { + name: "Skips Unsupported Types", + parts: []Part{ + {ContentType: "text/plain", Content: "Text"}, + {ContentType: "application/pdf", Data: []byte("PDF")}, + }, + expected: []openai.ChatMessagePart{ + {Type: openai.ChatMessagePartTypeText, Text: "Text"}, + }, + }, + { + name: "Order Preserved", + parts: []Part{ + {ContentType: "text/plain", Content: "First"}, + {ContentType: "image/gif", Data: []byte("GIF1")}, + {ContentType: "text/csv", Content: "a,b,c"}, + {ContentType: "image/webp", Data: []byte("WEBP")}, + {ContentType: "text/plain", Content: "Last"}, + }, + expected: []openai.ChatMessagePart{ + {Type: openai.ChatMessagePartTypeText, Text: "First"}, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "data:image/gif;base64,R0lGMQ==", + Detail: openai.ImageURLDetailAuto, + }, + }, + {Type: openai.ChatMessagePartTypeText, Text: "a,b,c"}, + {Type: openai.ChatMessagePartTypeImageURL, ImageURL: &openai.ChatMessageImageURL{ + URL: "data:image/webp;base64,V0VCUA==", + Detail: openai.ImageURLDetailAuto, + }}, + {Type: openai.ChatMessagePartTypeText, Text: "Last"}, + }, + }, + { + name: "HeaderWithAllFields", + parts: []Part{ + {ContentType: "text/plain", Content: "body"}, + }, + envelope: &imap.Envelope{ + From: []imap.Address{ + { + Name: "Alice", + Mailbox: "alice", + Host: "example.com", + }, + { + Mailbox: "bob", + Host: "example.org", + }, + }, + Date: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + Subject: "Test Subject", + }, + expected: []openai.ChatMessagePart{ + { + // The header part – From, Date and Subject. + Type: openai.ChatMessagePartTypeText, + Text: "From: Alice , bob@example.org\n" + + "Date: 2024-01-15T10:30:00Z\n" + + "Subject: Test Subject\n", + }, + { + // The body part that follows the header. + Type: openai.ChatMessagePartTypeText, + Text: "body", + }, + }, + }, + { + name: "HeaderWithoutFrom", + parts: []Part{ + {ContentType: "text/plain", Content: "only date"}, + }, + envelope: &imap.Envelope{ + Date: time.Date(2022, 12, 31, 23, 59, 59, 0, time.FixedZone("-05:00", -5*3600)), + Subject: "Year‑End", + }, + expected: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Date: 2022-12-31T23:59:59-05:00\n" + + "Subject: Year‑End\n", + }, + { + Type: openai.ChatMessagePartTypeText, + Text: "only date", + }, + }, + }, + { + name: "HeaderWithoutDate", + parts: []Part{ + {ContentType: "text/plain", Content: "no date"}, + }, + envelope: &imap.Envelope{ + From: []imap.Address{ + { + Name: "Charlie", + Mailbox: "charlie", + Host: "example.net", + }, + }, + Subject: "Missing‑Date", + }, + expected: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "From: Charlie \n" + + "Subject: Missing‑Date\n", + }, + { + Type: openai.ChatMessagePartTypeText, + Text: "no date", + }, + }, + }, + { + name: "NilEnvelope", + parts: []Part{ + {ContentType: "text/plain", Content: "just body"}, + }, + // envelope left nil – we expect *no* header part. + expected: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "just body", + }, + }, + }, +} + +// TestToOpenAIMessages verifies conversion to OpenAI multimodal format. +func TestToOpenAIMessages(t *testing.T) { + for _, tt := range chatTests { + t.Run(tt.name, func(t *testing.T) { + msg := &Message{ + Parts: tt.parts, + Attachments: tt.attachments, + Envelope: tt.envelope, + } + + got := msg.ToOpenAIMessages() + + if len(got) != len(tt.expected) { + t.Fatalf( + "Part count: got %d, want %d", + len(got), len(tt.expected), + ) + } + + for i, want := range tt.expected { + g := got[i] + if g.Type != want.Type { + t.Errorf( + "[%d] Type: got %q, want %q", + i, g.Type, want.Type, + ) + } + if want.Type == openai.ChatMessagePartTypeText { + if g.Text != want.Text { + t.Errorf( + "[%d] Text: got %q, want %q", + i, g.Text, want.Text, + ) + } + } + if want.Type == + openai.ChatMessagePartTypeImageURL { + if g.ImageURL == nil { + t.Errorf( + "[%d] ImageURL is nil", + i, + ) + continue + } + if g.ImageURL.URL != want.ImageURL.URL { + t.Errorf( + "[%d] ImageURL.URL: got %q, want %q", + i, g.ImageURL.URL, + want.ImageURL.URL, + ) + } + if g.ImageURL.Detail != + want.ImageURL.Detail { + t.Errorf( + "[%d] ImageURL.Detail: got %q, want %q", + i, g.ImageURL.Detail, + want.ImageURL.Detail, + ) + } + } + } + }) + } +} + +var fixedDate = time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) +var defaultFrom = &mail.Address{Name: "Bob", Address: "bob@example.com"} +var replyTests = []struct { + name string + msg *Message + from *mail.Address + replyText string + shouldError bool + expectedRecipients []string + expectedSubject string + expectedInReplyTo string + expectedReferences []string + expectedBodyContains []string +}{ + { + name: "Basic Reply", + msg: &Message{ + Envelope: &imap.Envelope{ + Subject: "Test", + MessageID: "orig@example.com", + From: []imap.Address{{ + Name: "Alice", + Mailbox: "alice", + Host: "example.com", + }}, + Date: fixedDate, + }, + Parts: []Part{{ + ContentType: "text/plain", + Content: "Original message", + }}, + }, + replyText: "My response", + expectedRecipients: []string{"alice@example.com"}, + expectedSubject: "Re: Test", + expectedInReplyTo: "", + expectedBodyContains: []string{ + "My response", + "> Original message", + "On Mon, 15 Jan 2024 10:30:00 +0000, Alice wrote:", + }, + }, + { + name: "Uses Reply-To", + msg: &Message{ + Envelope: &imap.Envelope{ + Subject: "Test", + From: []imap.Address{{ + Mailbox: "from", Host: "example.com", + }}, + ReplyTo: []imap.Address{{ + Mailbox: "replyto", + Host: "example.com", + }}, + Date: fixedDate, + }, + Parts: []Part{{ + ContentType: "text/plain", Content: "Body", + }}, + }, + from: &mail.Address{Address: "me@example.com"}, + replyText: "Response", + expectedRecipients: []string{"replyto@example.com"}, + }, + { + name: "Subject Already Re", + msg: &Message{ + Envelope: &imap.Envelope{ + Subject: "Re: Already replied", + From: []imap.Address{{ + Mailbox: "from", Host: "example.com", + }}, + }, + Parts: []Part{{ + ContentType: "text/plain", Content: "Body", + }}, + }, + replyText: "Response", + expectedSubject: "Re: Already replied", + }, + { + name: "Thread References", + msg: &Message{ + Envelope: &imap.Envelope{ + Subject: "Thread", + MessageID: "msg3@example.com", + From: []imap.Address{{ + Mailbox: "from", Host: "example.com", + }}, + }, + References: []string{ + "msg1@example.com", "msg2@example.com", + }, + Parts: []Part{{ + ContentType: "text/plain", Content: "Body", + }}, + }, + replyText: "Response", + expectedReferences: []string{ + "msg1@example.com", + "msg2@example.com", + "msg3@example.com", + }, + }, + { + name: "Nil Message", + msg: nil, + shouldError: true, + }, + { + name: "Nil Envelope", + msg: &Message{}, + shouldError: true, + }, + { + name: "Nil From", + msg: &Message{ + Envelope: &imap.Envelope{ + From: []imap.Address{{ + Mailbox: "from", Host: "example.com", + }}, + }, + }, + from: nil, + shouldError: true, + }, + { + name: "Empty From Address", + msg: &Message{ + Envelope: &imap.Envelope{ + From: []imap.Address{{ + Mailbox: "from", Host: "example.com", + }}, + }, + }, + from: &mail.Address{Name: "Name Only"}, + shouldError: true, + }, + { + name: "From Address No At Sign", + msg: &Message{ + Envelope: &imap.Envelope{ + From: []imap.Address{{ + Mailbox: "from", Host: "example.com", + }}, + }, + }, + from: &mail.Address{Address: "localonly"}, + shouldError: false, + }, + { + name: "No Recipients", + msg: &Message{ + Envelope: &imap.Envelope{Subject: "Test"}, + Parts: []Part{{ + ContentType: "text/plain", Content: "Body", + }}, + }, + shouldError: true, + }, +} + +// TestComposeReply verifies reply composition: recipient selection, +// subject handling, threading headers, and error conditions. +func TestComposeReply(t *testing.T) { + for _, tt := range replyTests { + t.Run(tt.name, func(t *testing.T) { + useFrom := tt.from + // Use default sender for non-error cases when from is + // not specified. + if useFrom == nil && !tt.shouldError { + useFrom = defaultFrom + } + + reply, err := tt.msg.ComposeReply( + fixedDate, useFrom, tt.replyText, + ) + if tt.shouldError { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(tt.expectedRecipients) > 0 { + recipients, err := reply.Recipients() + if err != nil { + t.Fatalf("Recipients() error: %v", err) + } + + if len(recipients) != + len(tt.expectedRecipients) { + t.Errorf( + "Recipients count: got %d, want %d", + len(recipients), + len(tt.expectedRecipients), + ) + } + for i, want := range tt.expectedRecipients { + if i < len(recipients) && + !strings.Contains( + recipients[i], want, + ) { + t.Errorf( + "Recipient[%d]: got %q, want substring %q", + i, recipients[i], want, + ) + } + } + } + + data, err := reply.Bytes() + if err != nil { + t.Fatalf("Bytes() error: %v", err) + } + bodyStr := string(data) + + if tt.expectedSubject != "" && + !strings.Contains( + bodyStr, + "Subject: "+tt.expectedSubject, + ) { + t.Errorf( + "Subject missing or mismatch.\nBody: %s\nExpected Subject: %s", + bodyStr, + tt.expectedSubject, + ) + } + if tt.expectedInReplyTo != "" { + if !strings.Contains( + bodyStr, + "In-Reply-To: "+tt.expectedInReplyTo, + ) { + t.Errorf( + "In-Reply-To missing.\nBody: %s\nExpected: %s", + bodyStr, + tt.expectedInReplyTo, + ) + } + } + + for _, ref := range tt.expectedReferences { + if !strings.Contains(bodyStr, ref) { + t.Errorf( + "References header missing ID %q", + ref, + ) + } + } + + for _, part := range tt.expectedBodyContains { + if !strings.Contains(bodyStr, part) { + t.Errorf( + "Body content missing: %q", + part, + ) + } + } + }) + } +} + +var attributionTests = []struct { + name string + envelope *imap.Envelope + expected string +}{ + { + name: "No From", + envelope: &imap.Envelope{}, + expected: "> \n", + }, + { + name: "No Date", + envelope: &imap.Envelope{ + From: []imap.Address{{ + Name: "Alice", + Mailbox: "alice", + Host: "example.com", + }}, + }, + expected: "Alice wrote:\n", + }, + { + name: "No Name", + envelope: &imap.Envelope{ + From: []imap.Address{{ + Mailbox: "alice", Host: "example.com", + }}, + }, + expected: "alice@example.com wrote:\n", + }, + { + name: "Full", + envelope: &imap.Envelope{ + From: []imap.Address{{ + Name: "Bob", Mailbox: "bob", Host: "ex.com", + }}, + Date: time.Date( + 2024, 1, 2, 15, 4, 5, 0, + time.FixedZone("", -25200), + ), + }, + expected: "On Tue, 2 Jan 2024 15:04:05 -0700, Bob wrote:\n", + }, +} + +// TestComposeAttribution verifies attribution line formatting for various +// sender/date combinations. +func TestComposeAttribution(t *testing.T) { + for _, tt := range attributionTests { + t.Run(tt.name, func(t *testing.T) { + msg := &Message{Envelope: tt.envelope} + if got := msg.composeAttribution(); got != tt.expected { + t.Errorf("got %q, want %q", got, tt.expected) + } + }) + } +} + +var quoteTests = []struct { + name string + parts []Part + expected string +}{ + { + name: "Simple", + parts: []Part{ + {ContentType: "text/plain", Content: "line1\nline2\n"}, + }, + expected: "> line1\n> line2\n> \n", + }, + { + name: "Strips CR", + parts: []Part{ + { + ContentType: "text/plain", + Content: "line1\r\nline2\r\n", + }, + }, + expected: "> line1\n> line2\n> \n", + }, + { + name: "Multiple Parts", + parts: []Part{ + {ContentType: "text/plain", Content: "first"}, + {ContentType: "text/plain", Content: "second"}, + }, + expected: "> first\n> second\n", + }, + { + name: "Ignores Images", + parts: []Part{ + {ContentType: "text/plain", Content: "text"}, + {ContentType: "image/png", Data: []byte("PNG")}, + }, + expected: "> text\n", + }, +} + +// TestQuotedBody verifies line quoting and CR stripping. +func TestQuotedBody(t *testing.T) { + for _, tt := range quoteTests { + t.Run(tt.name, func(t *testing.T) { + msg := &Message{Parts: tt.parts, Envelope: &imap.Envelope{}} + if got := msg.QuotedBody(); got != tt.expected { + t.Errorf("got %q, want %q", got, tt.expected) + } + }) + } +} + +var textBodyTests = []struct { + name string + parts []Part + expected string +}{ + { + name: "Single Text Part", + parts: []Part{ + {ContentType: "text/plain", Content: "Hello"}, + }, + expected: "Hello", + }, + { + name: "Multiple Text Parts", + parts: []Part{ + {ContentType: "text/plain", Content: "First"}, + {ContentType: "text/html", Content: "

Second

"}, + }, + expected: "First\n

Second

", + }, + { + name: "Mixed With Images", + parts: []Part{ + {ContentType: "text/plain", Content: "Before"}, + {ContentType: "image/png", Data: []byte("PNG")}, + {ContentType: "text/plain", Content: "After"}, + }, + expected: "Before\nAfter", + }, + { + name: "No Text Parts", + parts: []Part{{ + ContentType: "image/jpeg", + Data: []byte("JPG"), + }}, + expected: "", + }, +} + +// TestTextBody verifies text extraction from parts. +func TestTextBody(t *testing.T) { + for _, tt := range textBodyTests { + t.Run(tt.name, func(t *testing.T) { + msg := &Message{Parts: tt.parts} + if got := msg.TextBody(); got != tt.expected { + t.Errorf("got %q, want %q", got, tt.expected) + } + }) + } +} + +// TestTextFrom checks the formatting of the “From:” line produced by +// Message.TextFrom. +func TestTextFrom(t *testing.T) { + tests := []struct { + name string + env *imap.Envelope + expected string + }{ + { + name: "Nil envelope", + env: nil, + expected: "", + }, + { + name: "No From addresses", + env: &imap.Envelope{ + From: []imap.Address{}, + }, + expected: "", + }, + { + name: "Single address with name", + env: &imap.Envelope{ + From: []imap.Address{ + { + Name: "Alice", + Mailbox: "alice", + Host: "example.com", + }, + }, + }, + expected: "From: Alice ", + }, + { + name: "Single address without name", + env: &imap.Envelope{ + From: []imap.Address{ + { + Mailbox: "bob", + Host: "example.org", + }, + }, + }, + expected: "From: bob@example.org", + }, + { + name: "Multiple mixed addresses", + env: &imap.Envelope{ + From: []imap.Address{ + { + Name: "Carol", + Mailbox: "carol", + Host: "example.net", + }, + { + Mailbox: "dave", + Host: "example.net", + }, + }, + }, + expected: "From: Carol , dave@example.net", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &Message{Envelope: tt.env} + got := msg.TextFrom() + if got != tt.expected { + t.Errorf("TextFrom() = %q, want %q", got, tt.expected) + } + }) + } +} + +// TestPartHelpers verifies Part.IsImage and Part.IsText methods. +func TestPartHelpers(t *testing.T) { + var tests = []struct { + contentType string + isText bool + isImage bool + }{ + {"application/pdf", false, false}, + {"audio/mpeg", false, false}, + {"image/gif", false, true}, + {"image/jpeg", false, true}, + {"image/png", false, true}, + {"image/svg+xml", false, false}, + {"image/webp", false, true}, + {"text/csv", true, false}, + {"text/html", true, false}, + {"text/plain", true, false}, + } + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + p := Part{ContentType: tt.contentType} + if got := p.IsText(); got != tt.isText { + t.Errorf( + "IsText: got %v, want %v", + got, tt.isText, + ) + } + if got := p.IsImage(); got != tt.isImage { + t.Errorf( + "IsImage: got %v, want %v", + got, tt.isImage, + ) + } + }) + } +} diff --git a/internal/smtp/smtp.go b/internal/smtp/smtp.go new file mode 100644 index 0000000..3a82a1d --- /dev/null +++ b/internal/smtp/smtp.go @@ -0,0 +1,66 @@ +// Package smtp provides a config wrapper around net/smtp for sending email. +package smtp + +import ( + "fmt" + "net/mail" + "net/smtp" +) + +// SMTP holds configuration for sending emails via an SMTP server. +type SMTP struct { + From string `yaml:"from"` + Host string `yaml:"host"` + Password string `yaml:"password"` + Port string `yaml:"port"` + User string `yaml:"user"` +} + +// Address returns the host:port string. +func (s *SMTP) Address() string { + return fmt.Sprintf("%s:%s", s.Host, s.Port) +} + +// Auth creates a PlainAuth using the configured credentials. +func (s *SMTP) Auth() smtp.Auth { + return smtp.PlainAuth("", s.User, s.Password, s.Host) +} + +// Send emails the given recipients using net/smtp. +// Refer to smtp.SendMail for its parameters and limitations. +func (s *SMTP) Send(to []string, msg []byte) error { + if err := smtp.SendMail( + s.Address(), + s.Auth(), + s.From, + to, + msg, + ); err != nil { + return fmt.Errorf("smtp send: %w", err) + } + + return nil +} + +// Validate checks that all required configuration fields are set. +func (s *SMTP) Validate() error { + if s.From == "" { + return fmt.Errorf("missing from") + } + if _, err := mail.ParseAddress(s.From); err != nil { + return fmt.Errorf("invalid from: %w", err) + } + if s.Host == "" { + return fmt.Errorf("missing host") + } + if s.Password == "" { + return fmt.Errorf("missing password") + } + if s.Port == "" { + return fmt.Errorf("missing port") + } + if s.User == "" { + return fmt.Errorf("missing user") + } + return nil +} diff --git a/internal/smtp/smtp_test.go b/internal/smtp/smtp_test.go new file mode 100644 index 0000000..c6260a4 --- /dev/null +++ b/internal/smtp/smtp_test.go @@ -0,0 +1,111 @@ +package smtp + +import ( + "testing" +) + +func TestAddress(t *testing.T) { + s := &SMTP{ + Host: "smtp.example.com", + Port: "587", + } + expected := "smtp.example.com:587" + if got := s.Address(); got != expected { + t.Errorf("Address() = %q, want %q", got, expected) + } +} + +func TestAuth(t *testing.T) { + s := &SMTP{ + User: "user", + Password: "password", + Host: "smtp.example.com", + } + + // Verify return of non-nil Auth mechanism. + if got := s.Auth(); got == nil { + t.Error("Auth() returned nil") + } +} + +func TestValidate(t *testing.T) { + tests := []struct { + name string + s *SMTP + shouldError bool + }{ + { + name: "valid configuration", + s: &SMTP{ + From: "sender@example.com", + Host: "smtp.example.com", + Password: "secretpassword", + Port: "587", + User: "user", + }, + shouldError: false, + }, + { + name: "missing from", + s: &SMTP{ + Host: "smtp.example.com", + Password: "p", + Port: "587", + User: "u", + }, + shouldError: true, + }, + { + name: "missing host", + s: &SMTP{ + From: "f", + Password: "p", + Port: "587", + User: "u", + }, + shouldError: true, + }, + { + name: "missing password", + s: &SMTP{ + From: "f", + Host: "h", + Port: "587", + User: "u", + }, + shouldError: true, + }, + { + name: "missing port", + s: &SMTP{ + From: "f", + Host: "h", + Password: "p", + User: "u", + }, + shouldError: true, + }, + { + name: "missing user", + s: &SMTP{ + From: "f", + Host: "h", + Password: "p", + Port: "587", + }, + shouldError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.s.Validate() + if tt.shouldError && err == nil { + t.Fatalf("expected error, got nil") + } + if !tt.shouldError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/internal/tool/tools.go b/internal/tool/tools.go new file mode 100644 index 0000000..a6244c9 --- /dev/null +++ b/internal/tool/tools.go @@ -0,0 +1,235 @@ +// Package tool provides a registry for external tools that can be invoked by +// LLMs. +// +// The package bridges YAML configuration to exec.CommandContext, allowing +// tools to be defined declaratively without writing Go code. Each tool +// specifies a command, argument templates using Go's text/template syntax, +// JSON Schema parameters for LLM input, and an optional execution timeout. +// +// The registry validates all tool definitions at construction time, +// failing fast on configuration errors. At execution time, it expands +// argument templates with LLM-provided JSON, runs the subprocess, and +// returns stdout. +package tool + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os/exec" + "text/template" + "time" + + "github.com/sashabaranov/go-openai" +) + +// fn provides template functions available to argument templates. +var fn = template.FuncMap{ + "json": func(v any) string { + b, _ := json.Marshal(v) + return string(b) + }, +} + +// Tool represents an external tool that can be invoked by LLMs. +// Tools are executed as subprocesses with templated arguments. +type Tool struct { + // Name uniquely identifies the tool within a registry. + Name string `yaml:"name"` + // Description explains the tool's purpose for the LLM. + Description string `yaml:"description"` + // Command is the executable path or name. + Command string `yaml:"command"` + // Arguments are Go templates expanded with LLM-provided parameters. + // Empty results after expansion are filtered out. + Arguments []string `yaml:"arguments"` + // Parameters is a JSON Schema describing expected input from the LLM. + Parameters map[string]any `yaml:"parameters"` + // Timeout limits execution time (e.g., "30s", "5m"). + // Empty means no timeout. + Timeout string `yaml:"timeout"` + // timeout is the parsed duration, set during registry construction. + timeout time.Duration `yaml:"-"` +} + +// OpenAI converts the tool to an OpenAI function definition for API calls. +func (t Tool) OpenAI() openai.Tool { + return openai.Tool{ + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: t.Name, + Description: t.Description, + Parameters: t.Parameters, + }, + } +} + +// ParseArguments expands argument templates with the provided JSON data. +// The args parameter should be a JSON object string; empty string or "{}" +// results in an empty data map. Templates producing empty strings are +// filtered from the result, allowing conditional arguments. +func (t Tool) ParseArguments(args string) ([]string, error) { + var data = map[string]any{} + if args != "" && args != "{}" { + if err := json.Unmarshal([]byte(args), &data); err != nil { + return nil, fmt.Errorf( + "invalid arguments JSON: %w", err, + ) + } + } + + var result []string + for _, v := range t.Arguments { + tmpl, err := template.New("").Funcs(fn).Parse(v) + if err != nil { + return nil, fmt.Errorf( + "invalid template %q: %w", v, err, + ) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return nil, fmt.Errorf( + "execute template %q: %w", v, err, + ) + } + + // Filter out empty strings (unused conditional arguments). + if s := buf.String(); s != "" { + result = append(result, s) + } + } + + return result, nil +} + +// Validate checks that the tool definition is complete and valid. +// It verifies required fields are present, the timeout (if specified) +// is parseable, and all argument templates are syntactically valid. +func (t Tool) Validate() error { + if t.Name == "" { + return fmt.Errorf("missing name") + } + if t.Description == "" { + return fmt.Errorf("missing description") + } + if t.Command == "" { + return fmt.Errorf("missing command") + } + if t.Timeout != "" { + if _, err := time.ParseDuration(t.Timeout); err != nil { + return fmt.Errorf("invalid timeout: %v", err) + } + } + for _, arg := range t.Arguments { + if _, err := template.New("").Funcs(fn).Parse(arg); err != nil { + return fmt.Errorf("invalid argument template") + } + } + + return nil +} + +// Registry holds tools indexed by name and handles their execution. +// It validates all tools at construction time to fail fast on +// configuration errors. +type Registry struct { + tools map[string]*Tool +} + +// NewRegistry creates a registry from the provided tool definitions. +// Returns an error if any tool fails validation or if duplicate names exist. +func NewRegistry(tools []Tool) (*Registry, error) { + var r = &Registry{ + tools: make(map[string]*Tool), + } + for _, t := range tools { + if err := t.Validate(); err != nil { + return nil, fmt.Errorf("invalid tool: %v", err) + } + if t.Timeout != "" { + d, err := time.ParseDuration(t.Timeout) + if err != nil { + return nil, fmt.Errorf( + "parse timeout: %v", err, + ) + } + t.timeout = d + } + + if _, exists := r.tools[t.Name]; exists { + return nil, fmt.Errorf( + "duplicate tool name: %s", t.Name, + ) + } + r.tools[t.Name] = &t + } + + return r, nil +} + +// Get returns a tool by name and a boolean indicating if it was found. +func (r *Registry) Get(name string) (*Tool, bool) { + tool, ok := r.tools[name] + return tool, ok +} + +// List returns all registered tool names in arbitrary order. +func (r *Registry) List() []string { + names := make([]string, 0, len(r.tools)) + for name := range r.tools { + names = append(names, name) + } + return names +} + +// Execute runs a tool by name with the provided JSON arguments. +// It expands argument templates, executes the command as a subprocess, +// and returns stdout on success. The context can be used for cancellation; +// tool-specific timeouts are applied on top of any context deadline. +func (r *Registry) Execute( + ctx context.Context, name string, args string, +) (string, error) { + tool, ok := r.tools[name] + if !ok { + return "", fmt.Errorf("unknown tool: %s", name) + } + + // Evaluate argument templates. + cmdArgs, err := tool.ParseArguments(args) + if err != nil { + return "", fmt.Errorf("parse arguments: %w", err) + } + + // If defined, use the timeout. + if tool.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, tool.timeout) + defer cancel() + } + + // Setup and run the command. + var ( + stdout, stderr bytes.Buffer + cmd = exec.CommandContext( + ctx, tool.Command, cmdArgs..., + ) + ) + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + if ctx.Err() == context.DeadlineExceeded && tool.timeout > 0 { + return "", fmt.Errorf( + "tool %s timed out after %v", + name, tool.timeout, + ) + } + return "", fmt.Errorf( + "tool %s: %w\nstderr: %s", + name, err, stderr.String(), + ) + } + + return stdout.String(), nil +} diff --git a/internal/tool/tools_test.go b/internal/tool/tools_test.go new file mode 100644 index 0000000..b01c0d9 --- /dev/null +++ b/internal/tool/tools_test.go @@ -0,0 +1,374 @@ +package tool + +import ( + "context" + "slices" + "testing" + "time" +) + +func TestToolValidate(t *testing.T) { + tests := []struct { + name string + tool Tool + wantErr bool + }{ + { + name: "valid tool", + tool: Tool{ + Name: "echo", + Description: "echoes input", + Command: "echo", + Arguments: []string{"{{.message}}"}, + }, + wantErr: false, + }, + { + name: "valid tool with timeout", + tool: Tool{ + Name: "slow", + Description: "slow command", + Command: "sleep", + Arguments: []string{"1"}, + Timeout: "5s", + }, + wantErr: false, + }, + { + name: "missing name", + tool: Tool{ + Description: "test", + Command: "echo", + }, + wantErr: true, + }, + { + name: "missing description", + tool: Tool{ + Name: "test", + Command: "echo", + }, + wantErr: true, + }, + { + name: "missing command", + tool: Tool{ + Name: "test", + Description: "test", + }, + wantErr: true, + }, + { + name: "invalid timeout", + tool: Tool{ + Name: "test", + Description: "test", + Command: "echo", + Timeout: "invalid", + }, + wantErr: true, + }, + { + name: "invalid template", + tool: Tool{ + Name: "test", + Description: "test", + Command: "echo", + Arguments: []string{"{{.unclosed"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.tool.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestToolParseArguments(t *testing.T) { + tests := []struct { + name string + tool Tool + args string + want []string + wantErr bool + }{ + { + name: "simple substitution", + tool: Tool{ + Arguments: []string{"{{.message}}"}, + }, + args: `{"message": "hello"}`, + want: []string{"hello"}, + }, + { + name: "multiple arguments", + tool: Tool{ + Arguments: []string{"-n", "{{.count}}", "{{.file}}"}, + }, + args: `{"count": "10", "file": "test.txt"}`, + want: []string{"-n", "10", "test.txt"}, + }, + { + name: "conditional with if", + tool: Tool{ + Arguments: []string{"{{.required}}", "{{if .optional}}{{.optional}}{{end}}"}, + }, + args: `{"required": "value"}`, + want: []string{"value"}, + }, + { + name: "json function", + tool: Tool{ + Arguments: []string{`{{json .data}}`}, + }, + args: `{"data": {"key": "value"}}`, + want: []string{`{"key":"value"}`}, + }, + { + name: "empty JSON object", + tool: Tool{ + Arguments: []string{"fixed"}, + }, + args: "{}", + want: []string{"fixed"}, + }, + { + name: "empty string args", + tool: Tool{ + Arguments: []string{"fixed"}, + }, + args: "", + want: []string{"fixed"}, + }, + { + name: "invalid JSON", + tool: Tool{ + Arguments: []string{"{{.x}}"}, + }, + args: "not json", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.tool.ParseArguments(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("ParseArguments() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !slices.Equal(got, tt.want) { + t.Errorf("ParseArguments() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewRegistry(t *testing.T) { + tests := []struct { + name string + tools []Tool + wantErr bool + }{ + { + name: "valid registry", + tools: []Tool{ + {Name: "a", Description: "a", Command: "echo"}, + {Name: "b", Description: "b", Command: "cat"}, + }, + wantErr: false, + }, + { + name: "empty registry", + tools: []Tool{}, + wantErr: false, + }, + { + name: "duplicate names", + tools: []Tool{ + {Name: "dup", Description: "first", Command: "echo"}, + {Name: "dup", Description: "second", Command: "cat"}, + }, + wantErr: true, + }, + { + name: "invalid tool", + tools: []Tool{ + {Name: "", Description: "missing name", Command: "echo"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewRegistry(tt.tools) + if (err != nil) != tt.wantErr { + t.Errorf("NewRegistry() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRegistryGet(t *testing.T) { + r, err := NewRegistry([]Tool{ + {Name: "echo", Description: "echo", Command: "echo"}, + }) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + tool, ok := r.Get("echo") + if !ok { + t.Error("Get(echo) returned false, want true") + } + if tool.Name != "echo" { + t.Errorf("Get(echo).Name = %q, want %q", tool.Name, "echo") + } + + _, ok = r.Get("nonexistent") + if ok { + t.Error("Get(nonexistent) returned true, want false") + } +} + +func TestRegistryList(t *testing.T) { + r, err := NewRegistry([]Tool{ + {Name: "a", Description: "a", Command: "echo"}, + {Name: "b", Description: "b", Command: "cat"}, + {Name: "c", Description: "c", Command: "ls"}, + }) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + names := r.List() + if len(names) != 3 { + t.Errorf("List() returned %d names, want 3", len(names)) + } + + slices.Sort(names) + want := []string{"a", "b", "c"} + if !slices.Equal(names, want) { + t.Errorf("List() = %v, want %v", names, want) + } +} + +func TestRegistryExecute(t *testing.T) { + r, err := NewRegistry([]Tool{ + { + Name: "echo", + Description: "echo message", + Command: "echo", + Arguments: []string{"{{.message}}"}, + }, + { + Name: "cat", + Description: "cat stdin", + Command: "cat", + Arguments: []string{}, + }, + }) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + ctx := context.Background() + + // Successful execution. + out, err := r.Execute(ctx, "echo", `{"message": "hello world"}`) + if err != nil { + t.Errorf("Execute(echo) error = %v", err) + } + if out != "hello world\n" { + t.Errorf("Execute(echo) = %q, want %q", out, "hello world\n") + } + + // Unknown tool. + _, err = r.Execute(ctx, "unknown", "{}") + if err == nil { + t.Error("Execute(unknown) expected error, got nil") + } + + // Invalid JSON arguments. + _, err = r.Execute(ctx, "echo", "not json") + if err == nil { + t.Error("Execute with invalid JSON expected error, got nil") + } +} + +func TestRegistryExecuteTimeout(t *testing.T) { + r, err := NewRegistry([]Tool{ + { + Name: "slow", + Description: "slow command", + Command: "sleep", + Arguments: []string{"10"}, + Timeout: "50ms", + }, + }) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + ctx := context.Background() + start := time.Now() + _, err = r.Execute(ctx, "slow", "{}") + elapsed := time.Since(start) + + if err == nil { + t.Error("Execute(slow) expected timeout error, got nil") + } + if elapsed > time.Second { + t.Errorf("Execute(slow) took %v, expected ~50ms timeout", elapsed) + } +} + +func TestRegistryExecuteFailure(t *testing.T) { + r, err := NewRegistry([]Tool{ + { + Name: "fail", + Description: "always fails", + Command: "false", + }, + }) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + _, err = r.Execute(context.Background(), "fail", "{}") + if err == nil { + t.Error("Execute(fail) expected error, got nil") + } +} + +func TestToolOpenAI(t *testing.T) { + tool := Tool{ + Name: "test", + Description: "test tool", + Command: "echo", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "the message", + }, + }, + }, + } + + openaiTool := tool.OpenAI() + if openaiTool.Function.Name != "test" { + t.Errorf("OpenAI().Function.Name = %q, want %q", openaiTool.Function.Name, "test") + } + if openaiTool.Function.Description != "test tool" { + t.Errorf("OpenAI().Function.Description = %q, want %q", openaiTool.Function.Description, "test tool") + } +}