From 0ef45e5fc13bd5ae3337005db007b2a7f0f5a3ef Mon Sep 17 00:00:00 2001 From: Gradey Cullins Date: Mon, 13 Nov 2023 16:20:56 -0700 Subject: [PATCH] Update app to use ctx for endpoint handlers - fix tests - de-pointify a lot of stuff because reasons - add OPTIONS method for license endpoint - remove unneeded env flag from Make compose commands --- .dockerignore | 3 +- Makefile | 4 +- main.go | 36 +----- src/config.go | 126 ++++++++---------- src/db.go | 26 ++-- src/filter.go | 46 +++---- src/handlers.go | 261 +++++++++++++++++-------------------- src/images.go | 21 ++- src/images_test.go | 132 ++++++++++--------- src/license.go | 19 +-- src/mail.go | 10 +- src/middleware.go | 7 +- src/routes.go | 37 ------ src/server.go | 116 ++++++++++++----- src/server_test.go | 315 +++++++++++++++++++++++---------------------- src/stripe.go | 14 +- 16 files changed, 553 insertions(+), 620 deletions(-) delete mode 100644 src/routes.go diff --git a/.dockerignore b/.dockerignore index 7ae676a..b2462a7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,3 @@ pg-data/ -scripts/ \ No newline at end of file +scripts/ +.env* \ No newline at end of file diff --git a/Makefile b/Makefile index 27e10c5..5d8463a 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ TAG = latest .PHONY: docker-run run test docker-stop clean down docker-run: $(TARGET) - docker compose --env-file ./.env up --detach + docker compose up --detach run: $(TARGET) ./scripts/start-db.sh @@ -27,7 +27,7 @@ test: down: stop stop: - docker-compose --env-file ./.env down + docker-compose ./.env down clean: rm ${NAME} diff --git a/main.go b/main.go index 8f33ea5..4d0d82b 100644 --- a/main.go +++ b/main.go @@ -1,44 +1,12 @@ package main import ( - "flag" - "os" "purity-vision-filter/src" - "strconv" "github.com/joho/godotenv" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) -var portFlag int - func main() { - if err := godotenv.Load(); err != nil { - log.Fatal().Err(err) - } - - if err := src.InitConfig(); err != nil { - log.Fatal().Msg(err.Error()) - } - - flag.IntVar(&portFlag, "port", src.DefaultPort, "port to run the service on") - flag.Parse() - - logLevel, err := strconv.Atoi(src.LogLevel) - if err != nil { - panic(err) - } - zerolog.SetGlobalLevel(zerolog.Level(logLevel)) - - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: true}).With().Caller().Logger() - - conn, err := src.InitDB(src.DBName) - if err != nil { - log.Fatal().Msg(err.Error()) - } - defer conn.Close() - - s := src.NewServe() - s.InitServer(portFlag, conn) + godotenv.Load() + src.InitServer() } diff --git a/src/config.go b/src/config.go index 769fba6..fa49911 100644 --- a/src/config.go +++ b/src/config.go @@ -8,97 +8,73 @@ import ( "github.com/rs/zerolog" ) -var ( - // DefaultDBName is the default name of the database. - DefaultDBName = "purity" - - // DefaultDBTestName is the default name of the test database. - DefaultDBTestName = "purity_test" - - // DefaultPort is the default port to expose the API server. - DefaultPort int = 8080 - - // DBHost is the host machine running the postgres instance. - DBHost string - - // DBPort is the port that exposes the db server. - DBPort string - - // DBName is the postgres database name. - DBName string - - // DBUser is the postgres user account. - DBUser string - - // DBPassword is the password for the DBUser postgres account. - DBPassword string - - // DBSSLMode sets the SSL mode of the postgres client. - DBSSLMode string - - // LogLevel is the level of logging for the application. - LogLevel string - - // StripeKey is for making Stripe API requests. - StripeKey string - - // Name on email license delivery. - EmailName string - - // SendgridAPIKey is for sending emails. - SendgridAPIKey string - - // Stripe webhook secret. - StripeWebhookSecret string - - // From address for email license delivery. - EmailFrom string - - // TrialLicenseMaxUsage is the maximum image filters for a trial license. - TrialLicenseMaxUsage int = 1000 -) - -func InitConfig() error { - DefaultPort = 8080 - - DBHost = getEnvWithDefault("PURITY_DB_HOST", "localhost") - DBPort = getEnvWithDefault("PURITY_DB_PORT", "5432") - DBName = getEnvWithDefault("PURITY_DB_NAME", DefaultDBName) - DBUser = getEnvWithDefault("PURITY_DB_USER", "postgres") - DBPassword = getEnvWithDefault("PURITY_DB_PASS", "") - DBSSLMode = getEnvWithDefault("PURITY_DB_SSL_MODE", "disable") +type Config struct { + DBHost string // DBHost is the host machine running the postgres instance. + DBPort string // DBPort is the port that exposes the db server. + DBName string // DBName is the postgres database name. + DBUser string // DBUser is the postgres user account. + DBPassword string // DBPassword is the password for the DBUser postgres account. + DBSSLMode string // DBSSLMode sets the SSL mode of the postgres client. + LogLevel string // LogLevel is the level of logging for the application. + StripeKey string // StripeKey is for making Stripe API requests. + EmailName string // Name on email license delivery. + SendgridAPIKey string // SendgridAPIKey is for sending emails. + StripeWebhookSecret string // Stripe webhook secret. + EmailFrom string // From address for email license delivery. + TrialLicenseMaxUsage int // TrialLicenseMaxUsage is the maximum image filters for a trial license. +} - LogLevel = getEnvWithDefault("PURITY_LOG_LEVEL", strconv.Itoa(int(zerolog.InfoLevel))) +func missingEnvErr(envVar string) error { + return fmt.Errorf("%s not found in environment", envVar) +} - missingEnvErr := func(envVar string) error { - return fmt.Errorf("%s not found in environment", envVar) - } +func newConfig() (Config, error) { + var ( + StripeKey = os.Getenv("STRIPE_KEY") + StripeWebhookSecret = os.Getenv("STRIPE_WEBHOOK_SECRET") + EmailName = getEnvWithDefault("EMAIL_NAME", "John Doe") + EmailFrom = getEnvWithDefault("EMAIL_FROM", "test@example.com") + SendgridAPIKey = os.Getenv("SENDGRID_API_KEY") + ) if os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") == "" { - return missingEnvErr("GOOGLE_APPLICATION_CREDENTIALS") + return Config{}, missingEnvErr("GOOGLE_APPLICATION_CREDENTIALS") } - if StripeKey = os.Getenv("STRIPE_KEY"); StripeKey == "" { - return missingEnvErr("STRIPE_KEY") + if StripeKey == "" { + return Config{}, missingEnvErr("STRIPE_KEY") } - if StripeWebhookSecret = os.Getenv("STRIPE_WEBHOOK_SECRET"); StripeWebhookSecret == "" { - return missingEnvErr("STRIPE_WEBHOOK_SECRET") + if StripeWebhookSecret == "" { + return Config{}, missingEnvErr("STRIPE_WEBHOOK_SECRET") } - if EmailName = getEnvWithDefault("EMAIL_NAME", "John Doe"); EmailName == "" { - return missingEnvErr("EMAIL_NAME") + if EmailName == "" { + return Config{}, missingEnvErr("EMAIL_NAME") } - if EmailFrom = getEnvWithDefault("EMAIL_FROM", "test@example.com"); EmailFrom == "" { - return missingEnvErr("EMAIL_FROM") + if EmailFrom == "" { + return Config{}, missingEnvErr("EMAIL_FROM") } - if SendgridAPIKey = os.Getenv("SENDGRID_API_KEY"); SendgridAPIKey == "" { - return missingEnvErr("SENDGRID_API_KEY") + if SendgridAPIKey == "" { + return Config{}, missingEnvErr("SENDGRID_API_KEY") } - return nil + return Config{ + DBHost: getEnvWithDefault("PURITY_DB_HOST", "localhost"), + DBPort: getEnvWithDefault("PURITY_DB_PORT", "5432"), + DBName: getEnvWithDefault("PURITY_DB_NAME", "purity"), + DBUser: getEnvWithDefault("PURITY_DB_USER", "postgres"), + DBPassword: getEnvWithDefault("PURITY_DB_PASS", ""), + DBSSLMode: getEnvWithDefault("PURITY_DB_SSL_MODE", "disable"), + LogLevel: getEnvWithDefault("PURITY_LOG_LEVEL", strconv.Itoa(int(zerolog.InfoLevel))), + StripeKey: StripeKey, + StripeWebhookSecret: StripeWebhookSecret, + EmailName: EmailName, + EmailFrom: EmailFrom, + SendgridAPIKey: SendgridAPIKey, + }, nil } func getEnvWithDefault(name string, def string) string { diff --git a/src/db.go b/src/db.go index 9adb986..3115847 100644 --- a/src/db.go +++ b/src/db.go @@ -3,7 +3,9 @@ package src import ( "context" "fmt" + "os" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/go-pg/pg/v10" @@ -17,28 +19,18 @@ type User struct { } // InitDB intializes and returns a postgres database connection object. -func InitDB(dbName string) (*pg.DB, error) { - dbHost := DBHost - dbPort := DBPort - dbAddr := fmt.Sprintf("%s:%s", dbHost, dbPort) - if dbName == "" { - dbName = DBName - } - dbUser := DBUser - dbPassword := DBPassword +func InitDB(config Config) (*pg.DB, error) { + dbAddr := fmt.Sprintf("%s:%s", config.DBHost, config.DBPort) - if dbPassword == "" { + if config.DBPassword == "" { return nil, fmt.Errorf("missing postgres password. Export \"PURITY_DB_PASS=\"") } - // TODO: use - // tlsConfig := &tls.Config{} - conn := pg.Connect(&pg.Options{ Addr: dbAddr, - User: dbUser, - Password: dbPassword, - Database: dbName, + User: config.DBUser, + Password: config.DBPassword, + Database: config.DBName, }) // Print SQL queries to logger if loglevel is set to debug. @@ -55,6 +47,8 @@ func InitDB(dbName string) (*pg.DB, error) { type loggerHook struct{} func (h loggerHook) BeforeQuery(ctx context.Context, evt *pg.QueryEvent) (context.Context, error) { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: true}).With().Caller().Logger() + q, err := evt.FormattedQuery() if err != nil { return nil, err diff --git a/src/filter.go b/src/filter.go index 38459ca..a0596a9 100644 --- a/src/filter.go +++ b/src/filter.go @@ -10,9 +10,9 @@ import ( ) // return URIs that are not cached in annotations -func getCachedSSAs(uris []string) ([]*ImageAnnotation, []string, error) { +func getCachedSSAs(ctx appContext, uris []string) ([]*ImageAnnotation, []string, error) { var res []*ImageAnnotation - cachedSSAs, err := FindAnnotationsByURI(conn, uris) + cachedSSAs, err := FindAnnotationsByURI(ctx.db, uris) if err != nil { return nil, nil, err } @@ -25,7 +25,7 @@ func getCachedSSAs(uris []string) ([]*ImageAnnotation, []string, error) { if cachedSSA.URI == uri { res = append(res, &cachedSSA) found = true - logger.Info().Msgf("found cached image: %s", uri) + ctx.logger.Info().Msgf("found cached image: %s", uri) break } } @@ -37,8 +37,8 @@ func getCachedSSAs(uris []string) ([]*ImageAnnotation, []string, error) { return res, uncachedURIs, nil } -func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) { - res, uris, err := getCachedSSAs(uris) +func filterImages(ctx appContext, uris []string, licenseID string) ([]*ImageAnnotation, error) { + res, uris, err := getCachedSSAs(ctx, uris) if err != nil { return nil, err } @@ -46,7 +46,7 @@ func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) { return res, nil } - license, err := licenseStore.GetLicenseByID(licenseID) + license, err := ctx.licenseStore.GetLicenseByID(licenseID) if err != nil { return nil, fmt.Errorf("failed to fetch license: %s", err.Error()) } @@ -54,12 +54,12 @@ func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) { return nil, errors.New("license not found") } if license.IsTrial { - remainingUsage := TrialLicenseMaxUsage - license.RequestCount + remainingUsage := ctx.config.TrialLicenseMaxUsage - license.RequestCount if remainingUsage < len(uris) { uris = uris[:remainingUsage] } if remainingUsage <= 0 { // return early if trial license is expired - license, err := licenseStore.ExpireTrial(license) + license, err := ctx.licenseStore.ExpireTrial(license) if err != nil { return res, fmt.Errorf("failed to mark trial license as expired: %s", err.Error()) } else { @@ -75,11 +75,11 @@ func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) { if len(annotateImageResponses) > 0 { license.RequestCount += len(annotateImageResponses) - if err = licenseStore.UpdateLicense(license); err != nil { - logger.Error().Msgf("failed to update license request count: %s", err) + if err = ctx.licenseStore.UpdateLicense(license); err != nil { + ctx.logger.Error().Msgf("failed to update license request count: %s", err) } - if err := IncrementSubscriptionMeter(license, int64(len(annotateImageResponses))); err != nil { - logger.Error().Msgf("failed to update stripe subscription usage: %s", err.Error()) + if err := IncrementSubscriptionMeter(ctx.config.StripeKey, license, int64(len(annotateImageResponses))); err != nil { + ctx.logger.Error().Msgf("failed to update stripe subscription usage: %s", err.Error()) } } @@ -98,22 +98,12 @@ func filterImages(uris []string, licenseID string) ([]*ImageAnnotation, error) { safeSearchAnnotationsRes := buildSSARes(annotateImageResponses) res = append(res, safeSearchAnnotationsRes...) - // safeSearchAnnotationsRes := make([]*ImageAnnotation, 0) - // for i, annotation := range annotateImageResponses { - // if annotation == nil { - // continue - // } - // uri := uris[i] - // safeSearchAnnotationsRes = append(safeSearchAnnotationsRes, annotationToSafeSearchResponseRes(uri, annotation)) - // } - // res = append(res, safeSearchAnnotationsRes...) - - err = cacheAnnotations(safeSearchAnnotationsRes) + err = cacheAnnotations(ctx, safeSearchAnnotationsRes) if err != nil { - logger.Error().Msgf("failed to cache with uris: %v", uris) + ctx.logger.Error().Msgf("failed to cache with uris: %v", uris) } - logger.Info().Msgf("license: %s added %d to request count", licenseID, len(annotateImageResponses)) + ctx.logger.Info().Msgf("license: %s added %d to request count", licenseID, len(annotateImageResponses)) return res, nil } @@ -153,13 +143,13 @@ func annotationToSafeSearchResponseRes(uri string, annotation *pb.AnnotateImageR } } -func cacheAnnotations(annos []*ImageAnnotation) error { - if err := InsertAll(conn, annos); err != nil { +func cacheAnnotations(ctx appContext, annos []*ImageAnnotation) error { + if err := InsertAll(ctx.db, annos); err != nil { return err } for _, anno := range annos { - logger.Info().Msgf("adding %s to DB cache", anno.URI) + ctx.logger.Info().Msgf("adding %s to DB cache", anno.URI) } return nil diff --git a/src/handlers.go b/src/handlers.go index a6b1b4c..482a1ec 100644 --- a/src/handlers.go +++ b/src/handlers.go @@ -2,6 +2,7 @@ package src import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -22,7 +23,7 @@ func health(w http.ResponseWriter, req *http.Request) { const MAX_IMAGES_PER_REQUEST = 16 -func removeDuplicates(vals []string) []string { +func removeDuplicates(logger zerolog.Logger, vals []string) []string { res := make([]string, 0) strMap := make(map[string]bool, 0) @@ -38,77 +39,65 @@ func removeDuplicates(vals []string) []string { return res } +func handleBatchFilter(ctx appContext, w http.ResponseWriter, req *http.Request) (int, error) { + var filterReqPayload AnnotateReq -func handleBatchFilter(logger zerolog.Logger) func(w http.ResponseWriter, req *http.Request) { - return func(w http.ResponseWriter, req *http.Request) { - var filterReqPayload AnnotateReq - - decoder := json.NewDecoder(req.Body) - if err := decoder.Decode(&filterReqPayload); err != nil { - writeError(400, "JSON body missing or malformed", w) - return - } + decoder := json.NewDecoder(req.Body) + if err := decoder.Decode(&filterReqPayload); err != nil { + return http.StatusBadRequest, errors.New("JSON body missing or malformed") + } - if len(filterReqPayload.ImgURIList) == 0 { - writeError(400, "ImgUriList cannot be empty", w) - return - } + if len(filterReqPayload.ImgURIList) == 0 { + return http.StatusBadRequest, errors.New("ImgUriList cannot be empty") + } - var res []*ImageAnnotation + var res []*ImageAnnotation - uris := removeDuplicates(filterReqPayload.ImgURIList) + uris := removeDuplicates(ctx.logger, filterReqPayload.ImgURIList) - // Validate the request payload URIs - for _, uri := range uris { - if _, err := url.ParseRequestURI(uri); err != nil { - writeError(400, fmt.Sprintf("%s is not a valid URI", uri), w) - return - } + for _, uri := range uris { + if _, err := url.ParseRequestURI(uri); err != nil { + return http.StatusBadRequest, fmt.Errorf("%s is not a valid URI", uri) } + } - // Filter images in pages of size MAX_IMAGES_PER_REQUEST. - for i := 0; i < len(uris); { - var endIdx int - if i+MAX_IMAGES_PER_REQUEST > len(uris)-1 { - endIdx = len(uris) - } else { - endIdx = i + MAX_IMAGES_PER_REQUEST - } - - temp, err := filterImages(uris[i:endIdx], req.Header.Get("LicenseID")) - if err != nil { - logger.Error().Msgf("error while filtering: %s", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - res = append(res, temp...) + // Filter images in pages of size MAX_IMAGES_PER_REQUEST. + for i := 0; i < len(uris); { + var endIdx int + if i+MAX_IMAGES_PER_REQUEST > len(uris)-1 { + endIdx = len(uris) + } else { + endIdx = i + MAX_IMAGES_PER_REQUEST + } - i += MAX_IMAGES_PER_REQUEST + temp, err := filterImages(ctx, uris[i:endIdx], req.Header.Get("LicenseID")) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("error while filtering: %s", err) } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) + res = append(res, temp...) + + i += MAX_IMAGES_PER_REQUEST } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) + return http.StatusOK, nil } -func handleWebhook(w http.ResponseWriter, req *http.Request) { +func handleWebhook(ctx appContext, w http.ResponseWriter, req *http.Request) (int, error) { const MaxBodyBytes = int64(65536) req.Body = http.MaxBytesReader(w, req.Body, MaxBodyBytes) payload, err := io.ReadAll(req.Body) if err != nil { - logger.Error().Msgf("error reading request body: %v", err) - w.WriteHeader(http.StatusServiceUnavailable) - return + return http.StatusServiceUnavailable, fmt.Errorf("error reading request body: %v", err) } - endpointSecret := StripeWebhookSecret + endpointSecret := ctx.config.StripeWebhookSecret event, err := webhook.ConstructEvent(payload, req.Header.Get("Stripe-Signature"), endpointSecret) if err != nil { - logger.Error().Msgf("error verifying webhook signature: %v", err) - w.WriteHeader(http.StatusBadRequest) // Return a 400 error on a bad signature - return + return http.StatusBadRequest, fmt.Errorf("error verifying webhook signature: %v", err) } // Unmarshal the event data into an appropriate struct depending on its Type @@ -118,39 +107,32 @@ func handleWebhook(w http.ResponseWriter, req *http.Request) { var session stripe.CheckoutSession err := json.Unmarshal(event.Data.Raw, &session) if err != nil { - logger.Error().Msgf("error parsing webhook JSON: %v", err.Error()) - fmt.Fprintf(os.Stderr, "error parsing webhook JSON: %v\n", err) - w.WriteHeader(http.StatusBadRequest) - return + return http.StatusBadRequest, fmt.Errorf("error parsing webhook JSON: %v", err.Error()) } subscriptionID := session.Subscription.ID stripeID := session.Customer.ID email := session.CustomerDetails.Email - license, err := licenseStore.GetLicenseByStripeID(stripeID) + license, err := ctx.licenseStore.GetLicenseByStripeID(stripeID) if err != nil { - logger.Error().Msgf("error fetching license: %v", err) - PrintSomethingWrong(w) - w.WriteHeader(http.StatusInternalServerError) - return + return http.StatusInternalServerError, fmt.Errorf("error fetching license: %v", err) } // if license exists ensure IsValid is true and return if license != nil { - logger.Debug().Msg("Existing license found, ensuring IsValid is true") + ctx.logger.Debug().Msg("existing license found, ensuring IsValid is true") license.IsValid = true - if err := licenseStore.UpdateLicense(license); err != nil { - PrintSomethingWrong(w) - w.WriteHeader(http.StatusInternalServerError) + if err := ctx.licenseStore.UpdateLicense(license); err != nil { + return http.StatusInternalServerError, errors.New("") } // TODO: email person to remind them their subscription is renewed. - return + return http.StatusOK, nil } // else create new license and store in db licenseID := GenerateLicenseKey() - logger.Info().Msgf("generating new license: %s", licenseID) + ctx.logger.Info().Msgf("generating new license: %s", licenseID) license = &License{ ID: licenseID, @@ -162,85 +144,73 @@ func handleWebhook(w http.ResponseWriter, req *http.Request) { RequestCount: 0, } - if _, err = conn.Model(license).Insert(); err != nil { - logger.Error().Msgf("error creating license: %v", err) - PrintSomethingWrong(w) - w.WriteHeader(http.StatusInternalServerError) + if _, err = ctx.db.Model(license).Insert(); err != nil { + return http.StatusInternalServerError, fmt.Errorf("error creating license: %v", err) } - stripe.Key = StripeKey + stripe.Key = ctx.config.StripeKey metadata := map[string]string{ "license": licenseID, } if _, err := customer.Update(session.Customer.ID, &stripe.CustomerParams{ - Params: stripe.Params{ - Metadata: metadata, - }, + Params: stripe.Params{Metadata: metadata}, }); err != nil { - fmt.Fprintf(w, "error adding license to customer metadata: %v", err) - PrintSomethingWrong(w) - w.WriteHeader(http.StatusInternalServerError) + return http.StatusInternalServerError, fmt.Errorf("error adding license to customer metadata: %v", err) } - if err = SendLicenseMail(license.Email, license.ID); err != nil { + if err = SendLicenseMail(ctx.config, license.Email, license.ID); err != nil { // TODO: retry sending email so user can get their license. - logger.Error().Msgf("error sending license email: %v", err) - PrintSomethingWrong(w) - w.WriteHeader(http.StatusInternalServerError) + return http.StatusInternalServerError, fmt.Errorf("error sending license email: %v", err) } case "customer.subscription.updated": sub := stripe.Subscription{} err := json.Unmarshal(event.Data.Raw, &sub) if err != nil { - fmt.Fprintf(w, "error parsing webhook JSON: %v", err) - w.WriteHeader(http.StatusBadRequest) - return + return http.StatusBadRequest, fmt.Errorf("error parsing webhook JSON: %v", err) } - license, err := licenseStore.GetLicenseByStripeID(sub.Customer.ID) + license, err := ctx.licenseStore.GetLicenseByStripeID(sub.Customer.ID) if err != nil { - logger.Error().Msgf("error finding license for valid subscriber: %v", err) + return http.StatusInternalServerError, fmt.Errorf("error finding license for valid subscriber: %v", err) } if license == nil { - logger.Error().Msg("failed to find license for existing subscriber. Something is terribly wrong") - PrintSomethingWrong(w) - return + return http.StatusInternalServerError, errors.New("failed to find license") } if sub.CancellationDetails.Reason != "" { license.IsValid = false license.ValidityReason = fmt.Sprintf("subscription was cancelled: %s", sub.CancellationDetails.Reason) - logger.Error().Msgf("invalidated license: %s", license.ID) + ctx.logger.Info().Msgf("invalidated license: %s", license.ID) } else { license.IsValid = true license.ValidityReason = "" - logger.Error().Msgf("activated license: %s", license.ID) + ctx.logger.Info().Msgf("activated license: %s", license.ID) } - if err = licenseStore.UpdateLicense(license); err != nil { - logger.Error().Msgf("error updating license: %v", err) - PrintSomethingWrong(w) - return + if err = ctx.licenseStore.UpdateLicense(license); err != nil { + return http.StatusInternalServerError, fmt.Errorf("error updating license: %v", err) } default: fmt.Fprintf(os.Stderr, "Unhandled event type: %s", event.Type) } w.WriteHeader(http.StatusOK) + return http.StatusOK, nil } -func handleGetLicense(w http.ResponseWriter, req *http.Request) { +func handleGetLicense(ctx appContext, w http.ResponseWriter, req *http.Request) (int, error) { vars := mux.Vars(req) licenseID := vars["id"] - logger.Info().Msgf("verifying license: %s", licenseID) + if licenseID == "" { + return http.StatusBadRequest, errors.New("licenseID path parameter was empty") + } - license, err := licenseStore.GetLicenseByID(licenseID) + ctx.logger.Info().Msgf("verifying license: %s", licenseID) + + license, err := ctx.licenseStore.GetLicenseByID(licenseID) if err != nil { - logger.Error().Msgf("failed to get license: %s", err.Error()) - w.WriteHeader(http.StatusInternalServerError) - PrintSomethingWrong(w) - return + return http.StatusInternalServerError, fmt.Errorf("failed to get license: %s", err.Error()) } // if license == nil { @@ -250,53 +220,54 @@ func handleGetLicense(w http.ResponseWriter, req *http.Request) { // } json.NewEncoder(w).Encode(license) + return http.StatusOK, nil } type TrialRegisterReq struct { Email string } -func handleTrialRegister(w http.ResponseWriter, req *http.Request) { - var trialReq TrialRegisterReq - - decoder := json.NewDecoder(req.Body) - if err := decoder.Decode(&trialReq); err != nil { - writeError(400, "JSON body missing or malformed", w) - return - } - - if trialReq.Email == "" { - fmt.Fprint(w, "Email cannot be empty") - w.WriteHeader(http.StatusBadRequest) - return - } - - license, err := licenseStore.GetLicenseByEmail(trialReq.Email) - if err != nil { - logger.Error().Msgf("failed to fetch license by email: %s", err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - - if license != nil { - logger.Error().Msg("email is already registered") - w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, "email is already registered") - return - } - - if err = RegisterNewUser(trialReq.Email); err != nil { - logger.Error().Msgf("something went wrong registering a new user: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusOK) -} - -func RegisterNewUser(email string) error { +// func handleTrialRegister(ctx *appContext, w http.ResponseWriter, req *http.Request) { +// var trialReq TrialRegisterReq + +// decoder := json.NewDecoder(req.Body) +// if err := decoder.Decode(&trialReq); err != nil { +// writeError(ctx.logger, 400, "JSON body missing or malformed", w) +// return +// } + +// if trialReq.Email == "" { +// fmt.Fprint(w, "Email cannot be empty") +// w.WriteHeader(http.StatusBadRequest) +// return +// } + +// license, err := ctx.licenseStore.GetLicenseByEmail(trialReq.Email) +// if err != nil { +// ctx.logger.Error().Msgf("failed to fetch license by email: %s", err.Error()) +// w.WriteHeader(http.StatusInternalServerError) +// return +// } + +// if license != nil { +// ctx.logger.Error().Msg("email is already registered") +// w.WriteHeader(http.StatusBadRequest) +// fmt.Fprint(w, "email is already registered") +// return +// } + +// if err = RegisterNewUser(ctx.config, trialReq.Email); err != nil { +// ctx.logger.Error().Msgf("something went wrong registering a new user: %v", err) +// w.WriteHeader(http.StatusInternalServerError) +// return +// } + +// w.WriteHeader(http.StatusOK) +// } + +func RegisterNewUser(ctx appContext, email string) error { licenseID := GenerateLicenseKey() - logger.Info().Msgf("generated license: %s", licenseID) + ctx.logger.Info().Msgf("generated license: %s", licenseID) license := &License{ ID: licenseID, @@ -307,14 +278,14 @@ func RegisterNewUser(email string) error { RequestCount: 0, } - if _, err := conn.Model(license).Insert(); err != nil { - logger.Error().Msgf("error creating: %v", err) + if _, err := ctx.db.Model(license).Insert(); err != nil { + ctx.logger.Error().Msgf("error creating: %v", err) return err } - if err := SendLicenseMail(license.Email, license.ID); err != nil { + if err := SendLicenseMail(ctx.config, license.Email, license.ID); err != nil { // TODO: retry sending email so user can get their license. - logger.Error().Msgf("error sending license email: %v", err) + ctx.logger.Error().Msgf("error sending license email: %v", err) return err } diff --git a/src/images.go b/src/images.go index 1170e3b..d2354a6 100644 --- a/src/images.go +++ b/src/images.go @@ -6,7 +6,6 @@ import ( "time" "github.com/go-pg/pg/v10" - "github.com/rs/zerolog/log" ) type ImageAnnotation struct { @@ -24,20 +23,19 @@ type ImageAnnotation struct { } // FindByURI returns an image with the matching URI. -func FindByURI(conn *pg.DB, imgURI string) (*ImageAnnotation, error) { +func FindByURI(conn pg.DB, imgURI string) (ImageAnnotation, error) { var img ImageAnnotation err := conn.Model(&img).Where("uri = ?", imgURI).Select() if err != nil { - log.Error().Msgf("err: %v", err) - return nil, nil + return img, err } - return &img, nil + return img, nil } // FindAnnotationsByURI returns annotations that have matching URI's. -func FindAnnotationsByURI(conn *pg.DB, uris []string) ([]ImageAnnotation, error) { +func FindAnnotationsByURI(conn pg.DB, uris []string) ([]ImageAnnotation, error) { var annotations []ImageAnnotation if len(uris) == 0 { @@ -50,17 +48,17 @@ func FindAnnotationsByURI(conn *pg.DB, uris []string) ([]ImageAnnotation, error) } // Insert inserts the annotation into the DB. -func Insert(conn *pg.DB, image *ImageAnnotation) error { - _, err := conn.Model(image).Insert() +func Insert(conn pg.DB, image ImageAnnotation) error { + _, err := conn.Model(&image).Insert() if err != nil { return err } - logger.Debug().Msgf("inserted image: %s", image.URI) return nil } -func InsertAll(conn *pg.DB, images []*ImageAnnotation) error { +// InsertAll inserts all the image safe search annotations into the DB. +func InsertAll(conn pg.DB, images []*ImageAnnotation) error { if len(images) == 0 { return nil } @@ -69,13 +67,12 @@ func InsertAll(conn *pg.DB, images []*ImageAnnotation) error { if err != nil { return err } - logger.Debug().Msgf("inserted %d images", len(images)) return nil } // DeleteByURI deletes the images with matching URI. -func DeleteByURI(conn *pg.DB, uri string) error { +func DeleteByURI(conn pg.DB, uri string) error { img := ImageAnnotation{URI: uri} if _, err := conn.Model(&img).Where("uri = ?", uri).Delete(); err != nil { diff --git a/src/images_test.go b/src/images_test.go index 81bfc8e..8ec6487 100644 --- a/src/images_test.go +++ b/src/images_test.go @@ -2,86 +2,98 @@ package src import ( "database/sql" - "fmt" - "log" "os" "testing" "time" "github.com/joho/godotenv" + "github.com/rs/zerolog" ) // var conn *pg.DB -var testErr error -var imgURIList = []string{ - "https://hatrabbits.com/wp-content/uploads/2017/01/random.jpg", - "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT1ZgCJADylizZLNnOnyuhtwR2qVk5yOi0UoQ&usqp=CAU", - "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRKsJoGKlOJnxl-GNgfUtluGobgx_M8JBdsng&usqp=CAU", +func getTestCtx() (appContext, error) { + var ctx appContext + config, err := newConfig() + if err != nil { + return ctx, err + } + config.DBName = "purity_test" + conn, err := InitDB(config) + if err != nil { + return ctx, err + } + ctx.db = *conn + ctx.logger = zerolog.New(os.Stderr).With().Timestamp().Logger() + ctx.licenseStore = NewLicenseStore(conn) + ctx.annotationStore = nil + ctx.config = config + return ctx, nil } func TestMain(m *testing.M) { - if err := godotenv.Load("../.env"); err != nil { - log.Fatal(err) - } - InitConfig() + godotenv.Load() +} - fmt.Println("Got value: ", os.Getenv("PURITY_DB_HOST")) - conn, testErr = InitDB(DefaultDBTestName) - if testErr != nil { - log.Fatal(testErr) +func TestImages(t *testing.T) { + ctx, err := getTestCtx() + if err != nil { + t.Fatal(err) } - exitCode := m.Run() - - os.Exit(exitCode) -} + var imgURIList = []string{ + "https://hatrabbits.com/wp-content/uploads/2017/01/random.jpg", + "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT1ZgCJADylizZLNnOnyuhtwR2qVk5yOi0UoQ&usqp=CAU", + "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRKsJoGKlOJnxl-GNgfUtluGobgx_M8JBdsng&usqp=CAU", + } -func TestInsertImage(t *testing.T) { - for _, uri := range imgURIList { - fakeHash := Hash(uri) - anno := &ImageAnnotation{ - Hash: fakeHash, - URI: uri, - Error: sql.NullString{}, - DateAdded: time.Now(), - Adult: 0, - Spoof: 0, - Medical: 0, - Violence: 0, - Racy: 0, - } - testErr = Insert(conn, anno) - if testErr != nil { - t.Fatal(testErr.Error()) + t.Run("inserts images", func(t *testing.T) { + for _, uri := range imgURIList { + fakeHash := Hash(uri) + anno := ImageAnnotation{ + Hash: fakeHash, + URI: uri, + Error: sql.NullString{}, + DateAdded: time.Now(), + Adult: 0, + Spoof: 0, + Medical: 0, + Violence: 0, + Racy: 0, + } + err := Insert(ctx.db, anno) + if err != nil { + t.Fatal(err.Error()) + } } - } -} + }) -func TestFindImagesByURI(t *testing.T) { - smallURIList := imgURIList[:1] + t.Run("finds images by URI", func(t *testing.T) { + smallURIList := imgURIList[:1] - imgList, err := FindAnnotationsByURI(conn, smallURIList) - if err != nil { - t.Fatal(err.Error()) - } + imgList, err := FindAnnotationsByURI(ctx.db, smallURIList) + if err != nil { + t.Fatal(err.Error()) + } - if len(imgList) != 1 { - t.Fatalf("Expected 1 image in response but received %d", len(imgList)) - t.FailNow() - } + if len(imgList) != 1 { + t.Fatalf("Expected 1 image in response but received %d", len(imgList)) + t.FailNow() + } - smallURIList = []string{} - _, err = FindAnnotationsByURI(conn, smallURIList) - if err == nil { - t.Fatal("Expected FindImagesByURI to return an error because imgURIList cannot be empty") - } -} + smallURIList = []string{} + _, err = FindAnnotationsByURI(ctx.db, smallURIList) + if err == nil { + t.Fatal("Expected FindImagesByURI to return an error because imgURIList cannot be empty") + } + }) -func TestDeleteImagesByURI(t *testing.T) { - for _, uri := range imgURIList { - testErr = DeleteByURI(conn, uri) - if testErr != nil { - t.Fatal(testErr) + t.Run("deletes images by URI", func(t *testing.T) { + for _, uri := range imgURIList { + err := DeleteByURI(ctx.db, uri) + if err != nil { + t.Fatal(err) + } } - } + }) + } diff --git a/src/license.go b/src/license.go index 9d0e639..e92b826 100644 --- a/src/license.go +++ b/src/license.go @@ -15,23 +15,24 @@ type License struct { IsTrial bool `json:"isTrial"` } -type LicenseManager interface { +type LicenseStorer interface { GetLicenseByID(id string) (*License, error) GetLicenseByStripeID(id string) (*License, error) UpdateLicense(*License) error + GetLicenseByEmail(email string) (*License, error) ExpireTrial(*License) (*License, error) } -type LicenseStore struct { +type licenseStore struct { db *pg.DB } -func NewLicenseStore(db *pg.DB) *LicenseStore { - return &LicenseStore{db: db} +func NewLicenseStore(db *pg.DB) *licenseStore { + return &licenseStore{db: db} } // GetLicenseByID fetches a license from DB by license ID -func (store *LicenseStore) GetLicenseByID(id string) (*License, error) { +func (store *licenseStore) GetLicenseByID(id string) (*License, error) { license := new(License) err := store.db.Model(license).Where("id = ?", id).Select() if err != nil { @@ -43,7 +44,7 @@ func (store *LicenseStore) GetLicenseByID(id string) (*License, error) { return license, nil } -func (store *LicenseStore) GetLicenseByStripeID(stripeID string) (*License, error) { +func (store *licenseStore) GetLicenseByStripeID(stripeID string) (*License, error) { license := new(License) err := store.db.Model(license).Where("stripe_id = ?", stripeID).Select() if err != nil { @@ -55,12 +56,12 @@ func (store *LicenseStore) GetLicenseByStripeID(stripeID string) (*License, erro return license, nil } -func (store *LicenseStore) UpdateLicense(license *License) error { +func (store *licenseStore) UpdateLicense(license *License) error { _, err := store.db.Model(license).Where("id = ?", license.ID).Update(license) return err } -func (store *LicenseStore) GetLicenseByEmail(email string) (*License, error) { +func (store *licenseStore) GetLicenseByEmail(email string) (*License, error) { license := new(License) err := store.db.Model(license).Where("email = ?", email).Select() if err != nil { @@ -72,7 +73,7 @@ func (store *LicenseStore) GetLicenseByEmail(email string) (*License, error) { return license, nil } -func (store *LicenseStore) ExpireTrial(license *License) (*License, error) { +func (store *licenseStore) ExpireTrial(license *License) (*License, error) { license.IsValid = false license.ValidityReason = "trial license has expired" if err := store.UpdateLicense(license); err != nil { diff --git a/src/mail.go b/src/mail.go index 5a2d661..2b0aabe 100644 --- a/src/mail.go +++ b/src/mail.go @@ -15,11 +15,11 @@ type Email struct { Html string } -func SendMail(email Email) error { - from := mail.NewEmail(EmailName, EmailFrom) +func SendMail(config Config, email Email) error { + from := mail.NewEmail(config.EmailName, config.EmailFrom) to := mail.NewEmail(email.Name, email.To) message := mail.NewSingleEmail(from, email.Subject, to, email.Plain, email.Html) - client := sendgrid.NewSendClient(SendgridAPIKey) + client := sendgrid.NewSendClient(config.SendgridAPIKey) _, err := client.Send(message) if err != nil { @@ -29,7 +29,7 @@ func SendMail(email Email) error { return nil } -func SendLicenseMail(emailTo string, licenseID string) error { +func SendLicenseMail(config Config, emailTo string, licenseID string) error { email := Email{ Name: emailTo, To: emailTo, @@ -38,7 +38,7 @@ func SendLicenseMail(emailTo string, licenseID string) error { Html: fmt.Sprintf("

Your PurityVision License Key

%s

", licenseID), } - if err := SendMail(email); err != nil { + if err := SendMail(config, email); err != nil { return err } diff --git a/src/middleware.go b/src/middleware.go index f3a2096..646c898 100644 --- a/src/middleware.go +++ b/src/middleware.go @@ -7,7 +7,7 @@ import ( "github.com/google/uuid" ) -func getLicenseFromReq(ls LicenseManager, r *http.Request) (*License, error) { +func getLicenseFromReq(ls LicenseStorer, r *http.Request) (*License, error) { licenseID := r.Header.Get("LicenseID") _, err := uuid.Parse(licenseID) @@ -23,11 +23,12 @@ func getLicenseFromReq(ls LicenseManager, r *http.Request) (*License, error) { return license, nil } -func paywallMiddleware(ls LicenseManager) func(next http.Handler) http.Handler { +func paywallMiddleware(ctx appContext) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - license, err := getLicenseFromReq(ls, r) + license, err := getLicenseFromReq(ctx.licenseStore, r) if err != nil { + ctx.logger.Info().Msgf("failed to get license: %v", err) http.Error(w, err.Error(), http.StatusUnauthorized) return } diff --git a/src/routes.go b/src/routes.go deleted file mode 100644 index 74b2466..0000000 --- a/src/routes.go +++ /dev/null @@ -1,37 +0,0 @@ -package src - -import ( - "fmt" - "net/http" - - "github.com/go-pg/pg/v10" - "github.com/gorilla/mux" - "github.com/rs/zerolog/log" -) - -var PrintSomethingWrong = func(w http.ResponseWriter) { fmt.Fprint(w, "Something went wrong") } - -// Init intializes the Serve instance and exposes it based on the port parameter. -func (s *Serve) InitServer(port int, _conn *pg.DB) { - // Store the database connection in a global var. - conn = _conn - licenseStore = NewLicenseStore(conn) - - r := mux.NewRouter() - - r.Use(addCorsHeaders) - r.Handle("/", http.FileServer(http.Dir("./"))).Methods("GET") - r.HandleFunc("/health", health).Methods("GET", "OPTIONS") - r.HandleFunc("/license/{id}", handleGetLicense).Methods("GET") - r.HandleFunc("/webhook", handleWebhook).Methods("POST") - r.HandleFunc("/trial-register", handleTrialRegister).Methods("POST", "OPTIONS") - - // Paywalled filter routes. - filterR := r.PathPrefix("/filter").Subrouter() - filterR.Use(paywallMiddleware(licenseStore)) - filterR.HandleFunc("/batch", handleBatchFilter(logger)).Methods("POST", "OPTIONS") - - listenAddr = fmt.Sprintf("%s:%d", listenAddr, port) - log.Info().Msgf("Web server now listening on %s", listenAddr) - log.Fatal().Msg(http.ListenAndServe(listenAddr, r).Error()) -} diff --git a/src/server.go b/src/server.go index 8c5dd53..d157767 100644 --- a/src/server.go +++ b/src/server.go @@ -1,57 +1,113 @@ package src import ( - "database/sql" - "encoding/json" + "flag" + "fmt" "net/http" "os" + "strconv" "github.com/go-pg/pg/v10" + "github.com/gorilla/mux" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -// Server listens on localhost:8080 by default. -var listenAddr string = "" - -// Store the db connection passed from main.go. -var conn *pg.DB - -var licenseStore *LicenseStore - -var logger zerolog.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: true}).With().Caller().Logger() - // AnnotateReq is the form of an incoming JSON payload // for retrieving pass/fail status of each supplied image URI. type AnnotateReq struct { ImgURIList []string `json:"imgURIList"` } -// ErrorRes is a JSON response containing an error message from the API. -type ErrorRes struct { - Message string `json:"message"` +type AnnotationStore interface { + GetAnnotations([]string) ([]*ImageAnnotation, error) + PutAnnotations([]*ImageAnnotation) error } -// Server defines the actions of a Purity API Web Server. -type Server interface { - Init(int, *sql.DB) +// Serve is an instance of a Purity API Web Server. +type appContext struct { + db pg.DB + logger zerolog.Logger + licenseStore LicenseStorer + annotationStore AnnotationStore + config Config } -// Serve is an instance of a Purity API Web Server. -type Serve struct { +type appHandler struct { + appContext + H func(appContext, http.ResponseWriter, *http.Request) (int, error) } -// NewServe returns an uninitialized Serve instance. -func NewServe() *Serve { - return &Serve{} +// Our ServeHTTP method is mostly the same, and also has the ability to +// access our *appContext's fields (templates, loggers, etc.) as well. +func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Updated to pass ah.appContext as a parameter to our handler type. + status, err := ah.H(ah.appContext, w, r) + if err != nil { + log.Printf("HTTP %d: %q", status, err) + switch status { + case http.StatusNotFound: + http.NotFound(w, r) + // And if we wanted a friendlier error page, we can + // now leverage our context instance - e.g. + // err := ah.renderTemplate(w, "http_404.tmpl", nil) + case http.StatusInternalServerError: + http.Error(w, http.StatusText(status), status) + default: + http.Error(w, http.StatusText(status), status) + } + } } -func writeError(code int, message string, w http.ResponseWriter) { - logger.Info().Msg(message) - w.WriteHeader(code) - err := ErrorRes{ - Message: message, +// Init intializes the Serve instance and exposes it based on the port parameter. +func InitServer() { + var portFlag int + + config, err := newConfig() + if err != nil { + log.Fatal().Msg(err.Error()) + } + zerolog.SetGlobalLevel(zerolog.Level(zerolog.ErrorLevel)) + + conn, err := InitDB(config) + if err != nil { + log.Fatal().Msg(err.Error()) } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(err) + defer conn.Close() + + ctx := appContext{ + db: *conn, + logger: zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: true}).With().Timestamp().Logger(), + licenseStore: NewLicenseStore(conn), + annotationStore: nil, + config: config, + } + + flag.IntVar(&portFlag, "port", 8080, "port to run the service on") + flag.Parse() + + logLevel, err := strconv.Atoi(ctx.config.LogLevel) + if err != nil { + panic(err) + } + zerolog.SetGlobalLevel(zerolog.Level(logLevel)) + + r := mux.NewRouter() + + r.Use(addCorsHeaders) + r.Handle("/", http.FileServer(http.Dir("./"))).Methods("GET") + r.HandleFunc("/health", health).Methods("GET", "OPTIONS") + r.Handle("/license/{id}", &appHandler{ctx, handleGetLicense}).Methods("GET", "OPTIONS") + r.Handle("/webhook", &appHandler{ctx, handleWebhook}).Methods("POST") + // r.HandleFunc("/trial-register", handleTrialRegister).Methods("POST", "OPTIONS") + + // Paywalled filter routes. + filterR := r.PathPrefix("/filter").Subrouter() + filterR.Use(paywallMiddleware(ctx)) + filterR.Handle("/batch", &appHandler{ctx, handleBatchFilter}).Methods("POST", "OPTIONS") + + listenAddr := "" + listenAddr = fmt.Sprintf("%s:%d", listenAddr, portFlag) + ctx.logger.Info().Msgf("Web server now listening on %s", listenAddr) + ctx.logger.Fatal().Msg(http.ListenAndServe(listenAddr, r).Error()) } diff --git a/src/server_test.go b/src/server_test.go index 94f7377..37aff9c 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -9,8 +9,11 @@ import ( "log" "net/http" "net/http/httptest" + "os" "testing" "time" + + "github.com/rs/zerolog" ) const testLicenseID = "797e2754-7547-49c2-acfb-fa7b8357ab03" @@ -82,200 +85,200 @@ func TestHealthHandler(t *testing.T) { testHealthJunkBody(t) } -func testCleanup() { - conn.Model(&ImageAnnotation{}).Where("1=1").Delete() - _, err := conn.Model(&License{}).Where("1=1").Delete() +func TestServer(t *testing.T) { + ctx, err := getTestCtx() if err != nil { - fmt.Println("error: ", err) + t.Fatal(err) } - defer conn.Close() -} + ctx.logger = zerolog.Logger{} -func TestFilterHandlerTable(t *testing.T) { - t.Cleanup(testCleanup) + t.Cleanup(func() { + ctx.db.Model(&ImageAnnotation{}).Where("1=1").Delete() + _, err := ctx.db.Model(&License{}).Where("1=1").Delete() + if err != nil { + fmt.Println("error: ", err) + } + defer ctx.db.Close() + }) - licenseStore = NewLicenseStore(conn) - if serverTestErr != nil { - log.Fatal(serverTestErr) - } + t.Run("filters different URI lists", func(t *testing.T) { + if serverTestErr != nil { + log.Fatal(serverTestErr) + } - license := &License{ - ID: testLicenseID, - Email: "test@email.com", - StripeID: "stripe id", - IsValid: true, - ValidityReason: "", - } + license := &License{ + ID: testLicenseID, + Email: "test@email.com", + StripeID: "stripe id", + IsValid: true, + SubscriptionID: os.Getenv("STRIPE_TEST_SUB_ID"), + ValidityReason: "", + } - if _, serverTestErr = conn.Model(license).Insert(); serverTestErr != nil { - t.Error("failed to create test license") - } + if _, serverTestErr = ctx.db.Model(license).Insert(); serverTestErr != nil { + t.Error("failed to create test license") + } - tests := []FilterTest{ - { - Given: []string{}, - Expect: FilterTestExpect{ - Code: 400, - Error: errors.New("ImgUriList cannot be empty"), - Res: []*ImageAnnotation{}, - }, - }, - { - Given: []string{ - "https://i.imgur.com/FEpwOY8.jpg", - "https://i.imgur.com/FEpwOY8.jpg", - "https://i.imgur.com/FEpwOY8.jpg", - "https://i.imgur.com/FEpwOY8.jpg", - "https://i.imgur.com/FEpwOY8.jpg", + tests := []FilterTest{ + { + Given: []string{}, + Expect: FilterTestExpect{ + Code: 400, + Error: errors.New("ImgUriList cannot be empty"), + Res: []*ImageAnnotation{}, + }, }, - Expect: FilterTestExpect{ - Code: 200, - Error: nil, - Res: []*ImageAnnotation{ - { - Hash: "87408bebb6a1d42cd7cc1bbffb6d7dcc6aff14af4aea5c9af9fc5b624cf7c93a", - URI: "https://i.imgur.com/FEpwOY8.jpg", - Error: sql.NullString{}, - DateAdded: time.Now(), - Adult: 2, - Spoof: 1, - Medical: 2, - Violence: 3, - Racy: 5, + { + Given: []string{ + "https://i.imgur.com/FEpwOY8.jpg", + "https://i.imgur.com/FEpwOY8.jpg", + "https://i.imgur.com/FEpwOY8.jpg", + "https://i.imgur.com/FEpwOY8.jpg", + "https://i.imgur.com/FEpwOY8.jpg", + }, + Expect: FilterTestExpect{ + Code: 200, + Error: nil, + Res: []*ImageAnnotation{ + { + Hash: "87408bebb6a1d42cd7cc1bbffb6d7dcc6aff14af4aea5c9af9fc5b624cf7c93a", + URI: "https://i.imgur.com/FEpwOY8.jpg", + Error: sql.NullString{}, + DateAdded: time.Now(), + Adult: 2, + Spoof: 1, + Medical: 2, + Violence: 3, + Racy: 5, + }, }, }, }, - }, - { - Given: []string{ - "https://i.imgur.com/FEpwOY8.jpg", - "https://i.imgur.com/6ZOubbU.png", - "https://i.imgur.com/qtTfzH6.jpg", - "https://i.imgur.com/RwHI4jk.jpg", - }, - Expect: FilterTestExpect{ - Code: 200, - Error: nil, - Res: []*ImageAnnotation{ - { - Hash: "87408bebb6a1d42cd7cc1bbffb6d7dcc6aff14af4aea5c9af9fc5b624cf7c93a", - URI: "https://i.imgur.com/FEpwOY8.jpg", - Error: sql.NullString{}, - DateAdded: time.Now(), - Adult: 2, - Spoof: 1, - Medical: 2, - Violence: 3, - Racy: 5, - }, - { - Hash: "65d2ad788998a350e7476c4a110ece346d4d56ab76670d48ddd896444a0029b1", - URI: "https://i.imgur.com/6ZOubbU.png", - Error: sql.NullString{}, - DateAdded: time.Now(), - Adult: 1, - Spoof: 1, - Medical: 2, - Violence: 1, - Racy: 5, - }, - { - Hash: "2a5cdbc5148669ec4efc788d03f535cf99f13756ccd200ae48faf59fac30b811", - URI: "https://i.imgur.com/qtTfzH6.jpg", - Error: sql.NullString{}, - DateAdded: time.Now(), - Adult: 2, - Spoof: 3, - Medical: 2, - Violence: 4, - Racy: 2, - }, - { - Hash: "b2047dfb0412f815859b269288a948528587b77d9b3e0395cd57faf2ba4c37f5", - URI: "https://i.imgur.com/RwHI4jk.jpg", - Error: sql.NullString{}, - DateAdded: time.Now(), - Adult: 5, - Spoof: 1, - Medical: 3, - Violence: 3, - Racy: 5, + { + Given: []string{ + "https://i.imgur.com/FEpwOY8.jpg", + "https://i.imgur.com/6ZOubbU.png", + "https://i.imgur.com/qtTfzH6.jpg", + "https://i.imgur.com/RwHI4jk.jpg", + }, + Expect: FilterTestExpect{ + Code: 200, + Error: nil, + Res: []*ImageAnnotation{ + { + Hash: "87408bebb6a1d42cd7cc1bbffb6d7dcc6aff14af4aea5c9af9fc5b624cf7c93a", + URI: "https://i.imgur.com/FEpwOY8.jpg", + Error: sql.NullString{}, + DateAdded: time.Now(), + Adult: 2, + Spoof: 1, + Medical: 2, + Violence: 3, + Racy: 5, + }, + { + Hash: "65d2ad788998a350e7476c4a110ece346d4d56ab76670d48ddd896444a0029b1", + URI: "https://i.imgur.com/6ZOubbU.png", + Error: sql.NullString{}, + DateAdded: time.Now(), + Adult: 1, + Spoof: 1, + Medical: 2, + Violence: 1, + Racy: 5, + }, + { + Hash: "2a5cdbc5148669ec4efc788d03f535cf99f13756ccd200ae48faf59fac30b811", + URI: "https://i.imgur.com/qtTfzH6.jpg", + Error: sql.NullString{}, + DateAdded: time.Now(), + Adult: 2, + Spoof: 3, + Medical: 2, + Violence: 4, + Racy: 2, + }, + { + Hash: "b2047dfb0412f815859b269288a948528587b77d9b3e0395cd57faf2ba4c37f5", + URI: "https://i.imgur.com/RwHI4jk.jpg", + Error: sql.NullString{}, + DateAdded: time.Now(), + Adult: 5, + Spoof: 1, + Medical: 3, + Violence: 3, + Racy: 5, + }, }, }, }, - }, - } - - for _, test := range tests { - req := &AnnotateReq{ImgURIList: test.Given} - rec, err := testFilterHandler(req) - - if test.Expect.Error != nil { - decoder := json.NewDecoder(rec.Body) - var errRes ErrorRes - if err := decoder.Decode(&errRes); err != nil { - t.Error("JSON body missing or malformed") - } - if errRes.Message != test.Expect.Error.Error() { - t.Error("expected error but didn't get one") - } } - if test.Expect.Error == nil && err != nil { - t.Error("didn't expect error but got: ", err.Error()) - } - if rec.Code != test.Expect.Code { - t.Errorf("expected status %d but got %d", rec.Code, test.Expect.Code) - } - var annotations []*ImageAnnotation - json.Unmarshal(rec.Body.Bytes(), &annotations) - - if len(annotations) != len(test.Expect.Res) { - t.Errorf("expected %d annotation results but got %d", len(test.Expect.Res), len(annotations)) - } + for _, test := range tests { + req := &AnnotateReq{ImgURIList: test.Given} + rec, code, err := testFilterHandler(ctx, req) - for i, annotation := range annotations { - expected := test.Expect.Res[i] - if annotation.Adult != expected.Adult { - t.Errorf("expected adult to be %d but got %d", annotation.Adult, expected.Adult) + if test.Expect.Error != nil { + if err.Error() != test.Expect.Error.Error() { + t.Error("expected error but didn't get one") + } } - if annotation.Spoof != expected.Spoof { - t.Errorf("expected spoof to be %d but got %d", annotation.Spoof, expected.Spoof) + if test.Expect.Error == nil && err != nil { + t.Error("didn't expect error but got: ", err.Error()) } - - if annotation.Medical != expected.Medical { - t.Errorf("expected medical to be %d but got %d", annotation.Medical, expected.Medical) + if code != test.Expect.Code { + t.Errorf("expected status %d but got %d", test.Expect.Code, rec.Code) } + var annotations []*ImageAnnotation + json.Unmarshal(rec.Body.Bytes(), &annotations) - if annotation.Violence != expected.Violence { - t.Errorf("expected violence to be %d but got %d", annotation.Violence, expected.Violence) + if len(annotations) != len(test.Expect.Res) { + t.Errorf("expected %d annotation results but got %d", len(test.Expect.Res), len(annotations)) } - if annotation.Racy != expected.Racy { - t.Errorf("expected racy to be %d but got %d", annotation.Racy, expected.Racy) + for i, annotation := range annotations { + expected := test.Expect.Res[i] + if annotation.Adult != expected.Adult { + t.Errorf("expected adult to be %d but got %d", annotation.Adult, expected.Adult) + } + + if annotation.Spoof != expected.Spoof { + t.Errorf("expected spoof to be %d but got %d", annotation.Spoof, expected.Spoof) + } + + if annotation.Medical != expected.Medical { + t.Errorf("expected medical to be %d but got %d", annotation.Medical, expected.Medical) + } + + if annotation.Violence != expected.Violence { + t.Errorf("expected violence to be %d but got %d", annotation.Violence, expected.Violence) + } + + if annotation.Racy != expected.Racy { + t.Errorf("expected racy to be %d but got %d", annotation.Racy, expected.Racy) + } } } - } + }) + } -func testFilterHandler(fr *AnnotateReq) (*httptest.ResponseRecorder, error) { +func testFilterHandler(ctx appContext, fr *AnnotateReq) (*httptest.ResponseRecorder, int, error) { b, err := json.Marshal(fr) if err != nil { - return nil, fmt.Errorf("Failed to marshal request body struct") + return nil, -1, fmt.Errorf("Failed to marshal request body struct") } r := bytes.NewReader(b) req, err := http.NewRequest("POST", "/filter", r) req.Header.Add("LicenseID", testLicenseID) if err != nil { - return nil, errors.New("Failed to create test HTTP request") + return nil, -1, errors.New("Failed to create test HTTP request") } rr := httptest.NewRecorder() - handler := http.HandlerFunc(handleBatchFilter(logger)) - - handler.ServeHTTP(rr, req) - return rr, nil + code, err := handleBatchFilter(ctx, rr, req) + return rr, code, err } diff --git a/src/stripe.go b/src/stripe.go index 3b70233..caf8f91 100644 --- a/src/stripe.go +++ b/src/stripe.go @@ -8,14 +8,14 @@ import ( "github.com/stripe/stripe-go/v74/usagerecord" ) -func IncrementSubscriptionMeter(lic *License, quantity int64) error { - if StripeKey == "" { +func IncrementSubscriptionMeter(stripeKey string, lic *License, quantity int64) error { + if stripeKey == "" { return errors.New("STRIPE_KEY env var not found") } - stripe.Key = StripeKey + stripe.Key = stripeKey - s, err := fetchStripeSubscription(lic) + s, err := fetchStripeSubscription(stripeKey, lic) if err != nil { return err } @@ -31,12 +31,12 @@ func IncrementSubscriptionMeter(lic *License, quantity int64) error { return err } -func fetchStripeSubscription(lic *License) (*stripe.Subscription, error) { - if StripeKey == "" { +func fetchStripeSubscription(stripeKey string, lic *License) (*stripe.Subscription, error) { + if stripeKey == "" { return nil, errors.New("STRIPE_KEY env var not found") } - stripe.Key = StripeKey + stripe.Key = stripeKey sub, err := subscription.Get(lic.SubscriptionID, nil) if err != nil {