Skip to content

Commit

Permalink
Merge pull request #41 from Dewberry/feature-read-restriction
Browse files Browse the repository at this point in the history
Feature read restriction
  • Loading branch information
ShaneMPutnam authored Jun 17, 2024
2 parents 47ab4ad + dc2feb9 commit 2726603
Show file tree
Hide file tree
Showing 16 changed files with 652 additions and 486 deletions.
1 change: 1 addition & 0 deletions .example.env
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ S3API_SERVICE_PORT='5005'
KEYCLOAK_PUBLIC_KEYS_URL='public-keys-url-string'
AUTH_LEVEL=1 # Options: [0, 1] corresponds to [no FGAC, FGAC]. This integer value configures the initialization mode in docker-compose.
AUTH_LIMITED_WRITER_ROLE='s3_limited_writer'
AUTH_LIMITED_READER_ROLE='s3_limited_reader'

## DB for Auth:
POSTGRES_CONN_STRING='postgres://user:password@postgres:5432/db?sslmode=disable'
Expand Down
50 changes: 44 additions & 6 deletions auth/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@ import (
"os"

"github.com/labstack/gommon/log"
"github.com/lib/pq"
_ "github.com/lib/pq"
)

// Database interface abstracts database operations
type Database interface {
CheckUserPermission(userEmail, operation, s3_prefix string) bool
CheckUserPermission(userEmail, bucket, prefix string, operations []string) bool
Close() error
GetUserAccessiblePrefixes(userEmail, bucket string, operations []string) ([]string, error)
}

type PostgresDB struct {
Handle *sql.DB
}

