Add STT and TTS clients

This commit is contained in:
dwrz
2026-02-13 15:01:20 +00:00
parent faa1798eb0
commit c655dfc650
4 changed files with 721 additions and 0 deletions

208
internal/tts/tts.go Normal file
View File

@@ -0,0 +1,208 @@
// Package tts provides a client for kokoro-fastapi text-to-speech.
package tts
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
)
// Config holds the configuration for a TTS client.
type Config struct {
// URL is the base URL of kokoro-fastapi.
// For example, "http://localhost:8880".
URL string `yaml:"url"`
// Voice is the default voice ID to use (e.g., "af_heart").
Voice string `yaml:"voice"`
// Timeout is the maximum duration for a synthesis request.
// Defaults to 60s if empty.
Timeout string `yaml:"timeout"`
// VoiceMap maps whisper language names to voice IDs.
// Used by SelectVoice to auto-select voices based on detected
// language.
// If empty, defaultVoices is used.
VoiceMap map[string]string `yaml:"voice_map"`
}
// Validate checks that required configuration values are present and valid.
func (cfg Config) Validate() error {
if cfg.URL == "" {
return fmt.Errorf("missing URL")
}
if cfg.Voice == "" {
return fmt.Errorf("missing voice")
}
if cfg.Timeout != "" {
if _, err := time.ParseDuration(cfg.Timeout); err != nil {
return fmt.Errorf("invalid timeout: %w", err)
}
}
return nil
}
// Client wraps an HTTP client for kokoro-fastapi requests.
type Client struct {
client *http.Client
log *slog.Logger
timeout time.Duration
url string
voice string
voiceMap map[string]string
}
// NewClient creates a new TTS client with the provided configuration.
func NewClient(cfg Config, log *slog.Logger) (*Client, error) {
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid config: %v", err)
}
timeout := 60 * time.Second
if cfg.Timeout != "" {
d, err := time.ParseDuration(cfg.Timeout)
if err != nil {
return nil, fmt.Errorf("parse timeout: %v", err)
}
timeout = d
}
voiceMap := cfg.VoiceMap
if len(voiceMap) == 0 {
voiceMap = defaultVoices
}
return &Client{
client: &http.Client{Timeout: timeout},
log: log,
timeout: timeout,
url: cfg.URL,
voice: cfg.Voice,
voiceMap: voiceMap,
}, nil
}
// defaultVoices maps whisper language names to default Kokoro voices.
// Used when Config.VoiceMap is empty.
var defaultVoices = map[string]string{
"chinese": "zf_xiaobei",
"english": "af_heart",
"french": "ff_siwis",
"hindi": "hf_alpha",
"italian": "if_sara",
"japanese": "jf_alpha",
"korean": "kf_sarah",
"portuguese": "pf_dora",
"spanish": "ef_dora",
}
// SelectVoice returns the voice ID for the given whisper language name.
// Returns an empty string if no mapping exists for the language.
func (c *Client) SelectVoice(lang string) string {
if voice, ok := c.voiceMap[lang]; ok {
return voice
}
return ""
}
// Synthesize converts text to speech and returns WAV audio.
// If voice is empty, the default configured voice is used.
func (c *Client) Synthesize(ctx context.Context, text, voice string) (
[]byte, error,
) {
ctx, cancel := context.WithTimeout(ctx, c.timeout)
defer cancel()
if voice == "" {
voice = c.voice
}
body, err := json.Marshal(struct {
Input string `json:"input"`
Voice string `json:"voice"`
}{
Input: text,
Voice: voice,
})
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
c.url+"/v1/audio/speech",
bytes.NewReader(body),
)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
res, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
body, err := io.ReadAll(res.Body)
if err != nil {
c.log.ErrorContext(
ctx,
"failed to read response body",
slog.Any("error", err),
)
}
return nil, fmt.Errorf(
"tts error %d: %s", res.StatusCode, body,
)
}
audio, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
return audio, nil
}
// ListVoices returns the available voices from kokoro-fastapi.
func (c *Client) ListVoices(ctx context.Context) ([]string, error) {
ctx, cancel := context.WithTimeout(ctx, c.timeout)
defer cancel()
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
c.url+"/v1/audio/voices",
nil,
)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
res, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
body, _ := io.ReadAll(res.Body)
return nil, fmt.Errorf(
"voices error %d: %s", res.StatusCode, body,
)
}
var voices = struct {
Voices []string `json:"voices"`
}{}
if err := json.NewDecoder(res.Body).Decode(&voices); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
return voices.Voices, nil
}

