348 lines
7.7 KiB
Go
348 lines
7.7 KiB
Go
// Copyright 2023 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package client provides an interface for accessing vulnerability
|
|
// databases, via either HTTP or local filesystem access.
|
|
//
|
|
// The protocol is described at https://go.dev/security/vuln/database.
|
|
package client
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/vuln/internal/derrors"
|
|
"golang.org/x/vuln/internal/osv"
|
|
isem "golang.org/x/vuln/internal/semver"
|
|
"golang.org/x/vuln/internal/web"
|
|
)
|
|
|
|
// A Client for reading vulnerability databases.
//
// All reads go through the embedded source (see c.source.get). Which
// source backs a Client is chosen by the constructor that created it:
// NewClient (HTTP or local directory) or NewInMemoryClient.
type Client struct {
	source
}
|
|
|
|
// Options holds optional configuration for NewClient.
type Options struct {
	// HTTPClient is handed to the HTTP source that NewClient creates for
	// "http" and "https" URLs; it is not used for "file" URLs.
	// Note that the schema-detection request NewClient issues uses
	// net/http's default client, not HTTPClient.
	// NOTE(review): how a nil HTTPClient (or nil *Options) is treated is
	// decided inside newHTTPSource, which is not visible here — confirm.
	HTTPClient *http.Client
}
|
|
|
|
// NewClient returns a client that reads the vulnerability database
// in source, which must be an "http", "https" or "file" URL.
// Trailing slashes in source are ignored.
//
// It supports databases following the API described
// in https://go.dev/security/vuln/database#api.
func NewClient(source string, opts *Options) (_ *Client, err error) {
	uri, err := url.Parse(strings.TrimRight(source, "/"))
	if err != nil {
		return nil, err
	}

	scheme := uri.Scheme
	if scheme == "http" || scheme == "https" {
		return newHTTPClient(uri, opts)
	}
	if scheme == "file" {
		return newLocalClient(uri)
	}
	return nil, fmt.Errorf("source %q has unsupported scheme", uri)
}
|
|
|
|
// errUnknownSchema is returned when a database source does not look like
// it follows the v1 API schema. For HTTP sources, this means neither the
// canonical URL nor index/modules.json.gz was found. For local
// directories, it means there is no index/modules.json and the directory
// also could not be loaded as a flat list of OSV files.
var errUnknownSchema = errors.New("unrecognized vulndb format; see https://go.dev/security/vuln/database#api for accepted schema")
|
|
|
|
// newHTTPClient returns a Client backed by the HTTP(S) database at uri.
//
// The database is accepted only if it appears to follow the v1 schema.
// Either it is the canonical https://vuln.go.dev, or a HEAD request for
// its compressed modules index (index/modules.json.gz) succeeds.
// Otherwise errUnknownSchema is returned.
func newHTTPClient(uri *url.URL, opts *Options) (*Client, error) {
	base := uri.String()

	// The official database is known to be v1, so skip the network probe for it.
	isV1 := base == "https://vuln.go.dev" ||
		endpointExistsHTTP(base, "index/modules.json.gz")
	if !isV1 {
		return nil, errUnknownSchema
	}

	return &Client{source: newHTTPSource(base, opts)}, nil
}
|
|
|
|
// endpointExistsHTTP reports whether a HEAD request for endpoint, joined
// to the base URL source with a "/", succeeds with 200 OK.
// Transport errors and any other status are reported as false.
func endpointExistsHTTP(source, endpoint string) bool {
	resp, err := http.Head(source + "/" + endpoint)
	if err != nil {
		return false
	}
	// net/http requires callers to close every response body so the
	// underlying connection can be released, even for HEAD responses.
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
|
|
|
|
// newLocalClient returns a Client that reads a vulnerability database
// stored in the local directory named by the file URL uri.
//
// A directory containing the "index/modules.json" endpoint is treated
// as a v1 database. Otherwise the directory is tried as a flat list of
// OSV files via newHybridSource. That mode is currently a "hidden"
// feature, so its specific error is not surfaced; errUnknownSchema is
// returned instead.
func newLocalClient(uri *url.URL) (*Client, error) {
	dir, err := toDir(uri)
	if err != nil {
		return nil, err
	}

	if endpointExistsDir(dir, modulesEndpoint+".json") {
		return &Client{source: newLocalSource(dir)}, nil
	}

	// Not v1: attempt to interpret the directory as a flat list of OSV files.
	hybrid, hybridErr := newHybridSource(dir)
	if hybridErr != nil {
		return nil, errUnknownSchema
	}
	return &Client{source: hybrid}, nil
}
|
|
|
|
// toDir converts the file URL uri to a local filesystem path using
// web.URLToFilePath. It returns an error if the conversion fails, the
// path cannot be stat'ed, or the path is not a directory.
func toDir(uri *url.URL) (string, error) {
	localPath, err := web.URLToFilePath(uri)
	if err != nil {
		return "", err
	}

	info, err := os.Stat(localPath)
	switch {
	case err != nil:
		return "", err
	case !info.IsDir():
		return "", fmt.Errorf("%s is not a directory", localPath)
	}
	return localPath, nil
}
|
|
|
|
// endpointExistsDir reports whether endpoint, joined to dir, exists on
// disk as either a file or a directory. Any os.Stat error, not only
// "not exist", is reported as false.
func endpointExistsDir(dir, endpoint string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, endpoint)); statErr != nil {
		return false
	}
	return true
}
|
|
|
|
func NewInMemoryClient(entries []*osv.Entry) (*Client, error) {
|
|
s, err := newInMemorySource(entries)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Client{source: s}, nil
|
|
}
|
|
|
|
// LastModifiedTime returns the time the vulnerability database was last
// modified, as recorded in the database's metadata endpoint (dbEndpoint).
// Errors are annotated with "LastModifiedTime()".
func (c *Client) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
	// Wrap takes &err so it can annotate the error actually returned;
	// that only works when deferred. Called directly, it would run while
	// err is still nil.
	defer derrors.Wrap(&err, "LastModifiedTime()")

	b, err := c.source.get(ctx, dbEndpoint)
	if err != nil {
		return time.Time{}, err
	}

	var meta dbMeta
	if err := json.Unmarshal(b, &meta); err != nil {
		return time.Time{}, err
	}
	return meta.Modified, nil
}
|
|
|
|
// ModuleRequest describes a query, passed to ByModules, for the
// vulnerabilities of a single module.
type ModuleRequest struct {
	// The module path to filter on.
	// This must be set (non-empty).
	Path string
	// (Optional) If set, only return vulnerabilities affected
	// at this version. When set, it must be valid semver.
	Version string
}
|
|
|
|
// ModuleResponse is the result of ByModules for a single ModuleRequest.
type ModuleResponse struct {
	// Path and Version are copied from the corresponding request.
	Path    string
	Version string
	// Entries holds the matching OSV entries sorted by ID, or nil if
	// there are none.
	Entries []*osv.Entry
}
|
|
|
|
// ByModules returns a list of responses
// containing the OSV entries corresponding to each request.
//
// The order of the requests is preserved, and each request has
// a response even if there are no entries (in which case the Entries
// field is nil).
//
// The modules index is fetched once for all requests. The per-module
// lookups then run concurrently, at most 10 at a time. If any lookup
// fails, ByModules returns that error and no responses.
func (c *Client) ByModules(ctx context.Context, reqs []*ModuleRequest) (_ []*ModuleResponse, err error) {
	// Deferred so the annotation applies to the error actually returned.
	// Requests are summarized by count, because formatting a
	// []*ModuleRequest with %v prints only pointer addresses.
	defer derrors.Wrap(&err, "ByModules(%d requests)", len(reqs))

	metas, err := c.moduleMetas(ctx, reqs)
	if err != nil {
		return nil, err
	}

	resps := make([]*ModuleResponse, len(reqs))
	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(10)
	for i, req := range reqs {
		i, req := i, req // per-iteration copies for the closure (pre-Go 1.22 semantics)
		g.Go(func() error {
			entries, err := c.byModule(gctx, req, metas[i])
			if err != nil {
				return err
			}
			// Each goroutine writes only its own index, so no locking is needed.
			resps[i] = &ModuleResponse{
				Path:    req.Path,
				Version: req.Version,
				Entries: entries,
			}
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}

	return resps, nil
}
|
|
|
|
// moduleMetas reads the modules index (modulesEndpoint) and returns a
// slice parallel to reqs. Element i is the index record whose Path equals
// reqs[i].Path, or nil if the module is not in the index. If the index
// lists a path more than once, the last record wins.
func (c *Client) moduleMetas(ctx context.Context, reqs []*ModuleRequest) (_ []*moduleMeta, err error) {
	b, err := c.source.get(ctx, modulesEndpoint)
	if err != nil {
		return nil, err
	}

	dec, err := newStreamDecoder(b)
	if err != nil {
		return nil, err
	}

	// Map each requested path to the positions of the requests asking for
	// it. Each index record is then matched with one lookup, instead of a
	// scan over every request. Duplicate paths in reqs are all filled in.
	wanted := make(map[string][]int, len(reqs))
	for i, req := range reqs {
		wanted[req.Path] = append(wanted[req.Path], i)
	}

	metas := make([]*moduleMeta, len(reqs))
	for dec.More() {
		// Declared inside the loop so each stored pointer refers to a
		// distinct record.
		var m moduleMeta
		if err := dec.Decode(&m); err != nil {
			return nil, err
		}
		for _, i := range wanted[m.Path] {
			metas[i] = &m
		}
	}

	return metas, nil
}
|
|
|
|
// byModule returns the OSV entries for req.Path that affect req.Version,
// or all of the module's entries if req.Version is empty. The result is
// sorted by ID, and is (nil, nil) if there are none.
//
// m is the module's record from the modules index, or nil if the module
// is not in the database.
func (c *Client) byModule(ctx context.Context, req *ModuleRequest, m *moduleMeta) (_ []*osv.Entry, err error) {
	// This module isn't in the database. Note that the request validation
	// below only runs for modules present in the index.
	if m == nil {
		return nil, nil
	}

	if req.Path == "" {
		return nil, fmt.Errorf("module path must be set")
	}

	if req.Version != "" && !isem.Valid(req.Version) {
		return nil, fmt.Errorf("version %s is not valid semver", req.Version)
	}

	// Use the index to avoid fetching entries already fixed at or before
	// req.Version. With no version given, every entry is a candidate.
	var ids []string
	for _, v := range m.Vulns {
		if req.Version == "" || v.Fixed == "" || isem.Less(req.Version, v.Fixed) {
			ids = append(ids, v.ID)
		}
	}

	if len(ids) == 0 {
		return nil, nil
	}

	entries, err := c.byIDs(ctx, ids)
	if err != nil {
		return nil, err
	}

	// The index check above only looks at each vuln's Fixed version, so
	// confirm against the full affected ranges for this module that each
	// entry really affects req.Version.
	if req.Version != "" {
		affected := func(e *osv.Entry) bool {
			for _, a := range e.Affected {
				if a.Module.Path == req.Path && isem.Affects(a.Ranges, req.Version) {
					return true
				}
			}
			return false
		}

		var filtered []*osv.Entry
		for _, entry := range entries {
			if affected(entry) {
				filtered = append(filtered, entry)
			}
		}
		if len(filtered) == 0 {
			return nil, nil
		}
		entries = filtered
	}

	sort.SliceStable(entries, func(i, j int) bool {
		return entries[i].ID < entries[j].ID
	})

	return entries, nil
}
|
|
|
|
// byIDs fetches the OSV entry for each ID concurrently, with at most 10
// fetches in flight, and returns the entries in the same order as ids.
// If any fetch fails, the shared context is canceled and the first error
// is returned with no entries.
func (c *Client) byIDs(ctx context.Context, ids []string) (_ []*osv.Entry, err error) {
	results := make([]*osv.Entry, len(ids))

	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(10)
	for idx := range ids {
		idx := idx // per-iteration copy for the closure (pre-Go 1.22 semantics)
		group.Go(func() error {
			entry, fetchErr := c.byID(groupCtx, ids[idx])
			if fetchErr != nil {
				return fetchErr
			}
			results[idx] = entry
			return nil
		})
	}

	if waitErr := group.Wait(); waitErr != nil {
		return nil, waitErr
	}
	return results, nil
}
|
|
|
|
// byID returns the OSV entry with the given ID,
// or an error if it does not exist / cannot be unmarshaled.
// Errors are annotated with "byID(<id>)".
func (c *Client) byID(ctx context.Context, id string) (_ *osv.Entry, err error) {
	// Deferred so the annotation is applied to the error actually
	// returned. Called directly, it would run before err is ever set.
	defer derrors.Wrap(&err, "byID(%s)", id)

	b, err := c.source.get(ctx, entryEndpoint(id))
	if err != nil {
		return nil, err
	}

	var entry osv.Entry
	if err := json.Unmarshal(b, &entry); err != nil {
		return nil, err
	}

	return &entry, nil
}
|
|
|
|
// newStreamDecoder returns a decoder positioned just inside the top-level
// JSON array in b, so callers can read its elements one at a time with
// dec.More and dec.Decode.
//
// It returns an error if b is empty or does not start with a JSON array.
// Without this check, input like `null` would look like an empty stream,
// and an object would fail later with a confusing decode error.
func newStreamDecoder(b []byte) (*json.Decoder, error) {
	dec := json.NewDecoder(bytes.NewReader(b))

	// Consume the opening bracket.
	tok, err := dec.Token()
	if err != nil {
		return nil, err
	}
	if delim, ok := tok.(json.Delim); !ok || delim != '[' {
		return nil, fmt.Errorf("expected JSON array, got %v", tok)
	}

	return dec, nil
}
|