172 lines
4.5 KiB
Go
172 lines
4.5 KiB
Go
// Package stt provides a client for whisper.cpp speech-to-text.
|
|
package stt
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"time"
|
|
|
|
"log/slog"
|
|
)
|
|
|
|
// Config holds the configuration for an STT client.
type Config struct {
	// URL is the base URL of the whisper-server, without the endpoint path
	// and without a trailing slash (the endpoint path is appended verbatim).
	// For example, "http://localhost:8178".
	URL string `yaml:"url"`

	// Timeout is the maximum duration for a transcription request, written in
	// the format accepted by time.ParseDuration (for example "30s" or "1m").
	// Defaults to 30s when empty. If set, it must be positive.
	Timeout string `yaml:"timeout"`
}

// Validate checks that required configuration values are present and valid.
//
// It returns an error when URL is empty, when Timeout is set but cannot be
// parsed by time.ParseDuration, or when Timeout parses to a zero or negative
// duration. A non-positive timeout is rejected because it would give every
// request a deadline that has already passed.
func (cfg Config) Validate() error {
	if cfg.URL == "" {
		return fmt.Errorf("missing URL")
	}
	if cfg.Timeout != "" {
		d, err := time.ParseDuration(cfg.Timeout)
		if err != nil {
			return fmt.Errorf("invalid timeout: %w", err)
		}
		if d <= 0 {
			return fmt.Errorf("invalid timeout: must be positive, got %q", cfg.Timeout)
		}
	}
	return nil
}
|
|
|
|
// Client talks to a whisper-server over HTTP. Construct it with NewClient;
// the zero value has no HTTP client or logger configured.
type Client struct {
	client  *http.Client  // HTTP client used to send each request
	log     *slog.Logger  // logger for diagnostics
	url     string        // server base URL; endpoint paths are appended to it
	timeout time.Duration // deadline applied to each transcription request
}
|
|
|
|
// NewClient creates a new STT client with the provided configuration.
|
|
func NewClient(cfg Config, log *slog.Logger) (*Client, error) {
|
|
if err := cfg.Validate(); err != nil {
|
|
return nil, fmt.Errorf("invalid config: %v", err)
|
|
}
|
|
|
|
var timeout = 30 * time.Second
|
|
if cfg.Timeout != "" {
|
|
d, err := time.ParseDuration(cfg.Timeout)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse timeout: %v", err)
|
|
}
|
|
timeout = d
|
|
}
|
|
|
|
return &Client{
|
|
client: &http.Client{Timeout: timeout},
|
|
log: log,
|
|
url: cfg.URL,
|
|
timeout: timeout,
|
|
}, nil
|
|
}
|
|
|
|
type (
	// Output is the decoded JSON body whisper.cpp returns when the
	// "verbose_json" response format is requested. Text is always present in
	// the encoding; the remaining fields are omitted when empty.
	Output struct {
		Task                        string    `json:"task,omitempty"`
		Language                    string    `json:"language,omitempty"`
		Duration                    float64   `json:"duration,omitempty"`
		Text                        string    `json:"text"`
		DetectedLanguage            string    `json:"detected_language,omitempty"`
		DetectedLanguageProbability float64   `json:"detected_language_probability,omitempty"`
		Segments                    []Segment `json:"segments,omitempty"`
	}

	// Segment is one entry of Output.Segments: a span of transcribed text
	// together with its timing and confidence information.
	// NOTE(review): Start/End look like offsets in seconds — confirm against
	// the whisper.cpp server version in use.
	Segment struct {
		ID           int     `json:"id"`
		Start        float64 `json:"start"`
		End          float64 `json:"end"`
		Text         string  `json:"text"`
		Temperature  float64 `json:"temperature,omitempty"`
		AvgLogProb   float64 `json:"avg_logprob,omitempty"`
		NoSpeechProb float64 `json:"no_speech_prob,omitempty"`
		Tokens       []int   `json:"tokens,omitempty"`
	}
)
|
|
|
|
// Transcribe sends audio to whisper.cpp and returns the transcription output.
|
|
func (c *Client) Transcribe(ctx context.Context, audio []byte) (*Output, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, c.timeout)
|
|
defer cancel()
|
|
|
|
// Build multipart form.
|
|
var buf bytes.Buffer
|
|
w := multipart.NewWriter(&buf)
|
|
|
|
fw, err := w.CreateFormFile("file", "audio.webm")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create form file: %w", err)
|
|
}
|
|
if _, err := fw.Write(audio); err != nil {
|
|
return nil, fmt.Errorf("write audio: %w", err)
|
|
}
|
|
|
|
// Request verbose JSON response format to get full output including
|
|
// detected_language.
|
|
if err := w.WriteField("response_format", "verbose_json"); err != nil {
|
|
return nil, fmt.Errorf("write response_format: %w", err)
|
|
}
|
|
|
|
if err := w.Close(); err != nil {
|
|
return nil, fmt.Errorf("close multipart: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(
|
|
ctx,
|
|
http.MethodPost,
|
|
c.url+"/inference",
|
|
&buf,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
|
|
res, err := c.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
if res.StatusCode != http.StatusOK {
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
c.log.ErrorContext(
|
|
ctx,
|
|
"failed to read response body",
|
|
slog.Any("error", err),
|
|
)
|
|
}
|
|
|
|
return nil, fmt.Errorf(
|
|
"whisper error %d: %s", res.StatusCode, body,
|
|
)
|
|
}
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response: %w", err)
|
|
}
|
|
|
|
var output Output
|
|
if err := json.Unmarshal(body, &output); err != nil {
|
|
return nil, fmt.Errorf("parse response: %w", err)
|
|
}
|
|
c.log.DebugContext(ctx, "stt response",
|
|
slog.String("text", output.Text),
|
|
slog.String("language", output.Language),
|
|
slog.String("detected_language", output.DetectedLanguage),
|
|
slog.Float64("duration", output.Duration),
|
|
)
|
|
|
|
return &output, nil
|
|
}
|