Add LLM client
This commit is contained in:
410
internal/llm/llm.go
Normal file
410
internal/llm/llm.go
Normal file
@@ -0,0 +1,410 @@
|
||||
// Package llm provides an OpenAI-compatible client for LLM interactions.
|
||||
// It handles chat completions with automatic tool call execution.
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/chimerical-llc/odidere/internal/tool"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Config holds the configuration for an LLM client.
type Config struct {
	// Key is the API key for authentication.
	Key string `yaml:"key"`
	// Model is the model identifier.
	Model string `yaml:"model"`
	// SystemPrompt is prepended to all conversations.
	SystemPrompt string `yaml:"system_prompt"`
	// Timeout is the maximum duration for a query (e.g., "5m").
	// Defaults to 5 minutes if empty.
	Timeout string `yaml:"timeout"`
	// URL is the base URL of the OpenAI-compatible API endpoint.
	URL string `yaml:"url"`
}

// Validate checks that required configuration values are present and valid.
// It returns a descriptive error for the first problem found, or nil if the
// configuration is usable.
func (cfg Config) Validate() error {
	if cfg.Model == "" {
		// Static messages use errors.New rather than fmt.Errorf
		// (staticcheck: Errorf without format verbs).
		return errors.New("missing model")
	}
	if cfg.Timeout != "" {
		if _, err := time.ParseDuration(cfg.Timeout); err != nil {
			return fmt.Errorf("invalid timeout: %w", err)
		}
	}
	if cfg.URL == "" {
		return errors.New("missing URL")
	}
	return nil
}
|
||||
|
||||
// Client wraps an OpenAI-compatible client with tool execution support.
type Client struct {
	// client is the underlying OpenAI-compatible API client.
	client *openai.Client
	// log receives structured logs about tool execution and failures.
	log *slog.Logger
	// model is the default model used when a query omits one.
	model string
	// registry resolves and executes tools; nil disables tool calling.
	registry *tool.Registry
	// systemPrompt, when non-empty, is prepended to conversations.
	systemPrompt string
	// timeout bounds each Query, QueryStream, and ListModels call.
	timeout time.Duration
	// tools is the OpenAI tool spec derived from registry at construction.
	tools []openai.Tool
}
|
||||
|
||||
// NewClient creates a new LLM client with the provided configuration.
|
||||
// The registry is optional; if nil, tool calling is disabled.
|
||||
func NewClient(
|
||||
cfg Config,
|
||||
registry *tool.Registry,
|
||||
log *slog.Logger,
|
||||
) (*Client, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
llm := &Client{
|
||||
log: log,
|
||||
model: cfg.Model,
|
||||
systemPrompt: cfg.SystemPrompt,
|
||||
registry: registry,
|
||||
}
|
||||
|
||||
if cfg.Timeout == "" {
|
||||
llm.timeout = 5 * time.Minute
|
||||
} else {
|
||||
d, err := time.ParseDuration(cfg.Timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse timeout: %v", err)
|
||||
}
|
||||
llm.timeout = d
|
||||
}
|
||||
|
||||
// Setup client.
|
||||
clientConfig := openai.DefaultConfig(cfg.Key)
|
||||
clientConfig.BaseURL = cfg.URL
|
||||
llm.client = openai.NewClientWithConfig(clientConfig)
|
||||
|
||||
// Parse tools.
|
||||
if llm.registry != nil {
|
||||
for _, name := range llm.registry.List() {
|
||||
t, _ := llm.registry.Get(name)
|
||||
llm.tools = append(llm.tools, t.OpenAI())
|
||||
}
|
||||
}
|
||||
|
||||
return llm, nil
|
||||
}
|
||||
|
||||
// ListModels returns available models from the LLM server.
|
||||
func (c *Client) ListModels(ctx context.Context) ([]openai.Model, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := c.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing models: %w", err)
|
||||
}
|
||||
|
||||
return res.Models, nil
|
||||
}
|
||||
|
||||
// DefaultModel returns the configured default model.
// This is the model used when Query or QueryStream is called with an
// empty model string.
func (c *Client) DefaultModel() string {
	return c.model
}
|
||||
|
||||
// Query sends messages to the LLM using the specified model.
|
||||
// If model is empty, uses the default configured model.
|
||||
// Returns all messages generated during the query, including tool calls
|
||||
// and tool results. The final message is the last element in the slice.
|
||||
func (c *Client) Query(
|
||||
ctx context.Context,
|
||||
messages []openai.ChatCompletionMessage,
|
||||
model string,
|
||||
) ([]openai.ChatCompletionMessage, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Fallback to the default model.
|
||||
if model == "" {
|
||||
model = c.model
|
||||
}
|
||||
|
||||
// Prepend system prompt, if configured and not already present.
|
||||
if c.systemPrompt != "" && (len(messages) == 0 ||
|
||||
messages[0].Role != openai.ChatMessageRoleSystem) {
|
||||
messages = append(
|
||||
[]openai.ChatCompletionMessage{{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: c.systemPrompt,
|
||||
}},
|
||||
messages...,
|
||||
)
|
||||
}
|
||||
|
||||
// Track messages generated during this query.
|
||||
var generated []openai.ChatCompletionMessage
|
||||
|
||||
// Loop for tool calls.
|
||||
for {
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
}
|
||||
if len(c.tools) > 0 {
|
||||
req.Tools = c.tools
|
||||
}
|
||||
|
||||
res, err := c.client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chat completion: %w", err)
|
||||
}
|
||||
if len(res.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no response choices returned")
|
||||
}
|
||||
|
||||
choice := res.Choices[0]
|
||||
message := choice.Message
|
||||
|
||||
// If no tool calls, we're done.
|
||||
if len(message.ToolCalls) == 0 {
|
||||
generated = append(generated, message)
|
||||
return generated, nil
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls to history.
|
||||
generated = append(generated, message)
|
||||
messages = append(messages, message)
|
||||
|
||||
// Process each tool call.
|
||||
for _, tc := range message.ToolCalls {
|
||||
c.log.InfoContext(
|
||||
ctx,
|
||||
"calling tool",
|
||||
slog.String("name", tc.Function.Name),
|
||||
slog.String("args", tc.Function.Arguments),
|
||||
)
|
||||
|
||||
result, err := c.registry.Execute(
|
||||
ctx, tc.Function.Name, tc.Function.Arguments,
|
||||
)
|
||||
if err != nil {
|
||||
c.log.Error(
|
||||
"failed to call tool",
|
||||
slog.Any("error", err),
|
||||
slog.String("name", tc.Function.Name),
|
||||
)
|
||||
result = fmt.Sprintf(
|
||||
`{"ok": false, "error": %q}`, err,
|
||||
)
|
||||
} else {
|
||||
c.log.Info(
|
||||
"called tool",
|
||||
slog.String("name", tc.Function.Name),
|
||||
)
|
||||
}
|
||||
|
||||
// Content cannot be empty.
|
||||
if strings.TrimSpace(result) == "" {
|
||||
result = `{"ok": true, "result": null}`
|
||||
}
|
||||
|
||||
// Add tool result to messages.
|
||||
toolResult := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleTool,
|
||||
Content: result,
|
||||
Name: tc.Function.Name,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
generated = append(generated, toolResult)
|
||||
messages = append(messages, toolResult)
|
||||
}
|
||||
// Loop to get LLM's response after tool execution.
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent wraps a ChatCompletionMessage produced during streaming.
type StreamEvent struct {
	// Message is a complete message assembled from streamed deltas: an
	// assistant reply, an assistant message carrying tool calls, or a
	// tool result.
	Message openai.ChatCompletionMessage
}
|
||||
|
||||
// QueryStream sends messages to the LLM using the specified model and
|
||||
// streams results. Each complete message (assistant reply, tool call,
|
||||
// tool result) is sent to the events channel as it becomes available.
|
||||
// The channel is closed before returning.
|
||||
// Returns all messages generated during the query.
|
||||
func (c *Client) QueryStream(
|
||||
ctx context.Context,
|
||||
messages []openai.ChatCompletionMessage,
|
||||
model string,
|
||||
events chan<- StreamEvent,
|
||||
) error {
|
||||
defer close(events)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Fallback to the default model.
|
||||
if model == "" {
|
||||
model = c.model
|
||||
}
|
||||
|
||||
// Prepend system prompt, if configured and not already present.
|
||||
if c.systemPrompt != "" && (len(messages) == 0 ||
|
||||
messages[0].Role != openai.ChatMessageRoleSystem) {
|
||||
messages = append(
|
||||
[]openai.ChatCompletionMessage{{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: c.systemPrompt,
|
||||
}},
|
||||
messages...,
|
||||
)
|
||||
}
|
||||
|
||||
// Loop for tool calls.
|
||||
for {
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
}
|
||||
if len(c.tools) > 0 {
|
||||
req.Tools = c.tools
|
||||
}
|
||||
|
||||
stream, err := c.client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("chat completion stream: %w", err)
|
||||
}
|
||||
|
||||
// Accumulate the streamed response.
|
||||
var (
|
||||
content strings.Builder
|
||||
reasoning strings.Builder
|
||||
toolCalls []openai.ToolCall
|
||||
role string
|
||||
)
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
return fmt.Errorf("stream recv: %w", err)
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check the first Choice. Only one is expected, since
|
||||
// our request does not set N > 1.
|
||||
delta := chunk.Choices[0].Delta
|
||||
if delta.Role != "" {
|
||||
role = delta.Role
|
||||
}
|
||||
if delta.Content != "" {
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
if delta.ReasoningContent != "" {
|
||||
reasoning.WriteString(delta.ReasoningContent)
|
||||
}
|
||||
|
||||
// Accumulate tool call deltas by index.
|
||||
for _, tc := range delta.ToolCalls {
|
||||
i := 0
|
||||
if tc.Index != nil {
|
||||
i = *tc.Index
|
||||
}
|
||||
// Grow the slice as needed.
|
||||
for len(toolCalls) <= i {
|
||||
toolCalls = append(
|
||||
toolCalls,
|
||||
openai.ToolCall{
|
||||
Type: openai.ToolTypeFunction,
|
||||
},
|
||||
)
|
||||
}
|
||||
if tc.ID != "" {
|
||||
toolCalls[i].ID = tc.ID
|
||||
}
|
||||
if tc.Function.Name != "" {
|
||||
toolCalls[i].Function.Name +=
|
||||
tc.Function.Name
|
||||
}
|
||||
if tc.Function.Arguments != "" {
|
||||
toolCalls[i].Function.Arguments +=
|
||||
tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
stream.Close()
|
||||
|
||||
// Build the complete message from accumulated buffers.
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content.String(),
|
||||
ReasoningContent: reasoning.String(),
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
events <- StreamEvent{Message: message}
|
||||
|
||||
// If no tool calls, we're done.
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls to history.
|
||||
messages = append(messages, message)
|
||||
|
||||
// Process each tool call.
|
||||
for _, tc := range message.ToolCalls {
|
||||
c.log.InfoContext(
|
||||
ctx,
|
||||
"calling tool",
|
||||
slog.String("name", tc.Function.Name),
|
||||
slog.String("args", tc.Function.Arguments),
|
||||
)
|
||||
|
||||
result, err := c.registry.Execute(
|
||||
ctx, tc.Function.Name, tc.Function.Arguments,
|
||||
)
|
||||
if err != nil {
|
||||
c.log.Error(
|
||||
"failed to call tool",
|
||||
slog.Any("error", err),
|
||||
slog.String("name", tc.Function.Name),
|
||||
)
|
||||
result = fmt.Sprintf(
|
||||
`{"ok": false, "error": %q}`, err,
|
||||
)
|
||||
} else {
|
||||
c.log.Info(
|
||||
"called tool",
|
||||
slog.String("name", tc.Function.Name),
|
||||
)
|
||||
}
|
||||
// Content cannot be empty.
|
||||
if strings.TrimSpace(result) == "" {
|
||||
result = `{"ok": true, "result": null}`
|
||||
}
|
||||
|
||||
// Add tool result to messages.
|
||||
toolResult := openai.ChatCompletionMessage{
|
||||
Content: result,
|
||||
Name: tc.Function.Name,
|
||||
Role: openai.ChatMessageRoleTool,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResult)
|
||||
events <- StreamEvent{Message: toolResult}
|
||||
}
|
||||
// Loop to get LLM's response after tool execution.
|
||||
}
|
||||
}
|
||||
168
internal/llm/llm_test.go
Normal file
168
internal/llm/llm_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty config",
|
||||
cfg: Config{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
cfg: Config{
|
||||
URL: "http://localhost:8080",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing URL",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid timeout",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: "not-a-duration",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid minimal config",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
URL: "http://localhost:8080",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with timeout",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: "15m",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with complex timeout",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: "1h30m45s",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid full config",
|
||||
cfg: Config{
|
||||
Key: "sk-test-key",
|
||||
Model: "test-model",
|
||||
SystemPrompt: "You are a helpful assistant.",
|
||||
Timeout: "30m",
|
||||
URL: "http://localhost:8080",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.cfg.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"Validate() error = %v, wantErr %v",
|
||||
err, tt.wantErr,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid config",
|
||||
cfg: Config{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid config without timeout",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
URL: "http://localhost:8080",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with timeout",
|
||||
cfg: Config{
|
||||
Model: "test-model",
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: "10m",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with all fields",
|
||||
cfg: Config{
|
||||
Key: "test-key",
|
||||
Model: "test-model",
|
||||
SystemPrompt: "Test prompt",
|
||||
Timeout: "5m",
|
||||
URL: "http://localhost:8080",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, err := NewClient(tt.cfg, nil, nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"NewClient() error = %v, wantErr %v",
|
||||
err, tt.wantErr,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && client == nil {
|
||||
t.Error("NewClient() returned nil client")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDefaultModel(t *testing.T) {
|
||||
cfg := Config{
|
||||
Model: "my-custom-model",
|
||||
URL: "http://localhost:8080",
|
||||
}
|
||||
|
||||
client, err := NewClient(cfg, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
|
||||
if got := client.DefaultModel(); got != "my-custom-model" {
|
||||
t.Errorf(
|
||||
"DefaultModel() = %q, want %q",
|
||||
got, "my-custom-model",
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user