Skip to content

Commit

Permalink
Merge pull request #545 from traPtitech/breaking/use-common-traq-acce…
Browse files Browse the repository at this point in the history
…ss-token

traQのAPIアクセスにサーバー設定のAccessTokenを使う
  • Loading branch information
ras0q authored Jul 9, 2024
2 parents a9a665d + 3ed7903 commit 04e012d
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 52 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ TOKEN_KEY=
KNOQ_VERSION=
KNOQ_REVISION=
DEVELOPMENT=
TRAQ_ACCESS_TOKEN=
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ debug

# for development
_development/*
.env
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ knoQ の全ての機能を動作させるためには、追加の情報が必要
| KNOQ_VERSION | 環境変数 | UNKNOWN | knoQ のバージョン (github actions でイメージ作成時に指定) |
| KNOQ_REVISION | 環境変数 | UNKNOWN | git の sha1 (github actions でイメージ作成時に指定) |
| DEVELOPMENT | 環境変数 | | 開発時かどうか |
| TRAQ_ACCESS_TOKEN | 環境変数 | | traQ へのアクセストークン |
| service.json | ファイル | 空のファイル | google calendar api に必要(権限は必要なし) |

### テスト
Expand Down
1 change: 1 addition & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ services:
KNOQ_VERSION: ${KNOQ_VERSION:-dev}
DEVELOPMENT: true
GORM_LOG_LEVEL: info
TRAQ_ACCESS_TOKEN:
ports:
- "${APP_PORT:-3000}:3000"
depends_on:
Expand Down
16 changes: 9 additions & 7 deletions infra/traq/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (
"github.com/traPtitech/go-traq"
)

func (repo *TraQRepository) GetGroup(token *oauth2.Token, groupID uuid.UUID) (*traq.UserGroup, error) {
ctx := context.TODO()
apiClient := NewAPIClient(ctx, token)
func (repo *TraQRepository) GetGroup(groupID uuid.UUID) (*traq.UserGroup, error) {
ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken)
apiClient := traq.NewAPIClient(traqAPIConfig)
// TODO: 一定期間キャッシュする
group, resp, err := apiClient.GroupApi.GetUserGroup(ctx, groupID.String()).Execute()
if err != nil {
return nil, err
Expand All @@ -23,9 +24,10 @@ func (repo *TraQRepository) GetGroup(token *oauth2.Token, groupID uuid.UUID) (*t
return group, err
}

func (repo *TraQRepository) GetAllGroups(token *oauth2.Token) ([]traq.UserGroup, error) {
ctx := context.TODO()
apiClient := NewAPIClient(ctx, token)
func (repo *TraQRepository) GetAllGroups() ([]traq.UserGroup, error) {
ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken)
apiClient := traq.NewAPIClient(traqAPIConfig)
// TODO: 一定期間キャッシュする
groups, resp, err := apiClient.GroupApi.GetUserGroups(ctx).Execute()
if err != nil {
return nil, err
Expand All @@ -39,7 +41,7 @@ func (repo *TraQRepository) GetAllGroups(token *oauth2.Token) ([]traq.UserGroup,

func (repo *TraQRepository) GetUserBelongingGroupIDs(token *oauth2.Token, userID uuid.UUID) ([]uuid.UUID, error) {
ctx := context.TODO()
apiClient := NewAPIClient(ctx, token)
apiClient := NewOauth2APIClient(ctx, token)
user, resp, err := apiClient.UserApi.GetUser(ctx, userID.String()).Execute()
if err != nil {
return nil, err
Expand Down
11 changes: 7 additions & 4 deletions infra/traq/traq.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import (

// TraQRepository is traq
type TraQRepository struct {
Config *oauth2.Config
URL string
Config *oauth2.Config
URL string
ServerAccessToken string
}

var TraQDefaultConfig = &oauth2.Config{
Expand All @@ -29,6 +30,8 @@ var TraQDefaultConfig = &oauth2.Config{
},
}

var traqAPIConfig = traq.NewConfiguration()

func newPKCE() (pkceOptions []oauth2.AuthCodeOption, codeVerifier string) {
codeVerifier = random.AlphaNumeric(43, true)
result := sha256.Sum256([]byte(codeVerifier))
Expand Down Expand Up @@ -64,8 +67,8 @@ func (repo *TraQRepository) GetOAuthToken(query, state, codeVerifier string) (*o
return repo.Config.Exchange(ctx, code, option)
}

func NewAPIClient(ctx context.Context, token *oauth2.Token) *traq.APIClient {
traqconf := traq.NewConfiguration()
func NewOauth2APIClient(ctx context.Context, token *oauth2.Token) *traq.APIClient {
traqconf := traqAPIConfig
conf := TraQDefaultConfig
traqconf.HTTPClient = conf.Client(ctx, token)
apiClient := traq.NewAPIClient(traqconf)
Expand Down
16 changes: 9 additions & 7 deletions infra/traq/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
"golang.org/x/oauth2"
)

func (repo *TraQRepository) GetUser(token *oauth2.Token, userID uuid.UUID) (*traq.User, error) {
ctx := context.TODO()
apiClient := NewAPIClient(ctx, token)
func (repo *TraQRepository) GetUser(userID uuid.UUID) (*traq.User, error) {
ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken)
apiClient := traq.NewAPIClient(traqAPIConfig)
// TODO: 一定期間キャッシュする
userDetail, resp, err := apiClient.UserApi.GetUser(ctx, userID.String()).Execute()
if err != nil {
return nil, err
Expand All @@ -31,9 +32,10 @@ func (repo *TraQRepository) GetUser(token *oauth2.Token, userID uuid.UUID) (*tra
return &user, err
}

func (repo *TraQRepository) GetUsers(token *oauth2.Token, includeSuspended bool) ([]traq.User, error) {
ctx := context.TODO()
apiClient := NewAPIClient(ctx, token)
func (repo *TraQRepository) GetUsers(includeSuspended bool) ([]traq.User, error) {
ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken)
apiClient := traq.NewAPIClient(traqAPIConfig)
// TODO: 一定期間キャッシュする
users, resp, err := apiClient.UserApi.GetUsers(ctx).IncludeSuspended(includeSuspended).Execute()
if err != nil {
return nil, err
Expand All @@ -47,7 +49,7 @@ func (repo *TraQRepository) GetUsers(token *oauth2.Token, includeSuspended bool)

func (repo *TraQRepository) GetUserMe(token *oauth2.Token) (*traq.User, error) {
ctx := context.TODO()
apiClient := NewAPIClient(ctx, token)
apiClient := NewOauth2APIClient(ctx, token)
userDetail, resp, err := apiClient.MeApi.GetMe(ctx).Execute()
if err != nil {
return nil, err
Expand Down
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ var (
webhookSecret = getenv("WEBHOOK_SECRET", "")
activityChannelID = getenv("ACTIVITY_CHANNEL_ID", "")
dailyChannelID = getenv("DAILY_CHANNEL_ID", "")

// TODO: traQにClient Credential Flowが実装されたら定期的に取得するように変更する
// Issue: https://github.com/traPtitech/traQ/issues/2403
traqAccessToken = getenv("TRAQ_ACCESS_TOKEN", "")
)

func main() {
Expand All @@ -66,7 +70,8 @@ func main() {
TokenURL: "https://q.trap.jp/api/v3/oauth2/token",
},
},
URL: "https://q.trap.jp/api/v3",
URL: "https://q.trap.jp/api/v3",
ServerAccessToken: traqAccessToken,
}
repo := &repository.Repository{
GormRepo: gormRepo,
Expand Down
19 changes: 3 additions & 16 deletions repository/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ func (repo *Repository) GetGroup(groupID uuid.UUID, info *domain.ConInfo) (*doma
}

// traq group
t, err := repo.GormRepo.GetToken(info.ReqUserID)
if err != nil {
return nil, defaultErrorHandling(err)
}
g, err := repo.TraQRepo.GetGroup(t, groupID)
g, err := repo.TraQRepo.GetGroup(groupID)
if err != nil {
return nil, defaultErrorHandling(err)
}
Expand All @@ -95,16 +91,12 @@ func (repo *Repository) GetGroup(groupID uuid.UUID, info *domain.ConInfo) (*doma

func (repo *Repository) GetAllGroups(info *domain.ConInfo) ([]*domain.Group, error) {
groups := make([]*domain.Group, 0)
t, err := repo.GormRepo.GetToken(info.ReqUserID)
if err != nil {
return nil, defaultErrorHandling(err)
}
gg, err := repo.GormRepo.GetAllGroups()
if err != nil {
return nil, defaultErrorHandling(err)
}
groups = append(groups, db.ConvSPGroupToSPdomainGroup(gg)...)
tg, err := repo.TraQRepo.GetAllGroups(t)
tg, err := repo.TraQRepo.GetAllGroups()
if err != nil {
return nil, defaultErrorHandling(err)
}
Expand Down Expand Up @@ -195,12 +187,7 @@ func (repo *Repository) getTraPGroup(info *domain.ConInfo) *domain.Group {
}

func (repo *Repository) GetGradeGroupNames(info *domain.ConInfo) ([]string, error) {
t, err := repo.GormRepo.GetToken(info.ReqUserID)
if err != nil {
return nil, defaultErrorHandling(err)
}

groups, err := repo.TraQRepo.GetAllGroups(t)
groups, err := repo.TraQRepo.GetAllGroups()
if err != nil {
return nil, defaultErrorHandling(err)
}
Expand Down
20 changes: 3 additions & 17 deletions repository/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ func (repo *Repository) SyncUsers(info *domain.ConInfo) error {
if !repo.IsPrivilege(info) {
return domain.ErrForbidden
}
t, err := repo.GormRepo.GetToken(info.ReqUserID)
if err != nil {
return defaultErrorHandling(err)
}
traQUsers, err := repo.TraQRepo.GetUsers(t, true)
traQUsers, err := repo.TraQRepo.GetUsers(true)
if err != nil {
return defaultErrorHandling(err)
}
Expand Down Expand Up @@ -92,18 +88,13 @@ func (repo *Repository) LoginUser(query, state, codeVerifier string) (*domain.Us
}

func (repo *Repository) GetUser(userID uuid.UUID, info *domain.ConInfo) (*domain.User, error) {
t, err := repo.GormRepo.GetToken(info.ReqUserID)
if err != nil {
return nil, defaultErrorHandling(err)
}

userMeta, err := repo.GormRepo.GetUser(userID)
if err != nil {
return nil, defaultErrorHandling(err)
}

if userMeta.Provider.Issuer == traQIssuerName {
userBody, err := repo.TraQRepo.GetUser(t, userID)
userBody, err := repo.TraQRepo.GetUser(userID)
if err != nil {
return nil, defaultErrorHandling(err)
}
Expand All @@ -120,17 +111,12 @@ func (repo *Repository) GetUserMe(info *domain.ConInfo) (*domain.User, error) {
}

func (repo *Repository) GetAllUsers(includeSuspend, includeBot bool, info *domain.ConInfo) ([]*domain.User, error) {
t, err := repo.GormRepo.GetToken(info.ReqUserID)
if err != nil {
return nil, defaultErrorHandling(err)
}

userMetas, err := repo.GormRepo.GetAllUsers(!includeSuspend)
if err != nil {
return nil, defaultErrorHandling(err)
}
// TODO fix
traQUserBodys, err := repo.TraQRepo.GetUsers(t, includeSuspend)
traQUserBodys, err := repo.TraQRepo.GetUsers(includeSuspend)
if err != nil {
return nil, defaultErrorHandling(err)
}
Expand Down

0 comments on commit 04e012d

Please sign in to comment.