236
internal/tts/tts_test.go Normal file
View File

@@ -0,0 +1,236 @@
package tts
import (
"testing"
)
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
cfg Config
wantErr bool
}{
{
name: "empty config",
cfg: Config{},
wantErr: true,
},
{
name: "missing URL",
cfg: Config{
Voice: "af_heart",
},
wantErr: true,
},
{
name: "missing voice",
cfg: Config{
URL: "http://localhost:8880",
},
wantErr: true,
},
{
name: "invalid timeout",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
Timeout: "not-a-duration",
},
wantErr: true,
},
{
name: "valid minimal config",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
},
wantErr: false,
},
{
name: "valid config with timeout",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
Timeout: "120s",
},
wantErr: false,
},
{
name: "valid config with voice map",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
VoiceMap: map[string]string{
"english": "af_heart",
"chinese": "zf_xiaobei",
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.cfg.Validate()
if (err != nil) != tt.wantErr {
t.Errorf(
"Validate() error = %v, wantErr %v",
err, tt.wantErr,
)
}
})
}
}
func TestNewClient(t *testing.T) {
tests := []struct {
name string
cfg Config
wantErr bool
}{
{
name: "invalid config",
cfg: Config{},
wantErr: true,
},
{
name: "valid config without timeout",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
},
wantErr: false,
},
{
name: "valid config with timeout",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
Timeout: "90s",
},
wantErr: false,
},
{
name: "valid config with custom voice map",
cfg: Config{
URL: "http://localhost:8880",
Voice: "af_heart",
VoiceMap: map[string]string{
"english": "custom_en",
"french": "custom_fr",
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := NewClient(tt.cfg, nil)
if (err != nil) != tt.wantErr {
t.Errorf(
"NewClient() error = %v, wantErr %v",
err, tt.wantErr,
)
return
}
if !tt.wantErr && client == nil {
t.Error("NewClient() returned nil client")
}
})
}
}
func TestSelectVoice(t *testing.T) {
tests := []struct {
name string
voiceMap map[string]string
lang string
want string
}{
{
name: "default map english",
voiceMap: nil,
lang: "english",
want: "af_heart",
},
{
name: "default map chinese",
voiceMap: nil,
lang: "chinese",
want: "zf_xiaobei",
},
{
name: "default map japanese",
voiceMap: nil,
lang: "japanese",
want: "jf_alpha",
},
{
name: "default map unknown language",
voiceMap: nil,
lang: "klingon",
want: "",
},
{
name: "custom map",
voiceMap: map[string]string{
"english": "custom_english",
"german": "custom_german",
},
lang: "english",
want: "custom_english",
},
{
name: "custom map missing language",
voiceMap: map[string]string{
"english": "custom_english",
},
lang: "french",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := Config{
URL: "http://localhost:8880",
Voice: "af_heart",
VoiceMap: tt.voiceMap,
}
client, err := NewClient(cfg, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
got := client.SelectVoice(tt.lang)
if got != tt.want {
t.Errorf(
"SelectVoice(%q) = %q, want %q",
tt.lang, got, tt.want,
)
}
})
}
}
func TestDefaultVoiceMap(t *testing.T) {
// Verify that DefaultVoiceMap contains expected entries.
expected := map[string]string{
"english": "af_heart",
"chinese": "zf_xiaobei",
"japanese": "jf_alpha",
"spanish": "ef_dora",
"french": "ff_siwis",
"korean": "kf_sarah",
}
for lang, voice := range expected {
if got := defaultVoices[lang]; got != voice {
t.Errorf(
"DefaultVoiceMap[%q] = %q, want %q",
lang, got, voice,
)
}
}
}