From c655dfc65037136acfeb2d1e1e092831159b9bc0 Mon Sep 17 00:00:00 2001 From: dwrz Date: Fri, 13 Feb 2026 15:01:20 +0000 Subject: [PATCH] Add STT and TTS clients --- internal/stt/stt.go | 171 ++++++++++++++++++++++++++++ internal/stt/stt_test.go | 106 ++++++++++++++++++ internal/tts/tts.go | 208 ++++++++++++++++++++++++++++++++++ internal/tts/tts_test.go | 236 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 721 insertions(+) create mode 100644 internal/stt/stt.go create mode 100644 internal/stt/stt_test.go create mode 100644 internal/tts/tts.go create mode 100644 internal/tts/tts_test.go diff --git a/internal/stt/stt.go b/internal/stt/stt.go new file mode 100644 index 0000000..806bba7 --- /dev/null +++ b/internal/stt/stt.go @@ -0,0 +1,171 @@ +// Package stt provides a client for whisper.cpp speech-to-text. +package stt + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "time" + + "log/slog" +) + +// Config holds the configuration for a STT client. +type Config struct { + // URL is the base URL of the whisper-server. + // For example, "http://localhost:8178". + URL string `yaml:"url"` + // Timeout is the maximum duration for a transcription request. + // Defaults to 30s. + Timeout string `yaml:"timeout"` +} + +// Validate checks that required configuration values are present and valid. +func (cfg Config) Validate() error { + if cfg.URL == "" { + return fmt.Errorf("missing URL") + } + if cfg.Timeout != "" { + if _, err := time.ParseDuration(cfg.Timeout); err != nil { + return fmt.Errorf("invalid timeout: %w", err) + } + } + return nil +} + +// Client wraps an HTTP client for whisper-server requests. +type Client struct { + client *http.Client + log *slog.Logger + url string + timeout time.Duration +} + +// NewClient creates a new STT client with the provided configuration. 
+func NewClient(cfg Config, log *slog.Logger) (*Client, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %v", err) + } + + var timeout = 30 * time.Second + if cfg.Timeout != "" { + d, err := time.ParseDuration(cfg.Timeout) + if err != nil { + return nil, fmt.Errorf("parse timeout: %v", err) + } + timeout = d + } + + return &Client{ + client: &http.Client{Timeout: timeout}, + log: log, + url: cfg.URL, + timeout: timeout, + }, nil +} + +// Output represents the JSON response from whisper.cpp (verbose_json format). +type Output struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text"` + DetectedLanguage string `json:"detected_language,omitempty"` + DetectedLanguageProbability float64 `json:"detected_language_probability,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +// Segment represents a transcription segment with timing and confidence +// information. +type Segment struct { + ID int `json:"id"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Temperature float64 `json:"temperature,omitempty"` + AvgLogProb float64 `json:"avg_logprob,omitempty"` + NoSpeechProb float64 `json:"no_speech_prob,omitempty"` + Tokens []int `json:"tokens,omitempty"` +} + +// Transcribe sends audio to whisper.cpp and returns the transcription output. +func (c *Client) Transcribe(ctx context.Context, audio []byte) (*Output, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + // Build multipart form. 
+ var buf bytes.Buffer + w := multipart.NewWriter(&buf) + + fw, err := w.CreateFormFile("file", "audio.webm") + if err != nil { + return nil, fmt.Errorf("create form file: %w", err) + } + if _, err := fw.Write(audio); err != nil { + return nil, fmt.Errorf("write audio: %w", err) + } + + // Request verbose JSON response format to get full output including + // detected_language. + if err := w.WriteField("response_format", "verbose_json"); err != nil { + return nil, fmt.Errorf("write response_format: %w", err) + } + + if err := w.Close(); err != nil { + return nil, fmt.Errorf("close multipart: %w", err) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + c.url+"/inference", + &buf, + ) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", w.FormDataContentType()) + + res, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + c.log.ErrorContext( + ctx, + "failed to read response body", + slog.Any("error", err), + ) + } + + return nil, fmt.Errorf( + "whisper error %d: %s", res.StatusCode, body, + ) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + var output Output + if err := json.Unmarshal(body, &output); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + c.log.DebugContext(ctx, "stt response", + slog.String("text", output.Text), + slog.String("language", output.Language), + slog.String("detected_language", output.DetectedLanguage), + slog.Float64("duration", output.Duration), + ) + + return &output, nil +} diff --git a/internal/stt/stt_test.go b/internal/stt/stt_test.go new file mode 100644 index 0000000..d93de46 --- /dev/null +++ b/internal/stt/stt_test.go @@ -0,0 +1,106 @@ +package stt + +import ( + "testing" +) + +func TestConfigValidate(t 
*testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "empty config", + cfg: Config{}, + wantErr: true, + }, + { + name: "missing URL", + cfg: Config{ + Timeout: "30s", + }, + wantErr: true, + }, + { + name: "invalid timeout", + cfg: Config{ + URL: "http://localhost:8178", + Timeout: "not-a-duration", + }, + wantErr: true, + }, + { + name: "valid minimal config", + cfg: Config{ + URL: "http://localhost:8178", + }, + wantErr: false, + }, + { + name: "valid config with timeout", + cfg: Config{ + URL: "http://localhost:8178", + Timeout: "60s", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.Validate() + if (err != nil) != tt.wantErr { + t.Errorf( + "Validate() error = %v, wantErr %v", + err, tt.wantErr, + ) + } + }) + } +} + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "invalid config", + cfg: Config{}, + wantErr: true, + }, + { + name: "valid config without timeout", + cfg: Config{ + URL: "http://localhost:8178", + }, + wantErr: false, + }, + { + name: "valid config with timeout", + cfg: Config{ + URL: "http://localhost:8178", + Timeout: "45s", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.cfg, nil) + if (err != nil) != tt.wantErr { + t.Errorf( + "NewClient() error = %v, wantErr %v", + err, tt.wantErr, + ) + return + } + if !tt.wantErr && client == nil { + t.Error("NewClient() returned nil client") + } + }) + } +} diff --git a/internal/tts/tts.go b/internal/tts/tts.go new file mode 100644 index 0000000..ec96d81 --- /dev/null +++ b/internal/tts/tts.go @@ -0,0 +1,208 @@ +// Package tts provides a client for kokoro-fastapi text-to-speech. +package tts + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "time" +) + +// Config holds the configuration for a TTS client. 
type Config struct {
	// URL is the base URL of kokoro-fastapi.
	// For example, "http://localhost:8880".
	// Trailing slashes are ignored.
	URL string `yaml:"url"`
	// Voice is the default voice ID to use (e.g., "af_heart").
	Voice string `yaml:"voice"`
	// Timeout is the maximum duration for a single request, written in
	// time.ParseDuration syntax (e.g., "90s"). It must be positive.
	// Defaults to 60s if empty.
	Timeout string `yaml:"timeout"`
	// VoiceMap maps whisper language names to voice IDs.
	// Used by SelectVoice to auto-select voices based on detected
	// language.
	// If empty, defaultVoices is used.
	VoiceMap map[string]string `yaml:"voice_map"`
}

