Skip to content

Commit

Permalink
Support freezing media on the unauthenticated endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Jul 29, 2024
1 parent 3272722 commit f72092a
Show file tree
Hide file tree
Showing 19 changed files with 244 additions and 41 deletions.
9 changes: 9 additions & 0 deletions api/_apimeta/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ type ServerInfo struct {
ServerName string
}

type AuthContext struct {
User UserInfo
Server ServerInfo
}

func (a AuthContext) IsAuthenticated() bool {
return a.User.UserId != "" || a.Server.ServerName != ""
}

func GetRequestUserAdminStatus(r *http.Request, rctx rcontext.RequestContext, user UserInfo) (bool, bool) {
isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared
isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr)
Expand Down
33 changes: 24 additions & 9 deletions api/r0/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ import (
"github.com/t2bot/matrix-media-repo/common/rcontext"
)

func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
func DownloadMediaUser(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
return DownloadMedia(r, rctx, _apimeta.AuthContext{User: user})
}

func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta.AuthContext) interface{} {
server := _routers.GetParam("server", r)
mediaId := _routers.GetParam("mediaId", r)
filename := _routers.GetParam("filename", r)
Expand Down Expand Up @@ -61,28 +65,39 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
"authUserId": auth.User.UserId,
"authServerName": auth.Server.ServerName,
})

if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) {
rctx.Log.Warn("Request blocked due to domain being ignored.")
return _responses.MediaBlocked()
if auth.User.UserId != "" {
if !util.IsGlobalAdmin(auth.User.UserId) && util.IsHostIgnored(server) {
rctx.Log.Warn("Request blocked due to domain being ignored.")
return _responses.MediaBlocked()
}
}

media, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
BlockForReadUntil: blockFor,
CanRedirect: canRedirect,
RecordOnly: recordOnly,
AuthProvided: auth.IsAuthenticated(),
})
if err != nil {
var redirect datastores.RedirectError
if errors.Is(err, common.ErrMediaNotFound) {
return _responses.NotFoundError()
} else if errors.Is(err, common.ErrRestrictedAuth) {
return _responses.ErrorResponse{
Code: common.ErrCodeNotFound,
Message: "authentication is required to download this media",
InternalCode: common.ErrCodeUnauthorized,
}
} else if errors.Is(err, common.ErrMediaTooLarge) {
return _responses.RequestTooLarge()
} else if errors.Is(err, common.ErrRateLimitExceeded) {
Expand Down
31 changes: 23 additions & 8 deletions api/r0/thumbnail.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ import (
"github.com/t2bot/matrix-media-repo/common/rcontext"
)

func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
func ThumbnailMediaUser(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
return ThumbnailMedia(r, rctx, _apimeta.AuthContext{User: user})
}

func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta.AuthContext) interface{} {
server := _routers.GetParam("server", r)
mediaId := _routers.GetParam("mediaId", r)
allowRemote := r.URL.Query().Get("allow_remote")
Expand Down Expand Up @@ -55,15 +59,19 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
"mediaId": mediaId,
"server": server,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
"authUserId": auth.User.UserId,
"authServerName": auth.Server.ServerName,
})

