209 lines
4.8 KiB
Go
209 lines
4.8 KiB
Go
|
|
// Package tts provides a client for kokoro-fastapi text-to-speech.
|
||
|
|
package tts
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"log/slog"
|
||
|
|
"net/http"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Config describes how to reach a kokoro-fastapi server and which voices
// to request from it.
type Config struct {
	// URL is the base URL of kokoro-fastapi, for example
	// "http://localhost:8880".
	// Required.
	URL string `yaml:"url"`

	// Voice is the default voice ID to use (e.g., "af_heart").
	// Required.
	Voice string `yaml:"voice"`

	// Timeout is the maximum duration for a single request, written in
	// time.ParseDuration syntax such as "30s" or "2m".
	// Must be positive when set; defaults to 60s if empty.
	Timeout string `yaml:"timeout"`

	// VoiceMap maps whisper language names to voice IDs.
	// Used by SelectVoice to auto-select voices based on detected
	// language.
	// If empty, defaultVoices is used.
	VoiceMap map[string]string `yaml:"voice_map"`
}

// Validate reports the first problem found in cfg, or nil if cfg is usable.
//
// URL and Voice must be non-empty.
// Timeout may be empty; otherwise it must parse with time.ParseDuration
// and be strictly positive.
func (cfg Config) Validate() error {
	if cfg.URL == "" {
		return fmt.Errorf("missing URL")
	}
	if cfg.Voice == "" {
		return fmt.Errorf("missing voice")
	}
	if cfg.Timeout == "" {
		return nil
	}

	d, err := time.ParseDuration(cfg.Timeout)
	if err != nil {
		return fmt.Errorf("invalid timeout: %w", err)
	}
	// The client applies this value with context.WithTimeout on every call;
	// a zero or negative duration would produce a context that has already
	// expired, so reject it here instead of failing on each request.
	if d <= 0 {
		return fmt.Errorf("invalid timeout: %q is not positive", cfg.Timeout)
	}
	return nil
}
|
||
|
|
|
||
|
|
// Client talks to a kokoro-fastapi server over HTTP.
// Build one with NewClient.
type Client struct {
	// client performs every HTTP request issued by this Client.
	client *http.Client

	// log receives diagnostics that are not returned to the caller.
	log *slog.Logger

	// timeout is the per-call deadline applied through the request context.
	timeout time.Duration

	// url is the server base URL; endpoint paths are appended to it.
	url string

	// voice is the voice ID used when a call does not name one.
	voice string

	// voiceMap maps whisper language names to voice IDs for SelectVoice.
	voiceMap map[string]string
}
|
||
|
|
|
||
|
|
// NewClient creates a new TTS client with the provided configuration.
|
||
|
|
func NewClient(cfg Config, log *slog.Logger) (*Client, error) {
|
||
|
|
if err := cfg.Validate(); err != nil {
|
||
|
|
return nil, fmt.Errorf("invalid config: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
timeout := 60 * time.Second
|
||
|
|
if cfg.Timeout != "" {
|
||
|
|
d, err := time.ParseDuration(cfg.Timeout)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("parse timeout: %v", err)
|
||
|
|
}
|
||
|
|
timeout = d
|
||
|
|
}
|
||
|
|
|
||
|
|
voiceMap := cfg.VoiceMap
|
||
|
|
if len(voiceMap) == 0 {
|
||
|
|
voiceMap = defaultVoices
|
||
|
|
}
|
||
|
|
|
||
|
|
return &Client{
|
||
|
|
client: &http.Client{Timeout: timeout},
|
||
|
|
log: log,
|
||
|
|
timeout: timeout,
|
||
|
|
url: cfg.URL,
|
||
|
|
voice: cfg.Voice,
|
||
|
|
voiceMap: voiceMap,
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// defaultVoices is the built-in table from whisper language name to Kokoro
// voice ID.
// NewClient falls back to it when Config.VoiceMap is empty.
var defaultVoices = map[string]string{
	"chinese":    "zf_xiaobei",
	"english":    "af_heart",
	"french":     "ff_siwis",
	"hindi":      "hf_alpha",
	"italian":    "if_sara",
	"japanese":   "jf_alpha",
	"korean":     "kf_sarah",
	"portuguese": "pf_dora",
	"spanish":    "ef_dora",
}
|
||
|
|
|
||
|
|
// SelectVoice returns the voice ID for the given whisper language name.
|
||
|
|
// Returns an empty string if no mapping exists for the language.
|
||
|
|
func (c *Client) SelectVoice(lang string) string {
|
||
|
|
if voice, ok := c.voiceMap[lang]; ok {
|
||
|
|
return voice
|
||
|
|
}
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
|
||
|
|
// Synthesize converts text to speech and returns WAV audio.
|
||
|
|
// If voice is empty, the default configured voice is used.
|
||
|
|
func (c *Client) Synthesize(ctx context.Context, text, voice string) (
|
||
|
|
[]byte, error,
|
||
|
|
) {
|
||
|
|
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||
|
|
defer cancel()
|
||
|
|
if voice == "" {
|
||
|
|
voice = c.voice
|
||
|
|
}
|
||
|
|
|
||
|
|
body, err := json.Marshal(struct {
|
||
|
|
Input string `json:"input"`
|
||
|
|
Voice string `json:"voice"`
|
||
|
|
}{
|
||
|
|
Input: text,
|
||
|
|
Voice: voice,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
req, err := http.NewRequestWithContext(
|
||
|
|
ctx,
|
||
|
|
http.MethodPost,
|
||
|
|
c.url+"/v1/audio/speech",
|
||
|
|
bytes.NewReader(body),
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("create request: %w", err)
|
||
|
|
}
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
res, err := c.client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("send request: %w", err)
|
||
|
|
}
|
||
|
|
defer res.Body.Close()
|
||
|
|
|
||
|
|
if res.StatusCode != http.StatusOK {
|
||
|
|
body, err := io.ReadAll(res.Body)
|
||
|
|
if err != nil {
|
||
|
|
c.log.ErrorContext(
|
||
|
|
ctx,
|
||
|
|
"failed to read response body",
|
||
|
|
slog.Any("error", err),
|
||
|
|
)
|
||
|
|
}
|
||
|
|
return nil, fmt.Errorf(
|
||
|
|
"tts error %d: %s", res.StatusCode, body,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
audio, err := io.ReadAll(res.Body)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("read response: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return audio, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ListVoices returns the available voices from kokoro-fastapi.
|
||
|
|
func (c *Client) ListVoices(ctx context.Context) ([]string, error) {
|
||
|
|
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
req, err := http.NewRequestWithContext(
|
||
|
|
ctx,
|
||
|
|
http.MethodGet,
|
||
|
|
c.url+"/v1/audio/voices",
|
||
|
|
nil,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("create request: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
res, err := c.client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("send request: %w", err)
|
||
|
|
}
|
||
|
|
defer res.Body.Close()
|
||
|
|
|
||
|
|
if res.StatusCode != http.StatusOK {
|
||
|
|
body, _ := io.ReadAll(res.Body)
|
||
|
|
return nil, fmt.Errorf(
|
||
|
|
"voices error %d: %s", res.StatusCode, body,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
var voices = struct {
|
||
|
|
Voices []string `json:"voices"`
|
||
|
|
}{}
|
||
|
|
if err := json.NewDecoder(res.Body).Decode(&voices); err != nil {
|
||
|
|
return nil, fmt.Errorf("decode response: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return voices.Voices, nil
|
||
|
|
}
|