// requestTimeout returns the per-request timeout described by cfg.Timeout,
// or 60s when the field is empty. It fails if the value does not parse or
// is not strictly positive.
func (cfg Config) requestTimeout() (time.Duration, error) {
	if cfg.Timeout == "" {
		return 60 * time.Second, nil
	}
	d, err := time.ParseDuration(cfg.Timeout)
	if err != nil {
		return 0, err
	}
	// The value is handed to context.WithTimeout on every request; a zero
	// or negative duration yields a context that is already expired, so
	// such a client could never complete a call.
	if d <= 0 {
		return 0, fmt.Errorf("duration %q is not positive", cfg.Timeout)
	}
	return d, nil
}

// Validate checks that required configuration values are present and valid.
// URL and Voice must be non-empty; Timeout, when set, must parse as a
// positive duration.
func (cfg Config) Validate() error {
	if cfg.URL == "" {
		return fmt.Errorf("missing URL")
	}
	if cfg.Voice == "" {
		return fmt.Errorf("missing voice")
	}
	if _, err := cfg.requestTimeout(); err != nil {
		return fmt.Errorf("invalid timeout: %w", err)
	}
	return nil
}

// Client wraps an HTTP client for kokoro-fastapi requests.
type Client struct {
	client   *http.Client
	log      *slog.Logger
	timeout  time.Duration
	url      string
	voice    string
	voiceMap map[string]string
}