if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) {
rctx.Log.Warn("Request blocked due to domain being ignored.")
return _responses.MediaBlocked()
if auth.User.UserId != "" {
if !util.IsGlobalAdmin(auth.User.UserId) && util.IsHostIgnored(server) {
rctx.Log.Warn("Request blocked due to domain being ignored.")
return _responses.MediaBlocked()
}
}

widthStr := r.URL.Query().Get("width")
Expand Down Expand Up @@ -124,6 +132,7 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
BlockForReadUntil: blockFor,
RecordOnly: false, // overridden
CanRedirect: canRedirect,
AuthProvided: auth.IsAuthenticated(),
},
Width: width,
Height: height,
Expand All @@ -134,6 +143,12 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
var redirect datastores.RedirectError
if errors.Is(err, common.ErrMediaNotFound) {
return _responses.NotFoundError()
} else if errors.Is(err, common.ErrRestrictedAuth) {
return _responses.ErrorResponse{
Code: common.ErrCodeNotFound,
Message: "authentication is required to download this media",
InternalCode: common.ErrCodeUnauthorized,
}
} else if errors.Is(err, common.ErrMediaTooLarge) {
return _responses.RequestTooLarge()
} else if errors.Is(err, common.ErrRateLimitExceeded) {
Expand Down
4 changes: 2 additions & 2 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ func buildRoutes() http.Handler {
// Standard (spec) features
register([]string{"PUT"}, PrefixMedia, "upload/:server/:mediaId", mxV3, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaAsync), "upload_async", counter))
register([]string{"POST"}, PrefixMedia, "upload", mxSpecV3Transition, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaSync), "upload", counter))
downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMedia), "download", counter)
downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMediaUser), "download", counter)
register([]string{"GET", "HEAD"}, PrefixMedia, "download/:server/:mediaId/:filename", mxSpecV3Transition, router, downloadRoute)
register([]string{"GET", "HEAD"}, PrefixMedia, "download/:server/:mediaId", mxSpecV3Transition, router, downloadRoute)
register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMedia), "thumbnail", counter))
register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMediaUser), "thumbnail", counter))
previewUrlRoute := makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter)
register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, previewUrlRoute)
register([]string{"GET"}, PrefixMedia, "identicon/*seed", mxR0, router, makeRoute(_routers.OptionalAccessToken(r0.Identicon), "identicon", counter))
Expand Down
7 changes: 4 additions & 3 deletions api/v1/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package v1

