Skip to content

Commit

Permalink
Merge pull request #5 from LyricTian/dev
Browse files Browse the repository at this point in the history
优化相关功能
  • Loading branch information
LyricTian committed Jun 4, 2016
2 parents 4877b08 + 2806dae commit 8f69df4
Show file tree
Hide file tree
Showing 19 changed files with 329 additions and 119 deletions.
29 changes: 21 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ Golang OAuth 2.0协议实现
[![GoDoc](https://godoc.org/gopkg.in/oauth2.v1?status.svg)](https://godoc.org/gopkg.in/oauth2.v1)
[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v1)](https://goreportcard.com/report/gopkg.in/oauth2.v1)

> 基于Golang实现的OAuth 2.0协议相关操作,包括:令牌(或授权码)的生成、存储、验证操作以及更新令牌、废除令牌; 具有简单、灵活的特点; 其中所涉及的相关http请求操作在这里不做处理; 支持授权码模式、简化模式、密码模式、客户端模式; 默认使用MongoDB存储相关信息
获取
----

Expand All @@ -16,7 +14,7 @@ $ go get -v gopkg.in/oauth2.v1
范例
----

> 数据初始化:初始化相关的客户端信息
> 使用之前,初始化客户端信息
```go
package main
Expand All @@ -28,49 +26,64 @@ import (
)

func main() {
mongoConfig := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test")
// 初始化配置参数
ocfg := &oauth2.OAuthConfig{
ACConfig: &oauth2.ACConfig{
ATExpiresIn: 60 * 60 * 24,
},
}
mcfg := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test")

// 创建默认的OAuth2管理实例(基于MongoDB)
manager, err := oauth2.CreateDefaultOAuthManager(mongoConfig, "", "", nil)
manager, err := oauth2.NewDefaultOAuthManager(ocfg, mcfg, "xxx", "xxx")
if err != nil {
panic(err)
}
manager.SetACGenerate(oauth2.NewDefaultACGenerate())
manager.SetACStore(oauth2.NewACMemoryStore(0))

// 模拟授权码模式
// 使用默认参数,生成授权码
code, err := manager.GetACManager().
GenerateCode("clientID_x", "userID_x", "http://www.example.com/cb", "scopes")
if err != nil {
panic(err)
}

// 生成访问令牌及更新令牌
genToken, err := manager.GetACManager().
GenerateToken(code, "http://www.example.com/cb", "clientID_x", "clientSecret_x", true)
if err != nil {
panic(err)
}

// 检查访问令牌
checkToken, err := manager.CheckAccessToken(genToken.AccessToken)
if err != nil {
panic(err)
}

// TODO: 使用用户标识、申请的授权范围响应数据
fmt.Println(checkToken.UserID, checkToken.Scope)
// 申请一个新的访问令牌

// 更新令牌
newToken, err := manager.RefreshAccessToken(checkToken.RefreshToken, "scopes")
if err != nil {
panic(err)
}
fmt.Println(newToken.AccessToken, newToken.ATExpiresIn)
// TODO: 将新的访问令牌响应给客户端

}
```

执行测试
----
-------

```bash
$ go test -v
#
$ goconvey --port=9090
$ goconvey -port=9090
```

License
Expand Down
7 changes: 2 additions & 5 deletions authorizationCode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package oauth2
import (
"time"

"gopkg.in/LyricTian/lib.v2"
"github.com/LyricTian/go.uuid"
)

// NewACManager 创建授权码模式管理实例
Expand All @@ -13,9 +13,6 @@ func NewACManager(oaManager *OAuthManager, config *ACConfig) *ACManager {
if config == nil {
config = new(ACConfig)
}
if config.RandomCodeLen == 0 {
config.RandomCodeLen = DefaultRandomCodeLen
}
if config.ACExpiresIn == 0 {
config.ACExpiresIn = DefaultACExpiresIn
}
Expand Down Expand Up @@ -53,7 +50,7 @@ func (am *ACManager) GenerateCode(clientID, userID, redirectURI, scopes string)
UserID: userID,
RedirectURI: redirectURI,
Scope: scopes,
Code: lib.NewRandom(am.config.RandomCodeLen).NumberAndLetter(),
Code: uuid.NewV4().String(),
CreateAt: time.Now().Unix(),
ExpiresIn: time.Duration(am.config.ACExpiresIn) * time.Second,
}
Expand Down
33 changes: 16 additions & 17 deletions authorizationCodeGenerate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strconv"
"strings"

"github.com/LyricTian/go.uuid"

"gopkg.in/LyricTian/lib.v2"
)

Expand All @@ -30,31 +32,28 @@ func NewDefaultACGenerate() ACGenerate {
// ACGenerateDefault 默认的授权码生成方式
type ACGenerateDefault struct{}

func (ag *ACGenerateDefault) genToken(info *ACInfo) (string, error) {
var buf bytes.Buffer
_, _ = buf.WriteString(info.ClientID)
_ = buf.WriteByte('_')
func (ag *ACGenerateDefault) genCode(info *ACInfo) (string, error) {
ns, _ := uuid.FromString(info.Code)
buf := bytes.NewBuffer(uuid.NewV3(ns, info.ClientID).Bytes())
_, _ = buf.WriteString(info.UserID)
_ = buf.WriteByte('\n')
_, _ = buf.WriteString(strconv.FormatInt(info.CreateAt, 10))
_ = buf.WriteByte('\n')
_, _ = buf.WriteString(info.Code)

md5Val, err := lib.NewEncryption(buf.Bytes()).MD5()
if err != nil {
return "", err
}
buf.Reset()
md5Val = md5Val[:15]

return md5Val, nil
}

// Code Authorization code
func (ag *ACGenerateDefault) Code(info *ACInfo) (string, error) {
tokenVal, err := ag.genToken(info)
codeVal, err := ag.genCode(info)
if err != nil {
return "", err
}
val := base64.URLEncoding.EncodeToString([]byte(tokenVal + "." + strconv.FormatInt(info.ID, 10)))
val := base64.URLEncoding.EncodeToString([]byte(codeVal + "." + strconv.FormatInt(info.ID, 10)))
return strings.TrimRight(val, "="), nil
}

Expand All @@ -64,20 +63,20 @@ func (ag *ACGenerateDefault) parse(code string) (id int64, token string, err err
codeLen = 4 - codeLen
}
code = code + strings.Repeat("=", codeLen)
codeVal, err := base64.URLEncoding.DecodeString(code)
codeBV, err := base64.URLEncoding.DecodeString(code)
if err != nil {
return
}
tokenVal := strings.SplitN(string(codeVal), ".", 2)
if len(tokenVal) != 2 {
codeVal := strings.SplitN(string(codeBV), ".", 2)
if len(codeVal) != 2 {
err = errors.New("Token is invalid")
return
}
id, err = strconv.ParseInt(tokenVal[1], 10, 64)
id, err = strconv.ParseInt(codeVal[1], 10, 64)
if err != nil {
return
}
token = tokenVal[0]
token = codeVal[0]
return
}

Expand All @@ -93,9 +92,9 @@ func (ag *ACGenerateDefault) Verify(code string, info *ACInfo) (valid bool, err
if err != nil {
return
}
tokenVal, err := ag.genToken(info)
codeVal, err := ag.genCode(info)
if err != nil {
return
}
return token == tokenVal, nil
return token == codeVal, nil
}
7 changes: 3 additions & 4 deletions authorizationCodeGenerate_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package oauth2_test
package oauth2

import (
"testing"
"time"

"gopkg.in/LyricTian/lib.v2"
"gopkg.in/oauth2.v1"

. "github.com/smartystreets/goconvey/convey"
)

func TestACGenerate(t *testing.T) {
Convey("Authorization code generate test", t, func() {
acGenerate := oauth2.NewDefaultACGenerate()
info := &oauth2.ACInfo{
acGenerate := NewDefaultACGenerate()
info := &ACInfo{
ID: 1,
ClientID: "123456",
UserID: "999999",
Expand Down
20 changes: 12 additions & 8 deletions authorizationCodeMemoryStore_test.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
package oauth2_test
package oauth2

import (
"testing"
"time"

"gopkg.in/oauth2.v1"

. "github.com/smartystreets/goconvey/convey"
)

func TestACMemoryStore(t *testing.T) {
Convey("AC memory store test", t, func() {
store := oauth2.NewACMemoryStore(1)
item := oauth2.ACInfo{
store := NewACMemoryStore(1)
item := ACInfo{
ClientID: "123456",
UserID: "999999",
CreateAt: time.Now().Unix(),
ExpiresIn: time.Millisecond * 500,
}

Convey("Put Test", func() {
id, err := store.Put(item)
So(err, ShouldBeNil)
So(id, ShouldEqual, 1)
item.ID = id
So(id, ShouldBeGreaterThan, 0)
Convey("Take Test", func() {
info, err := store.TakeByID(id)
So(err, ShouldBeNil)
So(info.ClientID, ShouldEqual, item.ClientID)
So(info.UserID, ShouldEqual, item.UserID)
})
})

Convey("GC Test", func() {
id, err := store.Put(item)
So(err, ShouldBeNil)
So(id, ShouldBeGreaterThan, 0)
Convey("Take GC Test", func() {
time.Sleep(time.Second * 2)
time.Sleep(time.Millisecond * 1500)
info, err := store.TakeByID(id)
So(err, ShouldNotBeNil)
So(info, ShouldBeNil)
Expand Down
89 changes: 89 additions & 0 deletions authorizationCodeRedisStore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package oauth2

import (
"encoding/json"
"fmt"

"gopkg.in/redis.v3"
)

const (
// DefaultACRedisIDKey Redis存储授权码唯一标识的键
DefaultACRedisIDKey = "ACID"
)

// NewACRedisStore 创建Redis存储的实例
// config Redis配置参数
// key Redis存储授权码唯一标识的键(默认为ACID)
func NewACRedisStore(cfg *RedisConfig, key string) (*ACRedisStore, error) {
opt := &redis.Options{
Network: cfg.Network,
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
MaxRetries: cfg.MaxRetries,
DialTimeout: cfg.DialTimeout,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
PoolSize: cfg.PoolSize,
PoolTimeout: cfg.PoolTimeout,
}
cli := redis.NewClient(opt)
err := cli.Ping().Err()
if err != nil {
return nil, err
}
if key == "" {
key = DefaultACRedisIDKey
}
return &ACRedisStore{
cli: cli,
key: key,
}, nil
}

// ACRedisStore 提供授权码的redis存储
type ACRedisStore struct {
cli *redis.Client
key string
}

// Put 存储授权码
func (ar *ACRedisStore) Put(item ACInfo) (id int64, err error) {
n, err := ar.cli.Incr(ar.key).Result()
if err != nil {
return
}
item.ID = n
jv, err := json.Marshal(item)
if err != nil {
return
}
key := fmt.Sprintf("%s_%d", ar.key, n)
err = ar.cli.Set(key, string(jv), item.ExpiresIn).Err()
if err != nil {
return
}
id = item.ID
return
}

// TakeByID 取出授权码
func (ar *ACRedisStore) TakeByID(id int64) (info *ACInfo, err error) {
key := fmt.Sprintf("%s_%d", ar.key, id)
data, err := ar.cli.Get(key).Result()
if err != nil {
return
}
var v ACInfo
err = json.Unmarshal([]byte(data), &v)
if err != nil {
return
}
err = ar.cli.Del(key).Err()
if err != nil {
return
}
info = &v
return
}
Loading

0 comments on commit 8f69df4

Please sign in to comment.