// NewClient creates a new TTS client with the provided configuration.
//
// The configuration is validated first; any failure is returned wrapped so
// callers can inspect the cause with errors.Is / errors.As. A nil log is
// replaced by slog.Default() so the client never dereferences a nil logger.
// The voice map is copied, so later changes to cfg.VoiceMap do not affect
// the client.
func NewClient(cfg Config, log *slog.Logger) (*Client, error) {
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("invalid config: %w", err)
	}

	timeout, err := cfg.requestTimeout()
	if err != nil {
		return nil, fmt.Errorf("parse timeout: %w", err)
	}

	if log == nil {
		log = slog.Default()
	}

	// Copy the map instead of aliasing it: both the caller's map and the
	// package-level defaults must stay independent of this client.
	source := cfg.VoiceMap
	if len(source) == 0 {
		source = defaultVoices
	}
	voiceMap := make(map[string]string, len(source))
	for lang, voice := range source {
		voiceMap[lang] = voice
	}

	// Endpoint paths are appended with a leading slash, so drop trailing
	// slashes here to avoid requesting "//v1/...".
	baseURL := cfg.URL
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}

	return &Client{
		client:   &http.Client{Timeout: timeout},
		log:      log,
		timeout:  timeout,
		url:      baseURL,
		voice:    cfg.Voice,
		voiceMap: voiceMap,
	}, nil
}

// defaultVoices maps whisper language names to default Kokoro voices.
// Used when Config.VoiceMap is empty.
var defaultVoices = map[string]string{
	"chinese":    "zf_xiaobei",
	"english":    "af_heart",
	"french":     "ff_siwis",
	"hindi":      "hf_alpha",
	"italian":    "if_sara",
	"japanese":   "jf_alpha",
	"korean":     "kf_sarah",
	"portuguese": "pf_dora",
	"spanish":    "ef_dora",
}

// SelectVoice returns the voice ID for the given whisper language name.
// Returns an empty string if no mapping exists for the language.
// The lookup is an exact, case-sensitive match on lang.
func (c *Client) SelectVoice(lang string) string {
	// A missing key yields the zero value, which is the documented "".
	return c.voiceMap[lang]
}

// Synthesize converts text to speech and returns the audio bytes sent back
// by the server.
// If voice is empty, the default configured voice is used.
//
// The request carries only "input" and "voice", so the audio encoding is
// whatever the server defaults to.
// NOTE(review): this comment previously promised WAV; confirm against the
// server's default response format before relying on it.
//
// A non-200 response is returned as an error that includes the status code
// and the response body.
func (c *Client) Synthesize(ctx context.Context, text, voice string) (
	[]byte, error,
) {
	ctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	if voice == "" {
		voice = c.voice
	}

	payload, err := json.Marshal(struct {
		Input string `json:"input"`
		Voice string `json:"voice"`
	}{
		Input: text,
		Voice: voice,
	})
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		c.url+"/v1/audio/speech",
		bytes.NewReader(payload),
	)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		// Best effort: a partially read body is still useful in the
		// returned error, so a read failure is only logged.
		detail, readErr := io.ReadAll(res.Body)
		if readErr != nil {
			c.log.ErrorContext(
				ctx,
				"failed to read response body",
				slog.Any("error", readErr),
			)
		}
		return nil, fmt.Errorf(
			"tts error %d: %s", res.StatusCode, detail,
		)
	}

	audio, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}

	return audio, nil
}

