// Package llm provides an OpenAI-compatible client for LLM interactions.
// It handles chat completions with automatic tool call execution.
package llm

import (
	"context"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/chimerical-llc/raven/internal/message"
	"github.com/chimerical-llc/raven/internal/tool"

	openai "github.com/sashabaranov/go-openai"
)
// Config holds the configuration for an LLM client.
// All fields are populated from YAML configuration; see Validate for
// which fields are required.
type Config struct {
	// Key is the API key for authentication. It is not checked by
	// Validate, so it may be empty for endpoints without auth.
	Key string `yaml:"key"`
	// Model is the model identifier (e.g., "gpt-oss-120b", "qwen3-32b").
	// Required.
	Model string `yaml:"model"`
	// SystemPrompt is prepended to all conversations. Optional; an
	// empty value sends an empty system message.
	SystemPrompt string `yaml:"system_prompt"`
	// Timeout is the maximum duration for a query in
	// time.ParseDuration syntax (e.g., "15m").
	// Defaults to 15 minutes if empty.
	Timeout string `yaml:"timeout"`
	// URL is the base URL of the OpenAI-compatible API endpoint.
	// Required.
	URL string `yaml:"url"`
}
// Validate checks that required configuration values are present and valid.
|
|
func (cfg Config) Validate() error {
|
|
if cfg.Model == "" {
|
|
return fmt.Errorf("missing model")
|
|
}
|
|
if cfg.Timeout != "" {
|
|
if _, err := time.ParseDuration(cfg.Timeout); err != nil {
|
|
return fmt.Errorf("invalid duration")
|
|
}
|
|
}
|
|
if cfg.URL == "" {
|
|
return fmt.Errorf("missing URL")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Client wraps an OpenAI-compatible client with tool execution support.
// Construct with NewClient; the zero value is not usable.
type Client struct {
	// client is the underlying OpenAI-compatible API client.
	client *openai.Client
	// log receives structured logs for tool-call activity.
	log *slog.Logger
	// model is the model identifier sent with every request.
	model string
	// registry resolves and executes tools; nil disables tool calling.
	registry *tool.Registry
	// systemPrompt is sent as the system message of every conversation.
	systemPrompt string
	// timeout bounds the total duration of a single Query call,
	// including all tool-call round trips.
	timeout time.Duration
	// tools is the OpenAI tool-schema list derived from registry.
	tools []openai.Tool
}
// NewClient creates a new LLM client with the provided configuration.
|
|
// The registry is optional; if nil, tool calling is disabled.
|
|
func NewClient(
|
|
cfg Config,
|
|
registry *tool.Registry,
|
|
log *slog.Logger,
|
|
) (*Client, error) {
|
|
if err := cfg.Validate(); err != nil {
|
|
return nil, fmt.Errorf("invalid config: %v", err)
|
|
}
|
|
var llm = &Client{
|
|
log: log,
|
|
model: cfg.Model,
|
|
systemPrompt: cfg.SystemPrompt,
|
|
registry: registry,
|
|
}
|
|
if cfg.Timeout == "" {
|
|
llm.timeout = time.Duration(15 * time.Minute)
|
|
} else {
|
|
d, err := time.ParseDuration(cfg.Timeout)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse timeout: %v", err)
|
|
}
|
|
llm.timeout = d
|
|
}
|
|
|
|
// Setup client.
|
|
clientConfig := openai.DefaultConfig(cfg.Key)
|
|
clientConfig.BaseURL = cfg.URL
|
|
llm.client = openai.NewClientWithConfig(clientConfig)
|
|
|
|
// Parse tools.
|
|
if llm.registry != nil {
|
|
for _, name := range llm.registry.List() {
|
|
tool, _ := llm.registry.Get(name)
|
|
llm.tools = append(llm.tools, tool.OpenAI())
|
|
}
|
|
}
|
|
|
|
return llm, nil
|
|
}
|
|
|
|
// Query sends a message to the LLM and returns its response.
|
|
// It automatically executes any tool calls requested by the model,
|
|
// looping until the model returns a final text response.
|
|
// Supports multimodal content (text and images) via the Message type.
|
|
func (c *Client) Query(
|
|
ctx context.Context, msg *message.Message,
|
|
) (string, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
|
defer cancel()
|
|
|
|
// Build user message from email content.
|
|
// Uses MultiContent to support both text and image parts.
|
|
userMessage := openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleUser,
|
|
MultiContent: msg.ToOpenAIMessages(),
|
|
}
|
|
|
|
// Set the system message and the user prompt.
|
|
messages := []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: c.systemPrompt,
|
|
},
|
|
userMessage,
|
|
}
|
|
|
|
// Loop for tool calls.
|
|
for {
|
|
req := openai.ChatCompletionRequest{
|
|
Model: c.model,
|
|
Messages: messages,
|
|
}
|
|
if len(c.tools) > 0 {
|
|
req.Tools = c.tools
|
|
}
|
|
|
|
res, err := c.client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("chat completion: %w", err)
|
|
}
|
|
if len(res.Choices) == 0 {
|
|
return "", fmt.Errorf("no response choices returned")
|
|
}
|
|
|
|
choice := res.Choices[0]
|
|
message := choice.Message
|
|
|
|
// If no tool calls, we're done.
|
|
if len(message.ToolCalls) == 0 {
|
|
return message.Content, nil
|
|
}
|
|
|
|
// Add assistant message with tool calls to history.
|
|
messages = append(messages, message)
|
|
|
|
// Process each tool call.
|
|
for _, tc := range message.ToolCalls {
|
|
c.log.InfoContext(
|
|
ctx,
|
|
"calling tool",
|
|
slog.String("name", tc.Function.Name),
|
|
slog.String("args", tc.Function.Arguments),
|
|
)
|
|
|
|
res, err := c.registry.Execute(
|
|
ctx, tc.Function.Name, tc.Function.Arguments,
|
|
)
|
|
if err != nil {
|
|
// Return error to LLM so it can recover.
|
|
c.log.Error(
|
|
"failed to call tool",
|
|
slog.Any("error", err),
|
|
slog.String("name", tc.Function.Name),
|
|
)
|
|
// Assume JSON is a more helpful response
|
|
// for an LLM.
|
|
res = fmt.Sprintf(
|
|
`{"ok": false,"error": %q}`, err,
|
|
)
|
|
} else {
|
|
c.log.Info(
|
|
"called tool",
|
|
slog.String("name", tc.Function.Name),
|
|
)
|
|
}
|
|
|
|
// Content cannot be empty.
|
|
// Assume JSON is better than OK or (no output).
|
|
if strings.TrimSpace(res) == "" {
|
|
res = `{"ok":true,"result":null}`
|
|
}
|
|
|
|
// Add tool result to messages.
|
|
messages = append(
|
|
messages,
|
|
openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleTool,
|
|
Content: res,
|
|
ToolCallID: tc.ID,
|
|
},
|
|
)
|
|
}
|
|
// Loop to get LLM's response after tool execution.
|
|
}
|
|
}
|