From 57123b439d1478e3fd987f17522725096e57b662 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Thu, 12 Sep 2024 13:09:19 -0400 Subject: [PATCH] add FleetDM plugin (#363) --- .golangci.yml | 19 +-- cmd/pomerium-datasource/fleetdm.go | 70 +++++++++ cmd/pomerium-datasource/main.go | 1 + go.mod | 1 + go.sum | 2 + internal/fleetdm/client/client.go | 233 ++++++++++++++++++++++++++++ internal/fleetdm/client/config.go | 63 ++++++++ internal/fleetdm/client/model.go | 159 +++++++++++++++++++ internal/fleetdm/config.go | 43 +++++ internal/fleetdm/handlers.go | 15 ++ internal/fleetdm/report.go | 67 ++++++++ internal/fleetdm/server.go | 38 +++++ internal/jsonutil/jsonutil.go | 17 ++ internal/jsonutil/reader.go | 7 + pkg/directory/google/google_test.go | 1 + 15 files changed, 718 insertions(+), 18 deletions(-) create mode 100644 cmd/pomerium-datasource/fleetdm.go create mode 100644 internal/fleetdm/client/client.go create mode 100644 internal/fleetdm/client/config.go create mode 100644 internal/fleetdm/client/model.go create mode 100644 internal/fleetdm/config.go create mode 100644 internal/fleetdm/handlers.go create mode 100644 internal/fleetdm/report.go create mode 100644 internal/fleetdm/server.go diff --git a/.golangci.yml b/.golangci.yml index 5b1926f..c9625a7 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,19 +7,15 @@ linters: - bodyclose - decorder - dogsled - - dupl - durationcheck - errcheck - errname - errorlint - exhaustive - - gci - - goconst - - godox + # - gci # https://github.com/daixiang0/gci/issues/209 - gofumpt - goheader - goimports - - gomoddirectives - gomodguard - goprintffuncname - gosec @@ -28,42 +24,29 @@ linters: - grouper - importas - ineffassign - - interfacebloat - makezero - misspell - nakedret - nestif - nilerr - - nilnil - noctx - nolintlint - nosprintfhostport - - paralleltest - predeclared - promlinter - reassign - revive - staticcheck - - stylecheck - tenv - thelper - tparallel - - typecheck - unconvert - unused - whitespace - linters-settings: gci: custom-order: true sections: - standard # Standard section: captures all standard packages. - default # Default section: contains all imports that could not be matched to another section type. - - prefix(github.com/pomerium/datasource) -issues: - exclude-rules: - # Exclude some linters from running on test files. - - path: _test\.go$ - linters: - - gosec diff --git a/cmd/pomerium-datasource/fleetdm.go b/cmd/pomerium-datasource/fleetdm.go new file mode 100644 index 0000000..3cc7979 --- /dev/null +++ b/cmd/pomerium-datasource/fleetdm.go @@ -0,0 +1,70 @@ +package main + +import ( + "net/http" + + "github.com/go-playground/validator/v10" + "github.com/rs/zerolog" + "github.com/spf13/cobra" + + "github.com/pomerium/datasource/internal/fleetdm" + "github.com/pomerium/datasource/internal/server" +) + +type fleetDMCmd struct { + APIToken string `validate:"required"` + APIURL string `validate:"required,url"` + Address string `validate:"required"` + CertQueryID uint `validate:"required"` + + cobra.Command `validate:"-"` + zerolog.Logger `validate:"-"` +} + +func fleetDMCommand(log zerolog.Logger) *cobra.Command { + cmd := &fleetDMCmd{ + Command: cobra.Command{ + Use: "fleetdm", + Short: "run FleetDM connector", + }, + Logger: log, + } + cmd.RunE = cmd.exec + + cmd.setupFlags() + return &cmd.Command +} + +func (cmd *fleetDMCmd) setupFlags() { + flags := cmd.Flags() + flags.StringVar(&cmd.APIToken, "api-token", "", "FleetDM API token") + flags.StringVar(&cmd.APIURL, "api-url", "", "FleetDM API URL") + flags.UintVar(&cmd.CertQueryID, "cert-query-id", 0, "FleetDM certificate query ID") + flags.StringVar(&cmd.Address, "address", ":8080", "tcp address to listen to") +} + +func (cmd *fleetDMCmd) exec(c *cobra.Command, _ []string) error { + if err := validator.New().Struct(cmd); err != nil { + return err + } + + srv, err := cmd.newServer() + if err != nil { + return err + } + + return server.RunHTTPServer(c.Context(), cmd.Address, srv) +} + +func (cmd *fleetDMCmd) newServer() (http.Handler, error) { + srv, err := fleetdm.NewServer( + fleetdm.WithAPIToken(cmd.APIToken), + fleetdm.WithAPIURL(cmd.APIURL), + fleetdm.WithCertificateQueryID(cmd.CertQueryID), + ) + if err != nil { + return nil, err + } + + return srv, nil +} diff --git a/cmd/pomerium-datasource/main.go b/cmd/pomerium-datasource/main.go index 6a099c9..b37ec20 100644 --- a/cmd/pomerium-datasource/main.go +++ b/cmd/pomerium-datasource/main.go @@ -27,6 +27,7 @@ func main() { zenefitsCommand(logger), ip2LocationCmd, wellKnownIPsCmd, + fleetDMCommand(logger), ) if err := rootCmd.ExecuteContext(signalContext(logger)); err != nil { logger.Fatal().Err(err).Msg("exit") diff --git a/go.mod b/go.mod index c06a513..641501f 100644 --- a/go.mod +++ b/go.mod @@ -130,6 +130,7 @@ require ( github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect github.com/gostaticanalysis/nilerr v0.1.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-set/v3 v3.0.0-alpha.1 // indirect github.com/hashicorp/go-version v1.7.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect diff --git a/go.sum b/go.sum index e5dc0d9..709f350 100644 --- a/go.sum +++ b/go.sum @@ -378,6 +378,8 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-set/v3 v3.0.0-alpha.1 h1:dPUtuqKJGgxtF7YO42oE+NdUONXi5nfLMKH2NpBffIM= +github.com/hashicorp/go-set/v3 v3.0.0-alpha.1/go.mod h1:7bJRgsF3EL3AtRTzcKXdjAFbYGSef+1gHXhglGGO52k= github.com/hashicorp/go-version v1.2.1/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= diff --git a/internal/fleetdm/client/client.go b/internal/fleetdm/client/client.go new file mode 100644 index 0000000..b943c67 --- /dev/null +++ b/internal/fleetdm/client/client.go @@ -0,0 +1,233 @@ +package client + +import ( + "context" + "fmt" + "iter" + "net/http" + "net/url" + + "github.com/pomerium/datasource/internal/jsonutil" + + "github.com/hashicorp/go-set/v3" +) + +const ( + maxHostPerPage = 500 +) + +type Client struct { + cfg *config +} + +// New creates a new FleetDM API client +// see https://fleetdm.com/docs/rest-api/rest-api +func New(opts ...Option) (*Client, error) { + cfg := newConfig(opts...) + return &Client{ + cfg: cfg, + }, nil +} + +func (c *Client) ListHosts( + ctx context.Context, +) iter.Seq2[Host, error] { + var args []string + if c.cfg.withPolicies { + args = append(args, "populate_policies", "true") + } + if c.cfg.withVulnerabilities { + args = append(args, "populate_software", "true") + } + return fetchItemsPaged(ctx, c, convertHostRecord, "hosts", "/api/v1/fleet/hosts", maxHostPerPage, args...) +} + +func (c *Client) listTeams(ctx context.Context) ([]uint, error) { + iter, err := fetchItems(ctx, c, + func(tm struct { + ID uint `json:"id"` + }, + ) (uint, error) { + return tm.ID, nil + }, + "teams", "/api/v1/fleet/teams") + if err != nil { + return nil, err + } + + var ids []uint + for id, err := range iter { + if err != nil { + return nil, err + } + ids = append(ids, id) + } + + return ids, nil +} + +func (c *Client) ListPolicies(ctx context.Context) (iter.Seq2[Policy, error], error) { + teams, err := c.listTeams(ctx) + if err != nil { + return nil, fmt.Errorf("list teams: %w", err) + } + + global, err := fetchItems(ctx, c, convertPolicy, "policies", "/api/latest/fleet/policies") + if err != nil { + return nil, fmt.Errorf("list global policies: %w", err) + } + + policies := []iter.Seq2[Policy, error]{global} + for _, teamID := range teams { + p, err := fetchItems(ctx, c, convertPolicy, "policies", fmt.Sprintf("/api/latest/fleet/teams/%d/policies", teamID)) + if err != nil { + return nil, fmt.Errorf("list team policies: %w", err) + } + policies = append(policies, p) + } + + return dedup(policies...), nil +} + +func (c *Client) QueryCertificates( + ctx context.Context, + queryID uint, +) (iter.Seq2[CertificateSHA1QueryItem, error], error) { + return fetchItems(ctx, c, convertCertificateQuery, "results", fmt.Sprintf("/api/v1/fleet/queries/%d/report", queryID)) +} + +func fetchItemsPaged[InternalRecord, ExternalRecord any]( + ctx context.Context, + c *Client, + convert func(InternalRecord) (ExternalRecord, error), + key string, + path string, + itemsPerPage int, + args ...string, +) iter.Seq2[ExternalRecord, error] { + return func(yield func(ExternalRecord, error) bool) { + page := 0 + for { + iter, err := fetchItems(ctx, c, convert, key, path, append(args, "page", fmt.Sprint(page), "per_page", fmt.Sprint(itemsPerPage))...) + if err != nil { + var v ExternalRecord + if !yield(v, fmt.Errorf("fetch page %d: %w", page, err)) { + return + } + return + } + + itemCount := 0 + for v, err := range iter { + if err != nil { + err = fmt.Errorf("page %d: %w", page, err) + } + if !yield(v, err) { + return + } + + if err != nil { + return + } + + itemCount++ + } + + if itemCount < itemsPerPage { + return + } + + page++ + } + } +} + +func fetchItems[InternalRecord, ExternalRecord any]( + ctx context.Context, + c *Client, + convert func(InternalRecord) (ExternalRecord, error), + key string, + path string, + args ...string, +) (iter.Seq2[ExternalRecord, error], error) { + req, err := c.newRequest(ctx, "GET", path, args...) + if err != nil { + return nil, err + } + + resp, err := c.cfg.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return convertIter2( + jsonutil.StreamArrayReadAndClose[InternalRecord](resp.Body, []string{key}), + convert, + ), nil +} + +func (c *Client) newRequest( + ctx context.Context, + method string, + path string, + kv ...string, +) (*http.Request, error) { + u, err := url.Parse(c.cfg.url) + if err != nil { + return nil, fmt.Errorf("failed to parse api endpoint URL: %w", err) + } + if u.Scheme != "https" && u.Scheme != "http" { + return nil, fmt.Errorf("api endpoint URL scheme must be http or https") + } + u.Path = path + + if len(kv)%2 != 0 { + return nil, fmt.Errorf("key-value pairs must be even") + } + + query := make(url.Values) + for i := 0; i < len(kv); i += 2 { + query.Add(kv[i], kv[i+1]) + } + u.RawQuery = query.Encode() + + req, err := http.NewRequest(method, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.cfg.token)) + return req.WithContext(ctx), nil +} + +func dedup[ID comparable, T interface{ GetID() ID }]( + iters ...iter.Seq2[T, error], +) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + seen := set.New[ID](0) + for _, iter := range iters { + for v, err := range iter { + if err != nil { + if !yield(v, err) { + return + } + continue + } + id := v.GetID() + if seen.Contains(id) { + continue + } + seen.Insert(id) + if !yield(v, nil) { + return + } + } + } + } +} diff --git a/internal/fleetdm/client/config.go b/internal/fleetdm/client/config.go new file mode 100644 index 0000000..b20f87c --- /dev/null +++ b/internal/fleetdm/client/config.go @@ -0,0 +1,63 @@ +package client + +import "net/http" + +type config struct { + token string + url string + httpClient *http.Client + withPolicies bool + withVulnerabilities bool +} + +type Option func(*config) + +var defaults = []Option{ + WithHTTPClient(http.DefaultClient), +} + +func newConfig(opts ...Option) *config { + cfg := new(config) + for _, opt := range defaults { + opt(cfg) + } + for _, opt := range opts { + opt(cfg) + } + return cfg +} + +// WithToken sets the token on the config. +func WithToken(token string) Option { + return func(cfg *config) { + cfg.token = token + } +} + +// WithURL sets the URL on the config. +func WithURL(url string) Option { + return func(cfg *config) { + cfg.url = url + } +} + +// WithHTTPClient sets the HTTP client on the config. +func WithHTTPClient(httpClient *http.Client) Option { + return func(cfg *config) { + cfg.httpClient = httpClient + } +} + +// WithPolicies will fetch policy data and populate policy passing fields for hosts. +func WithPolicies() Option { + return func(cfg *config) { + cfg.withPolicies = true + } +} + +// WithVulnerabilities will fetch vulnerability data and populate CVE fields for hosts. +func WithVulnerabilities() Option { + return func(cfg *config) { + cfg.withVulnerabilities = true + } +} diff --git a/internal/fleetdm/client/model.go b/internal/fleetdm/client/model.go new file mode 100644 index 0000000..444c573 --- /dev/null +++ b/internal/fleetdm/client/model.go @@ -0,0 +1,159 @@ +package client + +import ( + "iter" + "maps" + "slices" + "strconv" + "time" +) + +type Host struct { + ID string `json:"id"` + Seen time.Time `json:"seen_time"` + FailingPoliciesCount uint64 `json:"failing_policies_count"` + // FailingCriticalPoliciesCount is the number of critical policies that the host is failing. This is a calculated value. + FailingCriticalPoliciesCount uint64 `json:"failing_critical_policies_count"` + CriticalVulnerabilitiesCount *uint64 `json:"critical_vulnerabilities_count,omitempty"` + PoliciesPassing []uint `json:"policies_passing,omitempty"` + PoliciesFailing []uint `json:"policies_failing,omitempty"` + // CVEs is a map of CVE to whether the host is vulnerable to the CVE + CVEs []string `json:"cves,omitempty"` +} + +type HostPolicyStatus struct { + ID uint `json:"id"` + Response string `json:"response"` +} + +type hostRecord struct { + ID uint `json:"id"` + Seen time.Time `json:"seen_time"` + HostIssues struct { + FailingPoliciesCount uint64 `json:"failing_policies_count"` + CriticalVulnerabilitiesCount *uint64 `json:"critical_vulnerabilities_count,omitempty"` + TotalIssuesCount uint64 `json:"total_issues_count"` + } `json:"issues,omitempty"` + Policies []struct { + ID uint `json:"id"` + Response string `json:"response"` + Critical bool `json:"critical"` + } `json:"policies"` + Software []struct { + Vulnerabilities []struct { + CVE string `json:"cve"` + } `json:"vulnerabilities"` + } `json:"software"` +} + +func convertHostRecord(r hostRecord) (Host, error) { + var policiesPassing, policiesFailing []uint + failingCriticalPoliciesCount := uint64(0) + + for _, p := range r.Policies { + passing := p.Response == "pass" + failing := p.Response == "fail" + + if p.Critical && failing { + failingCriticalPoliciesCount++ + } + + if passing { + policiesPassing = append(policiesPassing, p.ID) + } + if failing { + policiesFailing = append(policiesFailing, p.ID) + } + } + + cves := make(map[string]bool) + for _, s := range r.Software { + for _, v := range s.Vulnerabilities { + cves[v.CVE] = true + } + } + + return Host{ + ID: strconv.FormatUint(uint64(r.ID), 10), + FailingPoliciesCount: r.HostIssues.FailingPoliciesCount, + FailingCriticalPoliciesCount: failingCriticalPoliciesCount, + CriticalVulnerabilitiesCount: r.HostIssues.CriticalVulnerabilitiesCount, + PoliciesPassing: policiesPassing, + PoliciesFailing: policiesFailing, + CVEs: slices.Collect(maps.Keys(cves)), + Seen: r.Seen, + }, nil +} + +type certificateQueryRecord struct { + HostID uint64 `json:"host_id"` + Columns struct { + SHA1 string `json:"sha1"` + } `json:"columns"` +} + +type CertificateSHA1QueryItem struct { + HostID string `json:"host_id"` + SHA1 string `json:"id"` +} + +func convertCertificateQuery(c certificateQueryRecord) (CertificateSHA1QueryItem, error) { + return CertificateSHA1QueryItem{ + HostID: strconv.FormatUint(c.HostID, 10), + SHA1: c.Columns.SHA1, + }, nil +} + +type Policy struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Resolution string `json:"resolution"` +} + +func (p Policy) GetID() string { + return p.ID +} + +type policyRecord struct { + ID uint `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Resolution string `json:"resolution"` +} + +func convertPolicy(r policyRecord) (Policy, error) { + return Policy{ + ID: strconv.FormatUint(uint64(r.ID), 10), + Name: r.Name, + Description: r.Description, + Resolution: r.Resolution, + }, nil +} + +func convertIter2[T1, T2 any]( + iter1 iter.Seq2[T1, error], + fn func(T1) (T2, error), +) iter.Seq2[T2, error] { + return func(yield func(T2, error) bool) { + for v1, err := range iter1 { + var v2 T2 + if err != nil { + if !yield(v2, err) { + return + } + continue + } + v2, err := fn(v1) + if err != nil { + if !yield(v2, err) { + return + } + continue + } + if !yield(v2, nil) { + return + } + } + } +} diff --git a/internal/fleetdm/config.go b/internal/fleetdm/config.go new file mode 100644 index 0000000..faffd77 --- /dev/null +++ b/internal/fleetdm/config.go @@ -0,0 +1,43 @@ +package fleetdm + +type config struct { + apiToken string + apiURL string + certificateQueryID uint +} + +type Option func(*config) + +var defaults = []Option{} + +func newConfig(opts ...Option) *config { + cfg := new(config) + for _, opt := range defaults { + opt(cfg) + } + for _, opt := range opts { + opt(cfg) + } + return cfg +} + +// WithAPIToken sets the API token on the config. +func WithAPIToken(token string) Option { + return func(cfg *config) { + cfg.apiToken = token + } +} + +// WithAPIURL sets the API URL on the config. +func WithAPIURL(url string) Option { + return func(cfg *config) { + cfg.apiURL = url + } +} + +// WithCertificateQueryID sets the certificate query ID on the config. +func WithCertificateQueryID(id uint) Option { + return func(cfg *config) { + cfg.certificateQueryID = id + } +} diff --git a/internal/fleetdm/handlers.go b/internal/fleetdm/handlers.go new file mode 100644 index 0000000..abec32c --- /dev/null +++ b/internal/fleetdm/handlers.go @@ -0,0 +1,15 @@ +package fleetdm + +import ( + "net/http" +) + +func (srv *server) getIndexHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", "attachment; filename=fleetdm.zip") + + err := srv.writeRecords(r.Context(), w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/internal/fleetdm/report.go b/internal/fleetdm/report.go new file mode 100644 index 0000000..49fd4f2 --- /dev/null +++ b/internal/fleetdm/report.go @@ -0,0 +1,67 @@ +package fleetdm + +import ( + "archive/zip" + "context" + "fmt" + "io" + + "github.com/pomerium/datasource/internal/jsonutil" +) + +const ( + typeCertificateSHA1Fingerprint = "fleetdm.com/CertificateSHA1Fingerprint" + typeHost = "fleetdm.com/Host" + typePolicy = "fleetdm.com/Policy" +) + +func (srv *server) writeRecords( + ctx context.Context, + dst io.Writer, +) error { + zw := zip.NewWriter(dst) + + fw, err := zw.Create(typeCertificateSHA1Fingerprint + ".json") + if err != nil { + return fmt.Errorf("write header: %w", err) + } + + certs, err := srv.client.QueryCertificates(ctx, srv.cfg.certificateQueryID) + if err != nil { + return fmt.Errorf("query certificates: %w", err) + } + + err = jsonutil.StreamWriteArray(fw, certs) + if err != nil { + return fmt.Errorf("write certificates: %w", err) + } + + fw, err = zw.Create(typeHost + ".json") + if err != nil { + return fmt.Errorf("write header: %w", err) + } + + hosts := srv.client.ListHosts(ctx) + + err = jsonutil.StreamWriteArray(fw, hosts) + if err != nil { + return fmt.Errorf("write hosts: %w", err) + } + + fw, err = zw.Create(typePolicy + ".json") + if err != nil { + return fmt.Errorf("write header: %w", err) + } + + policies, err := srv.client.ListPolicies(ctx) + if err != nil { + return fmt.Errorf("list policies: %w", err) + } + + err = jsonutil.StreamWriteArray(fw, policies) + if err != nil { + return fmt.Errorf("write policies: %w", err) + } + + return zw.Close() +} diff --git a/internal/fleetdm/server.go b/internal/fleetdm/server.go new file mode 100644 index 0000000..8a590fc --- /dev/null +++ b/internal/fleetdm/server.go @@ -0,0 +1,38 @@ +package fleetdm + +import ( + "net/http" + + "github.com/gorilla/mux" + + "github.com/pomerium/datasource/internal/fleetdm/client" +) + +func NewServer(opts ...Option) (*mux.Router, error) { + cfg := newConfig(opts...) + + client, err := client.New( + client.WithToken(cfg.apiToken), + client.WithURL(cfg.apiURL), + client.WithPolicies(), + client.WithVulnerabilities(), + ) + if err != nil { + return nil, err + } + + srv := server{ + cfg: cfg, + client: client, + } + + r := mux.NewRouter() + r.Path("/").Methods(http.MethodGet).HandlerFunc(srv.getIndexHandler) + + return r, nil +} + +type server struct { + cfg *config + client *client.Client +} diff --git a/internal/jsonutil/jsonutil.go b/internal/jsonutil/jsonutil.go index 7e2f9f5..2698da2 100644 --- a/internal/jsonutil/jsonutil.go +++ b/internal/jsonutil/jsonutil.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "io" + "iter" ) // A JSONArrayStream represents a streaming array of JSON objects. @@ -58,3 +59,19 @@ func (stream *JSONArrayStream) Encode(obj any) error { _, err = stream.buf.Write(bs) return err } + +func StreamWriteArray[T any](w io.Writer, src iter.Seq2[T, error]) error { + stream := NewJSONArrayStream(w) + defer stream.Close() + + for v, err := range src { + if err != nil { + return err + } + err = stream.Encode(v) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/jsonutil/reader.go b/internal/jsonutil/reader.go index c8f4716..7f7d834 100644 --- a/internal/jsonutil/reader.go +++ b/internal/jsonutil/reader.go @@ -8,6 +8,13 @@ import ( "iter" ) +func StreamArrayReadAndClose[T any](r io.ReadCloser, keys []string) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + StreamArrayReader[T](r, keys)(yield) + _ = r.Close() + } +} + // StreamArrayReader reads a JSON array from r and yields each element. // keys is a list of keys hierarchy to traverse before reading the array. // the returned iterator is single-use. diff --git a/pkg/directory/google/google_test.go b/pkg/directory/google/google_test.go index 7f849f0..a3a46e8 100644 --- a/pkg/directory/google/google_test.go +++ b/pkg/directory/google/google_test.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/datasource/pkg/directory" ) +//nolint:gosec var privateKey = ` -----BEGIN RSA PRIVATE KEY----- MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH