From 2034d9783eab44960f3f90bb77543137eaa599a7 Mon Sep 17 00:00:00 2001 From: Mats Linander Date: Tue, 2 Jan 2024 18:28:38 -0500 Subject: [PATCH] cmd: enable realtime queries w headers (api keys) (#17) --- cmd/departures.go | 12 +++- cmd/main.go | 91 +++++++++++++++++++++++++++-- downloader/filesystem.go | 121 +++++++++++++++++++++++++++++++++++++++ downloader/memory.go | 32 +++++------ manager.go | 2 +- manager_test.go | 2 +- 6 files changed, 233 insertions(+), 27 deletions(-) create mode 100644 downloader/filesystem.go diff --git a/cmd/departures.go b/cmd/departures.go index 4afd7a5..c2ca7f1 100644 --- a/cmd/departures.go +++ b/cmd/departures.go @@ -56,11 +56,17 @@ func departures(cmd *cobra.Command, args []string) error { } for _, departure := range departures { - line := fmt.Sprintf("%s %s %s", departure.RouteID, departure.Time, departure.Headsign) + delay := "" if departure.Delay != 0 { - line += fmt.Sprintf(" (%s)", departure.Delay) + delay = fmt.Sprintf("(%s)", departure.Delay) } - fmt.Println(line) + fmt.Printf( + "%s%s - %s - %s\n", + departure.Time.Format("15:04:05"), + delay, + departure.RouteID, + departure.Headsign, + ) } return nil diff --git a/cmd/main.go b/cmd/main.go index 058b366..a08cae3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,11 +4,13 @@ import ( "context" "fmt" "os" + "strings" "time" "github.com/spf13/cobra" "tidbyt.dev/gtfs" + "tidbyt.dev/gtfs/downloader" "tidbyt.dev/gtfs/storage" ) @@ -20,13 +22,37 @@ var rootCmd = &cobra.Command{ } var ( - staticURL string - realtimeURL string + staticURL string + realtimeURL string + staticHeaders []string + realtimeHeaders []string + sharedHeaders []string ) func init() { - rootCmd.PersistentFlags().StringVarP(&staticURL, "static", "", "", "GTFS Static URL") - rootCmd.PersistentFlags().StringVarP(&realtimeURL, "realtime", "", "", "GTFS Realtime URL") + rootCmd.PersistentFlags().StringVarP(&staticURL, "static-url", "", "", "GTFS Static URL") + rootCmd.PersistentFlags().StringVarP(&realtimeURL, "realtime-url", "", "", "GTFS Realtime URL") + rootCmd.PersistentFlags().StringSliceVarP( + &staticHeaders, + "static-header", + "", + []string{}, + "GTFS Static HTTP header", + ) + rootCmd.PersistentFlags().StringSliceVarP( + &realtimeHeaders, + "realtime-header", + "", + []string{}, + "GTFS Realtime HTTP header", + ) + rootCmd.PersistentFlags().StringSliceVarP( + &sharedHeaders, + "header", + "", + []string{}, + "GTFS HTTP header (shared between static and realtime)", + ) rootCmd.AddCommand(departuresCmd) } @@ -37,6 +63,18 @@ func main() { } } +func parseHeaders(headers []string) (map[string]string, error) { + parsed := map[string]string{} + for _, header := range headers { + parts := strings.SplitN(header, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("'%s' is not on form :", header) + } + parsed[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } + return parsed, nil +} + func LoadStaticFeed() (*gtfs.Static, error) { if staticURL == "" { return nil, fmt.Errorf("static URL is required") @@ -46,9 +84,24 @@ func LoadStaticFeed() (*gtfs.Static, error) { if err != nil { return nil, err } + manager := gtfs.NewManager(s) - static, err := manager.LoadStaticAsync("cli", staticURL, nil, time.Now()) + headers, err := parseHeaders(staticHeaders) + if err != nil { + return nil, fmt.Errorf("invalid static header: %w", err) + } + + shared, err := parseHeaders(sharedHeaders) + if err != nil { + return nil, fmt.Errorf("invalid header: %w", err) + } + + for k, v := range shared { + headers[k] = v + } + + static, err := manager.LoadStaticAsync("cli", staticURL, headers, time.Now()) if err != nil { err = manager.Refresh(context.Background()) if err != nil { @@ -71,13 +124,39 @@ func LoadRealtimeFeed() (*gtfs.Realtime, error) { return nil, fmt.Errorf("static URL is required") } + sh, err := parseHeaders(staticHeaders) + if err != nil { + return nil, fmt.Errorf("invalid static header: %w", err) + } + + rh, err := parseHeaders(realtimeHeaders) + if err != nil { + return nil, fmt.Errorf("invalid realtime header: %w", err) + } + + shared, err := parseHeaders(sharedHeaders) + if err != nil { + return nil, fmt.Errorf("invalid header: %w", err) + } + + for k, v := range shared { + sh[k] = v + rh[k] = v + } + + fs, err := downloader.NewFilesystem("./gtfs-rt-cache.json") + if err != nil { + return nil, fmt.Errorf("creating realtime cache: %w", err) + } + s, err := storage.NewSQLiteStorage(storage.SQLiteConfig{OnDisk: true, Directory: "."}) if err != nil { return nil, err } manager := gtfs.NewManager(s) + manager.Downloader = fs - realtime, err := manager.LoadRealtime("cli", staticURL, nil, realtimeURL, nil, time.Now()) + realtime, err := manager.LoadRealtime("cli", staticURL, sh, realtimeURL, rh, time.Now()) if err != nil { return nil, err } diff --git a/downloader/filesystem.go b/downloader/filesystem.go new file mode 100644 index 0000000..187b15a --- /dev/null +++ b/downloader/filesystem.go @@ -0,0 +1,121 @@ +package downloader + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "sync" + "time" +) + +type Filesystem struct { + Path string + Records map[string]fsRecord + + mutex sync.Mutex +} + +type fsRecord struct { + Body string `json:"body"` + RetrievedAt string `json:"retrieved_at"` +} + +func NewFilesystem(path string) (*Filesystem, error) { + fs := &Filesystem{ + Path: path, + Records: map[string]fsRecord{}, + } + + err := fs.load() + if err != nil { + return nil, err + } + + return fs, nil +} + +func (f *Filesystem) Get( + ctx context.Context, + url string, + headers map[string]string, + options GetOptions, +) ([]byte, error) { + + f.mutex.Lock() + defer f.mutex.Unlock() + + if options.Cache { + if record, found := f.Records[url]; found { + retrievedAt, err := time.Parse(time.RFC3339, record.RetrievedAt) + if err != nil { + return nil, err + } + if retrievedAt.Add(options.CacheTTL).After(time.Now()) { + body, err := base64.StdEncoding.DecodeString(record.Body) + if err != nil { + return nil, fmt.Errorf("decoding: %w", err) + } + fmt.Println("cache hit") + return body, nil + } + fmt.Println("cache expired") + } + } + + body, err := HTTPGet(ctx, url, headers, options) + if err != nil { + return nil, fmt.Errorf("http get: %w", err) + } + + if options.Cache { + bodyB64 := base64.StdEncoding.EncodeToString(body) + f.Records[url] = fsRecord{ + Body: bodyB64, + RetrievedAt: time.Now().UTC().Format(time.RFC3339), + } + err = f.save() + if err != nil { + return nil, fmt.Errorf("saving: %w", err) + } + } + + return body, nil +} + +func (f *Filesystem) load() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + _, err := os.Stat(f.Path) + if os.IsNotExist(err) { + return nil + } + + buf, err := os.ReadFile(f.Path) + if err != nil { + return fmt.Errorf("reading: %w", err) + } + + err = json.Unmarshal(buf, &f.Records) + if err != nil { + return fmt.Errorf("unmarshalling: %w", err) + } + + return nil +} + +func (f *Filesystem) save() error { + buf, err := json.Marshal(f.Records) + if err != nil { + return fmt.Errorf("marshalling: %w", err) + } + + err = os.WriteFile(f.Path, buf, 0644) + if err != nil { + return fmt.Errorf("writing: %w", err) + } + + return nil +} diff --git a/downloader/memory.go b/downloader/memory.go index e83fe5e..0f40b09 100644 --- a/downloader/memory.go +++ b/downloader/memory.go @@ -7,26 +7,26 @@ import ( ) // Caches downloaded files in memory -type MemoryDownloader struct { - mutex sync.Mutex - cache map[string]downloaderCacheEntry +type Memory struct { + mutex sync.Mutex + records map[string]memoryRecord TimeNow func() time.Time } -func NewMemoryDownloader() *MemoryDownloader { - return &MemoryDownloader{ - cache: make(map[string]downloaderCacheEntry), - TimeNow: time.Now, - } -} - -type downloaderCacheEntry struct { +type memoryRecord struct { data []byte expiration time.Time } -func (d *MemoryDownloader) Get( +func NewMemory() *Memory { + return &Memory{ + records: map[string]memoryRecord{}, + TimeNow: time.Now, + } +} + +func (d *Memory) Get( ctx context.Context, url string, headers map[string]string, @@ -36,9 +36,9 @@ func (d *MemoryDownloader) Get( d.mutex.Lock() defer d.mutex.Unlock() - if entry, ok := d.cache[url]; ok { - if entry.expiration.After(d.TimeNow()) { - return entry.data, nil + if record, ok := d.records[url]; ok { + if record.expiration.After(d.TimeNow()) { + return record.data, nil } } } @@ -49,7 +49,7 @@ func (d *MemoryDownloader) Get( } if options.Cache { - d.cache[url] = downloaderCacheEntry{ + d.records[url] = memoryRecord{ data: body, expiration: d.TimeNow().Add(options.CacheTTL), } diff --git a/manager.go b/manager.go index 1cc21a4..80b2960 100644 --- a/manager.go +++ b/manager.go @@ -54,7 +54,7 @@ func NewManager(s storage.Storage) *Manager { StaticMaxSize: DefaultStaticMaxSize, StaticRefreshInterval: DefaultStaticRefreshInterval, - Downloader: downloader.NewMemoryDownloader(), + Downloader: downloader.NewMemory(), storage: s, } diff --git a/manager_test.go b/manager_test.go index e562286..75f908d 100644 --- a/manager_test.go +++ b/manager_test.go @@ -740,7 +740,7 @@ func testManagerLoadRealtime(t *testing.T, strg storage.Storage) { // Mock clock on the downloader to control caching now := time.Now() - dl := downloader.NewMemoryDownloader() + dl := downloader.NewMemory() dl.TimeNow = func() time.Time { return now }