// Package llm provides an OpenAI-compatible client for LLM interactions.
// It handles chat completions with automatic tool call execution.
package llm
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"strings"
"time"
"code.chimeric.al/chimerical/odidere/internal/tool"
openai "github.com/sashabaranov/go-openai"
)
// Config holds the configuration for an LLM client.
type Config struct {
	// Key is the API key for authentication. It is not checked by
	// Validate, so it may be left empty for servers that do not
	// require authentication.
	Key string `yaml:"key"`
	// Model is the model identifier.
	Model string `yaml:"model"`
	// SystemPrompt is prepended to all conversations.
	SystemPrompt string `yaml:"system_prompt"`
	// Timeout is the maximum duration for a query (e.g., "5m").
	// Defaults to 5 minutes if empty.
	Timeout string `yaml:"timeout"`
	// URL is the base URL of the OpenAI-compatible API endpoint.
	URL string `yaml:"url"`
}

// Validate checks that required configuration values are present and valid.
// It reports the first problem found: a missing model, a timeout that does
// not parse as a time.Duration, or a missing URL.
func (cfg Config) Validate() error {
	if cfg.Model == "" {
		return fmt.Errorf("missing model")
	}
	if t := cfg.Timeout; t != "" {
		// An empty timeout is fine (a default applies); a non-empty
		// one must be a parsable duration.
		if _, err := time.ParseDuration(t); err != nil {
			return fmt.Errorf("invalid timeout: %w", err)
		}
	}
	if cfg.URL == "" {
		return fmt.Errorf("missing URL")
	}
	return nil
}
// Client wraps an OpenAI-compatible client with tool execution support.
type Client struct {
	// client is the underlying OpenAI-compatible API client.
	client *openai.Client
	// log receives structured records of tool-call activity.
	log *slog.Logger
	// model is the default model used when a query names none.
	model string
	// registry resolves and executes tools; nil disables tool calling.
	registry *tool.Registry
	// systemPrompt, if non-empty, is prepended to conversations.
	systemPrompt string
	// timeout bounds each ListModels, Query, and QueryStream call.
	timeout time.Duration
	// tools is the OpenAI tool list derived from registry at
	// construction time; empty when registry is nil.
	tools []openai.Tool
}
// NewClient creates a new LLM client with the provided configuration.
// The registry is optional; if nil, tool calling is disabled.
//
// The configuration is validated before use; an invalid config is
// returned as an error rather than deferred to the first query.
func NewClient(
	cfg Config,
	registry *tool.Registry,
	log *slog.Logger,
) (*Client, error) {
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("invalid config: %w", err)
	}
	llm := &Client{
		log:          log,
		model:        cfg.Model,
		systemPrompt: cfg.SystemPrompt,
		registry:     registry,
	}
	// Resolve the timeout, defaulting to 5 minutes. Validate has
	// already confirmed a non-empty value parses, but the error is
	// still wrapped defensively.
	if cfg.Timeout == "" {
		llm.timeout = 5 * time.Minute
	} else {
		d, err := time.ParseDuration(cfg.Timeout)
		if err != nil {
			return nil, fmt.Errorf("parse timeout: %w", err)
		}
		llm.timeout = d
	}
	// Setup client.
	clientConfig := openai.DefaultConfig(cfg.Key)
	clientConfig.BaseURL = cfg.URL
	llm.client = openai.NewClientWithConfig(clientConfig)
	// Convert registered tools to their OpenAI representations.
	if llm.registry != nil {
		for _, name := range llm.registry.List() {
			// NOTE(review): the second return of Get is discarded;
			// names come from List so lookups should succeed, but
			// confirm Get cannot return a nil tool for a listed name.
			t, _ := llm.registry.Get(name)
			llm.tools = append(llm.tools, t.OpenAI())
		}
	}
	return llm, nil
}
// ListModels returns available models from the LLM server.
// The request is bounded by the client's configured timeout.
func (c *Client) ListModels(ctx context.Context) ([]openai.Model, error) {
	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()
	list, err := c.client.ListModels(ctx)
	if err != nil {
		return nil, fmt.Errorf("listing models: %w", err)
	}
	return list.Models, nil
}
// DefaultModel reports the model identifier used when a query does not
// name a model explicitly.
func (c *Client) DefaultModel() string {
	return c.model
}
// Query sends messages to the LLM using the specified model.
// If model is empty, uses the default configured model.
// Returns all messages generated during the query, including tool calls
// and tool results. The final message is the last element in the slice.
//
// The whole exchange — including any tool-call round trips — is bounded
// by the client's configured timeout.
func (c *Client) Query(
	ctx context.Context,
	messages []openai.ChatCompletionMessage,
	model string,
) ([]openai.ChatCompletionMessage, error) {
	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()
	// Fall back to the default model.
	if model == "" {
		model = c.model
	}
	// Prepend system prompt, if configured and not already present.
	if c.systemPrompt != "" && (len(messages) == 0 ||
		messages[0].Role != openai.ChatMessageRoleSystem) {
		messages = append(
			[]openai.ChatCompletionMessage{{
				Role:    openai.ChatMessageRoleSystem,
				Content: c.systemPrompt,
			}},
			messages...,
		)
	}
	// Track messages generated during this query.
	var generated []openai.ChatCompletionMessage
	// Loop until the model stops requesting tool calls.
	for {
		req := openai.ChatCompletionRequest{
			Model:    model,
			Messages: messages,
		}
		if len(c.tools) > 0 {
			req.Tools = c.tools
		}
		res, err := c.client.CreateChatCompletion(ctx, req)
		if err != nil {
			return nil, fmt.Errorf("chat completion: %w", err)
		}
		if len(res.Choices) == 0 {
			return nil, fmt.Errorf("no response choices returned")
		}
		message := res.Choices[0].Message
		// If no tool calls, we're done.
		if len(message.ToolCalls) == 0 {
			generated = append(generated, message)
			return generated, nil
		}
		// The registry may be nil (tool calling disabled); a server
		// that returns tool calls anyway must not panic Execute below.
		if c.registry == nil {
			return nil, fmt.Errorf(
				"model requested tool calls but no tool registry is configured",
			)
		}
		// Add assistant message with tool calls to history.
		generated = append(generated, message)
		messages = append(messages, message)
		// Process each tool call.
		for _, tc := range message.ToolCalls {
			c.log.InfoContext(
				ctx,
				"calling tool",
				slog.String("name", tc.Function.Name),
				slog.String("args", tc.Function.Arguments),
			)
			result, err := c.registry.Execute(
				ctx, tc.Function.Name, tc.Function.Arguments,
			)
			if err != nil {
				// Surface the failure to the model as a JSON
				// payload so the conversation can continue.
				c.log.ErrorContext(
					ctx,
					"failed to call tool",
					slog.Any("error", err),
					slog.String("name", tc.Function.Name),
				)
				result = fmt.Sprintf(
					`{"ok": false, "error": %q}`, err,
				)
			} else {
				c.log.InfoContext(
					ctx,
					"called tool",
					slog.String("name", tc.Function.Name),
				)
			}
			// Content cannot be empty.
			if strings.TrimSpace(result) == "" {
				result = `{"ok": true, "result": null}`
			}
			// Add tool result to messages.
			toolResult := openai.ChatCompletionMessage{
				Role:       openai.ChatMessageRoleTool,
				Content:    result,
				Name:       tc.Function.Name,
				ToolCallID: tc.ID,
			}
			generated = append(generated, toolResult)
			messages = append(messages, toolResult)
		}
		// Loop to get LLM's response after tool execution.
	}
}
// StreamEvent wraps a ChatCompletionMessage produced during streaming.
// QueryStream emits one event per completed message: an assistant reply,
// an assistant message carrying tool calls, or a tool result.
type StreamEvent struct {
	// Message is the fully accumulated message.
	Message openai.ChatCompletionMessage
}
// QueryStream sends messages to the LLM using the specified model and
// streams results. Each complete message (assistant reply, tool call,
// tool result) is sent to the events channel as it becomes available.
// The channel is closed before returning.
//
// Sends on events honor ctx cancellation and the configured timeout, so
// an abandoned consumer cannot block this call forever.
func (c *Client) QueryStream(
	ctx context.Context,
	messages []openai.ChatCompletionMessage,
	model string,
	events chan<- StreamEvent,
) error {
	defer close(events)
	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()
	// Fall back to the default model.
	if model == "" {
		model = c.model
	}
	// Prepend system prompt, if configured and not already present.
	if c.systemPrompt != "" && (len(messages) == 0 ||
		messages[0].Role != openai.ChatMessageRoleSystem) {
		messages = append(
			[]openai.ChatCompletionMessage{{
				Role:    openai.ChatMessageRoleSystem,
				Content: c.systemPrompt,
			}},
			messages...,
		)
	}
	// Loop until the model stops requesting tool calls.
	for {
		req := openai.ChatCompletionRequest{
			Model:    model,
			Messages: messages,
		}
		if len(c.tools) > 0 {
			req.Tools = c.tools
		}
		stream, err := c.client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			return fmt.Errorf("chat completion stream: %w", err)
		}
		// Accumulate the streamed response.
		var (
			content   strings.Builder
			reasoning strings.Builder
			toolCalls []openai.ToolCall
			role      string
		)
		for {
			chunk, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				break
			}
			if err != nil {
				stream.Close()
				return fmt.Errorf("stream recv: %w", err)
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			// Check the first Choice. Only one is expected, since
			// our request does not set N > 1.
			delta := chunk.Choices[0].Delta
			if delta.Role != "" {
				role = delta.Role
			}
			if delta.Content != "" {
				content.WriteString(delta.Content)
			}
			if delta.ReasoningContent != "" {
				reasoning.WriteString(delta.ReasoningContent)
			}
			// Accumulate tool call deltas by index; one call's
			// name/arguments may arrive across many chunks.
			for _, tc := range delta.ToolCalls {
				i := 0
				if tc.Index != nil {
					i = *tc.Index
				}
				// Grow the slice as needed.
				for len(toolCalls) <= i {
					toolCalls = append(
						toolCalls,
						openai.ToolCall{
							Type: openai.ToolTypeFunction,
						},
					)
				}
				if tc.ID != "" {
					toolCalls[i].ID = tc.ID
				}
				if tc.Function.Name != "" {
					toolCalls[i].Function.Name +=
						tc.Function.Name
				}
				if tc.Function.Arguments != "" {
					toolCalls[i].Function.Arguments +=
						tc.Function.Arguments
				}
			}
		}
		stream.Close()
		// Build the complete message from accumulated buffers.
		message := openai.ChatCompletionMessage{
			Role:             role,
			Content:          content.String(),
			ReasoningContent: reasoning.String(),
			ToolCalls:        toolCalls,
		}
		// Send under ctx so a stalled consumer cannot leak this
		// goroutine; the deadline set above bounds the wait.
		select {
		case events <- StreamEvent{Message: message}:
		case <-ctx.Done():
			return ctx.Err()
		}
		// If no tool calls, we're done.
		if len(toolCalls) == 0 {
			return nil
		}
		// The registry may be nil (tool calling disabled); a server
		// that returns tool calls anyway must not panic Execute below.
		if c.registry == nil {
			return fmt.Errorf(
				"model requested tool calls but no tool registry is configured",
			)
		}
		// Add assistant message with tool calls to history.
		messages = append(messages, message)
		// Process each tool call.
		for _, tc := range message.ToolCalls {
			c.log.InfoContext(
				ctx,
				"calling tool",
				slog.String("name", tc.Function.Name),
				slog.String("args", tc.Function.Arguments),
			)
			result, err := c.registry.Execute(
				ctx, tc.Function.Name, tc.Function.Arguments,
			)
			if err != nil {
				// Surface the failure to the model as a JSON
				// payload so the conversation can continue.
				c.log.ErrorContext(
					ctx,
					"failed to call tool",
					slog.Any("error", err),
					slog.String("name", tc.Function.Name),
				)
				result = fmt.Sprintf(
					`{"ok": false, "error": %q}`, err,
				)
			} else {
				c.log.InfoContext(
					ctx,
					"called tool",
					slog.String("name", tc.Function.Name),
				)
			}
			// Content cannot be empty.
			if strings.TrimSpace(result) == "" {
				result = `{"ok": true, "result": null}`
			}
			// Add tool result to messages.
			toolResult := openai.ChatCompletionMessage{
				Content:    result,
				Name:       tc.Function.Name,
				Role:       openai.ChatMessageRoleTool,
				ToolCallID: tc.ID,
			}
			messages = append(messages, toolResult)
			// Stream the tool result, again honoring cancellation.
			select {
			case events <- StreamEvent{Message: toolResult}:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		// Loop to get LLM's response after tool execution.
	}
}