// Initialize the database and create tables if they do not exist.
// NewPostgresDB initializes the database and creates tables if they do not exist.
func NewPostgresDB() (*PostgresDB, error) {
connString, exist := os.LookupEnv("POSTGRES_CONN_STRING")
if !exist {
Expand All @@ -41,7 +43,7 @@ func NewPostgresDB() (*PostgresDB, error) {
return pgDB, nil
}

// Creates the necessary tables in the database.
// createTables creates the necessary tables in the database.
func (db *PostgresDB) createTables() error {
createPermissionsTable := `
CREATE TABLE IF NOT EXISTS permissions (
Expand All @@ -63,21 +65,57 @@ func (db *PostgresDB) createTables() error {
return nil
}

// GetUserAccessiblePrefixes retrieves the accessible prefixes for a user.
func (db *PostgresDB) GetUserAccessiblePrefixes(userEmail, bucket string, operations []string) ([]string, error) {
query := `
WITH unnested_permissions AS (
SELECT DISTINCT unnest(allowed_s3_prefixes) AS allowed_prefix
FROM permissions
WHERE user_email = $1 AND operation = ANY($3)
)
SELECT allowed_prefix
FROM unnested_permissions
WHERE allowed_prefix LIKE $2 || '/%'
ORDER BY allowed_prefix;
`

rows, err := db.Handle.Query(query, userEmail, "/"+bucket, pq.Array(operations))
if err != nil {
return nil, fmt.Errorf("database error: %s", err)
}
defer rows.Close()

var prefixes []string
var prefix string
for rows.Next() {
if err := rows.Scan(&prefix); err != nil {
return nil, fmt.Errorf("scan error: %s", err)
}
prefixes = append(prefixes, prefix)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("row error: %s", err)
}

return prefixes, nil
}

// CheckUserPermission checks if a user has permission for a specific request.
func (db *PostgresDB) CheckUserPermission(userEmail, operation, s3_prefix string) bool {
func (db *PostgresDB) CheckUserPermission(userEmail, bucket, prefix string, operations []string) bool {
s3Prefix := fmt.Sprintf("/%s/%s", bucket, prefix)
query := `
SELECT EXISTS (
SELECT 1
FROM permissions,
UNNEST(allowed_s3_prefixes) AS allowed_prefix
WHERE user_email = $1
AND operation = $2
AND operation = ANY($2)
AND $3 LIKE allowed_prefix || '%'
);
`

var hasPermission bool
if err := db.Handle.QueryRow(query, userEmail, operation, s3_prefix).Scan(&hasPermission); err != nil {
if err := db.Handle.QueryRow(query, userEmail, pq.Array(operations), s3Prefix).Scan(&hasPermission); err != nil {
log.Errorf("error querying user permissions: %v", err)
return false
}
Expand Down
49 changes: 49 additions & 0 deletions blobstore/blobhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"sync"

"github.com/Dewberry/s3api/auth"
Expand All @@ -30,6 +31,7 @@ type Config struct {
// external sources like configuration files, environment variables should go here.
AuthLevel int
LimitedWriterRoleName string
LimitedReaderRoleName string
DefaultTempPrefix string
DefaultDownloadPresignedUrlExpiration int
DefaultUploadPresignedUrlExpiration int
Expand Down Expand Up @@ -305,3 +307,50 @@ func (bh *BlobHandler) PingWithAuth(c echo.Context) error {

return c.JSON(http.StatusOK, bucketHealth)
}

func (bh *BlobHandler) GetS3ReadPermissions(c echo.Context, bucket string) ([]string, bool, int, error) {
permissions, fullAccess, err := bh.GetUserS3ReadListPermission(c, bucket)
if err != nil {
//TEMP solution before error library is implimented and string check ups become redundant
httpStatus := http.StatusInternalServerError
if strings.Contains(err.Error(), "this endpoint requires authentication information that is unavailable when authorization is disabled.") {
httpStatus = http.StatusForbidden
}
return nil, false, httpStatus, fmt.Errorf("error fetching user permissions: %s", err.Error())
}
if !fullAccess && len(permissions) == 0 {
return nil, false, http.StatusForbidden, fmt.Errorf("user does not have permission to read the %s bucket", bucket)
}
return permissions, fullAccess, http.StatusOK, nil
}

func (bh *BlobHandler) HandleCheckS3UserPermission(c echo.Context) error {
if bh.Config.AuthLevel == 0 {
log.Info("Checked user permissions successfully")
return c.JSON(http.StatusOK, true)
}
initAuth := os.Getenv("INIT_AUTH")
if initAuth == "0" {
errMsg := fmt.Errorf("this endpoint requires authentication information that is unavailable when authorization is disabled. Please enable authorization to use this functionality")
log.Error(errMsg.Error())
return c.JSON(http.StatusForbidden, errMsg.Error())
}
prefix := c.QueryParam("prefix")
bucket := c.QueryParam("bucket")
operation := c.QueryParam("operation")
claims, ok := c.Get("claims").(*auth.Claims)
if !ok {
errMsg := fmt.Errorf("could not get claims from request context")
log.Error(errMsg.Error())
return c.JSON(http.StatusInternalServerError, errMsg.Error())
}
userEmail := claims.Email
if operation == "" || prefix == "" || bucket == "" {
errMsg := fmt.Errorf("`prefix`, `operation` and 'bucket are required params")
log.Error(errMsg.Error())
return c.JSON(http.StatusUnprocessableEntity, errMsg.Error())
}
isAllowed := bh.DB.CheckUserPermission(userEmail, bucket, prefix, []string{operation})
log.Info("Checked user permissions successfully")
return c.JSON(http.StatusOK, isAllowed)
}
86 changes: 67 additions & 19 deletions blobstore/blobstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package blobstore
import (
"fmt"
"net/http"
"time"
"os"
"strings"

"github.com/Dewberry/s3api/auth"
"github.com/Dewberry/s3api/utils"
Expand All @@ -14,6 +15,7 @@ import (
)

func (s3Ctrl *S3Controller) KeyExists(bucket string, key string) (bool, error) {

_, err := s3Ctrl.S3Svc.HeadObject(&s3.HeadObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Expand All @@ -33,23 +35,23 @@ func (s3Ctrl *S3Controller) KeyExists(bucket string, key string) (bool, error) {
}

// function that will get the most recently uploaded file in a prefix
func (s3Ctrl *S3Controller) getMostRecentModTime(bucket, prefix string) (time.Time, error) {
// Initialize a time variable to store the most recent modification time
var mostRecent time.Time
// func (s3Ctrl *S3Controller) getMostRecentModTime(bucket, prefix string, permissions []string, fullAccess bool) (time.Time, error) {
// // Initialize a time variable to store the most recent modification time
// var mostRecent time.Time

// Call GetList to retrieve the list of objects with the specified prefix
response, err := s3Ctrl.GetList(bucket, prefix, false)
if err != nil {
return time.Time{}, err
}
// Iterate over the returned objects to find the most recent modification time
for _, item := range response.Contents {
if item.LastModified != nil && item.LastModified.After(mostRecent) {
mostRecent = *item.LastModified
}
}
return mostRecent, nil
}
// // Call GetList to retrieve the list of objects with the specified prefix
// response, err := s3Ctrl.GetList(bucket, prefix, false)
// if err != nil {
// return time.Time{}, err
// }
// // Iterate over the returned objects to find the most recent modification time
// for _, item := range response.Contents {
// if item.LastModified != nil && item.LastModified.After(mostRecent) {
// mostRecent = *item.LastModified
// }
// }
// return mostRecent, nil
// }

func arrayContains(a string, arr []string) bool {
for _, b := range arr {
Expand Down Expand Up @@ -80,8 +82,13 @@ func isIdenticalArray(array1, array2 []string) bool {
return true
}

func (bh *BlobHandler) CheckUserS3WritePermission(c echo.Context, bucket, key string) (int, error) {
func (bh *BlobHandler) CheckUserS3Permission(c echo.Context, bucket, prefix string, permissions []string) (int, error) {
if bh.Config.AuthLevel > 0 {
initAuth := os.Getenv("INIT_AUTH")
if initAuth == "0" {
errMsg := fmt.Errorf("this requires authentication information that is unavailable when authorization is disabled. Please enable authorization to use this functionality")
return http.StatusForbidden, errMsg
}
claims, ok := c.Get("claims").(*auth.Claims)
if !ok {
return http.StatusInternalServerError, fmt.Errorf("could not get claims from request context")
Expand All @@ -91,13 +98,54 @@ func (bh *BlobHandler) CheckUserS3WritePermission(c echo.Context, bucket, key st

// Check for required roles
isLimitedWriter := utils.StringInSlice(bh.Config.LimitedWriterRoleName, roles)
// Ensure the prefix ends with a slash
if !strings.HasSuffix(prefix, "/") {
prefix += "/"
}

// We assume if someone is limited_writer, they should never be admin or super_writer
if isLimitedWriter {
if !bh.DB.CheckUserPermission(ue, "write", fmt.Sprintf("/%s/%s", bucket, key)) {
if !bh.DB.CheckUserPermission(ue, bucket, prefix, permissions) {
return http.StatusForbidden, fmt.Errorf("forbidden")
}
}
}
return 0, nil
}

func (bh *BlobHandler) GetUserS3ReadListPermission(c echo.Context, bucket string) ([]string, bool, error) {
permissions := make([]string, 0)

if bh.Config.AuthLevel > 0 {
initAuth := os.Getenv("INIT_AUTH")
if initAuth == "0" {
errMsg := fmt.Errorf("this endpoint requires authentication information that is unavailable when authorization is disabled. Please enable authorization to use this functionality")
return permissions, false, errMsg
}
fullAccess := false
claims, ok := c.Get("claims").(*auth.Claims)
if !ok {
return permissions, fullAccess, fmt.Errorf("could not get claims from request context")
}
roles := claims.RealmAccess["roles"]

// Check if user has the limited reader role
isLimitedReader := utils.StringInSlice(bh.Config.LimitedReaderRoleName, roles)

// If user is not a limited reader, assume they have full read access
if !isLimitedReader {
fullAccess = true // Indicating full access
return permissions, fullAccess, nil
}

// If user is a limited reader, fetch specific permissions
ue := claims.Email
permissions, err := bh.DB.GetUserAccessiblePrefixes(ue, bucket, []string{"read", "write"})
if err != nil {
return permissions, fullAccess, err
}
return permissions, fullAccess, nil
}

return permissions, true, nil
}
55 changes: 40 additions & 15 deletions blobstore/buckets.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package blobstore
import (
"fmt"
"net/http"
"sort"

"github.com/aws/aws-sdk-go/service/s3"
"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -75,17 +76,26 @@ func (s3Ctrl *S3Controller) ListBuckets() (*s3.ListBucketsOutput, error) {
// }

type BucketInfo struct {
ID int `json:"id"`
Name string `json:"name"`
ID int `json:"id"`
Name string `json:"name"`
CanRead bool `json:"can_read"`
}

func (bh *BlobHandler) HandleListBuckets(c echo.Context) error {
var allBuckets []BucketInfo
currentID := 1 // Initialize ID counter

bh.Mu.Lock()
for i := 0; i < len(bh.S3Controllers); i++ {
defer bh.Mu.Unlock()

// Check user's overall read access level
_, fullAccess, err := bh.GetUserS3ReadListPermission(c, "")
if err != nil {
return c.JSON(http.StatusInternalServerError, fmt.Errorf("error fetching user permissions: %s", err.Error()))
}

for _, controller := range bh.S3Controllers {
if bh.AllowAllBuckets {
result, err := bh.S3Controllers[i].ListBuckets()
result, err := controller.ListBuckets()
if err != nil {
errMsg := fmt.Errorf("error returning list of buckets, error: %s", err)
log.Error(errMsg)
Expand All @@ -95,24 +105,39 @@ func (bh *BlobHandler) HandleListBuckets(c echo.Context) error {
for _, b := range result.Buckets {
mostRecentBucketList = append(mostRecentBucketList, *b.Name)
}
if !isIdenticalArray(bh.S3Controllers[i].Buckets, mostRecentBucketList) {

bh.S3Controllers[i].Buckets = mostRecentBucketList

if !isIdenticalArray(controller.Buckets, mostRecentBucketList) {
controller.Buckets = mostRecentBucketList
}
}

// Extract the bucket names from the response and append to allBuckets
for _, bucket := range bh.S3Controllers[i].Buckets {
for i, bucket := range controller.Buckets {
canRead := fullAccess
if !fullAccess {
permissions, _, err := bh.GetUserS3ReadListPermission(c, bucket)
if err != nil {
return c.JSON(http.StatusInternalServerError, fmt.Errorf("error fetching user permissions: %s", err.Error()))
}
canRead = len(permissions) > 0
}
allBuckets = append(allBuckets, BucketInfo{
ID: currentID,
Name: bucket,
ID: i,
Name: bucket,
CanRead: canRead,
})
currentID++ // Increment the ID for the next bucket

}
}
bh.Mu.Unlock()

// Sorting allBuckets slice by CanRead true first and then by Name field alphabetically
sort.Slice(allBuckets, func(i, j int) bool {
if allBuckets[i].CanRead == allBuckets[j].CanRead {
return allBuckets[i].Name < allBuckets[j].Name
}
return allBuckets[i].CanRead && !allBuckets[j].CanRead
})

log.Info("Successfully retrieved list of buckets")

return c.JSON(http.StatusOK, allBuckets)
}

Expand Down
1 change: 1 addition & 0 deletions blobstore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func newConfig(authLvl int) *Config {
c := &Config{
AuthLevel: authLvl,
LimitedWriterRoleName: os.Getenv("AUTH_LIMITED_WRITER_ROLE"),
LimitedReaderRoleName: os.Getenv("AUTH_LIMITED_READER_ROLE"),
DefaultTempPrefix: getEnvOrDefault("TEMP_PREFIX", defaultTempPrefix),
DefaultDownloadPresignedUrlExpiration: getIntEnvOrDefault("DOWNLOAD_URL_EXP_DAYS", defaultDownloadPresignedUrlExpiration),
DefaultUploadPresignedUrlExpiration: getIntEnvOrDefault("UPLOAD_URL_EXP_MIN", defaultUploadPresignedUrlExpiration),
Expand Down
Loading

0 comments on commit 2726603

Please sign in to comment.