Skip to content

Commit

Permalink
cmd: enable realtime queries w headers (api keys) (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
matslina authored Jan 2, 2024
1 parent 04e9e8c commit 2034d97
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 27 deletions.
12 changes: 9 additions & 3 deletions cmd/departures.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ func departures(cmd *cobra.Command, args []string) error {
}

for _, departure := range departures {
line := fmt.Sprintf("%s %s %s", departure.RouteID, departure.Time, departure.Headsign)
delay := ""
if departure.Delay != 0 {
line += fmt.Sprintf(" (%s)", departure.Delay)
delay = fmt.Sprintf("(%s)", departure.Delay)
}
fmt.Println(line)
fmt.Printf(
"%s%s - %s - %s\n",
departure.Time.Format("15:04:05"),
delay,
departure.RouteID,
departure.Headsign,
)
}

return nil
Expand Down
91 changes: 85 additions & 6 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"fmt"
"os"
"strings"
"time"

"github.com/spf13/cobra"

"tidbyt.dev/gtfs"
"tidbyt.dev/gtfs/downloader"
"tidbyt.dev/gtfs/storage"
)

Expand All @@ -20,13 +22,37 @@ var rootCmd = &cobra.Command{
}

var (
staticURL string
realtimeURL string
staticURL string
realtimeURL string
staticHeaders []string
realtimeHeaders []string
sharedHeaders []string
)

func init() {
rootCmd.PersistentFlags().StringVarP(&staticURL, "static", "", "", "GTFS Static URL")
rootCmd.PersistentFlags().StringVarP(&realtimeURL, "realtime", "", "", "GTFS Realtime URL")
rootCmd.PersistentFlags().StringVarP(&staticURL, "static-url", "", "", "GTFS Static URL")
rootCmd.PersistentFlags().StringVarP(&realtimeURL, "realtime-url", "", "", "GTFS Realtime URL")
rootCmd.PersistentFlags().StringSliceVarP(
&staticHeaders,
"static-header",
"",
[]string{},
"GTFS Static HTTP header",
)
rootCmd.PersistentFlags().StringSliceVarP(
&realtimeHeaders,
"realtime-header",
"",
[]string{},
"GTFS Realtime HTTP header",
)
rootCmd.PersistentFlags().StringSliceVarP(
&sharedHeaders,
"header",
"",
[]string{},
"GTFS HTTP header (shared between static and realtime)",
)
rootCmd.AddCommand(departuresCmd)
}

Expand All @@ -37,6 +63,18 @@ func main() {
}
}

