Skip to content

Commit

Permalink
chore: move auth logic into middleware && naming the files better (#115)
Browse files Browse the repository at this point in the history
* chore: move auth logic into middleware

* refactor: move the project ownership into central place

* refactor: give files better names
  • Loading branch information
iandyh authored Sep 25, 2024
1 parent 7d27593 commit c0b7c40
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 90 deletions.
2 changes: 1 addition & 1 deletion shibuya/api/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func getCollection(collectionID string) (*model.Collection, error) {
}

func (s *ShibuyaAPI) collectionConfigGetHandler(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(req, params)
collection, err := hasCollectionOwnership(req, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down
8 changes: 8 additions & 0 deletions shibuya/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ func makeInternalServerError(message string) error {
func makeInvalidResourceError(resource string) error {
return fmt.Errorf("%winvalid %s", invalidRequestErr, resource)
}

func makeProjectOwnershipError() error {
return fmt.Errorf("%w%s", noPermissionErr, "You don't own the project")
}

func makeCollectionOwnershipError() error {
return fmt.Errorf("%w%s", noPermissionErr, "You don't own the collection")
}
87 changes: 36 additions & 51 deletions shibuya/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strconv"
"strings"
"time"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -89,11 +90,7 @@ func (s *ShibuyaAPI) handleErrors(w http.ResponseWriter, err error) {
}

func (s *ShibuyaAPI) projectsGetHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
account := model.GetAccountBySession(r)
if account == nil {
s.makeFailMessage(w, "Need to login", http.StatusForbidden)
return
}
account := r.Context().Value(accountKey).(*model.Account)
qs := r.URL.Query()
var includeCollections, includePlans bool
var err error
Expand Down Expand Up @@ -145,11 +142,7 @@ func (s *ShibuyaAPI) projectUpdateHandler(w http.ResponseWriter, _ *http.Request
}

func (s *ShibuyaAPI) projectCreateHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
account := model.GetAccountBySession(r)
if account == nil {
s.handleErrors(w, makeLoginError())
return
}
account := r.Context().Value(accountKey).(*model.Account)
r.ParseForm()
name := r.Form.Get("name")
if name == "" {
Expand Down Expand Up @@ -191,18 +184,14 @@ func (s *ShibuyaAPI) projectCreateHandler(w http.ResponseWriter, r *http.Request
}

func (s *ShibuyaAPI) projectDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
account := model.GetAccountBySession(r)
if account == nil {
s.handleErrors(w, makeLoginError())
return
}
account := r.Context().Value(accountKey).(*model.Account)
project, err := getProject(params.ByName("project_id"))
if err != nil {
s.handleErrors(w, err)
return
}
if _, ok := account.MLMap[project.Owner]; !ok {
s.handleErrors(w, noPermissionErr)
if r := hasProjectOwnership(project, account); !r {
s.handleErrors(w, makeProjectOwnershipError())
return
}
collectionIDs, err := project.GetCollections()
Expand Down Expand Up @@ -260,20 +249,16 @@ func (s *ShibuyaAPI) collectionAdminGetHandler(w http.ResponseWriter, r *http.Re
}

func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
account := model.GetAccountBySession(r)
if account == nil {
s.handleErrors(w, makeLoginError())
return
}
account := r.Context().Value(accountKey).(*model.Account)
r.ParseForm()
projectID := r.Form.Get("project_id")
project, err := getProject(projectID)
if err != nil {
s.handleErrors(w, err)
return
}
if _, ok := account.MLMap[project.Owner]; !ok {
s.handleErrors(w, makeNoPermissionErr("You don't own the project"))
if r := hasProjectOwnership(project, account); !r {
s.handleErrors(w, makeProjectOwnershipError())
return
}
name := r.Form.Get("name")
Expand All @@ -294,11 +279,7 @@ func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _
}

func (s *ShibuyaAPI) planDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
account := model.GetAccountBySession(r)
if account == nil {
s.handleErrors(w, makeLoginError())
return
}
account := r.Context().Value(accountKey).(*model.Account)
plan, err := getPlan(params.ByName("plan_id"))
if err != nil {
s.handleErrors(w, err)
Expand All @@ -309,8 +290,8 @@ func (s *ShibuyaAPI) planDeleteHandler(w http.ResponseWriter, r *http.Request, p
s.handleErrors(w, err)
return
}
if _, ok := account.MLMap[project.Owner]; !ok {
s.handleErrors(w, makeLoginError())
if r := hasProjectOwnership(project, account); !r {
s.handleErrors(w, makeProjectOwnershipError())
return
}
using, err := plan.IsBeingUsed()
Expand Down Expand Up @@ -355,7 +336,7 @@ func (s *ShibuyaAPI) collectionFilesGetHandler(w http.ResponseWriter, _ *http.Re
}

func (s *ShibuyaAPI) collectionFilesUploadHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand All @@ -375,7 +356,7 @@ func (s *ShibuyaAPI) collectionFilesUploadHandler(w http.ResponseWriter, r *http
}

func (s *ShibuyaAPI) collectionFilesDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down Expand Up @@ -415,11 +396,7 @@ func (s *ShibuyaAPI) planFilesDeleteHandler(w http.ResponseWriter, r *http.Reque
}

func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
account := model.GetAccountBySession(r)
if account == nil {
s.handleErrors(w, makeLoginError())
return
}
account := r.Context().Value(accountKey).(*model.Account)
r.ParseForm()
collectionName := r.Form.Get("name")
if collectionName == "" {
Expand All @@ -432,8 +409,8 @@ func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Requ
s.handleErrors(w, err)
return
}
if _, ok := account.MLMap[project.Owner]; !ok {
s.handleErrors(w, makeNoPermissionErr("You don't have the permission"))
if r := hasProjectOwnership(project, account); !r {
s.handleErrors(w, makeProjectOwnershipError())
return
}
collectionID, err := model.CreateCollection(collectionName, project.ID)
Expand All @@ -450,7 +427,7 @@ func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Requ
}

func (s *ShibuyaAPI) collectionDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down Expand Up @@ -480,7 +457,7 @@ func (s *ShibuyaAPI) collectionDeleteHandler(w http.ResponseWriter, r *http.Requ
}

func (s *ShibuyaAPI) collectionGetHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down Expand Up @@ -519,7 +496,7 @@ func (s *ShibuyaAPI) collectionUpdateHandler(w http.ResponseWriter, _ *http.Requ
}

func (s *ShibuyaAPI) collectionUploadHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down Expand Up @@ -613,7 +590,7 @@ func (s *ShibuyaAPI) collectionUploadHandler(w http.ResponseWriter, r *http.Requ
}

func (s *ShibuyaAPI) collectionEnginesDetailHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand All @@ -627,7 +604,7 @@ func (s *ShibuyaAPI) collectionEnginesDetailHandler(w http.ResponseWriter, r *ht
}

func (s *ShibuyaAPI) collectionDeploymentHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand All @@ -644,7 +621,7 @@ func (s *ShibuyaAPI) collectionDeploymentHandler(w http.ResponseWriter, r *http.
}

func (s *ShibuyaAPI) collectionTriggerHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand All @@ -656,7 +633,7 @@ func (s *ShibuyaAPI) collectionTriggerHandler(w http.ResponseWriter, r *http.Req
}

func (s *ShibuyaAPI) collectionTermHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand All @@ -668,7 +645,7 @@ func (s *ShibuyaAPI) collectionTermHandler(w http.ResponseWriter, r *http.Reques
}

func (s *ShibuyaAPI) collectionStatusHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand All @@ -681,7 +658,7 @@ func (s *ShibuyaAPI) collectionStatusHandler(w http.ResponseWriter, r *http.Requ
}

func (s *ShibuyaAPI) collectionPurgeHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down Expand Up @@ -714,7 +691,7 @@ func (s *ShibuyaAPI) planLogHandler(w http.ResponseWriter, r *http.Request, para
}

func (s *ShibuyaAPI) streamCollectionMetrics(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
collection, err := checkCollectionOwnership(r, params)
collection, err := hasCollectionOwnership(r, params)
if err != nil {
s.handleErrors(w, err)
return
Expand Down Expand Up @@ -789,7 +766,7 @@ type Route struct {
type Routes []*Route

func (s *ShibuyaAPI) InitRoutes() Routes {
return Routes{
routes := Routes{
&Route{"get_projects", "GET", "/api/projects", s.projectsGetHandler},
&Route{"create_project", "POST", "/api/projects", s.projectCreateHandler},
&Route{"delete_project", "DELETE", "/api/projects/:project_id", s.projectDeleteHandler},
Expand Down Expand Up @@ -833,4 +810,12 @@ func (s *ShibuyaAPI) InitRoutes() Routes {

&Route{"admin_collections", "GET", "/api/admin/collections", s.collectionAdminGetHandler},
}
for _, r := range routes {
// TODO! We don't require auth for usage endpoint for now.
if strings.Contains(r.Path, "usage") {
continue
}
r.HandlerFunc = s.authRequired(r.HandlerFunc)
}
return routes
}
40 changes: 40 additions & 0 deletions shibuya/api/middlewares.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package api

import (
"context"
"errors"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/rakutentech/shibuya/shibuya/model"
)

const (
accountKey = "account"
)

func authWithSession(r *http.Request) (*model.Account, error) {
account := model.GetAccountBySession(r)
if account == nil {
return nil, makeLoginError()
}
return account, nil
}

// TODO add JWT token auth in the future
func authWithToken(_ *http.Request) (*model.Account, error) {
return nil, errors.New("No token presented")
}

func (s *ShibuyaAPI) authRequired(next httprouter.Handle) httprouter.Handle {
return httprouter.Handle(func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
var account *model.Account
var err error
account, err = authWithSession(r)
if err != nil {
s.handleErrors(w, err)
return
}
next(w, r.WithContext(context.WithValue(r.Context(), accountKey, account)), params)
})
}
14 changes: 14 additions & 0 deletions shibuya/api/networkutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package api

import (
"net/http"
"strings"
)

func retrieveClientIP(r *http.Request) string {
t := r.Header.Get("x-forwarded-for")
if t == "" {
return r.RemoteAddr
}
return strings.Split(t, ",")[0]
}
33 changes: 33 additions & 0 deletions shibuya/api/ownership.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package api

import (
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/rakutentech/shibuya/shibuya/model"
)

func hasProjectOwnership(project *model.Project, account *model.Account) bool {
if _, ok := account.MLMap[project.Owner]; !ok {
if !account.IsAdmin() {
return false
}
}
return true
}

func hasCollectionOwnership(r *http.Request, params httprouter.Params) (*model.Collection, error) {
collection, err := getCollection(params.ByName("collection_id"))
if err != nil {
return nil, err
}
account := r.Context().Value(accountKey).(*model.Account)
project, err := model.GetProject(collection.ProjectID)
if err != nil {
return nil, err
}
if r := hasProjectOwnership(project, account); !r {
return nil, makeCollectionOwnershipError()
}
return collection, nil
}
Loading

0 comments on commit c0b7c40

Please sign in to comment.