Skip to content

Commit

Permalink
Update app to use ctx for endpoint handlers
Browse files Browse the repository at this point in the history
- fix tests
- de-pointify a lot of stuff because reasons
- add OPTIONS method for license endpoint
- remove unneeded env flag from Make compose commands
  • Loading branch information
GradeyCullins committed Nov 13, 2023
1 parent a204660 commit 0ef45e5
Show file tree
Hide file tree
Showing 16 changed files with 553 additions and 620 deletions.
3 changes: 2 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pg-data/
scripts/
scripts/
.env*
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ TAG = latest
.PHONY: docker-run run test docker-stop clean down

docker-run: $(TARGET)
docker compose --env-file ./.env up --detach
docker compose up --detach

run: $(TARGET)
./scripts/start-db.sh
Expand All @@ -27,7 +27,7 @@ test:

down: stop
stop:
docker-compose --env-file ./.env down
docker-compose ./.env down

clean:
rm ${NAME}
Expand Down
36 changes: 2 additions & 34 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,12 @@
package main

import (
"flag"
"os"
"purity-vision-filter/src"
"strconv"

"github.com/joho/godotenv"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

var portFlag int

func main() {
if err := godotenv.Load(); err != nil {
log.Fatal().Err(err)
}

if err := src.InitConfig(); err != nil {
log.Fatal().Msg(err.Error())
}

flag.IntVar(&portFlag, "port", src.DefaultPort, "port to run the service on")
flag.Parse()

logLevel, err := strconv.Atoi(src.LogLevel)
if err != nil {
panic(err)
}
zerolog.SetGlobalLevel(zerolog.Level(logLevel))

log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: true}).With().Caller().Logger()

conn, err := src.InitDB(src.DBName)
if err != nil {
log.Fatal().Msg(err.Error())
}
defer conn.Close()

s := src.NewServe()
s.InitServer(portFlag, conn)
godotenv.Load()
src.InitServer()
}
126 changes: 51 additions & 75 deletions src/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,97 +8,73 @@ import (
"github.com/rs/zerolog"
)

var (
// DefaultDBName is the default name of the database.
DefaultDBName = "purity"

// DefaultDBTestName is the default name of the test database.
DefaultDBTestName = "purity_test"

// DefaultPort is the default port to expose the API server.
DefaultPort int = 8080

// DBHost is the host machine running the postgres instance.
DBHost string

// DBPort is the port that exposes the db server.
DBPort string

// DBName is the postgres database name.
DBName string

// DBUser is the postgres user account.
DBUser string

// DBPassword is the password for the DBUser postgres account.
DBPassword string

// DBSSLMode sets the SSL mode of the postgres client.
DBSSLMode string

// LogLevel is the level of logging for the application.
LogLevel string

// StripeKey is for making Stripe API requests.
StripeKey string

// Name on email license delivery.
EmailName string

// SendgridAPIKey is for sending emails.
SendgridAPIKey string

// Stripe webhook secret.
StripeWebhookSecret string

// From address for email license delivery.
EmailFrom string

// TrialLicenseMaxUsage is the maximum image filters for a trial license.
TrialLicenseMaxUsage int = 1000
)