func parseHeaders(headers []string) (map[string]string, error) {
parsed := map[string]string{}
for _, header := range headers {
parts := strings.SplitN(header, ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("'%s' is not on form <key>:<value>", header)
}
parsed[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
return parsed, nil
}

func LoadStaticFeed() (*gtfs.Static, error) {
if staticURL == "" {
return nil, fmt.Errorf("static URL is required")
Expand All @@ -46,9 +84,24 @@ func LoadStaticFeed() (*gtfs.Static, error) {
if err != nil {
return nil, err
}

manager := gtfs.NewManager(s)

static, err := manager.LoadStaticAsync("cli", staticURL, nil, time.Now())
headers, err := parseHeaders(staticHeaders)
if err != nil {
return nil, fmt.Errorf("invalid static header: %w", err)
}

shared, err := parseHeaders(sharedHeaders)
if err != nil {
return nil, fmt.Errorf("invalid header: %w", err)
}

for k, v := range shared {
headers[k] = v
}

static, err := manager.LoadStaticAsync("cli", staticURL, headers, time.Now())
if err != nil {
err = manager.Refresh(context.Background())
if err != nil {
Expand All @@ -71,13 +124,39 @@ func LoadRealtimeFeed() (*gtfs.Realtime, error) {
return nil, fmt.Errorf("static URL is required")
}

sh, err := parseHeaders(staticHeaders)
if err != nil {
return nil, fmt.Errorf("invalid static header: %w", err)
}

rh, err := parseHeaders(realtimeHeaders)
if err != nil {
return nil, fmt.Errorf("invalid realtime header: %w", err)
}

shared, err := parseHeaders(sharedHeaders)
if err != nil {
return nil, fmt.Errorf("invalid header: %w", err)
}

for k, v := range shared {
sh[k] = v
rh[k] = v
}

fs, err := downloader.NewFilesystem("./gtfs-rt-cache.json")
if err != nil {
return nil, fmt.Errorf("creating realtime cache: %w", err)
}

s, err := storage.NewSQLiteStorage(storage.SQLiteConfig{OnDisk: true, Directory: "."})
if err != nil {
return nil, err
}
manager := gtfs.NewManager(s)
manager.Downloader = fs

realtime, err := manager.LoadRealtime("cli", staticURL, nil, realtimeURL, nil, time.Now())
realtime, err := manager.LoadRealtime("cli", staticURL, sh, realtimeURL, rh, time.Now())
if err != nil {
return nil, err
}
Expand Down
121 changes: 121 additions & 0 deletions downloader/filesystem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package downloader

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"sync"
"time"
)

type Filesystem struct {
Path string
Records map[string]fsRecord

mutex sync.Mutex
}

type fsRecord struct {
Body string `json:"body"`
RetrievedAt string `json:"retrieved_at"`
}

func NewFilesystem(path string) (*Filesystem, error) {
fs := &Filesystem{
Path: path,
Records: map[string]fsRecord{},
}

err := fs.load()
if err != nil {
return nil, err
}

return fs, nil
}

func (f *Filesystem) Get(
ctx context.Context,
url string,
headers map[string]string,
options GetOptions,
) ([]byte, error) {

f.mutex.Lock()
defer f.mutex.Unlock()

if options.Cache {
if record, found := f.Records[url]; found {
retrievedAt, err := time.Parse(time.RFC3339, record.RetrievedAt)
if err != nil {
return nil, err
}
if retrievedAt.Add(options.CacheTTL).After(time.Now()) {
body, err := base64.StdEncoding.DecodeString(record.Body)
if err != nil {
return nil, fmt.Errorf("decoding: %w", err)
}
fmt.Println("cache hit")
return body, nil
}
fmt.Println("cache expired")
}
}

body, err := HTTPGet(ctx, url, headers, options)
if err != nil {
return nil, fmt.Errorf("http get: %w", err)
}

if options.Cache {
bodyB64 := base64.StdEncoding.EncodeToString(body)
f.Records[url] = fsRecord{
Body: bodyB64,
RetrievedAt: time.Now().UTC().Format(time.RFC3339),
}
err = f.save()
if err != nil {
return nil, fmt.Errorf("saving: %w", err)
}
}

return body, nil
}

func (f *Filesystem) load() error {
f.mutex.Lock()
defer f.mutex.Unlock()

_, err := os.Stat(f.Path)
if os.IsNotExist(err) {
return nil
}

buf, err := os.ReadFile(f.Path)
if err != nil {
return fmt.Errorf("reading: %w", err)
}

err = json.Unmarshal(buf, &f.Records)
if err != nil {
return fmt.Errorf("unmarshalling: %w", err)
}

return nil
}

func (f *Filesystem) save() error {
buf, err := json.Marshal(f.Records)
if err != nil {
return fmt.Errorf("marshalling: %w", err)
}

err = os.WriteFile(f.Path, buf, 0644)
if err != nil {
return fmt.Errorf("writing: %w", err)
}

return nil
}
32 changes: 16 additions & 16 deletions downloader/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@ import (
)

// Caches downloaded files in memory
type MemoryDownloader struct {
mutex sync.Mutex
cache map[string]downloaderCacheEntry
type Memory struct {
mutex sync.Mutex
records map[string]memoryRecord

TimeNow func() time.Time
}

func NewMemoryDownloader() *MemoryDownloader {
return &MemoryDownloader{
cache: make(map[string]downloaderCacheEntry),
TimeNow: time.Now,
}
}

type downloaderCacheEntry struct {
type memoryRecord struct {
data []byte
expiration time.Time
}

func (d *MemoryDownloader) Get(
func NewMemory() *Memory {
return &Memory{
records: map[string]memoryRecord{},
TimeNow: time.Now,
}
}

func (d *Memory) Get(
ctx context.Context,
url string,
headers map[string]string,
Expand All @@ -36,9 +36,9 @@ func (d *MemoryDownloader) Get(
d.mutex.Lock()
defer d.mutex.Unlock()

if entry, ok := d.cache[url]; ok {
if entry.expiration.After(d.TimeNow()) {
return entry.data, nil
if record, ok := d.records[url]; ok {
if record.expiration.After(d.TimeNow()) {
return record.data, nil
}
}
}
Expand All @@ -49,7 +49,7 @@ func (d *MemoryDownloader) Get(
}

if options.Cache {
d.cache[url] = downloaderCacheEntry{
d.records[url] = memoryRecord{
data: body,
expiration: d.TimeNow().Add(options.CacheTTL),
}
Expand Down
2 changes: 1 addition & 1 deletion manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewManager(s storage.Storage) *Manager {
StaticMaxSize: DefaultStaticMaxSize,
StaticRefreshInterval: DefaultStaticRefreshInterval,

Downloader: downloader.NewMemoryDownloader(),
Downloader: downloader.NewMemory(),

storage: s,
}
Expand Down
2 changes: 1 addition & 1 deletion manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ func testManagerLoadRealtime(t *testing.T, strg storage.Storage) {

// Mock clock on the downloader to control caching
now := time.Now()
dl := downloader.NewMemoryDownloader()
dl := downloader.NewMemory()
dl.TimeNow = func() time.Time {
return now
}
Expand Down

0 comments on commit 2034d97

Please sign in to comment.