From 6f0509ff183a35f3d4fa69f45ce0bf05c572d931 Mon Sep 17 00:00:00 2001
From: dwrz
Date: Fri, 13 Feb 2026 15:02:30 +0000
Subject: [PATCH] Add LLM client

---
 internal/llm/llm.go      | 410 +++++++++++++++++++++++++++++++++++++++
 internal/llm/llm_test.go | 168 ++++++++++++++++
 2 files changed, 578 insertions(+)
 create mode 100644 internal/llm/llm.go
 create mode 100644 internal/llm/llm_test.go

diff --git a/internal/llm/llm.go b/internal/llm/llm.go
new file mode 100644
index 0000000..fd6f1a3
--- /dev/null
+++ b/internal/llm/llm.go
@@ -0,0 +1,410 @@
+// Package llm provides an OpenAI-compatible client for LLM interactions.
+// It handles chat completions with automatic tool call execution.
+package llm
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"strings"
+	"time"
+
+	"github.com/chimerical-llc/odidere/internal/tool"
+
+	openai "github.com/sashabaranov/go-openai"
+)
+
+// Config holds the configuration for an LLM client.
+type Config struct {
+	// Key is the API key for authentication.
+	Key string `yaml:"key"`
+	// Model is the model identifier.
+	Model string `yaml:"model"`
+	// SystemPrompt is prepended to all conversations.
+	SystemPrompt string `yaml:"system_prompt"`
+	// Timeout is the maximum duration for a query (e.g., "5m").
+	// Defaults to 5 minutes if empty.
+	Timeout string `yaml:"timeout"`
+	// URL is the base URL of the OpenAI-compatible API endpoint.
+	URL string `yaml:"url"`
+}
+
+// Validate checks that required configuration values are present and valid.
+// The Key is intentionally optional: some local OpenAI-compatible servers
+// do not require authentication.
+func (cfg Config) Validate() error {
+	if cfg.Model == "" {
+		// Constant message: errors.New, not fmt.Errorf (staticcheck S1039).
+		return errors.New("missing model")
+	}
+	if cfg.Timeout != "" {
+		if _, err := time.ParseDuration(cfg.Timeout); err != nil {
+			return fmt.Errorf("invalid timeout: %w", err)
+		}
+	}
+	if cfg.URL == "" {
+		return errors.New("missing URL")
+	}
+	return nil
+}
+
+// Client wraps an OpenAI-compatible client with tool execution support.
+type Client struct { + client *openai.Client + log *slog.Logger + model string + registry *tool.Registry + systemPrompt string + timeout time.Duration + tools []openai.Tool +} + +// NewClient creates a new LLM client with the provided configuration. +// The registry is optional; if nil, tool calling is disabled. +func NewClient( + cfg Config, + registry *tool.Registry, + log *slog.Logger, +) (*Client, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + llm := &Client{ + log: log, + model: cfg.Model, + systemPrompt: cfg.SystemPrompt, + registry: registry, + } + + if cfg.Timeout == "" { + llm.timeout = 5 * time.Minute + } else { + d, err := time.ParseDuration(cfg.Timeout) + if err != nil { + return nil, fmt.Errorf("parse timeout: %v", err) + } + llm.timeout = d + } + + // Setup client. + clientConfig := openai.DefaultConfig(cfg.Key) + clientConfig.BaseURL = cfg.URL + llm.client = openai.NewClientWithConfig(clientConfig) + + // Parse tools. + if llm.registry != nil { + for _, name := range llm.registry.List() { + t, _ := llm.registry.Get(name) + llm.tools = append(llm.tools, t.OpenAI()) + } + } + + return llm, nil +} + +// ListModels returns available models from the LLM server. +func (c *Client) ListModels(ctx context.Context) ([]openai.Model, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + res, err := c.client.ListModels(ctx) + if err != nil { + return nil, fmt.Errorf("listing models: %w", err) + } + + return res.Models, nil +} + +// DefaultModel returns the configured default model. +func (c *Client) DefaultModel() string { + return c.model +} + +// Query sends messages to the LLM using the specified model. +// If model is empty, uses the default configured model. +// Returns all messages generated during the query, including tool calls +// and tool results. The final message is the last element in the slice. 
+func (c *Client) Query(
+	ctx context.Context,
+	messages []openai.ChatCompletionMessage,
+	model string,
+) ([]openai.ChatCompletionMessage, error) {
+	ctx, cancel := context.WithTimeout(ctx, c.timeout)
+	defer cancel()
+
+	// Fallback to the default model.
+	if model == "" {
+		model = c.model
+	}
+
+	messages = c.withSystemPrompt(messages)
+
+	// Track messages generated during this query so the caller can see
+	// intermediate tool calls and tool results, not just the final reply.
+	var generated []openai.ChatCompletionMessage
+
+	// Loop: each iteration sends the conversation, and either returns the
+	// final assistant message or executes the requested tool calls and
+	// asks the model again.
+	for {
+		req := openai.ChatCompletionRequest{
+			Model:    model,
+			Messages: messages,
+		}
+		if len(c.tools) > 0 {
+			req.Tools = c.tools
+		}
+
+		res, err := c.client.CreateChatCompletion(ctx, req)
+		if err != nil {
+			return nil, fmt.Errorf("chat completion: %w", err)
+		}
+		if len(res.Choices) == 0 {
+			return nil, errors.New("no response choices returned")
+		}
+
+		// Only one choice is expected, since the request does not set
+		// N > 1.
+		message := res.Choices[0].Message
+
+		// If no tool calls, we're done.
+		if len(message.ToolCalls) == 0 {
+			generated = append(generated, message)
+			return generated, nil
+		}
+
+		// Add the assistant message with tool calls to the history.
+		generated = append(generated, message)
+		messages = append(messages, message)
+
+		// Execute each tool call and feed the results back in.
+		for _, tc := range message.ToolCalls {
+			toolResult := c.executeToolCall(ctx, tc)
+			generated = append(generated, toolResult)
+			messages = append(messages, toolResult)
+		}
+		// Loop to get the LLM's response after tool execution.
+	}
+}
+
+// withSystemPrompt prepends the configured system prompt, unless no prompt
+// is configured or the conversation already starts with a system message.
+func (c *Client) withSystemPrompt(
+	messages []openai.ChatCompletionMessage,
+) []openai.ChatCompletionMessage {
+	if c.systemPrompt == "" {
+		return messages
+	}
+	if len(messages) > 0 &&
+		messages[0].Role == openai.ChatMessageRoleSystem {
+		return messages
+	}
+	return append(
+		[]openai.ChatCompletionMessage{{
+			Role:    openai.ChatMessageRoleSystem,
+			Content: c.systemPrompt,
+		}},
+		messages...,
+	)
+}
+
+// executeToolCall runs a single tool call and returns the tool-result
+// message to append to the conversation. Tool failures are reported back
+// to the model as a JSON error payload rather than aborting the query.
+func (c *Client) executeToolCall(
+	ctx context.Context,
+	tc openai.ToolCall,
+) openai.ChatCompletionMessage {
+	c.log.InfoContext(
+		ctx,
+		"calling tool",
+		slog.String("name", tc.Function.Name),
+		slog.String("args", tc.Function.Arguments),
+	)
+
+	var (
+		result string
+		err    error
+	)
+	if c.registry == nil {
+		// Defensive: tools are only advertised when a registry is
+		// configured, but a misbehaving server could still return
+		// tool calls; report instead of panicking.
+		err = errors.New("no tool registry configured")
+	} else {
+		result, err = c.registry.Execute(
+			ctx, tc.Function.Name, tc.Function.Arguments,
+		)
+	}
+	if err != nil {
+		c.log.ErrorContext(
+			ctx,
+			"failed to call tool",
+			slog.Any("error", err),
+			slog.String("name", tc.Function.Name),
+		)
+		result = fmt.Sprintf(
+			`{"ok": false, "error": %q}`, err,
+		)
+	} else {
+		c.log.InfoContext(
+			ctx,
+			"called tool",
+			slog.String("name", tc.Function.Name),
+		)
+	}
+
+	// Content cannot be empty.
+	if strings.TrimSpace(result) == "" {
+		result = `{"ok": true, "result": null}`
+	}
+
+	return openai.ChatCompletionMessage{
+		Role:       openai.ChatMessageRoleTool,
+		Content:    result,
+		Name:       tc.Function.Name,
+		ToolCallID: tc.ID,
+	}
+}
+
+// StreamEvent wraps a ChatCompletionMessage produced during streaming.
+type StreamEvent struct {
+	Message openai.ChatCompletionMessage
+}
+
+// QueryStream sends messages to the LLM using the specified model and
+// streams results. Each complete message (assistant reply, tool call,
+// tool result) is sent to the events channel as it becomes available.
+// The channel is closed before returning.
+func (c *Client) QueryStream(
+	ctx context.Context,
+	messages []openai.ChatCompletionMessage,
+	model string,
+	events chan<- StreamEvent,
+) error {
+	defer close(events)
+
+	ctx, cancel := context.WithTimeout(ctx, c.timeout)
+	defer cancel()
+
+	// Fallback to the default model.
+	if model == "" {
+		model = c.model
+	}
+
+	messages = c.withSystemPrompt(messages)
+
+	// Loop: each iteration streams one assistant turn, then either
+	// returns or executes the requested tool calls and asks again.
+	for {
+		req := openai.ChatCompletionRequest{
+			Model:    model,
+			Messages: messages,
+		}
+		if len(c.tools) > 0 {
+			req.Tools = c.tools
+		}
+
+		stream, err := c.client.CreateChatCompletionStream(ctx, req)
+		if err != nil {
+			return fmt.Errorf("chat completion stream: %w", err)
+		}
+
+		// Accumulate the streamed deltas into one complete message.
+		var (
+			content   strings.Builder
+			reasoning strings.Builder
+			toolCalls []openai.ToolCall
+			role      string
+		)
+		for {
+			chunk, err := stream.Recv()
+			if errors.Is(err, io.EOF) {
+				break
+			}
+			if err != nil {
+				// Close explicitly: a deferred Close would not
+				// run until QueryStream returns, leaking
+				// streams across loop iterations.
+				stream.Close()
+				return fmt.Errorf("stream recv: %w", err)
+			}
+			if len(chunk.Choices) == 0 {
+				continue
+			}
+
+			// Check the first Choice. Only one is expected, since
+			// our request does not set N > 1.
+			delta := chunk.Choices[0].Delta
+			if delta.Role != "" {
+				role = delta.Role
+			}
+			if delta.Content != "" {
+				content.WriteString(delta.Content)
+			}
+			if delta.ReasoningContent != "" {
+				reasoning.WriteString(delta.ReasoningContent)
+			}
+
+			// Accumulate tool call deltas by index: IDs and names
+			// may arrive in one chunk and argument fragments in
+			// later chunks.
+			for _, tc := range delta.ToolCalls {
+				i := 0
+				if tc.Index != nil {
+					i = *tc.Index
+				}
+				// Grow the slice as needed.
+				for len(toolCalls) <= i {
+					toolCalls = append(
+						toolCalls,
+						openai.ToolCall{
+							Type: openai.ToolTypeFunction,
+						},
+					)
+				}
+				if tc.ID != "" {
+					toolCalls[i].ID = tc.ID
+				}
+				if tc.Function.Name != "" {
+					toolCalls[i].Function.Name +=
+						tc.Function.Name
+				}
+				if tc.Function.Arguments != "" {
+					toolCalls[i].Function.Arguments +=
+						tc.Function.Arguments
+				}
+			}
+		}
+		stream.Close()
+
+		// Build the complete message from the accumulated buffers.
+		message := openai.ChatCompletionMessage{
+			Role:             role,
+			Content:          content.String(),
+			ReasoningContent: reasoning.String(),
+			ToolCalls:        toolCalls,
+		}
+		events <- StreamEvent{Message: message}
+
+		// If no tool calls, we're done.
+		if len(toolCalls) == 0 {
+			return nil
+		}
+
+		// Add the assistant message with tool calls to the history.
+		messages = append(messages, message)
+
+		// Execute each tool call and feed the results back in.
+		for _, tc := range message.ToolCalls {
+			toolResult := c.executeToolCall(ctx, tc)
+			messages = append(messages, toolResult)
+			events <- StreamEvent{Message: toolResult}
+		}
+		// Loop to get the LLM's response after tool execution.
+	}
+}
diff --git a/internal/llm/llm_test.go b/internal/llm/llm_test.go
new file mode 100644
index 0000000..967d0b8
--- /dev/null
+++ b/internal/llm/llm_test.go
@@ -0,0 +1,168 @@
+package llm
+
+import (
+	"testing"
+)
+
+func TestConfigValidate(t *testing.T) {
+	tests := []struct {
+		name    string
+		cfg     Config
+		wantErr bool
+	}{
+		{
+			name:    "empty config",
+			cfg:     Config{},
+			wantErr: true,
+		},
+		{
+			name: "missing model",
+			cfg: Config{
+				URL: "http://localhost:8080",
+			},
+			wantErr: true,
+		},
+		{
+			name: "missing URL",
+			cfg: Config{
+				Model: "test-model",
+			},
+			wantErr: true,
+		},
+		{
+			name: "invalid timeout",
+			cfg: Config{
+				Model:   "test-model",
+				URL:     "http://localhost:8080",
+				Timeout: "not-a-duration",
+			},
+			wantErr: true,
+		},
+		{
+			name: "valid minimal config",
+			cfg: Config{
+				Model: "test-model",
+				URL:   "http://localhost:8080",
+			},
+			wantErr: false,
+		},
+		{
+			name: "valid config with timeout",
+			cfg: Config{
+				Model:   "test-model",
+				URL:     "http://localhost:8080",
+				Timeout: "15m",
+			},
+			wantErr: false,
+		},
+		{
+			name: "valid config with complex timeout",
+			cfg: Config{
+				Model:   "test-model",
+				URL:     "http://localhost:8080",
+				Timeout: "1h30m45s",
+			},
+			wantErr: false,
+		},
+		{
+			name: "valid full config",
+			cfg: Config{
+				Key:          "sk-test-key",
+				Model:        "test-model",
+				SystemPrompt: "You are a helpful assistant.",
+				Timeout:      "30m",
+				URL:          "http://localhost:8080",
+			},
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := tt.cfg.Validate()
+			if (err != nil) != tt.wantErr {
+				t.Errorf(
+					"Validate() error = %v, wantErr %v",
+					err, tt.wantErr,
+				)
+			}
+		})
+	}
+}
+
+func TestNewClient(t *testing.T) {
+	tests := []struct {
+		name    string
+		cfg     Config
+		wantErr bool
+	}{
+		{
+			name:    "invalid config",
+			cfg:     Config{},
+			wantErr: true,
+		},
+		{
+			name: "valid config without timeout",
+			cfg: Config{
+				Model: "test-model",
+				URL:   "http://localhost:8080",
+			},
+			wantErr: false,
+		},
+		{
+			name: "valid config with timeout",
+			cfg: Config{
+				Model:   "test-model",
+				URL:     "http://localhost:8080",
+				Timeout: "10m",
+			},
+			wantErr: false,
+		},
+		{
+			name: "valid config with all fields",
+			cfg: Config{
+				Key:          "test-key",
+				Model:        "test-model",
+				SystemPrompt: "Test prompt",
+				Timeout:      "5m",
+				URL:          "http://localhost:8080",
+			},
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			client, err := NewClient(tt.cfg, nil, nil)
+			if (err != nil) != tt.wantErr {
+				t.Errorf(
+					"NewClient() error = %v, wantErr %v",
+					err, tt.wantErr,
+				)
+				return
+			}
+			if !tt.wantErr && client == nil {
+				t.Error("NewClient() returned nil client")
+			}
+		})
+	}
+}
+
+func TestClientDefaultModel(t *testing.T) {
+	cfg := Config{
+		Model: "my-custom-model",
+		URL:   "http://localhost:8080",
+	}
+
+	client, err := NewClient(cfg, nil, nil)
+	if err != nil {
+		t.Fatalf("NewClient() error = %v", err)
+	}
+
+	if got := client.DefaultModel(); got != "my-custom-model" {
+		t.Errorf(
+			"DefaultModel() = %q, want %q",
+			got, "my-custom-model",
+		)
+	}
+}