func InitConfig() error {
DefaultPort = 8080

DBHost = getEnvWithDefault("PURITY_DB_HOST", "localhost")
DBPort = getEnvWithDefault("PURITY_DB_PORT", "5432")
DBName = getEnvWithDefault("PURITY_DB_NAME", DefaultDBName)
DBUser = getEnvWithDefault("PURITY_DB_USER", "postgres")
DBPassword = getEnvWithDefault("PURITY_DB_PASS", "")
DBSSLMode = getEnvWithDefault("PURITY_DB_SSL_MODE", "disable")
type Config struct {
DBHost string // DBHost is the host machine running the postgres instance.
DBPort string // DBPort is the port that exposes the db server.
DBName string // DBName is the postgres database name.
DBUser string // DBUser is the postgres user account.
DBPassword string // DBPassword is the password for the DBUser postgres account.
DBSSLMode string // DBSSLMode sets the SSL mode of the postgres client.
LogLevel string // LogLevel is the level of logging for the application.
StripeKey string // StripeKey is for making Stripe API requests.
EmailName string // Name on email license delivery.
SendgridAPIKey string // SendgridAPIKey is for sending emails.
StripeWebhookSecret string // Stripe webhook secret.
EmailFrom string // From address for email license delivery.
TrialLicenseMaxUsage int // TrialLicenseMaxUsage is the maximum image filters for a trial license.
}

LogLevel = getEnvWithDefault("PURITY_LOG_LEVEL", strconv.Itoa(int(zerolog.InfoLevel)))
func missingEnvErr(envVar string) error {
return fmt.Errorf("%s not found in environment", envVar)
}

missingEnvErr := func(envVar string) error {
return fmt.Errorf("%s not found in environment", envVar)
}
func newConfig() (Config, error) {
var (
StripeKey = os.Getenv("STRIPE_KEY")
StripeWebhookSecret = os.Getenv("STRIPE_WEBHOOK_SECRET")
EmailName = getEnvWithDefault("EMAIL_NAME", "John Doe")
EmailFrom = getEnvWithDefault("EMAIL_FROM", "[email protected]")
SendgridAPIKey = os.Getenv("SENDGRID_API_KEY")
)

if os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") == "" {
return missingEnvErr("GOOGLE_APPLICATION_CREDENTIALS")
return Config{}, missingEnvErr("GOOGLE_APPLICATION_CREDENTIALS")
}

if StripeKey = os.Getenv("STRIPE_KEY"); StripeKey == "" {
return missingEnvErr("STRIPE_KEY")
if StripeKey == "" {
return Config{}, missingEnvErr("STRIPE_KEY")
}

if StripeWebhookSecret = os.Getenv("STRIPE_WEBHOOK_SECRET"); StripeWebhookSecret == "" {
return missingEnvErr("STRIPE_WEBHOOK_SECRET")
if StripeWebhookSecret == "" {
return Config{}, missingEnvErr("STRIPE_WEBHOOK_SECRET")
}

if EmailName = getEnvWithDefault("EMAIL_NAME", "John Doe"); EmailName == "" {
return missingEnvErr("EMAIL_NAME")
if EmailName == "" {
return Config{}, missingEnvErr("EMAIL_NAME")
}

if EmailFrom = getEnvWithDefault("EMAIL_FROM", "[email protected]"); EmailFrom == "" {
return missingEnvErr("EMAIL_FROM")
if EmailFrom == "" {
return Config{}, missingEnvErr("EMAIL_FROM")
}

if SendgridAPIKey = os.Getenv("SENDGRID_API_KEY"); SendgridAPIKey == "" {
return missingEnvErr("SENDGRID_API_KEY")
if SendgridAPIKey == "" {
return Config{}, missingEnvErr("SENDGRID_API_KEY")
}

return nil
return Config{
DBHost: getEnvWithDefault("PURITY_DB_HOST", "localhost"),
DBPort: getEnvWithDefault("PURITY_DB_PORT", "5432"),
DBName: getEnvWithDefault("PURITY_DB_NAME", "purity"),
DBUser: getEnvWithDefault("PURITY_DB_USER", "postgres"),
DBPassword: getEnvWithDefault("PURITY_DB_PASS", ""),
DBSSLMode: getEnvWithDefault("PURITY_DB_SSL_MODE", "disable"),
LogLevel: getEnvWithDefault("PURITY_LOG_LEVEL", strconv.Itoa(int(zerolog.InfoLevel))),
StripeKey: StripeKey,
StripeWebhookSecret: StripeWebhookSecret,
EmailName: EmailName,
EmailFrom: EmailFrom,
SendgridAPIKey: SendgridAPIKey,
}, nil
}

func getEnvWithDefault(name string, def string) string {
Expand Down
26 changes: 10 additions & 16 deletions src/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package src
import (
"context"
"fmt"
"os"

"github.com/rs/zerolog"
"github.com/rs/zerolog/log"

"github.com/go-pg/pg/v10"
Expand All @@ -17,28 +19,18 @@ type User struct {
}

// InitDB intializes and returns a postgres database connection object.
func InitDB(dbName string) (*pg.DB, error) {
dbHost := DBHost
dbPort := DBPort
dbAddr := fmt.Sprintf("%s:%s", dbHost, dbPort)
if dbName == "" {
dbName = DBName
}
dbUser := DBUser
dbPassword := DBPassword
func InitDB(config Config) (*pg.DB, error) {
dbAddr := fmt.Sprintf("%s:%s", config.DBHost, config.DBPort)

if dbPassword == "" {
if config.DBPassword == "" {
return nil, fmt.Errorf("missing postgres password. Export \"PURITY_DB_PASS=<your_password>\"")
}

// TODO: use
// tlsConfig := &tls.Config{}

conn := pg.Connect(&pg.Options{
Addr: dbAddr,
User: dbUser,
Password: dbPassword,
Database: dbName,
User: config.DBUser,
Password: config.DBPassword,
Database: config.DBName,
})

// Print SQL queries to logger if loglevel is set to debug.
Expand All @@ -55,6 +47,8 @@ func InitDB(dbName string) (*pg.DB, error) {
type loggerHook struct{}

func (h loggerHook) BeforeQuery(ctx context.Context, evt *pg.QueryEvent) (context.Context, error) {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: true}).With().Caller().Logger()

q, err := evt.FormattedQuery()
if err != nil {
return nil, err
Expand Down
46 changes: 18 additions & 28 deletions src/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (
)

// return URIs that are not cached in annotations
func getCachedSSAs(uris []string) ([]*ImageAnnotation, []string, error) {
func getCachedSSAs(ctx appContext, uris []string) ([]*ImageAnnotation, []string, error) {
var res []*ImageAnnotation
cachedSSAs, err := FindAnnotationsByURI(conn, uris)
cachedSSAs, err := FindAnnotationsByURI(ctx.db, uris)
if err != nil {
return nil, nil, err
}
Expand All @@ -25,7 +25,7 @@ func getCachedSSAs(uris []string) ([]*ImageAnnotation, []string, error) {
if cachedSSA.URI == uri {
res = append(res, &cachedSSA)
found = true
logger.Info().Msgf("found cached image: %s", uri)
ctx.logger.Info().Msgf("found cached image: %s", uri)
break
}
}
Expand All @@ -37,29 +37,29 @@ func getCachedSSAs(uris []string) ([]*ImageAnnotation, []string, error) {
return res, uncachedURIs, nil
}

func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) {
res, uris, err := getCachedSSAs(uris)
func filterImages(ctx appContext, uris []string, licenseID string) ([]*ImageAnnotation, error) {
res, uris, err := getCachedSSAs(ctx, uris)
if err != nil {
return nil, err
}
if len(uris) == 0 {
return res, nil
}

license, err := licenseStore.GetLicenseByID(licenseID)
license, err := ctx.licenseStore.GetLicenseByID(licenseID)
if err != nil {
return nil, fmt.Errorf("failed to fetch license: %s", err.Error())
}
if license == nil {
return nil, errors.New("license not found")
}
if license.IsTrial {
remainingUsage := TrialLicenseMaxUsage - license.RequestCount
remainingUsage := ctx.config.TrialLicenseMaxUsage - license.RequestCount
if remainingUsage < len(uris) {
uris = uris[:remainingUsage]
}
if remainingUsage <= 0 { // return early if trial license is expired
license, err := licenseStore.ExpireTrial(license)
license, err := ctx.licenseStore.ExpireTrial(license)
if err != nil {
return res, fmt.Errorf("failed to mark trial license as expired: %s", err.Error())
} else {
Expand All @@ -75,11 +75,11 @@ func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) {

if len(annotateImageResponses) > 0 {
license.RequestCount += len(annotateImageResponses)
if err = licenseStore.UpdateLicense(license); err != nil {
logger.Error().Msgf("failed to update license request count: %s", err)
if err = ctx.licenseStore.UpdateLicense(license); err != nil {
ctx.logger.Error().Msgf("failed to update license request count: %s", err)
}
if err := IncrementSubscriptionMeter(license, int64(len(annotateImageResponses))); err != nil {
logger.Error().Msgf("failed to update stripe subscription usage: %s", err.Error())
if err := IncrementSubscriptionMeter(ctx.config.StripeKey, license, int64(len(annotateImageResponses))); err != nil {
ctx.logger.Error().Msgf("failed to update stripe subscription usage: %s", err.Error())
}
}

Expand All @@ -98,22 +98,12 @@ func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) {
safeSearchAnnotationsRes := buildSSARes(annotateImageResponses)
res = append(res, safeSearchAnnotationsRes...)

// safeSearchAnnotationsRes := make([]*ImageAnnotation, 0)
// for i, annotation := range annotateImageResponses {
// if annotation == nil {
// continue
// }
// uri := uris[i]
// safeSearchAnnotationsRes = append(safeSearchAnnotationsRes, annotationToSafeSearchResponseRes(uri, annotation))
// }
// res = append(res, safeSearchAnnotationsRes...)

err = cacheAnnotations(safeSearchAnnotationsRes)
err = cacheAnnotations(ctx, safeSearchAnnotationsRes)
if err != nil {
logger.Error().Msgf("failed to cache with uris: %v", uris)
ctx.logger.Error().Msgf("failed to cache with uris: %v", uris)
}

logger.Info().Msgf("license: %s added %d to request count", licenseID, len(annotateImageResponses))
ctx.logger.Info().Msgf("license: %s added %d to request count", licenseID, len(annotateImageResponses))

return res, nil
}
Expand Down Expand Up @@ -153,13 +143,13 @@ func annotationToSafeSearchResponseRes(uri string, annotation *pb.AnnotateImageR
}
}

func cacheAnnotations(annos []*ImageAnnotation) error {
if err := InsertAll(conn, annos); err != nil {
func cacheAnnotations(ctx appContext, annos []*ImageAnnotation) error {
if err := InsertAll(ctx.db, annos); err != nil {
return err
}

for _, anno := range annos {
logger.Info().Msgf("adding %s to DB cache", anno.URI)
ctx.logger.Info().Msgf("adding %s to DB cache", anno.URI)
}

return nil
Expand Down
Loading

0 comments on commit 0ef45e5

Please sign in to comment.