import (
"bytes"
"github.com/t2bot/matrix-media-repo/util/ids"
"net/http"

"github.com/t2bot/matrix-media-repo/util/ids"

"github.com/t2bot/matrix-media-repo/api/_apimeta"
"github.com/t2bot/matrix-media-repo/api/_responses"
"github.com/t2bot/matrix-media-repo/api/_routers"
Expand All @@ -16,7 +17,7 @@ import (
func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
r.URL.Query().Set("allow_remote", "true")
r.URL.Query().Set("allow_redirect", "true")
return r0.DownloadMedia(r, rctx, user)
return r0.DownloadMedia(r, rctx, _apimeta.AuthContext{User: user})
}

func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} {
Expand All @@ -26,7 +27,7 @@ func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, serv
r.URL.RawQuery = query.Encode()
r = _routers.ForceSetParam("server", r.Host, r)

res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{})
res := r0.DownloadMedia(r, rctx, _apimeta.AuthContext{Server: server})
boundary, err := ids.NewUniqueId()
if err != nil {
rctx.Log.Error("Error generating boundary on response: ", err)
Expand Down
7 changes: 4 additions & 3 deletions api/v1/thumbnail.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package v1

import (
"bytes"
"github.com/t2bot/matrix-media-repo/util/ids"
"net/http"

"github.com/t2bot/matrix-media-repo/util/ids"

"github.com/t2bot/matrix-media-repo/api/_apimeta"
"github.com/t2bot/matrix-media-repo/api/_responses"
"github.com/t2bot/matrix-media-repo/api/_routers"
Expand All @@ -16,7 +17,7 @@ import (
func ClientThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
r.URL.Query().Set("allow_remote", "true")
r.URL.Query().Set("allow_redirect", "true")
return r0.ThumbnailMedia(r, rctx, user)
return r0.ThumbnailMedia(r, rctx, _apimeta.AuthContext{User: user})
}

func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} {
Expand All @@ -26,7 +27,7 @@ func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, ser
r.URL.RawQuery = query.Encode()
r = _routers.ForceSetParam("server", r.Host, r)

res := r0.ThumbnailMedia(r, rctx, _apimeta.UserInfo{})
res := r0.ThumbnailMedia(r, rctx, _apimeta.AuthContext{Server: server})
boundary, err := ids.NewUniqueId()
if err != nil {
rctx.Log.Error("Error generating boundary on response: ", err)
Expand Down
1 change: 1 addition & 0 deletions archival/entity_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func ExportEntityData(ctx rcontext.RequestContext, exportId string, entityId str
FetchRemoteIfNeeded: false,
BlockForReadUntil: 10 * time.Minute,
RecordOnly: false,
AuthProvided: true, // it's for an export, so assume authentication
})
if errors.Is(err, common.ErrMediaQuarantined) {
ctx.Log.Warnf("%s is quarantined and will not be included in the export", mxc)
Expand Down
17 changes: 9 additions & 8 deletions common/config/conf_main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ func NewDefaultMainConfig() MainRepoConfig {
return MainRepoConfig{
MinimumRepoConfig: NewDefaultMinimumRepoConfig(),
General: GeneralConfig{
BindAddress: "127.0.0.1",
Port: 8000,
LogDirectory: "logs",
LogColors: false,
JsonLogs: false,
LogLevel: "info",
TrustAnyForward: false,
UseForwardedHost: true,
BindAddress: "127.0.0.1",
Port: 8000,
LogDirectory: "logs",
LogColors: false,
JsonLogs: false,
LogLevel: "info",
TrustAnyForward: false,
UseForwardedHost: true,
FreezeUnauthenticatedMedia: false,
},
Database: DatabaseConfig{
Postgres: "postgres://your_username:your_password@localhost/database_name?sslmode=disable",
Expand Down
17 changes: 9 additions & 8 deletions common/config/models_main.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package config

type GeneralConfig struct {
BindAddress string `yaml:"bindAddress"`
Port int `yaml:"port"`
LogDirectory string `yaml:"logDirectory"`
LogColors bool `yaml:"logColors"`
JsonLogs bool `yaml:"jsonLogs"`
LogLevel string `yaml:"logLevel"`
TrustAnyForward bool `yaml:"trustAnyForwardedAddress"`
UseForwardedHost bool `yaml:"useForwardedHost"`
BindAddress string `yaml:"bindAddress"`
Port int `yaml:"port"`
LogDirectory string `yaml:"logDirectory"`
LogColors bool `yaml:"logColors"`
JsonLogs bool `yaml:"jsonLogs"`
LogLevel string `yaml:"logLevel"`
TrustAnyForward bool `yaml:"trustAnyForwardedAddress"`
UseForwardedHost bool `yaml:"useForwardedHost"`
FreezeUnauthenticatedMedia bool `yaml:"freezeUnauthenticatedMedia"`
}

type HomeserverConfig struct {
Expand Down
1 change: 1 addition & 0 deletions common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ var ErrAlreadyUploaded = errors.New("already uploaded")
var ErrMediaNotYetUploaded = errors.New("media not yet uploaded")
var ErrMediaDimensionsTooSmall = errors.New("media is too small dimensionally")
var ErrRateLimitExceeded = errors.New("rate limit exceeded")
var ErrRestrictedAuth = errors.New("authentication is required to download this media")
11 changes: 11 additions & 0 deletions config.sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ repo:
# See https://github.com/t2bot/matrix-media-repo/issues/202 for more information.
useForwardedHost: true

# If true, media uploaded or cached from that point forwards will require authentication in order to
# be accessed. Media uploaded or cached prior will remain accessible on the unauthenticated endpoints.
# If set to false after being set to true, media uploaded or cached while the flag was true will still
# only be accessible over authenticated endpoints, though future media will be accessible on both
# authenticated and unauthenticated media.
#
# This flag currently defaults to false. A future release, likely in August 2024, will remove this flag
# and have the same effect as it being true (always on). This flag is primarily intended for servers to
# opt-in to the behaviour early.
freezeUnauthenticatedMedia: false

# Options for dealing with federation
federation:
# On a per-host basis, the number of consecutive failures in calling the host before the
Expand Down
4 changes: 4 additions & 0 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Database struct {
Tasks *tasksTableStatements
Exports *exportsTableStatements
ExportParts *exportPartsTableStatements
RestrictedMedia *restrictedMediaTableStatements
}

var instance *Database
Expand Down Expand Up @@ -126,6 +127,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error
if d.ExportParts, err = prepareExportPartsTables(d.conn); err != nil {
return errors.New("failed to create export parts table accessor: " + err.Error())
}
if d.RestrictedMedia, err = prepareRestrictedMediaTables(d.conn); err != nil {
return errors.New("failed to create restricted media table accessor: " + err.Error())
}

instance = d
return nil
Expand Down
87 changes: 87 additions & 0 deletions database/table_restricted_media.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package database

import (
"database/sql"
"errors"

"github.com/t2bot/matrix-media-repo/common/rcontext"
)

type RestrictedCondition string

const RestrictedRequiresAuth RestrictedCondition = "io.t2bot.requires_authentication" // Internal extension

type DbRestrictedMedia struct {
Origin string
MediaId string
Condition RestrictedCondition
ConditionValue string
}

const insertRestrictedMedia = "INSERT INTO restricted_media (origin, media_id, condition_type, condition_value) VALUES ($1, $2, $3, $4);"
const updateRestrictedMedia = "UPDATE restricted_media SET condition_type = $3, condition_value = $4 WHERE origin = $1 AND media_id = $2;"
const selectRestrictedMedia = "SELECT origin, media_id, condition_type, condition_value FROM restricted_media WHERE origin = $1 AND media_id = $2;"

type restrictedMediaTableStatements struct {
insertRestrictedMedia *sql.Stmt
updateRestrictedMedia *sql.Stmt
selectRestrictedMedia *sql.Stmt
}

type restrictedMediaTableWithContext struct {
statements *restrictedMediaTableStatements
ctx rcontext.RequestContext
}

func prepareRestrictedMediaTables(db *sql.DB) (*restrictedMediaTableStatements, error) {
var err error
var stmts = &restrictedMediaTableStatements{}

if stmts.insertRestrictedMedia, err = db.Prepare(insertRestrictedMedia); err != nil {
return nil, errors.New("error preparing insertRestrictedMedia: " + err.Error())
}
if stmts.updateRestrictedMedia, err = db.Prepare(updateRestrictedMedia); err != nil {
return nil, errors.New("error preparing updateRestrictedMedia: " + err.Error())
}
if stmts.selectRestrictedMedia, err = db.Prepare(selectRestrictedMedia); err != nil {
return nil, errors.New("error preparing selectRestrictedMedia: " + err.Error())
}

return stmts, nil
}

func (s *restrictedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *restrictedMediaTableWithContext {
return &restrictedMediaTableWithContext{
statements: s,
ctx: ctx,
}
}

func (s *restrictedMediaTableWithContext) Insert(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error {
_, err := s.statements.insertRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue)
return err
}

func (s *restrictedMediaTableWithContext) Update(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error {
_, err := s.statements.updateRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue)
return err
}

func (s *restrictedMediaTableWithContext) GetAllForId(origin string, mediaId string) ([]*DbRestrictedMedia, error) {
results := make([]*DbRestrictedMedia, 0)
rows, err := s.statements.selectRestrictedMedia.QueryContext(s.ctx, origin, mediaId)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return results, nil
}
return nil, err
}
for rows.Next() {
val := &DbRestrictedMedia{}
if err = rows.Scan(&val.Origin, &val.MediaId, &val.Condition, &val.ConditionValue); err != nil {
return nil, err
}
results = append(results, val)
}
return results, nil
}
2 changes: 2 additions & 0 deletions migrations/29_create_media_restrictions_down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DROP INDEX IF EXISTS idx_restricted_media;
DROP TABLE IF EXISTS restricted_media;
Loading

0 comments on commit f72092a

Please sign in to comment.