// ListVoices returns the available voices from kokoro-fastapi.
+func (c *Client) ListVoices(ctx context.Context) ([]string, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + c.url+"/v1/audio/voices", + nil, + ) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + res, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + body, _ := io.ReadAll(res.Body) + return nil, fmt.Errorf( + "voices error %d: %s", res.StatusCode, body, + ) + } + + var voices = struct { + Voices []string `json:"voices"` + }{} + if err := json.NewDecoder(res.Body).Decode(&voices); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return voices.Voices, nil +} diff --git a/internal/tts/tts_test.go b/internal/tts/tts_test.go new file mode 100644 index 0000000..f2be497 --- /dev/null +++ b/internal/tts/tts_test.go @@ -0,0 +1,236 @@ +package tts + +import ( + "testing" +) + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "empty config", + cfg: Config{}, + wantErr: true, + }, + { + name: "missing URL", + cfg: Config{ + Voice: "af_heart", + }, + wantErr: true, + }, + { + name: "missing voice", + cfg: Config{ + URL: "http://localhost:8880", + }, + wantErr: true, + }, + { + name: "invalid timeout", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + Timeout: "not-a-duration", + }, + wantErr: true, + }, + { + name: "valid minimal config", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + }, + wantErr: false, + }, + { + name: "valid config with timeout", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + Timeout: "120s", + }, + wantErr: false, + }, + { + name: "valid config with voice map", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + VoiceMap: 
map[string]string{ + "english": "af_heart", + "chinese": "zf_xiaobei", + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.Validate() + if (err != nil) != tt.wantErr { + t.Errorf( + "Validate() error = %v, wantErr %v", + err, tt.wantErr, + ) + } + }) + } +} + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "invalid config", + cfg: Config{}, + wantErr: true, + }, + { + name: "valid config without timeout", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + }, + wantErr: false, + }, + { + name: "valid config with timeout", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + Timeout: "90s", + }, + wantErr: false, + }, + { + name: "valid config with custom voice map", + cfg: Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + VoiceMap: map[string]string{ + "english": "custom_en", + "french": "custom_fr", + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.cfg, nil) + if (err != nil) != tt.wantErr { + t.Errorf( + "NewClient() error = %v, wantErr %v", + err, tt.wantErr, + ) + return + } + if !tt.wantErr && client == nil { + t.Error("NewClient() returned nil client") + } + }) + } +} + +func TestSelectVoice(t *testing.T) { + tests := []struct { + name string + voiceMap map[string]string + lang string + want string + }{ + { + name: "default map english", + voiceMap: nil, + lang: "english", + want: "af_heart", + }, + { + name: "default map chinese", + voiceMap: nil, + lang: "chinese", + want: "zf_xiaobei", + }, + { + name: "default map japanese", + voiceMap: nil, + lang: "japanese", + want: "jf_alpha", + }, + { + name: "default map unknown language", + voiceMap: nil, + lang: "klingon", + want: "", + }, + { + name: "custom map", + voiceMap: map[string]string{ + "english": "custom_english", + "german": 
"custom_german", + }, + lang: "english", + want: "custom_english", + }, + { + name: "custom map missing language", + voiceMap: map[string]string{ + "english": "custom_english", + }, + lang: "french", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Config{ + URL: "http://localhost:8880", + Voice: "af_heart", + VoiceMap: tt.voiceMap, + } + client, err := NewClient(cfg, nil) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + + got := client.SelectVoice(tt.lang) + if got != tt.want { + t.Errorf( + "SelectVoice(%q) = %q, want %q", + tt.lang, got, tt.want, + ) + } + }) + } +} + +func TestDefaultVoiceMap(t *testing.T) { + // Verify that DefaultVoiceMap contains expected entries. + expected := map[string]string{ + "english": "af_heart", + "chinese": "zf_xiaobei", + "japanese": "jf_alpha", + "spanish": "ef_dora", + "french": "ff_siwis", + "korean": "kf_sarah", + } + + for lang, voice := range expected { + if got := defaultVoices[lang]; got != voice { + t.Errorf( + "DefaultVoiceMap[%q] = %q, want %q", + lang, got, voice, + ) + } + } +}