From 3590945fe1452602e34643d13fdcfcdd1d1df7b8 Mon Sep 17 00:00:00 2001 From: caixw Date: Tue, 14 May 2024 16:39:06 +0800 Subject: [PATCH] =?UTF-8?q?refactor(cmfx):=20=E9=87=8D=E6=9E=84=E9=A1=B9?= =?UTF-8?q?=E7=9B=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server/web.xml | 6 +- cmfx/cmfx.go | 4 +- cmfx/initial/cmd/cmd.go | 4 +- cmfx/initial/cmd/config.go | 146 ++++++++-------- cmfx/modules/admin/admintest/admintest.go | 41 +++-- cmfx/modules/admin/config.go | 32 ++++ cmfx/modules/admin/config_test.go | 22 +++ cmfx/modules/admin/install.go | 20 +-- cmfx/modules/admin/install_test.go | 13 +- cmfx/modules/admin/models.go | 2 +- cmfx/modules/admin/{loader.go => module.go} | 96 +++++------ cmfx/modules/admin/route_admins.go | 62 +++---- cmfx/modules/admin/route_current.go | 24 +-- cmfx/modules/admin/route_rbac.go | 28 +-- cmfx/modules/admin/route_token.go | 31 ++-- cmfx/modules/system/install.go | 11 +- cmfx/modules/system/install_test.go | 9 +- cmfx/modules/system/linkage.go | 20 ++- cmfx/modules/system/{loader.go => module.go} | 10 +- .../system/{loader_test.go => module_test.go} | 10 +- cmfx/modules/system/route.go | 38 +++-- cmfx/modules/system/route_test.go | 12 +- cmfx/modules/system/systemtest/systemtest.go | 5 +- cmfx/user/loader_test.go | 9 - cmfx/user/models.go | 76 ++++----- cmfx/user/{loader.go => module.go} | 25 +-- cmfx/user/module_test.go | 42 +++++ cmfx/user/passport/adapter.go | 14 +- cmfx/user/passport/adaptertest/adaptertest.go | 10 ++ cmfx/user/passport/code/code.go | 65 ++++--- cmfx/user/passport/code/models.go | 2 +- cmfx/user/passport/code/sender.go | 24 ++- cmfx/user/passport/code/sender_test.go | 2 +- cmfx/user/passport/errors.go | 2 + cmfx/user/passport/oauth/oauth.go | 13 +- cmfx/user/passport/passport.go | 31 ++-- cmfx/user/passport/passport_test.go | 34 ++-- cmfx/user/passport/password/models.go | 7 +- cmfx/user/passport/password/password.go | 92 ++++++---- cmfx/user/securitylog.go | 8 +- cmfx/user/token.go | 96 ++++++----- cmfx/user/token_test.go | 159 ++++++++++++++++++ locales/locales.go | 8 + 43 files changed, 866 insertions(+), 499 deletions(-) create mode 100644 cmfx/modules/admin/config.go create mode 100644 cmfx/modules/admin/config_test.go rename cmfx/modules/admin/{loader.go => module.go} (59%) rename cmfx/modules/system/{loader.go => module.go} (90%) rename cmfx/modules/system/{loader_test.go => module_test.go} (66%) delete mode 100644 cmfx/user/loader_test.go rename cmfx/user/{loader.go => module.go} (65%) create mode 100644 cmfx/user/module_test.go create mode 100644 cmfx/user/token_test.go diff --git a/cmd/server/web.xml b/cmd/server/web.xml index 31d3e10f..6fde0b99 100644 --- a/cmd/server/web.xml +++ b/cmd/server/web.xml @@ -49,8 +49,10 @@ http://localhost:8080 - - /admin + + + /admin + diff --git a/cmfx/cmfx.go b/cmfx/cmfx.go index 6c1a5abc..a4398343 100644 --- a/cmfx/cmfx.go +++ b/cmfx/cmfx.go @@ -42,9 +42,7 @@ const ( UnauthorizedSecurityToken = "40103" // 需要强验证 UnauthorizedInvalidAccount = "40104" // 无效的账号或密码 UnauthorizedNeedChangePassword = "40105" - - // 可注册的状态,比如 OAuth2 验证,如果未注册,返回一个 ID 可用以注册。 - UnauthorizedRegistrable = "40106" + UnauthorizedRegistrable = "40106" // 可注册的状态,比如 OAuth2 验证,如果未注册,返回一个 ID 可用以注册。 ) // 403 diff --git a/cmfx/initial/cmd/cmd.go b/cmfx/initial/cmd/cmd.go index 5b14afe8..7ab56c07 100644 --- a/cmfx/initial/cmd/cmd.go +++ b/cmfx/initial/cmd/cmd.go @@ -60,8 +60,8 @@ func initServer(name, ver string, o *server.Options, user *Config, action string adminL := admin.Load(adminMod, user.Admin) system.Load(systemMod, user.System, adminL) case "install": - admin.Install(adminMod) - system.Install(systemMod) + adminL := admin.Install(adminMod, user.Admin) + system.Install(systemMod, user.System, adminL) case "upgrade": panic("not implements") default: diff --git a/cmfx/initial/cmd/config.go b/cmfx/initial/cmd/config.go index 87e9faa0..de564544 100644 --- a/cmfx/initial/cmd/config.go +++ b/cmfx/initial/cmd/config.go @@ -5,97 +5,97 @@ package cmd import ( - "github.com/issue9/orm/v6" - "github.com/issue9/orm/v6/dialect" - "github.com/issue9/web" + "github.com/issue9/orm/v6" + "github.com/issue9/orm/v6/dialect" + "github.com/issue9/web" - "github.com/issue9/cmfx/cmfx/modules/system" - "github.com/issue9/cmfx/cmfx/user" - "github.com/issue9/cmfx/locales" + "github.com/issue9/cmfx/cmfx/modules/admin" + "github.com/issue9/cmfx/cmfx/modules/system" + "github.com/issue9/cmfx/locales" ) // Config 配置文件的自定义部分内容 type Config struct { - // DB 数据库配置 - DB *DB `yaml:"db" xml:"db" json:"db"` + // DB 数据库配置 + DB *DB `yaml:"db" xml:"db" json:"db"` - // URL 路由的基地址 - URL string `yaml:"url" xml:"url" json:"url"` + // URL 路由的基地址 + URL string `yaml:"url" xml:"url" json:"url"` - // Admin 后台管理员用户的相关配置 - Admin *user.Config `yaml:"admin" xml:"admin" json:"admin"` + // Admin 后台管理员用户的相关配置 + Admin *admin.Config `yaml:"admin" xml:"admin" json:"admin"` - // System 系统模块的相关配置 - System *system.Config `yaml:"system,omitempty" xml:"system,omitempty" json:"system,omitempty"` + // System 系统模块的相关配置 + System *system.Config `yaml:"system,omitempty" xml:"system,omitempty" json:"system,omitempty"` } // DB 数据库的配置项 type DB struct { - // Prefix 表名前缀 - // - // 可以为空。 - Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty" xml:"prefix,attr,omitempty"` - - // 表示数据库的类型 - // - // 目前支持以下几种类型: - // - sqlite3 - // - sqlite 纯 Go - // - mysql - // - mariadb - // - postgres - Type string `yaml:"type" json:"type" xml:"type,attr"` - - // 连接数据库的参数 - DSN string `yaml:"dsn" json:"dsn" xml:"dsn"` - - db *orm.DB + // Prefix 表名前缀 + // + // 可以为空。 + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty" xml:"prefix,attr,omitempty"` + + // 表示数据库的类型 + // + // 目前支持以下几种类型: + // - sqlite3 + // - sqlite 纯 Go + // - mysql + // - mariadb + // - postgres + Type string `yaml:"type" json:"type" xml:"type,attr"` + + // 连接数据库的参数 + DSN string `yaml:"dsn" json:"dsn" xml:"dsn"` + + db *orm.DB } func (c *Config) SanitizeConfig() *web.FieldError { - if err := c.DB.SanitizeConfig(); err != nil { - return err.AddFieldParent("db") - } - - if err := c.Admin.SanitizeConfig(); err != nil { - return err.AddFieldParent("admin") - } - - if c.System != nil { - if err := c.System.SanitizeConfig(); err != nil { - return err.AddFieldParent("system") - } - } - - return nil + if err := c.DB.SanitizeConfig(); err != nil { + return err.AddFieldParent("db") + } + + if err := c.Admin.SanitizeConfig(); err != nil { + return err.AddFieldParent("admin") + } + + if c.System != nil { + if err := c.System.SanitizeConfig(); err != nil { + return err.AddFieldParent("system") + } + } + + return nil } func (conf *DB) SanitizeConfig() *web.FieldError { - var d orm.Dialect - switch conf.Type { - case "sqlite3": - d = dialect.Sqlite3("sqlite3") - case "sqlite": - d = dialect.Sqlite3("sqlite") - case "mysql": - d = dialect.Mysql("mysql") - case "mariadb": - d = dialect.Mariadb("mysql") - case "postgres": - d = dialect.Postgres("postgres") - default: - err := web.NewFieldError("type", locales.InvalidValue) - err.Value = conf.Type - return err - } - - db, err := orm.NewDB(conf.Prefix, conf.DSN, d) - if err != nil { - return web.NewFieldError("", err) - } - conf.db = db - - return nil + var d orm.Dialect + switch conf.Type { + case "sqlite3": + d = dialect.Sqlite3("sqlite3") + case "sqlite": + d = dialect.Sqlite3("sqlite") + case "mysql": + d = dialect.Mysql("mysql") + case "mariadb": + d = dialect.Mariadb("mysql") + case "postgres": + d = dialect.Postgres("postgres") + default: + err := web.NewFieldError("type", locales.InvalidValue) + err.Value = conf.Type + return err + } + + db, err := orm.NewDB(conf.Prefix, conf.DSN, d) + if err != nil { + return web.NewFieldError("", err) + } + conf.db = db + + return nil } // DB 返回根据配置项生成的 [orm.DB] 实例 diff --git a/cmfx/modules/admin/admintest/admintest.go b/cmfx/modules/admin/admintest/admintest.go index fde3a1f4..a7a2fc6c 100644 --- a/cmfx/modules/admin/admintest/admintest.go +++ b/cmfx/modules/admin/admintest/admintest.go @@ -5,26 +5,47 @@ package admintest import ( + "encoding/json" + "net/http" + + "github.com/issue9/assert/v4" + "github.com/issue9/mux/v8/header" + "github.com/issue9/webuse/v7/middlewares/auth/token" + "github.com/issue9/cmfx/cmfx/initial/test" "github.com/issue9/cmfx/cmfx/modules/admin" "github.com/issue9/cmfx/cmfx/user" ) -func NewAdmin(s *test.Suite) *admin.Loader { +// NewAdmin 声明一个用于测试的 [admin.Module] 实例 +func NewAdmin(s *test.Suite) *admin.Module { mod := s.NewModule("admin") - admin.Install(mod) - o := &user.Config{ - URLPrefix: "/admin", - AccessExpired: 60, - RefreshExpired: 120, + o := &admin.Config{ + SuperUser: 1, + User: &user.Config{ + URLPrefix: "/admin", + AccessExpired: 60, + RefreshExpired: 120, + }, } s.Assertion().NotError(o.SanitizeConfig()) - a := admin.Load(mod, o) - s.Assertion().NotNil(a) + loader := admin.Install(mod, o) + s.Assertion().NotNil(loader) + + return loader +} - // TODO 返回用于测试的令牌 +// GetToken 获得后台的访问令牌 +func GetToken(s *test.Suite, loader *admin.Module) string { + r := &token.Response{} + s.Post(loader.URLPrefix()+"/login?type=password", []byte(`{"username":"admin","password":"123"}`)). + Header(header.ContentType, header.JSON+";charset=utf-8"). + Header(header.Accept, header.JSON). + Do(nil). + Status(http.StatusCreated). + BodyFunc(func(a *assert.Assertion, body []byte) { a.NotError(json.Unmarshal(body, r)) }) - return a + return r.AccessToken } diff --git a/cmfx/modules/admin/config.go b/cmfx/modules/admin/config.go new file mode 100644 index 00000000..fa3a9047 --- /dev/null +++ b/cmfx/modules/admin/config.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2024 caixw +// +// SPDX-License-Identifier: MIT + +package admin + +import ( + "github.com/issue9/web" + + "github.com/issue9/cmfx/cmfx/user" + "github.com/issue9/cmfx/locales" +) + +type Config struct { + // SuperUser 超级用户的 ID + SuperUser int64 `json:"superUser" xml:"superUser,attr" yaml:"superUser"` + + // User 用户相关的配置 + User *user.Config `json:"user" xml:"user" yaml:"user"` +} + +func (c *Config) SanitizeConfig() *web.FieldError { + if c.SuperUser <= 0 { + return web.NewFieldError("superUser", locales.MustBeGreaterThan(0)) + } + + if err := c.User.SanitizeConfig(); err != nil { + return err.AddFieldParent("user") + } + + return nil +} diff --git a/cmfx/modules/admin/config_test.go b/cmfx/modules/admin/config_test.go new file mode 100644 index 00000000..0865e1df --- /dev/null +++ b/cmfx/modules/admin/config_test.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2024 caixw +// +// SPDX-License-Identifier: MIT + +package admin + +import ( + "testing" + + "github.com/issue9/assert/v4" + "github.com/issue9/config" +) + +var _ config.Sanitizer = &Config{} + +func TestConfig_SanitizeConfig(t *testing.T) { + a := assert.New(t, false) + + conf := &Config{} + err := conf.SanitizeConfig() + a.Equal(err.Field, "superUser") +} diff --git a/cmfx/modules/admin/install.go b/cmfx/modules/admin/install.go index 787580c4..18725675 100644 --- a/cmfx/modules/admin/install.go +++ b/cmfx/modules/admin/install.go @@ -16,7 +16,7 @@ import ( "github.com/issue9/cmfx/cmfx/user/rbac" ) -func Install(mod *cmfx.Module) { +func Install(mod *cmfx.Module, o *Config) *Module { user.Install(mod) password.Install(mod) rbac.Install(mod) @@ -25,19 +25,15 @@ func Install(mod *cmfx.Module) { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } - r := rbac.New(mod, nil) - a, err := r.NewRoleGroup(mod.ID(), 0) - if err != nil { - panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) - } + l := Load(mod, o) - if _, err = a.NewRole("管理员", "拥有超级权限", ""); err != nil { + if _, err := l.newRole("管理员", "拥有超级权限", ""); err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } - if _, err = a.NewRole("财务", "财务", ""); err != nil { + if _, err := l.newRole("财务", "财务", ""); err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } - if _, err = a.NewRole("编辑", "仅有编辑文章的相关权限", ""); err != nil { + if _, err := l.newRole("编辑", "仅有编辑文章的相关权限", ""); err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } @@ -87,11 +83,11 @@ func Install(mod *cmfx.Module) { }, } - p := password.New(mod, 11) - for _, u := range us { - if err := newAdmin(mod, p, u, time.Now()); err != nil { + if err := l.newAdmin(l.password, u, time.Now()); err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } } + + return l } diff --git a/cmfx/modules/admin/install_test.go b/cmfx/modules/admin/install_test.go index ab03f315..ed26a220 100644 --- a/cmfx/modules/admin/install_test.go +++ b/cmfx/modules/admin/install_test.go @@ -10,6 +10,7 @@ import ( "github.com/issue9/assert/v4" "github.com/issue9/cmfx/cmfx/initial/test" + "github.com/issue9/cmfx/cmfx/user" ) func TestInstall(t *testing.T) { @@ -18,7 +19,17 @@ func TestInstall(t *testing.T) { defer suite.Close() mod := suite.NewModule("test") - Install(mod) + o := &Config{ + SuperUser: 1, + User: &user.Config{ + URLPrefix: "/admin", + AccessExpired: 60, + RefreshExpired: 120, + }, + } + suite.Assertion().NotError(o.SanitizeConfig()) + l := Install(mod, o) + a.NotNil(l) suite.TableExists(mod.ID() + "_info") } diff --git a/cmfx/modules/admin/models.go b/cmfx/modules/admin/models.go index d1fd07ba..eebf8def 100644 --- a/cmfx/modules/admin/models.go +++ b/cmfx/modules/admin/models.go @@ -28,7 +28,7 @@ type modelInfo struct { } type respInfo struct { - m *Loader + m *Module XMLName struct{} `xml:"info" json:"-" cbor:"-"` diff --git a/cmfx/modules/admin/loader.go b/cmfx/modules/admin/module.go similarity index 59% rename from cmfx/modules/admin/loader.go rename to cmfx/modules/admin/module.go index 8989a0bc..aa3101a3 100644 --- a/cmfx/modules/admin/loader.go +++ b/cmfx/modules/admin/module.go @@ -7,7 +7,6 @@ package admin import ( "context" - "errors" "time" "github.com/issue9/events" @@ -22,12 +21,12 @@ import ( ) const ( + passwordID = "password" defaultPassword = "123" - SystemID = 0 // 表示系统的 ID ) -type Loader struct { - user *user.Loader +type Module struct { + user *user.Module password passport.Adapter roleGroup *rbac.RoleGroup @@ -40,15 +39,14 @@ type Loader struct { // Load 声明 Admin 对象 // // o 表示初始化的一些额外选项,这些值可以直接从配置文件中加载; -func Load(mod *cmfx.Module, o *user.Config) *Loader { +func Load(mod *cmfx.Module, o *Config) *Module { loadProblems(mod.Server()) - u := user.Load(mod, o) + u := user.Load(mod, o.User) - pp := passport.New(mod) pass := password.New(mod, 8) - pp.Register("password", pass, web.StringPhrase("password mode")) - m := &Loader{ + u.Passport().Register(passwordID, pass, web.StringPhrase("password mode")) + m := &Module{ user: u, password: pass, @@ -64,7 +62,7 @@ func Load(mod *cmfx.Module, o *user.Config) *Loader { } return u.ID, nil }) - rg, err := inst.NewRoleGroup("0", SystemID) + rg, err := inst.NewRoleGroup("0", o.SuperUser) if err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } @@ -82,10 +80,10 @@ func Load(mod *cmfx.Module, o *user.Config) *Loader { mod.Router().Prefix(m.URLPrefix()). Post("/login", m.postLogin). - Delete("/login", m.AuthMiddleware(m.deleteLogin)). - Get("/token", m.AuthMiddleware(m.getToken)) + Delete("/login", m.Middleware(m.deleteLogin)). + Put("/login", m.Middleware(m.getToken)) - mod.Router().Prefix(m.URLPrefix(), web.MiddlewareFunc(m.AuthMiddleware)). + mod.Router().Prefix(m.URLPrefix(), web.MiddlewareFunc(m.Middleware)). Get("/resources", m.getResources). Get("/roles", m.getRoles). Post("/roles", postGroup(m.postRoles)). @@ -94,13 +92,13 @@ func Load(mod *cmfx.Module, o *user.Config) *Loader { Get("/roles/{id:digit}/resources", m.getRoleResources). Put("/roles/{id:digit}/resources", putGroupResources(m.putRoleResources)) - mod.Router().Prefix(m.URLPrefix(), web.MiddlewareFunc(m.AuthMiddleware)). + mod.Router().Prefix(m.URLPrefix(), web.MiddlewareFunc(m.Middleware)). Get("/info", m.getInfo). Patch("/info", m.patchInfo). Get("/securitylog", m.getSecurityLogs). Put("/password", m.putCurrentPassword) - mod.Router().Prefix(m.URLPrefix(), web.MiddlewareFunc(m.AuthMiddleware)). + mod.Router().Prefix(m.URLPrefix(), web.MiddlewareFunc(m.Middleware)). Get("/admins", getAdmin(m.getAdmins)). Post("/admins", postAdmin(m.postAdmins)). Get("/admins/{id:digit}", getAdmin(m.getAdmin)). @@ -113,81 +111,63 @@ func Load(mod *cmfx.Module, o *user.Config) *Loader { return m } -func (l *Loader) URLPrefix() string { return l.user.URLPrefix() } +func (m *Module) URLPrefix() string { return m.user.URLPrefix() } -// AuthMiddleware 验证是否登录 -func (l *Loader) AuthMiddleware(next web.HandlerFunc) web.HandlerFunc { - return l.user.Middleware(next) -} +// Middleware 验证是否登录 +func (m *Module) Middleware(next web.HandlerFunc) web.HandlerFunc { return m.user.Middleware(next) } // CurrentUser 获取当前登录的用户信息 -func (l *Loader) CurrentUser(ctx *web.Context) *user.User { return l.user.CurrentUser(ctx) } +func (m *Module) CurrentUser(ctx *web.Context) *user.User { return m.user.CurrentUser(ctx) } // NewResourceGroup 新建资源分组 -func (l *Loader) NewResourceGroup(mod *cmfx.Module) *rbac.ResourceGroup { - return l.roleGroup.RBAC().NewResourceGroup(mod.ID(), mod.Desc()) +func (m *Module) NewResourceGroup(mod *cmfx.Module) *rbac.ResourceGroup { + return m.roleGroup.RBAC().NewResourceGroup(mod.ID(), mod.Desc()) } // GetResourceGroup 获取指定 ID 的资源分组 -func (l *Loader) GetResourceGroup(id string) *rbac.ResourceGroup { - return l.roleGroup.RBAC().ResourceGroup(id) +func (m *Module) GetResourceGroup(id string) *rbac.ResourceGroup { + return m.roleGroup.RBAC().ResourceGroup(id) } // ResourceGroup 当前资源组 -func (l *Loader) ResourceGroup() *rbac.ResourceGroup { return l.GetResourceGroup(l.user.Module().ID()) } +func (m *Module) ResourceGroup() *rbac.ResourceGroup { return m.GetResourceGroup(m.user.Module().ID()) } -func (l *Loader) AddSecurityLog(tx *orm.Tx, uid int64, content, ip, ua string) error { - return l.user.AddSecurityLog(tx, uid, ip, ua, content) +func (m *Module) AddSecurityLog(tx *orm.Tx, uid int64, content, ip, ua string) error { + return m.user.AddSecurityLog(tx, uid, ip, ua, content) } -func (l *Loader) AddSecurityLogWithContext(tx *orm.Tx, uid int64, ctx *web.Context, content string) error { - return l.user.AddSecurityLogFromContext(tx, uid, ctx, content) +func (m *Module) AddSecurityLogWithContext(tx *orm.Tx, uid int64, ctx *web.Context, content string) error { + return m.user.AddSecurityLogFromContext(tx, uid, ctx, content) } // OnLogin 注册登录事件 -func (l *Loader) OnLogin(f func(*user.User)) context.CancelFunc { - return l.loginEvent.Subscribe(f) -} +func (m *Module) OnLogin(f func(*user.User)) context.CancelFunc { return m.loginEvent.Subscribe(f) } // OnLogout 注册用户主动退出时的事 -func (l *Loader) OnLogout(f func(*user.User)) context.CancelFunc { - return l.logoutEvent.Subscribe(f) -} - -func newAdmin(mod *cmfx.Module, pa passport.Adapter, data *respInfoWithAccount, now time.Time) error { - tx, err := mod.DB().Begin() - if err != nil { - return err - } +func (m *Module) OnLogout(f func(*user.User)) context.CancelFunc { return m.logoutEvent.Subscribe(f) } - e := tx.NewEngine(mod.DB().TablePrefix()) +func (m *Module) Module() *cmfx.Module { return m.user.Module() } - id, err := e.LastInsertID(&user.User{NO: mod.Server().UniqueID()}) +// 手动添加一个新的管理员 +func (m *Module) newAdmin(pa passport.Adapter, data *respInfoWithAccount, now time.Time) error { + uid, err := m.user.NewUser(pa, data.Username, data.Password, now) if err != nil { - return errors.Join(err, tx.Rollback()) + return err } a := &modelInfo{ - ID: id, + ID: uid, Nickname: data.Nickname, Name: data.Name, Avatar: data.Avatar, Sex: data.Sex, } - if _, err = e.Insert(a); err != nil { - return errors.Join(err, tx.Rollback()) - } - - if err := tx.Commit(); err != nil { - return err - } - - if err := pa.Add(id, data.Username, data.Password, now); err != nil { + if _, err = m.Module().DB().Insert(a); err != nil { return err } for _, role := range data.roles { - if err := role.Link(id); err != nil { + if err := role.Link(uid); err != nil { return err } } @@ -195,4 +175,6 @@ func newAdmin(mod *cmfx.Module, pa passport.Adapter, data *respInfoWithAccount, return nil } -func (l *Loader) Module() *cmfx.Module { return l.user.Module() } +func (m *Module) newRole(name, desc, parent string) (*rbac.Role, error) { + return m.roleGroup.NewRole(name, desc, parent) +} diff --git a/cmfx/modules/admin/route_admins.go b/cmfx/modules/admin/route_admins.go index 6dda2dbf..f22dc280 100644 --- a/cmfx/modules/admin/route_admins.go +++ b/cmfx/modules/admin/route_admins.go @@ -24,14 +24,14 @@ import ( // @tag admin // @path id int 管理的 ID // @resp 200 * respInfoWithRoleState -func (l *Loader) getAdmin(ctx *web.Context) web.Responser { +func (m *Module) getAdmin(ctx *web.Context) web.Responser { id, resp := ctx.PathID("id", cmfx.BadRequestInvalidPath) if resp != nil { return resp } a := &modelInfo{ID: id} - found, err := l.Module().DB().Select(a) + found, err := m.Module().DB().Select(a) if err != nil { return ctx.Error(err, "") } @@ -39,8 +39,8 @@ func (l *Loader) getAdmin(ctx *web.Context) web.Responser { return ctx.NotFound() } - roles := l.roleGroup.UserRoles(id) - u, err := l.user.GetUser(id) + roles := m.roleGroup.UserRoles(id) + u, err := m.user.GetUser(id) if err != nil { return ctx.Error(err, "") } @@ -68,7 +68,7 @@ type queryAdmins struct { Roles []string `query:"role"` States []user.State `query:"state,normal"` Sexes []types.Sex `query:"sex"` - m *Loader + m *Module } func (q *queryAdmins) Filter(v *web.FilterContext) { @@ -89,16 +89,16 @@ func (q *queryAdmins) Filter(v *web.FilterContext) { // @tag admin // @query queryAdmins // @resp 200 * github.com/issue9/cmfx/cmfx/query.Page[respInfoWithRoleState] -func (l *Loader) getAdmins(ctx *web.Context) web.Responser { - q := &queryAdmins{m: l} +func (m *Module) getAdmins(ctx *web.Context) web.Responser { + q := &queryAdmins{m: m} if resp := ctx.QueryObject(true, q, cmfx.BadRequestInvalidQuery); resp != nil { return resp } - sql := l.Module().DB().SQLBuilder().Select().Column("*").From(orm.TableName(&modelInfo{}), "info") + sql := m.Module().DB().SQLBuilder().Select().Column("*").From(orm.TableName(&modelInfo{}), "info") if len(q.States) > 0 { - l.user.LeftJoin(sql, "user", "user.id=info.id", q.States) + m.user.LeftJoin(sql, "user", "user.id=info.id", q.States) } if len(q.Sexes) > 0 { @@ -120,7 +120,7 @@ func (l *Loader) getAdmins(ctx *web.Context) web.Responser { } return query.PagingResponserWithConvert[info, respInfoWithRoleState](ctx, &q.Limit, sql, func(i *info) *respInfoWithRoleState { - roles := l.roleGroup.UserRoles(i.ID) + roles := m.roleGroup.UserRoles(i.ID) rs := make([]string, 0, len(roles)) for _, r := range roles { rs = append(rs, r.ID) @@ -145,13 +145,13 @@ func (l *Loader) getAdmins(ctx *web.Context) web.Responser { // @tag admin // @req * respInfoWithRoleState // @resp 204 * {} -func (l *Loader) patchAdmin(ctx *web.Context) web.Responser { +func (m *Module) patchAdmin(ctx *web.Context) web.Responser { id, resp := ctx.PathID("id", cmfx.BadRequestInvalidPath) if resp != nil { return resp } - u, err := l.user.GetUser(id) + u, err := m.user.GetUser(id) if err != nil { return ctx.Error(err, "") } @@ -171,18 +171,18 @@ func (l *Loader) patchAdmin(ctx *web.Context) web.Responser { Sex: data.Sex, } - tx, err := l.Module().DB().Begin() + tx, err := m.Module().DB().Begin() if err != nil { return ctx.Error(err, "") } - e := tx.NewEngine(l.Module().DB().TablePrefix()) + e := tx.NewEngine(m.Module().DB().TablePrefix()) if _, err := e.Update(aa, "sex"); err != nil { return ctx.Error(errors.Join(err, tx.Rollback()), "") } for _, rid := range data.Roles { - r := l.roleGroup.Role(rid) + r := m.roleGroup.Role(rid) if r == nil { continue } @@ -192,7 +192,7 @@ func (l *Loader) patchAdmin(ctx *web.Context) web.Responser { } } - if err := l.user.SetState(tx, u, data.State); err != nil { + if err := m.user.SetState(tx, u, data.State); err != nil { return ctx.Error(errors.Join(err, tx.Rollback()), "") } @@ -207,19 +207,19 @@ func (l *Loader) patchAdmin(ctx *web.Context) web.Responser { // @tag admin // @resp 204 * {} // @path id int 管理的 ID -func (l *Loader) deleteAdminPassword(ctx *web.Context) web.Responser { +func (m *Module) deleteAdminPassword(ctx *web.Context) web.Responser { id, resp := ctx.PathID("id", cmfx.BadRequestInvalidPath) if resp != nil { return resp } // 查看指定的用户是否真实存在,不判断状态,即使锁定,也能改其信息 - if _, err := l.user.GetUser(id); err != nil { + if _, err := m.user.GetUser(id); err != nil { return ctx.Error(err, "") } // 更新数据库 - if err := l.password.Set(id, defaultPassword); err != nil { + if err := m.password.Set(id, defaultPassword); err != nil { return ctx.Error(err, "") } @@ -230,13 +230,13 @@ func (l *Loader) deleteAdminPassword(ctx *web.Context) web.Responser { // @tag admin // @req * respInfoWithAccount // @resp 201 * {} -func (l *Loader) postAdmins(ctx *web.Context) web.Responser { - data := &respInfoWithAccount{respInfoWithRoleState: respInfoWithRoleState{respInfo: respInfo{m: l}}} +func (m *Module) postAdmins(ctx *web.Context) web.Responser { + data := &respInfoWithAccount{respInfoWithRoleState: respInfoWithRoleState{respInfo: respInfo{m: m}}} if resp := ctx.Read(true, data, cmfx.BadRequestInvalidBody); resp != nil { return resp } - if err := newAdmin(l.Module(), l.password, data, ctx.Begin()); err != nil { + if err := m.newAdmin(m.password, data, ctx.Begin()); err != nil { return ctx.Error(err, "") } return web.Created(nil, "") @@ -245,38 +245,38 @@ func (l *Loader) postAdmins(ctx *web.Context) web.Responser { // # api POST /admins/{id}/locked 锁定管理员 // @tag admin // @resp 201 * {} -func (l *Loader) postAdminLocked(ctx *web.Context) web.Responser { - return l.setAdminState(ctx, user.StateLocked, http.StatusCreated) +func (m *Module) postAdminLocked(ctx *web.Context) web.Responser { + return m.setAdminState(ctx, user.StateLocked, http.StatusCreated) } // # api delete /admins/{id} 删除管理员 // @tag admin // @path id id 管理员的 ID // @resp 201 * {} -func (l *Loader) deleteAdmin(ctx *web.Context) web.Responser { - return l.setAdminState(ctx, user.StateDeleted, http.StatusCreated) +func (m *Module) deleteAdmin(ctx *web.Context) web.Responser { + return m.setAdminState(ctx, user.StateDeleted, http.StatusCreated) } // # api delete /admins/{id}/locked 解除锁定 // @tag admin // @path id id 管理员的 ID // @resp 204 * {} -func (l *Loader) deleteAdminLocked(ctx *web.Context) web.Responser { - return l.setAdminState(ctx, user.StateNormal, http.StatusNoContent) +func (m *Module) deleteAdminLocked(ctx *web.Context) web.Responser { + return m.setAdminState(ctx, user.StateNormal, http.StatusNoContent) } -func (l *Loader) setAdminState(ctx *web.Context, state user.State, code int) web.Responser { +func (m *Module) setAdminState(ctx *web.Context, state user.State, code int) web.Responser { id, resp := ctx.PathID("id", cmfx.BadRequestInvalidPath) if resp != nil { return resp } - u, err := l.user.GetUser(id) + u, err := m.user.GetUser(id) if err != nil { return ctx.Error(err, "") } - if err := l.user.SetState(nil, u, state); err != nil { + if err := m.user.SetState(nil, u, state); err != nil { return ctx.Error(err, "") } diff --git a/cmfx/modules/admin/route_current.go b/cmfx/modules/admin/route_current.go index 17e3e279..e776c7bf 100644 --- a/cmfx/modules/admin/route_current.go +++ b/cmfx/modules/admin/route_current.go @@ -18,23 +18,23 @@ import ( // # api get /info 获取当前登用户的信息 // @tag admin // @resp 200 * respInfo -func (l *Loader) getInfo(ctx *web.Context) web.Responser { - return web.OK(l.CurrentUser(ctx)) +func (m *Module) getInfo(ctx *web.Context) web.Responser { + return web.OK(m.CurrentUser(ctx)) } // # api patch /info 更新当前登用户的信息 // @tag admin // @req * respInfo 更新的信息,将忽略 id // @resp 204 * {} -func (l *Loader) patchInfo(ctx *web.Context) web.Responser { +func (m *Module) patchInfo(ctx *web.Context) web.Responser { data := &respInfo{} if resp := ctx.Read(true, data, cmfx.BadRequestInvalidBody); resp != nil { return resp } - a := l.CurrentUser(ctx) + a := m.CurrentUser(ctx) - _, err := l.Module().DB().Update(&modelInfo{ + _, err := m.Module().DB().Update(&modelInfo{ ID: a.ID, Nickname: data.Nickname, Avatar: data.Avatar, @@ -45,7 +45,7 @@ func (l *Loader) patchInfo(ctx *web.Context) web.Responser { return ctx.Error(err, "") } - if err := l.user.AddSecurityLogFromContext(nil, a.ID, ctx, "更新个人信息"); err != nil { + if err := m.user.AddSecurityLogFromContext(nil, a.ID, ctx, "更新个人信息"); err != nil { return ctx.Error(err, "") } @@ -72,27 +72,27 @@ func (p *putPassword) Filter(v *web.FilterContext) { // @tag admin // @req * putPassword // @resp 204 * {} -func (l *Loader) putCurrentPassword(ctx *web.Context) web.Responser { +func (m *Module) putCurrentPassword(ctx *web.Context) web.Responser { data := &putPassword{} if resp := ctx.Read(true, data, cmfx.BadRequestInvalidBody); resp != nil { return resp } - a := l.CurrentUser(ctx) - err := l.password.Change(a.ID, data.Old, data.New) + a := m.CurrentUser(ctx) + err := m.password.Change(a.ID, data.Old, data.New) if errors.Is(err, passport.ErrUnauthorized()) { return ctx.Problem(cmfx.Unauthorized) } else if err != nil { return ctx.Error(err, "") } - return l.user.Logout(ctx, nil, web.Phrase("change password")) + return m.user.Logout(ctx, nil, web.Phrase("change password")) } // # api get /securitylog 当前用户的安全操作记录 // @tag admin // @query github.com/issue9/cmfx/cmfx/user.queryLog // @resp 200 * github.com/issue9/cmfx/cmfx/query.Page[github.com/issue9/cmfx/cmfx/user.respLog] -func (l *Loader) getSecurityLogs(ctx *web.Context) web.Responser { - return l.user.GetSecurityLogs(ctx) +func (m *Module) getSecurityLogs(ctx *web.Context) web.Responser { + return m.user.GetSecurityLogs(ctx) } diff --git a/cmfx/modules/admin/route_rbac.go b/cmfx/modules/admin/route_rbac.go index 2241c52d..9765da3b 100644 --- a/cmfx/modules/admin/route_rbac.go +++ b/cmfx/modules/admin/route_rbac.go @@ -13,16 +13,16 @@ import ( // # api get /roles 获取权限组列表 // @tag admin rbac // @resp 200 * []github.com/issue9/cmfx/cmfx/user/rbac.respRole -func (l *Loader) getRoles(ctx *web.Context) web.Responser { - return rbac.GetRolesHandle(l.roleGroup, ctx) +func (m *Module) getRoles(ctx *web.Context) web.Responser { + return rbac.GetRolesHandle(m.roleGroup, ctx) } // # api post /roles 添加一个权限组 // @tag admin rbac // @req * github.com/issue9/cmfx/cmfx/user/rbac.reqRole // @resp 201 * {} -func (l *Loader) postRoles(ctx *web.Context) web.Responser { - return rbac.PostRolesHandle(l.roleGroup, ctx) +func (m *Module) postRoles(ctx *web.Context) web.Responser { + return rbac.PostRolesHandle(m.roleGroup, ctx) } // # api put /roles/{id} 修改权限组 @@ -30,31 +30,31 @@ func (l *Loader) postRoles(ctx *web.Context) web.Responser { // @path id id 权限组 ID // @req * github.com/issue9/cmfx/cmfx/user/rbac.reqRole // @resp 204 * {} -func (l *Loader) putRole(ctx *web.Context) web.Responser { - return rbac.PutRoleHandle(l.roleGroup, "id", ctx) +func (m *Module) putRole(ctx *web.Context) web.Responser { + return rbac.PutRoleHandle(m.roleGroup, "id", ctx) } // # api delete /roles/{id} 删除权限组 // @tag admin rbac // @path id id 权限组 ID // @resp 204 * {} -func (l *Loader) deleteRole(ctx *web.Context) web.Responser { - return rbac.DeleteRoleHandle(l.roleGroup, "id", ctx) +func (m *Module) deleteRole(ctx *web.Context) web.Responser { + return rbac.DeleteRoleHandle(m.roleGroup, "id", ctx) } // # api get /resources 获取所有的资源 // @tag admin rbac // @resp 200 * map 键名为资源 ID,键值为资源描述 -func (l *Loader) getResources(ctx *web.Context) web.Responser { - return rbac.GetResourcesHandle(l.roleGroup, ctx) +func (m *Module) getResources(ctx *web.Context) web.Responser { + return rbac.GetResourcesHandle(m.roleGroup, ctx) } // # api get /roles/{id}/resources 获得角色已被允许访问的资源 // @tag admin rbac // @path id id 权限组 ID // @resp 200 application/json map 键名为资源 ID,键值为资源描述 -func (l *Loader) getRoleResources(ctx *web.Context) web.Responser { - return rbac.GetRoleResourcesHandle(l.roleGroup, "id", ctx) +func (m *Module) getRoleResources(ctx *web.Context) web.Responser { + return rbac.GetRoleResourcesHandle(m.roleGroup, "id", ctx) } // # api patch /roles/{id}/resources 设置权限组的可访问的资源 @@ -62,6 +62,6 @@ func (l *Loader) getRoleResources(ctx *web.Context) web.Responser { // @path id id 权限组 ID // @req * []string 资源 ID 列表 // @resp 204 * {} -func (l *Loader) putRoleResources(ctx *web.Context) web.Responser { - return rbac.PutRoleResourcesHandle(l.roleGroup, "id", ctx) +func (m *Module) putRoleResources(ctx *web.Context) web.Responser { + return rbac.PutRoleResourcesHandle(m.roleGroup, "id", ctx) } diff --git a/cmfx/modules/admin/route_token.go b/cmfx/modules/admin/route_token.go index b550e4c6..5dd364e5 100644 --- a/cmfx/modules/admin/route_token.go +++ b/cmfx/modules/admin/route_token.go @@ -6,42 +6,51 @@ package admin import ( "github.com/issue9/web" + "github.com/issue9/web/filter" "github.com/issue9/cmfx/cmfx" "github.com/issue9/cmfx/cmfx/user" + "github.com/issue9/cmfx/locales" ) type queryLogin struct { - Type string `query:"type"` + m *Module + Type string `query:"type,password"` +} + +func (q *queryLogin) Filter(c *web.FilterContext) { + v := func(s string) bool { return q.m.user.Passport().Get(s) != nil } + c.Add(filter.NewBuilder(filter.V(v, locales.InvalidValue))("type", &q.Type)) } // # API POST /login 管理员登录 // @tag admin auth +// @query queryLogin // @req * github.com/issue9/cmfx/cmfx/user.reqAccount // @resp 201 * github.com/issue9/webuse/v7/middlewares/auth/token.Response -func (l *Loader) postLogin(ctx *web.Context) web.Responser { - q := &queryLogin{} +func (m *Module) postLogin(ctx *web.Context) web.Responser { + q := &queryLogin{m: m} if resp := ctx.QueryObject(true, q, cmfx.BadRequestInvalidQuery); resp != nil { return resp } - return l.user.Login(q.Type, ctx, nil, func(u *user.User) { - l.loginEvent.Publish(false, u) + return m.user.Login(q.Type, ctx, nil, func(u *user.User) { + m.loginEvent.Publish(false, u) }) } // # api delete /login 注销当前管理员的登录 // @tag admin auth // @resp 204 * {} -func (l *Loader) deleteLogin(ctx *web.Context) web.Responser { - return l.user.Logout(ctx, func(u *user.User) { - l.logoutEvent.Publish(false, u) +func (m *Module) deleteLogin(ctx *web.Context) web.Responser { + return m.user.Logout(ctx, func(u *user.User) { + m.logoutEvent.Publish(false, u) }, web.Phrase("logout")) } -// # api get /token 续定 token +// # api put /login 续定 token // @tag admin auth // @resp 201 * github.com/issue9/webuse/v7/middlewares/auth/token.Response -func (l *Loader) getToken(ctx *web.Context) web.Responser { - return l.user.RefreshToken(ctx) +func (m *Module) getToken(ctx *web.Context) web.Responser { + return m.user.RefreshToken(ctx) } diff --git a/cmfx/modules/system/install.go b/cmfx/modules/system/install.go index 3ac87aec..d7494dc8 100644 --- a/cmfx/modules/system/install.go +++ b/cmfx/modules/system/install.go @@ -14,24 +14,27 @@ import ( "github.com/issue9/web" "github.com/issue9/cmfx/cmfx" + "github.com/issue9/cmfx/cmfx/modules/admin" ) -func Install(mod *cmfx.Module) { +func Install(mod *cmfx.Module, conf *Config, adminL *admin.Module) *Module { if err := mod.DB().Create(&modelHealth{}, &modelLinkage{}); err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } + + return Load(mod, conf, adminL) } // DeleteLinkage 删除指定的级联数据 -func DeleteLinkage(l *Loader, key string) error { +func DeleteLinkage(l *Module, key string) error { mod := &modelLinkage{Deleted: sql.NullTime{Valid: true, Time: time.Now()}} _, err := l.mod.DB().Where("key=?", key).Update(mod) return err } // InstallLinkage 安装一组级联数据 -func InstallLinkage[T any](l *Loader, key, title string, items []*LinkageItem[T]) error { - // TODO 限制 T 的类型能为指针? +func InstallLinkage[T any](l *Module, key, title string, items []*LinkageItem[T]) error { + checkObjectType[T]() db := l.mod.DB() root := &modelLinkage{} diff --git a/cmfx/modules/system/install_test.go b/cmfx/modules/system/install_test.go index 9b8e64d2..ba4468c7 100644 --- a/cmfx/modules/system/install_test.go +++ b/cmfx/modules/system/install_test.go @@ -10,14 +10,21 @@ import ( "github.com/issue9/assert/v4" "github.com/issue9/cmfx/cmfx/initial/test" + "github.com/issue9/cmfx/cmfx/modules/admin/admintest" ) func TestInstall(t *testing.T) { a := assert.New(t, false) s := test.NewSuite(a) + adminL := admintest.NewAdmin(s) + + conf := &Config{} + s.Assertion().NotError(conf.SanitizeConfig()) mod := s.NewModule("mod") - Install(mod) + l := Install(mod, conf, adminL) + a.NotNil(l) + s.TableExists("mod_linkages").TableExists("mod_api_healths") } diff --git a/cmfx/modules/system/linkage.go b/cmfx/modules/system/linkage.go index 177bd509..3a74ecd2 100644 --- a/cmfx/modules/system/linkage.go +++ b/cmfx/modules/system/linkage.go @@ -8,6 +8,7 @@ import ( "database/sql" "encoding/json" "fmt" + "reflect" "slices" "time" @@ -17,7 +18,7 @@ import ( // Linkage 级联菜单元素 type Linkage[T any] struct { - l *Loader + l *Module id int64 title string items []*LinkageItem[T] @@ -30,8 +31,8 @@ type LinkageItem[T any] struct { Items []*LinkageItem[T] } -func LoadLinkage[T any](l *Loader, key string) (*Linkage[T], error) { - // TODO 限制 T 的类型不能为指针? +func LoadLinkage[T any](l *Module, key string) (*Linkage[T], error) { + checkObjectType[T]() db := l.mod.DB() @@ -73,7 +74,7 @@ func LoadLinkage[T any](l *Loader, key string) (*Linkage[T], error) { } if root == nil { - panic(fmt.Sprintf("未从数据库中找到 %s 的根元素", key)) + return nil, web.NewLocaleError("linkage key %s not found", key) } root.items = make([]*LinkageItem[T], 0, len(linkageItems)) @@ -86,10 +87,10 @@ func LoadLinkage[T any](l *Loader, key string) (*Linkage[T], error) { modItem, found := sliceutil.At(modItems, func(i *modelLinkage, _ int) bool { return i.ID == item.ID }) if !found { - panic("无法从原数据中找到相应的 ID") + panic("无法从原数据中找到相应的 ID") // 这应该是代码出错或是数据库被修改 } if modItem.Parent == 0 { - panic("parent.ID 为零的应该在之前代码中被过滤") + panic("parent.ID 为零的应该在之前代码中被过滤") // 这应该是代码出错或是数据库被修改 } if modItem.Parent == root.id { @@ -101,7 +102,6 @@ func LoadLinkage[T any](l *Loader, key string) (*Linkage[T], error) { } parent.Items = append(parent.Items, item) } - } return root, nil @@ -215,3 +215,9 @@ func (l *LinkageItem[T]) find(id int64) (parent, curr *LinkageItem[T]) { func errLinkageItemNotFound(id int64) error { return web.NewLocaleError("linkage item %d not found", id) } + +func checkObjectType[T any]() { + if reflect.TypeFor[T]().Kind() == reflect.Pointer { + panic(fmt.Sprintf("T 的约束必须是结构体")) + } +} diff --git a/cmfx/modules/system/loader.go b/cmfx/modules/system/module.go similarity index 90% rename from cmfx/modules/system/loader.go rename to cmfx/modules/system/module.go index 802d5108..603210a8 100644 --- a/cmfx/modules/system/loader.go +++ b/cmfx/modules/system/module.go @@ -15,9 +15,9 @@ import ( "github.com/issue9/cmfx/cmfx/modules/admin" ) -type Loader struct { +type Module struct { mod *cmfx.Module - admin *admin.Loader + admin *admin.Module health *health.Health monitor *monitor.Monitor @@ -28,13 +28,13 @@ type Loader struct { // Load 加载当前模块 // // conf 当前模块的配置项,需要调用者自先调用 [Config.SanitizeConfig] 对数据进行校正; -func Load(mod *cmfx.Module, conf *Config, adminL *admin.Loader) *Loader { +func Load(mod *cmfx.Module, conf *Config, adminL *admin.Module) *Module { store, err := newHealthDBStore(mod) if err != nil { panic(web.SprintError(mod.Server().Locale().Printer(), true, err)) } - m := &Loader{ + m := &Module{ mod: mod, admin: adminL, health: health.New(store), @@ -50,7 +50,7 @@ func Load(mod *cmfx.Module, conf *Config, adminL *admin.Loader) *Loader { resGetAPIs := g.New("get-apis", web.Phrase("view apis")) resBackup := g.New("backup", web.Phrase("backup database")) - adminRouter := mod.Router().Prefix(adminL.URLPrefix()+conf.URLPrefix, web.MiddlewareFunc(m.admin.AuthMiddleware)) + adminRouter := mod.Router().Prefix(adminL.URLPrefix()+conf.URLPrefix, web.MiddlewareFunc(m.admin.Middleware)) adminRouter.Get("/info", resGetInfo(m.adminGetInfo)). Get("/services", resGetServices(m.adminGetServices)). Get("/apis", resGetAPIs(m.adminGetAPIs)). diff --git a/cmfx/modules/system/loader_test.go b/cmfx/modules/system/module_test.go similarity index 66% rename from cmfx/modules/system/loader_test.go rename to cmfx/modules/system/module_test.go index 911aca96..c6b36893 100644 --- a/cmfx/modules/system/loader_test.go +++ b/cmfx/modules/system/module_test.go @@ -9,16 +9,10 @@ import ( "github.com/issue9/cmfx/cmfx/modules/admin/admintest" ) -func newSystem(s *test.Suite) *Loader { +func newSystem(s *test.Suite) *Module { adminM := admintest.NewAdmin(s) - mod := s.NewModule("test") - Install(mod) - conf := &Config{} s.Assertion().NotError(conf.SanitizeConfig()) - sys := Load(mod, conf, adminM) - s.Assertion().NotNil(sys) - - return sys + return Install(s.NewModule("test"), conf, adminM) } diff --git a/cmfx/modules/system/route.go b/cmfx/modules/system/route.go index e8af8392..0bc41b0c 100644 --- a/cmfx/modules/system/route.go +++ b/cmfx/modules/system/route.go @@ -14,7 +14,7 @@ import ( // # api get /system/apis API 信息 // @tag system admin // @resp 200 * []github.com/issue9/webuse/v7/plugins/health.State -func (l *Loader) adminGetAPIs(_ *web.Context) web.Responser { return web.OK(l.health.States()) } +func (m *Module) adminGetAPIs(_ *web.Context) web.Responser { return web.OK(m.health.States()) } type dbInfo struct { Name string `json:"name" xml:"name" cbor:"name"` // 数据库驱动 @@ -47,9 +47,9 @@ type info struct { // # api get /system/info 系统信息 // @tag system admin // @resp 200 * info -func (l *Loader) adminGetInfo(ctx *web.Context) web.Responser { - dbVersion := l.mod.DB().Version() - stats := l.mod.DB().Stats() +func (m *Module) adminGetInfo(ctx *web.Context) web.Responser { + dbVersion := m.mod.DB().Version() + stats := m.mod.DB().Stats() srv := ctx.Server() return web.OK(&info{ @@ -62,7 +62,7 @@ func (l *Loader) adminGetInfo(ctx *web.Context) web.Responser { CPUS: runtime.NumCPU(), Goroutines: runtime.NumGoroutine(), DB: &dbInfo{ - Name: l.mod.DB().Dialect().Name(), + Name: m.mod.DB().Dialect().Name(), Version: dbVersion, MaxOpenConnections: stats.MaxOpenConnections, OpenConnections: stats.OpenConnections, @@ -98,7 +98,7 @@ type services struct { // # api get /system/services 系统服务状态 // @tag system admin // @resp 200 * services -func (l *Loader) adminGetServices(ctx *web.Context) web.Responser { +func (m *Module) adminGetServices(ctx *web.Context) web.Responser { ss := services{} ctx.Server().Services().Visit(func(title web.LocaleStringer, state web.State, err error) { var err1 string @@ -114,11 +114,16 @@ func (l *Loader) adminGetServices(ctx *web.Context) web.Responser { }) ctx.Server().Services().VisitJobs(func(j *web.Job) { + var err string + if j.Err() != nil { + err = j.Err().Error() + } + ss.Jobs = append(ss.Jobs, job{ service: service{ Title: j.Title().LocaleString(ctx.LocalePrinter()), State: j.State(), - Err: j.Err().Error(), + Err: err, }, Next: j.Next(), Prev: j.Prev(), @@ -129,17 +134,18 @@ func (l *Loader) adminGetServices(ctx *web.Context) web.Responser { } type problem struct { - Prefix string `json:"prefix" xml:"prefix" cbor:"prefix"` // URL 前缀以,如果此值不为空,与 ID 组成一上完整的地址。 - ID string `json:"id" xml:"id" cbor:"id"` // 唯一 ID - Status int `json:"status" xml:"status,attr" cbor:"status"` // 对应的原始 HTTP 状态码 - Title string `json:"title" xml:"title" cbor:"title"` // 错误的简要描述 - Detail string `json:"detail" xml:"detail" cbor:"detail"` // 错误的明细 + XMLName struct{} `json:"-" cbor:"-" xml:"problem"` + Prefix string `json:"prefix" xml:"prefix" cbor:"prefix"` // URL 前缀以,如果此值不为空,与 ID 组成一上完整的地址。 + ID string `json:"id" xml:"id" cbor:"id"` // 唯一 ID + Status int `json:"status" xml:"status,attr" cbor:"status"` // 对应的原始 HTTP 状态码 + Title string `json:"title" xml:"title" cbor:"title"` // 错误的简要描述 + Detail string `json:"detail" xml:"detail" cbor:"detail"` // 错误的明细 } // # api get /system/problems 系统错误信息 // @tag system // @resp 200 * []problem -func (l *Loader) commonGetProblems(ctx *web.Context) web.Responser { +func (m *Module) commonGetProblems(ctx *web.Context) web.Responser { ps := make([]*problem, 0, 100) ctx.Server().Problems().Visit(func(status int, p *web.LocaleProblem) { ps = append(ps, &problem{ @@ -156,13 +162,13 @@ func (l *Loader) commonGetProblems(ctx *web.Context) web.Responser { // # api get /system/monitor 监视系统数据 // @tag system // @resp 200 text/event-stream github.com/issue9/webuse/v7/handlers/monitor.Stats -func (l *Loader) adminGetMonitor(ctx *web.Context) web.Responser { return l.monitor.Handle(ctx) } +func (m *Module) adminGetMonitor(ctx *web.Context) web.Responser { return m.monitor.Handle(ctx) } // # api post /system/backup 手动执行备份数据 // @tag system // @resp 201 * {} -func (l *Loader) adminPostBackup(ctx *web.Context) web.Responser { - if err := l.mod.DB().Backup(l.buildBackupFilename(ctx.Begin())); err != nil { +func (m *Module) adminPostBackup(ctx *web.Context) web.Responser { + if err := m.mod.DB().Backup(m.buildBackupFilename(ctx.Begin())); err != nil { return ctx.Error(err, web.ProblemInternalServerError) } return web.Created(nil, "") diff --git a/cmfx/modules/system/route_test.go b/cmfx/modules/system/route_test.go index 01334199..20a8b735 100644 --- a/cmfx/modules/system/route_test.go +++ b/cmfx/modules/system/route_test.go @@ -10,10 +10,13 @@ import ( "testing" "github.com/issue9/assert/v4" - "github.com/issue9/cmfx/cmfx/initial/test" "github.com/issue9/mux/v8/header" "github.com/issue9/web/server/servertest" + "github.com/issue9/webuse/v7/middlewares/auth" "golang.org/x/text/language" + + "github.com/issue9/cmfx/cmfx/initial/test" + "github.com/issue9/cmfx/cmfx/modules/admin/admintest" ) func TestSystem_apis(t *testing.T) { @@ -21,14 +24,17 @@ func TestSystem_apis(t *testing.T) { suite := test.NewSuite(a) err := suite.Module().Server().Locale().SetString(language.SimplifiedChinese, "v1 desc", "v1 cn") a.NotError(err) - newSystem(suite) + l := newSystem(suite) defer servertest.Run(a, suite.Module().Server())() defer suite.Close() + token := admintest.GetToken(suite, l.admin) + suite.Get("/admin/system/info"). Header(header.AcceptLanguage, language.SimplifiedChinese.String()). Header(header.Accept, "application/json;charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, token)). Do(nil). Status(http.StatusOK). BodyFunc(func(a *assert.Assertion, body []byte) { @@ -39,6 +45,7 @@ func TestSystem_apis(t *testing.T) { suite.Get("/admin/system/services"). Header(header.AcceptLanguage, language.SimplifiedChinese.String()). Header(header.Accept, "application/json;charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, token)). Do(nil). Status(http.StatusOK). BodyFunc(func(a *assert.Assertion, body []byte) { @@ -49,6 +56,7 @@ func TestSystem_apis(t *testing.T) { suite.Get("/admin/system/apis"). Header(header.AcceptLanguage, language.SimplifiedChinese.String()). Header(header.Accept, "application/json;charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, token)). Do(nil). Status(http.StatusOK). BodyFunc(func(a *assert.Assertion, body []byte) { diff --git a/cmfx/modules/system/systemtest/systemtest.go b/cmfx/modules/system/systemtest/systemtest.go index c538d658..8d7a948c 100644 --- a/cmfx/modules/system/systemtest/systemtest.go +++ b/cmfx/modules/system/systemtest/systemtest.go @@ -10,13 +10,12 @@ import ( "github.com/issue9/cmfx/cmfx/modules/system" ) -func NewSystem(s *test.Suite, adminL *admin.Loader) *system.Loader { +func NewSystem(s *test.Suite, adminL *admin.Module) *system.Module { mod := s.NewModule("system") - system.Install(mod) conf := &system.Config{} s.Assertion().NotError(conf.SanitizeConfig()) - sys := system.Load(mod, conf, adminL) + sys := system.Install(mod, conf, adminL) s.Assertion().NotNil(sys) return sys diff --git a/cmfx/user/loader_test.go b/cmfx/user/loader_test.go deleted file mode 100644 index ba879036..00000000 --- a/cmfx/user/loader_test.go +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-FileCopyrightText: 2022-2024 caixw -// -// SPDX-License-Identifier: MIT - -package user - -import "github.com/issue9/web" - -var _ web.Middleware = &Loader{} diff --git a/cmfx/user/models.go b/cmfx/user/models.go index 83c09824..c5c5dd84 100644 --- a/cmfx/user/models.go +++ b/cmfx/user/models.go @@ -16,49 +16,51 @@ const ( StateDeleted // 离职 ) -type ( - // State 表示管理员的状态 - // - // @enum - // @type string - State int8 - - User struct { - XMLName struct{} `orm:"-" json:"-" xml:"user" cbor:"-"` - ID int64 `orm:"name(id);ai" json:"id" xml:"id,attr" cbor:"id"` // 用户的自增 ID - NO string `orm:"name(no);len(32);unique(no)" json:"no" xml:"no,attr" cbor:"no"` // 用户的唯一编号,一般用于前端 - Created time.Time `orm:"name(created)" json:"created" xml:"created,attr" cbor:"created"` // 添加时间 - State State `orm:"name(state)" json:"state" xml:"state,attr" cbor:"state"` // 状态 - } - - modelLog struct { - ID int64 `orm:"name(id);ai"` - Created time.Time `orm:"name(created)"` - - UID int64 `orm:"name(uid);index(uid)"` // 关联的用户 - Content string `orm:"name(content);len(-1)"` - IP string `orm:"name(ip);len(50)"` - UserAgent string `orm:"name(user_agent);len(500)"` - } - - respLog struct { - Content string `json:"content" xml:",cdata" cbor:"content"` - IP string `json:"ip" xml:"ip,attr" cbor:"ip"` - UserAgent string `json:"ua" xml:"ua" cbor:"ua"` - Created time.Time `xml:"created" json:"created" cbor:"created"` - } -) +// State 表示管理员的状态 +// +// @enum +// @type string +type State int8 + +type respLog struct { + Content string `json:"content" xml:",cdata" cbor:"content"` + IP string `json:"ip" xml:"ip,attr" cbor:"ip"` + UserAgent string `json:"ua" xml:"ua" cbor:"ua"` + Created time.Time `xml:"created" json:"created" cbor:"created"` +} + +//--------------------------------------- user --------------------------------------- + +type User struct { + XMLName struct{} `orm:"-" json:"-" xml:"user" cbor:"-"` + ID int64 `orm:"name(id);ai" json:"id" xml:"id,attr" cbor:"id"` // 用户的自增 ID + NO string `orm:"name(no);len(32);unique(no)" json:"no" xml:"no,attr" cbor:"no"` // 用户的唯一编号,一般用于前端 + Created time.Time `orm:"name(created)" json:"created" xml:"created,attr" cbor:"created"` // 添加时间 + State State `orm:"name(state)" json:"state" xml:"state,attr" cbor:"state"` // 状态 +} func (u *User) GetUID() string { return u.NO } func (*User) TableName() string { return `_users` } -func (a *User) BeforeInsert() error { - a.ID = 0 - a.Created = time.Now() +func (u *User) BeforeInsert() error { + u.ID = 0 + u.Created = time.Now() return nil } +//--------------------------------- modelLog --------------------------------------------- + +type modelLog struct { + ID int64 `orm:"name(id);ai"` + Created time.Time `orm:"name(created)"` + + UID int64 `orm:"name(uid);index(uid)"` // 关联的用户 + Content string `orm:"name(content);len(-1)"` + IP string `orm:"name(ip);len(50)"` + UserAgent string `orm:"name(user_agent);len(500)"` +} + func (l *modelLog) TableName() string { return "_securitylogs" } func (l *modelLog) BeforeInsert() error { @@ -70,6 +72,4 @@ func (l *modelLog) BeforeInsert() error { return nil } -func (l *modelLog) BeforeUpdate() error { - panic("此表不存在更新记录的情况") -} +func (l *modelLog) BeforeUpdate() error { panic("此表不存在更新记录的情况") } diff --git a/cmfx/user/loader.go b/cmfx/user/module.go similarity index 65% rename from cmfx/user/loader.go rename to cmfx/user/module.go index bf38d59d..aff472ad 100644 --- a/cmfx/user/loader.go +++ b/cmfx/user/module.go @@ -7,7 +7,6 @@ package user import ( "net/http" - "os" "github.com/issue9/cache" "github.com/issue9/orm/v6" @@ -20,8 +19,8 @@ import ( "github.com/issue9/cmfx/cmfx/user/passport" ) -// Loader 用户账号加载 -type Loader struct { +// Module 用户账号模块 +type Module struct { mod *cmfx.Module urlPrefix string // 所有接口的 URL 前缀 token *tokens @@ -29,10 +28,10 @@ type Loader struct { } // Load 加载当前模块的环境 -func Load(mod *cmfx.Module, conf *Config) *Loader { +func Load(mod *cmfx.Module, conf *Config) *Module { store := token.NewCacheStore[*User](cache.Prefix(mod.Server().Cache(), mod.ID())) - return &Loader{ + return &Module{ mod: mod, urlPrefix: conf.URLPrefix, token: token.New(mod.Server(), store, conf.accessExpired, conf.refreshExpired, web.ProblemUnauthorized, nil), @@ -40,24 +39,26 @@ func Load(mod *cmfx.Module, conf *Config) *Loader { } } -func (m *Loader) URLPrefix() string { return m.urlPrefix } +func (m *Module) URLPrefix() string { return m.urlPrefix } // GetUser 获取指定 uid 的用户 -func (m *Loader) GetUser(uid int64) (*User, error) { +func (m *Module) GetUser(uid int64) (*User, error) { u := &User{ID: uid} found, err := m.Module().DB().Select(u) if err != nil { return nil, err } if !found { - return nil, web.NewError(http.StatusNotFound, os.ErrNotExist) + return nil, web.NewError(http.StatusNotFound, cmfx.ErrNotFound()) } return u, nil } -func (m *Loader) LeftJoin(sql *sqlbuilder.SelectStmt, alias, on string, states []State) { - sql.Column(alias + ".state") - sql.Join("left", orm.TableName(&User{}), alias, on).WhereStmt().AndIn(alias+".state", sliceutil.AnySlice(states)...) +func (m *Module) LeftJoin(sql *sqlbuilder.SelectStmt, alias, on string, states []State) { + sql.Column(alias+".state"). + Join("left", orm.TableName(&User{}), alias, on). + WhereStmt(). + AndIn(alias+".state", sliceutil.AnySlice(states)...) } -func (m *Loader) Module() *cmfx.Module { return m.mod } +func (m *Module) Module() *cmfx.Module { return m.mod } diff --git a/cmfx/user/module_test.go b/cmfx/user/module_test.go new file mode 100644 index 00000000..9231b2b9 --- /dev/null +++ b/cmfx/user/module_test.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2022-2024 caixw +// +// SPDX-License-Identifier: MIT + +package user + +import ( + "time" + + "github.com/issue9/web" + + "github.com/issue9/cmfx/cmfx/initial/test" + "github.com/issue9/cmfx/cmfx/user/passport/password" +) + +var _ web.Middleware = &Module{} + +// 声明 [Module] 变量 +// +// 安装了注释库并提供一个 password 名称的密码验证功能 +func newLoader(s *test.Suite) *Module { + conf := &Config{ + URLPrefix: "/user", + AccessExpired: 60, + RefreshExpired: 600, + } + s.Assertion().NotError(conf.SanitizeConfig()) + + mod := s.NewModule("user") + Install(mod) + password.Install(mod) + + u := Load(mod, conf) + s.Assertion().NotNil(u) + + u.Passport().Register("password", password.New(u.Module(), 9), web.Phrase("password")) + p := u.Passport().Get("password") + uid, err := u.NewUser(p, "admin", "password", time.Now()) + s.Assertion().NotError(err).NotZero(uid) + + return u +} diff --git a/cmfx/user/passport/adapter.go b/cmfx/user/passport/adapter.go index fbbb3ec4..b40aac92 100644 --- a/cmfx/user/passport/adapter.go +++ b/cmfx/user/passport/adapter.go @@ -13,9 +13,10 @@ type Adapter interface { // username, password 向验证器提供的登录凭证,不同的实现对此两者的定义可能是不同的, // 比如 oauth2 中表示的是由 authURL 返回的 state 和 code 参数。 // - // uid 表示验证成功之后返回与 username 关联的用户 ID; - // identity 表示在当前适配器中的唯一 ID 表示。部分适配器在 uid 为 0 时可能也返回一个非空的 identity,比如 ouath。 - // 用户可以将之与 uid 关联并注册; + // uid 和 identity 分别表示验证成功之后,与之关联的用户 ID 以及在当前适配器中表示的唯一 ID。 + // 有可能存在 uid 为零而 identity 不会空的情况,比如由 [Adapter.Add] 添了一条 uid 为零的数据 + // 或是像 oauth 验证等也可能返回 uid 为零。一旦返回的 uid 为零,表示用户提交的数据没问题, + // 但是找不到与外部用户关联的 uid,可通过 [Adapter.Add] 与具体的 uid 进行关联; // // 如果验证失败,将返回 [ErrUnauthorized] 错误。 Valid(username, password string, t time.Time) (uid int64, identity string, err error) @@ -26,22 +27,27 @@ type Adapter interface { Identity(int64) (string, error) // Delete 解绑用户 + // + // 如果 uid 为零值,清空所有的临时验证数据。 Delete(uid int64) error // Change 改变用户的认证数据 // - // uid 为需要操作的用户; + // uid 为需要操作的用户,不能为零; // pass 一般为旧的认证代码,比如密码、验证码等; // n 为新的认证数据,由用户自定义,一般为新密码或是新的设备 ID 等; Change(uid int64, pass, n string) error // Set 强制修改用户 uid 的认证数据 // + // uid 为需要操作的用户,不能为零; // n 的定义与 [Adapter.Change] 是相同的。 Set(uid int64, n string) error // Add 关联用户数据 // + // uid 表示关联的用户 ID,如果为空值,表示添加一个临时的验证数据, + // 之后在 [Adapter.Valid] 中验证不再返回错误,但是返回的 uid 为零; // identity 为用户在当前对象中的唯一标记; // code 为实现者的自定义行为,比如密码、设备的当前代码等。 Add(uid int64, identity, code string, t time.Time) error diff --git a/cmfx/user/passport/adaptertest/adaptertest.go b/cmfx/user/passport/adaptertest/adaptertest.go index bd927260..d76c943b 100644 --- a/cmfx/user/passport/adaptertest/adaptertest.go +++ b/cmfx/user/passport/adaptertest/adaptertest.go @@ -16,11 +16,18 @@ import ( // Run 测试 p 的基本功能 func Run(a *assert.Assertion, p passport.Adapter) { // Add + a.NotError(p.Add(1024, "1024", "1024", time.Now())) a.ErrorIs(p.Add(1024, "1024", "1024", time.Now()), passport.ErrUIDExists()) a.ErrorIs(p.Add(1000, "1024", "1024", time.Now()), passport.ErrIdentityExists()) + a.NotError(p.Add(0, "2025", "2025", time.Now())) + a.NotError(p.Add(0, "2026", "2026", time.Now())) + a.ErrorIs(p.Add(111, "1024", "1024", time.Now()), passport.ErrIdentityExists()) // 1024 已经有 uid + a.NotError(p.Add(2025, "2025", "2025", time.Now())) // 将 "2025" 关联 uid + // Valid + uid, identity, err := p.Valid("1024", "1024", time.Now()) a.NotError(err).Equal(identity, "1024").Equal(uid, 1024) uid, identity, err = p.Valid("1024", "pass", time.Now()) // 密码错误 @@ -29,6 +36,7 @@ func Run(a *assert.Assertion, p passport.Adapter) { a.Equal(err, passport.ErrUnauthorized()).Equal(identity, "").Equal(uid, 0) // Change + a.ErrorIs(p.Change(1025, "1024", "1024"), passport.ErrUIDNotExists()) a.ErrorIs(p.Change(1024, "1025", "1024"), passport.ErrUnauthorized()) a.NotError(p.Change(1024, "1024", "1025")) @@ -36,12 +44,14 @@ func Run(a *assert.Assertion, p passport.Adapter) { a.NotError(err).Equal(identity, "1024").Equal(uid, 1024) // Identity + identity, err = p.Identity(1024) a.NotError(err).Equal(identity, "1024") identity, err = p.Identity(10240) a.Equal(err, passport.ErrUIDNotExists()).Empty(identity) // Delete + a.NotError(p.Delete(1024)). NotError(p.Delete(1024)) // 多次删除 uid, identity, err = p.Valid("1024", "1025", time.Now()) diff --git a/cmfx/user/passport/code/code.go b/cmfx/user/passport/code/code.go index 0cbb4e46..dba30b91 100644 --- a/cmfx/user/passport/code/code.go +++ b/cmfx/user/passport/code/code.go @@ -6,7 +6,6 @@ package code import ( - "database/sql" "time" "github.com/issue9/orm/v6" @@ -38,7 +37,7 @@ func New(mod *cmfx.Module, expired time.Duration, tableName string, sender Sende } func (e *code) Delete(uid int64) error { - _, err := e.db.Delete(&modelCode{UID: uid}) + _, err := e.db.Where("uid=?", uid).Delete(&modelCode{}) return err } @@ -57,18 +56,22 @@ func (e *code) Valid(identity, code string, now time.Time) (int64, string, error } func (e *code) Identity(uid int64) (string, error) { - mod := &modelCode{UID: uid} - found, err := e.db.Select(mod) + mod := &modelCode{} + size, err := e.db.Where("uid=?", uid).Select(true, mod) if err != nil { return "", err } - if !found { + if size == 0 { return "", passport.ErrUIDNotExists() } return mod.Identity, nil } func (e *code) Change(uid int64, pass, code string) error { + if uid == 0 { + return passport.ErrUIDMustBeGreatThanZero() + } + m := e.getModel(uid) if m == nil { return passport.ErrUIDNotExists() @@ -77,23 +80,27 @@ func (e *code) Change(uid int64, pass, code string) error { return passport.ErrUnauthorized() } - return e.set(uid, m.Identity, code) + return e.set(m.Identity, code) } func (e *code) Set(uid int64, code string) error { + if uid == 0 { + return passport.ErrUIDMustBeGreatThanZero() + } + m := e.getModel(uid) if m == nil { return passport.ErrUIDNotExists() } - return e.set(uid, m.Identity, code) + return e.set(m.Identity, code) } -func (e *code) set(uid int64, identity, code string) error { - if _, err := e.db.Update(&modelCode{UID: uid, Code: code}, "code", "uid"); err != nil { +func (e *code) set(identity, code string) error { + if _, err := e.db.Update(&modelCode{Identity: identity, Code: code}, "code"); err != nil { return err } - return e.sender.Send(identity, code) + return e.sender.Sent(identity, code) } // Add 注册新用户 @@ -104,33 +111,41 @@ func (e *code) Add(uid int64, identity, code string, now time.Time) error { return passport.ErrInvalidIdentity() } - if e.getModel(uid) != nil { + if uid > 0 && e.getModel(uid) != nil { return passport.ErrUIDExists() } - m := &modelCode{Identity: identity} - f, err := e.db.Select(m) + mod := &modelCode{Identity: identity} + found, err := e.db.Select(mod) if err != nil { return err } - if f { - return passport.ErrIdentityExists() + if found { + if mod.UID > 0 { + return passport.ErrIdentityExists() + } + + _, err = e.db.Update(&modelCode{ + UID: uid, + Identity: identity, + Code: code, + }) + } else { + _, err = e.db.Insert(&modelCode{ + Created: now, + Expired: now.Add(e.expired), + Identity: identity, + UID: uid, + Code: code, + }) } - _, err = e.db.Insert(&modelCode{ - Created: now, - Expired: now.Add(e.expired), - Verified: sql.NullTime{}, - Identity: identity, - UID: uid, - Code: code, - }) return err } func (e *code) getModel(uid int64) *modelCode { - m := &modelCode{UID: uid} - if f, err := e.db.Select(m); err == nil && f { + m := &modelCode{} + if f, err := e.db.Where("uid=?", uid).Select(true, m); err == nil && f > 0 { return m } return nil diff --git a/cmfx/user/passport/code/models.go b/cmfx/user/passport/code/models.go index e8611a04..8f7a18a6 100644 --- a/cmfx/user/passport/code/models.go +++ b/cmfx/user/passport/code/models.go @@ -22,7 +22,7 @@ type modelCode struct { Verified sql.NullTime `orm:"name(verified);nullable"` // 验证时间 Identity string `orm:"name(identity);len(500);unique(identity)"` // 接收者,手机号、邮箱等。 Code string `orm:"name(code);len(8)"` // 验证码 - UID int64 `orm:"name(uid);nullable;unique(uid)"` // 关联的 UID,可以为空 + UID int64 `orm:"name(uid);default(0)"` // 关联的 UID,可以为空 } func (l *modelCode) TableName() string { return `` } diff --git a/cmfx/user/passport/code/sender.go b/cmfx/user/passport/code/sender.go index 00106aec..e35fa456 100644 --- a/cmfx/user/passport/code/sender.go +++ b/cmfx/user/passport/code/sender.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/issue9/errwrap" + "github.com/issue9/mux/v8/header" "github.com/issue9/webfilter/validator" ) @@ -17,11 +18,11 @@ type Sender interface { // ValidIdentity 验证接收地址的格式是否正确 ValidIdentity(string) bool - // Send 发送验证码 + // Sent 发送验证码 // // target 为接收验证码的目标,比如邮箱地址或是手机号码等; // code 为发送的验证码; - Send(target, code string) error + Sent(target, code string) error } // 验证码的占位符 @@ -36,15 +37,24 @@ type smtpSender struct { auth smtp.Auth } -// NewSMTP 基于 SMTP 的 [Sender] 实现 +type emptySender struct{} + +// NewEmptySender 一个空的 [Sender] 实现 +func NewEmptySender() Sender { return &emptySender{} } + +func (s *emptySender) ValidIdentity(_ string) bool { return true } + +func (s *emptySender) Sent(_, _ string) error { return nil } + +// NewSMTPSender 基于 SMTP 的 [Sender] 实现 // // subject 为发送邮件的主题; // addr 为 smtp 的主机地址,需要带上端口号; // template 为邮件模板,可以有一个占位符 %%code%%; -func NewSMTP(subject, addr, from, template string, auth smtp.Auth) Sender { +func NewSMTPSender(subject, addr, from, template string, auth smtp.Auth) Sender { b := errwrap.Buffer{} b.Grow(1024) - b.WString("From: ").WString(from).WString("\r\n"). + b.WString(header.From).WByte(' ').WString(from).WString("\r\n"). WString("Subject: ").WString(subject).WString("\r\n"). WString("MIME-Version: ").WString("1.0\r\n"). WString(`Content-Type: text/plain; charset="utf-8"`) @@ -66,8 +76,8 @@ func NewSMTP(subject, addr, from, template string, auth smtp.Auth) Sender { func (s *smtpSender) ValidIdentity(identity string) bool { return validator.Email(identity) } -// Send 发送邮件 -func (s *smtpSender) Send(email, code string) error { +// Sent 发送邮件 +func (s *smtpSender) Sent(email, code string) error { b := errwrap.Buffer{} b.Grow(1024) b.WString(s.head). diff --git a/cmfx/user/passport/code/sender_test.go b/cmfx/user/passport/code/sender_test.go index 5326d987..cb52de20 100644 --- a/cmfx/user/passport/code/sender_test.go +++ b/cmfx/user/passport/code/sender_test.go @@ -11,6 +11,6 @@ var ( type sender struct{} -func (s *sender) Send(_, _ string) error { return nil } +func (s *sender) Sent(_, _ string) error { return nil } func (s *sender) ValidIdentity(id string) bool { return true } diff --git a/cmfx/user/passport/errors.go b/cmfx/user/passport/errors.go index 41540f1b..5c93cd4c 100644 --- a/cmfx/user/passport/errors.go +++ b/cmfx/user/passport/errors.go @@ -15,6 +15,8 @@ var ( errInvalidIdentity = web.NewLocaleError("invalid indetity format") ) +func ErrUIDMustBeGreatThanZero() error { return web.NewLocaleError("uid must be great than 0") } + func ErrIdentityExists() error { return errIdentityExists } func ErrUIDExists() error { return errUIDExists } diff --git a/cmfx/user/passport/oauth/oauth.go b/cmfx/user/passport/oauth/oauth.go index ef73961d..637d03dc 100644 --- a/cmfx/user/passport/oauth/oauth.go +++ b/cmfx/user/passport/oauth/oauth.go @@ -25,6 +25,7 @@ import ( "github.com/issue9/cmfx/cmfx" "github.com/issue9/cmfx/cmfx/user/passport" + "github.com/issue9/cmfx/locales" ) // UserInfo 表示 OAuth 登录后获取的用户信息 @@ -110,15 +111,15 @@ func (o *OAuth[T]) Identity(uid int64) (string, error) { return mod.Identity, nil } -func (o *OAuth[T]) Change(uid int64, pass, n string) error { - return errors.ErrUnsupported -} +func (o *OAuth[T]) Change(_ int64, _, _ string) error { return errors.ErrUnsupported } -func (o *OAuth[T]) Set(uid int64, n string) error { - return errors.ErrUnsupported -} +func (o *OAuth[T]) Set(_ int64, _ string) error { return errors.ErrUnsupported } func (o *OAuth[T]) Add(uid int64, identity, _ string, now time.Time) error { + if uid == 0 { + return locales.ErrMustBeGreaterThan(0) + } + _, err := o.db.Insert(&modelOAuth{ Created: now, UID: uid, diff --git a/cmfx/user/passport/passport.go b/cmfx/user/passport/passport.go index d76f5404..f5b0c391 100644 --- a/cmfx/user/passport/passport.go +++ b/cmfx/user/passport/passport.go @@ -23,9 +23,9 @@ type Passport struct { } type adapter struct { - id string - name web.LocaleStringer - auth Adapter + id string + name web.LocaleStringer + adapter Adapter } // New 声明 [Passport] 对象 @@ -39,25 +39,30 @@ func New(mod *cmfx.Module) *Passport { // Register 注册 [Adapter] // // id 为适配器的类型名称,需要唯一; -// name 为该验证器的本地化名称; +// name 为该适配器的本地化名称; func (p *Passport) Register(id string, auth Adapter, name web.LocaleStringer) { if _, found := p.adapters[id]; found { panic(fmt.Sprintf("已经存在同名 %s 的验证器", id)) } p.adapters[id] = &adapter{ - id: id, - name: name, - auth: auth, + id: id, + name: name, + adapter: auth, } } +// Get 返回注册的适配器 +// +// 如果找不到,则返回 nil。 +func (p *Passport) Get(id string) Adapter { return p.adapters[id].adapter } + // Valid 验证账号密码 // -// id 表示通过 [Passport.Register] 注册验证器时的 id; +// id 表示通过 [Passport.Register] 注册适配器时的 id; func (p *Passport) Valid(id, identity, password string, now time.Time) (int64, string, bool) { if info, found := p.adapters[id]; found { - uid, ident, err := info.auth.Valid(identity, password, now) + uid, ident, err := info.adapter.Valid(identity, password, now) switch { case errors.Is(err, ErrUnauthorized()): return 0, "", false @@ -71,7 +76,7 @@ func (p *Passport) Valid(id, identity, password string, now time.Time) (int64, s return 0, "", false } -// All 返回所有的验证器 +// All 返回所有的适配器对象 func (p *Passport) All(printer *message.Printer) map[string]string { m := make(map[string]string, len(p.adapters)) for _, i := range p.adapters { @@ -80,13 +85,13 @@ func (p *Passport) All(printer *message.Printer) map[string]string { return m } -// Identities 获取 uid 已经关联的验证器 +// Identities 获取 uid 已经关联的适配器 // -// 返回值键名为验证器 id,键值为该验证器对应的账号。 +// 返回值键名为验证器 id,键值为该适配器对应的账号。 func (p *Passport) Identities(uid int64) map[string]string { m := make(map[string]string, len(p.adapters)) for _, info := range p.adapters { - if identity, err := info.auth.Identity(uid); err == nil { + if identity, err := info.adapter.Identity(uid); err == nil { m[info.id] = identity } else { p.mod.Server().Logs().ERROR().Error(err) diff --git a/cmfx/user/passport/passport_test.go b/cmfx/user/passport/passport_test.go index ff9a1d8d..2210aade 100644 --- a/cmfx/user/passport/passport_test.go +++ b/cmfx/user/passport/passport_test.go @@ -25,39 +25,41 @@ func TestPassport(t *testing.T) { password.Install(mod1) password.Install(mod2) - auth := passport.New(suite.Module()) - a.NotNil(auth) - a.Length(auth.All(suite.Module().Server().Locale().Printer()), 0) + p := passport.New(suite.Module()) + a.NotNil(p). + Length(p.All(suite.Module().Server().Locale().Printer()), 0) // Register - var p1 passport.Adapter = password.New(mod1, 5) - auth.Register("p1", p1, web.Phrase("password")) + var p1 = password.New(mod1, 5) + p.Register("p1", p1, web.Phrase("password")) + a.Equal(p.Get("p1"), p1) p2 := password.New(mod2, 5) - auth.Register("p2", p2, web.Phrase("password")) + p.Register("p2", p2, web.Phrase("password")) + a.Equal(p.Get("p2"), p2) a.PanicString(func() { - auth.Register("p1", p1, web.Phrase("password")) + p.Register("p1", p1, web.Phrase("password")) }, "已经存在同名 p1 的验证器") - a.Length(auth.All(suite.Module().Server().Locale().Printer()), 2) + a.Length(p.All(suite.Module().Server().Locale().Printer()), 2) // Valid / Identities - uid, identity, ok := auth.Valid("p1", "1024", "1024", time.Now()) + uid, identity, ok := p.Valid("p1", "1024", "1024", time.Now()) a.False(ok).Equal(identity, "").Zero(uid) - a.Empty(auth.Identities(1024)) + a.Empty(p.Identities(1024)) // p1.Add - p1.Add(1024, "1024", "1024", time.Now()) - uid, identity, ok = auth.Valid("p1", "1024", "1024", time.Now()) + a.NotError(p1.Add(1024, "1024", "1024", time.Now())) + uid, identity, ok = p.Valid("p1", "1024", "1024", time.Now()) a.True(ok).Equal(identity, "1024").Equal(uid, 1024) - a.Equal(auth.Identities(1024), map[string]string{"p1": "1024"}) + a.Equal(p.Identities(1024), map[string]string{"p1": "1024"}) // p2.Add - p2.Add(1024, "1024", "1024", time.Now()) - uid, identity, ok = auth.Valid("p2", "1024", "not match", time.Now()) + a.NotError(p2.Add(1024, "1024", "1024", time.Now())) + uid, identity, ok = p.Valid("p2", "1024", "not match", time.Now()) a.Zero(identity).Zero(uid).False(ok) - a.Equal(auth.Identities(1024), map[string]string{"p1": "1024", "p2": "1024"}) + a.Equal(p.Identities(1024), map[string]string{"p1": "1024", "p2": "1024"}) } diff --git a/cmfx/user/passport/password/models.go b/cmfx/user/passport/password/models.go index 4a515f05..167a01ef 100644 --- a/cmfx/user/passport/password/models.go +++ b/cmfx/user/passport/password/models.go @@ -11,14 +11,9 @@ type modelPassword struct { Created time.Time `orm:"name(created)"` Updated time.Time `orm:"name(updated)"` - UID int64 `orm:"name(uid);unique(uid)"` + UID int64 `orm:"name(uid);default(0)"` Identity string `orm:"name(identity);len(32);unique(identity)"` Password []byte `orm:"name(password);len(64)"` } func (p *modelPassword) TableName() string { return `_auth_passwords` } - -func (p *modelPassword) BeforeUpdate() error { - p.Updated = time.Now() - return nil -} diff --git a/cmfx/user/passport/password/password.go b/cmfx/user/passport/password/password.go index fa76e1ae..ad096167 100644 --- a/cmfx/user/passport/password/password.go +++ b/cmfx/user/passport/password/password.go @@ -42,88 +42,105 @@ func (p *password) Add(uid int64, identity, pass string, now time.Time) error { if err != nil { return err } - if n > 0 { + if uid > 0 && n > 0 { return passport.ErrUIDExists() } - n, err = db.Where("identity=?", identity).Count(&modelPassword{}) + mod := &modelPassword{Identity: identity} + found, err := db.Select(mod) if err != nil { return err } - if n > 0 { - return passport.ErrIdentityExists() - } pa, err := bcrypt.GenerateFromPassword([]byte(pass), p.cost) if err != nil { return err } - _, err = db.Insert(&modelPassword{ - Created: now, - Updated: now, - UID: uid, - Identity: identity, - Password: pa, - }) + + if found { + if mod.UID > 0 { + return passport.ErrIdentityExists() + } + + _, err = db.Update(&modelPassword{ + Updated: now, + UID: uid, + Identity: identity, + Password: pa, + }) + } else { + _, err = db.Insert(&modelPassword{ + Created: now, + Updated: now, + UID: uid, + Identity: identity, + Password: pa, + }) + } + return err } // Delete 删除关联的密码信息 func (p *password) Delete(uid int64) error { - _, err := p.mod.DB().Delete(&modelPassword{UID: uid}) + _, err := p.mod.DB().Where("uid=?", uid).Delete(&modelPassword{}) return err } // Set 强制修改密码 func (p *password) Set(uid int64, pass string) error { - found, err := p.mod.DB().Select(&modelPassword{UID: uid}) + if uid == 0 { + return passport.ErrUIDMustBeGreatThanZero() + } + + mod := &modelPassword{} + size, err := p.mod.DB().Where("uid=?", uid).Select(true, mod) if err != nil { return err } - if !found { + if size == 0 { return passport.ErrUIDNotExists() } - return p.set(uid, pass) + return p.set(mod.Identity, pass) } -func (p *password) set(uid int64, pass string) error { +func (p *password) set(identity, pass string) error { pa, err := bcrypt.GenerateFromPassword([]byte(pass), p.cost) - if err != nil { - return err + if err == nil { + _, err = p.mod.DB().Update(&modelPassword{Identity: identity, Password: pa}) } - - _, err = p.mod.DB().Update(&modelPassword{ - UID: uid, - Password: pa, - }) return err } // Change 验证并修改 func (p *password) Change(uid int64, old, pass string) error { - pp := &modelPassword{UID: uid} - found, err := p.mod.DB().Select(pp) + if uid == 0 { + return passport.ErrUIDMustBeGreatThanZero() + } + + mod := &modelPassword{UID: uid} + size, err := p.mod.DB().Where("uid=?", uid).Select(true, mod) if err != nil { return err } - if !found { + if size == 0 { return passport.ErrUIDNotExists() } - err = bcrypt.CompareHashAndPassword(pp.Password, []byte(old)) + err = bcrypt.CompareHashAndPassword(mod.Password, []byte(old)) switch { case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): return passport.ErrUnauthorized() case err != nil: return err default: - return p.set(uid, pass) + return p.set(mod.Identity, pass) } } func (p *password) Valid(username, pass string, _ time.Time) (int64, string, error) { - pp := &modelPassword{Identity: username} - found, err := p.mod.DB().Select(pp) + mod := &modelPassword{Identity: username} + found, err := p.mod.DB().Select(mod) if err != nil { return 0, "", err } @@ -131,28 +148,28 @@ func (p *password) Valid(username, pass string, _ time.Time) (int64, string, err return 0, "", passport.ErrUnauthorized() } - err = bcrypt.CompareHashAndPassword(pp.Password, []byte(pass)) + err = bcrypt.CompareHashAndPassword(mod.Password, []byte(pass)) switch { case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): return 0, "", passport.ErrUnauthorized() case err != nil: return 0, "", err default: - return pp.UID, pp.Identity, nil + return mod.UID, mod.Identity, nil } } func (p *password) Identity(uid int64) (string, error) { - pp := &modelPassword{UID: uid} - found, err := p.mod.DB().Select(pp) + mod := &modelPassword{} + size, err := p.mod.DB().Where("uid=?", uid).Select(true, mod) if err != nil { return "", err } - if !found { + if size == 0 { return "", passport.ErrUIDNotExists() } - return pp.Identity, nil + return mod.Identity, nil } func validIdentity(id string) bool { @@ -170,4 +187,5 @@ func validIdentity(id string) bool { } func isAlpha(r rune) bool { return r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' } + func isDigit(r rune) bool { return r >= '0' && r <= '9' } diff --git a/cmfx/user/securitylog.go b/cmfx/user/securitylog.go index 24ced397..b7545dce 100644 --- a/cmfx/user/securitylog.go +++ b/cmfx/user/securitylog.go @@ -17,7 +17,7 @@ import ( // AddSecurityLog 添加一条记录 // // tx 如果为空,表示由 AddSecurityLog 直接提交数据; -func (m *Loader) AddSecurityLog(tx *orm.Tx, uid int64, ip, ua, content string) error { +func (m *Module) AddSecurityLog(tx *orm.Tx, uid int64, ip, ua, content string) error { _, err := m.Module().Engine(tx).Insert(&modelLog{ UID: uid, Content: content, @@ -27,7 +27,7 @@ func (m *Loader) AddSecurityLog(tx *orm.Tx, uid int64, ip, ua, content string) e return err } -func (m *Loader) AddSecurityLogFromContext(tx *orm.Tx, uid int64, ctx *web.Context, content string) error { +func (m *Module) AddSecurityLogFromContext(tx *orm.Tx, uid int64, ctx *web.Context, content string) error { return m.AddSecurityLog(tx, uid, ctx.ClientIP(), ctx.Request().UserAgent(), content) } @@ -38,13 +38,13 @@ type queryLog struct { } // GetSecurityLogs 将数据以固定的格式输出客户端 -func (m *Loader) GetSecurityLogs(ctx *web.Context) web.Responser { +func (m *Module) GetSecurityLogs(ctx *web.Context) web.Responser { u := m.CurrentUser(ctx) return m.getSecurityLogs(u.ID, ctx) } // getSecurityLogs 将数据以固定的格式输出客户端 -func (m *Loader) getSecurityLogs(uid int64, ctx *web.Context) web.Responser { +func (m *Module) getSecurityLogs(uid int64, ctx *web.Context) web.Responser { q := &queryLog{} if rslt := ctx.QueryObject(true, q, cmfx.BadRequestInvalidQuery); rslt != nil { return rslt diff --git a/cmfx/user/token.go b/cmfx/user/token.go index e0315b22..fa52f380 100644 --- a/cmfx/user/token.go +++ b/cmfx/user/token.go @@ -5,8 +5,8 @@ package user import ( - "errors" "net/http" + "time" "github.com/issue9/mux/v8/header" "github.com/issue9/orm/v6" @@ -24,8 +24,9 @@ type tokens = token.Token[*User] type AfterFunc = func(*User) type reqAccount struct { - Username string `json:"username" xml:"username" yaml:"username" cbor:"username"` - Password string `json:"password" xml:"password" yaml:"password" cbor:"password"` + XMLName struct{} `xml:"account" json:"-" cbor:"-"` + Username string `json:"username" xml:"username" yaml:"username" cbor:"username"` + Password string `json:"password" xml:"password" yaml:"password" cbor:"password"` } func (c *reqAccount) Filter(v *web.FilterContext) { @@ -38,13 +39,13 @@ func (c *reqAccount) Filter(v *web.FilterContext) { // 如果状态为非 [StateNormal],那么也将会被禁止登录。 // // NOTE: 需要保证 u.ID、u.State 和 u.NO 是有效的。 -func (m *Loader) SetState(tx *orm.Tx, u *User, s State) (err error) { +func (m *Module) SetState(tx *orm.Tx, u *User, s State) (err error) { if u.State == s { return nil } if s != StateNormal { - err = m.token.Delete(u) + err = m.token.Delete(u) // 用到 User.NO } if err == nil { @@ -53,82 +54,70 @@ func (m *Loader) SetState(tx *orm.Tx, u *User, s State) (err error) { return err } -// Login 如果 reg 不为空,表示在验证成功,但是不存在用户数是执行注册服务,其原型如下: +// Login 执行登录操作并在成功的情况下发放新的令牌 // -// func(tx *orm.Tx, uid int64, identity string) error +// 如果 reg 不为空,表示在验证成功,但是不存在用户数是执行注册服务,其原型如下: // -// tx 为事务接口,uid 为新用户的 uid,identity 为验证成功时返回的对应值。 -// 如果返回 error,将取消整个事务。 -func (m *Loader) Login(typ string, ctx *web.Context, reg func(*orm.Tx, int64, string) error, after AfterFunc) web.Responser { +// func( uid int64) error +// +// uid 为新用户的 uid。 +func (m *Module) Login(typ string, ctx *web.Context, reg func(int64) error, after AfterFunc) web.Responser { data := &reqAccount{} if resp := ctx.Read(true, data, cmfx.BadRequestInvalidBody); resp != nil { return resp } - // 密码错误 uid, identity, ok := m.passport.Valid(typ, data.Username, data.Password, ctx.Begin()) - if !ok { - return ctx.Problem(cmfx.Unauthorized) + if !ok { // 密码或账号错误 + return ctx.Problem(cmfx.UnauthorizedInvalidAccount) } // 注册 if uid == 0 && reg != nil { - tx, err := m.mod.DB().Begin() - if err != nil { + var err error + if uid, err = m.NewUser(m.Passport().Get(typ), identity, data.Password, ctx.Begin()); err != nil { return ctx.Error(err, "") } - e := tx.NewEngine(m.mod.DB().TablePrefix()) - - a := &User{NO: ctx.Server().UniqueID()} - uid, err = e.LastInsertID(a) - if err != nil { - return ctx.Error(errors.Join(err, tx.Rollback()), "") - } - if err := reg(tx, uid, identity); err != nil { - return ctx.Error(errors.Join(err, tx.Rollback()), "") + if err := reg(uid); err != nil { + return ctx.Error(err, "") } msg := web.StringPhrase("auto register").LocaleString(ctx.LocalePrinter()) - if err := m.AddSecurityLogFromContext(tx, a.ID, ctx, msg); err != nil { // 记录日志出错不回滚 + if err := m.AddSecurityLogFromContext(nil, uid, ctx, msg); err != nil { // 记录日志出错不回滚 ctx.Server().Logs().ERROR().Error(err) } - - if err = tx.Commit(); err != nil { - return ctx.Error(err, "") - } } - a := &User{ID: uid} - found, err := m.mod.DB().Select(a) + u := &User{ID: uid} + found, err := m.mod.DB().Select(u) if err != nil { return ctx.Error(err, "") } if !found { - ctx.Logs().DEBUG().Printf("用户名 %v 不存在\n", data.Username) - return ctx.Problem(cmfx.UnauthorizedRegistrable).WithExtensions(&struct { - Identity string `json:"identity" xml:"identity" yaml:"identity"` - }{Identity: identity}) + ctx.Logs().DEBUG().Printf("数据库不同步,%s 存在于适配器 %s,但是不存在于用户列表数据库\n", data.Username, typ) + return ctx.Problem(cmfx.UnauthorizedInvalidAccount) } - if a.State != StateNormal { + if u.State != StateNormal { return ctx.Problem(cmfx.UnauthorizedInvalidState) } - if err := m.AddSecurityLogFromContext(nil, a.ID, ctx, "登录"); err != nil { + msg := web.Phrase("login").LocaleString(ctx.LocalePrinter()) + if err := m.AddSecurityLogFromContext(nil, u.ID, ctx, msg); err != nil { ctx.Server().Logs().ERROR().Error(err) } if after != nil { - after(a) + after(u) } - return m.token.New(ctx, a, http.StatusCreated) + return m.token.New(ctx, u, http.StatusCreated) } // Logout 退出 -func (m *Loader) Logout(ctx *web.Context, after AfterFunc, reason web.LocaleStringer) web.Responser { +func (m *Module) Logout(ctx *web.Context, after AfterFunc, reason web.LocaleStringer) web.Responser { u := m.CurrentUser(ctx) // 先拿到用户数据再执行 logout if err := m.token.Logout(ctx); err != nil { @@ -148,7 +137,7 @@ func (m *Loader) Logout(ctx *web.Context, after AfterFunc, reason web.LocaleStri } // RefreshToken 刷新令牌 -func (m *Loader) RefreshToken(ctx *web.Context) web.Responser { +func (m *Module) RefreshToken(ctx *web.Context) web.Responser { u := m.CurrentUser(ctx) if u == nil { return web.Status(http.StatusUnauthorized) @@ -163,17 +152,34 @@ func (m *Loader) RefreshToken(ctx *web.Context) web.Responser { } // Passport 管理验证登录信息 -func (m *Loader) Passport() *passport.Passport { return m.passport } +func (m *Module) Passport() *passport.Passport { return m.passport } // Middleware 验证是否登录 -func (m *Loader) Middleware(next web.HandlerFunc) web.HandlerFunc { return m.token.Middleware(next) } +func (m *Module) Middleware(next web.HandlerFunc) web.HandlerFunc { return m.token.Middleware(next) } // CurrentUser 获取当前登录的用户信息 // -// 该信息由 [Loader.Middleware] 存储在 [web.Context.vars] 之中。 -func (m *Loader) CurrentUser(ctx *web.Context) *User { +// 该信息由 [Module.Middleware] 存储在 [web.Context.vars] 之中。 +func (m *Module) CurrentUser(ctx *web.Context) *User { if u, found := m.token.GetInfo(ctx); found { return u } panic("未检测到登录用户") // 未登录账号,不应该到达此处,在中间件部分应该已经被拒绝。 } + +// NewUser 添加新用户 +// +// pa 为注册用户的验证方式 +func (m *Module) NewUser(pa passport.Adapter, identity, password string, t time.Time) (int64, error) { + u := &User{NO: m.Module().Server().UniqueID()} + uid, err := m.mod.DB().LastInsertID(u) + if err != nil { + return 0, err + } + + if err = pa.Add(uid, identity, password, t); err != nil { + return 0, err + } + + return uid, nil +} diff --git a/cmfx/user/token_test.go b/cmfx/user/token_test.go new file mode 100644 index 00000000..402ee0f3 --- /dev/null +++ b/cmfx/user/token_test.go @@ -0,0 +1,159 @@ +// SPDX-FileCopyrightText: 2024 caixw +// +// SPDX-License-Identifier: MIT + +package user + +import ( + "bytes" + "encoding/json" + "net/http" + "strconv" + "testing" + "time" + + "github.com/issue9/assert/v4" + "github.com/issue9/mux/v8/header" + "github.com/issue9/web" + "github.com/issue9/web/server/servertest" + "github.com/issue9/webuse/v7/middlewares/auth" + "github.com/issue9/webuse/v7/middlewares/auth/token" + + "github.com/issue9/cmfx/cmfx/initial/test" + "github.com/issue9/cmfx/cmfx/user/passport/code" +) + +func TestLoader_Login(t *testing.T) { + a := assert.New(t, false) + s := test.NewSuite(a) + u := newLoader(s) + + // 添加用于测试的验证码验证 + code.Install(u.Module(), "_code") + pc := code.New(u.Module(), time.Second, "_code", code.NewEmptySender()) + u.Passport().Register("code", pc, web.Phrase("code")) + a.NotError(pc.Add(0, "new", "password", time.Now())) + + s.Module().Router().Post("/login", func(ctx *web.Context) web.Responser { + q, err := ctx.Queries(true) + if err != nil { + return ctx.Error(err, "") + } + + switch q.String("type", "password") { + case "password": + output := &bytes.Buffer{} + resp := u.Login("password", ctx, func(id int64) error { + _, err := output.WriteString(strconv.FormatInt(id, 10)) + return err + }, func(_ *User) { + output.WriteString("after") + }) + + a.NotNil(resp).Equal(output.String(), "after") // 用户已经存在 + return resp + case "code": + output := &bytes.Buffer{} + resp := u.Login("code", ctx, func(id int64) error { + _, err := output.WriteString(strconv.FormatInt(id, 10)) + return err + }, func(_ *User) { output.WriteString("after") }) + + a.NotNil(resp).Equal(output.String(), "2after") // 注册的新用户 + return resp + default: + return ctx.NotImplemented() + } + }) + + // 测试 SetState + s.Module().Router().Post("/state", u.Middleware(func(ctx *web.Context) web.Responser { + user := u.CurrentUser(ctx) + a.NotError(u.SetState(nil, user, StateNormal)) + a.NotError(u.SetState(nil, user, StateLocked)) + return web.NoContent() + })) + + s.Module().Router().Post("/refresh", u.Middleware(func(ctx *web.Context) web.Responser { + return u.RefreshToken(ctx) + })) + + s.Module().Router().Get("/info", u.Middleware(func(ctx *web.Context) web.Responser { + return web.OK(u.CurrentUser(ctx)) + })) + + s.Module().Router().Delete("/login", u.Middleware(func(ctx *web.Context) web.Responser { + return u.Logout(ctx, nil, web.Phrase("reason")) + })) + + defer servertest.Run(a, s.Module().Server())() + defer s.Close() + + //--------------------------- user 1 ------------------------------------- + + tk1 := &token.Response{} + s.Post("/login", []byte(`{"username":"admin","password":"password"}`)). + Header(header.Accept, header.JSON). + Header(header.ContentType, header.JSON+"; charset=utf-8"). + Do(nil). + Status(http.StatusCreated). + BodyFunc(func(a *assert.Assertion, body []byte) { a.NotError(json.Unmarshal(body, tk1)) }) + + s.Post("/state", nil). + Header(header.Accept, header.JSON). + Header(header.ContentType, header.JSON+"; charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, tk1.AccessToken)). + Do(nil). + Status(http.StatusNoContent) + + // 状态已改变 + s.Get("/info"). + Header(header.Accept, header.JSON). + Header(header.ContentType, header.JSON+"; charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, tk1.AccessToken)). + Do(nil). + Status(http.StatusUnauthorized) + + //--------------------------- user 2 ------------------------------------- + + tk1 = &token.Response{} + s.Post("/login?type=code", []byte(`{"username":"new","password":"password"}`)). + Header(header.Accept, header.JSON). + Header(header.ContentType, header.JSON+"; charset=utf-8"). + Do(nil). + Status(http.StatusCreated). + BodyFunc(func(a *assert.Assertion, body []byte) { a.NotError(json.Unmarshal(body, tk1)) }) + + // 正常 + s.Get("/info"). + Header(header.Accept, header.JSON). + Header(header.ContentType, header.JSON+"; charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, tk1.AccessToken)). + Do(nil). + Status(http.StatusOK) + + tk2 := &token.Response{} + s.Post("/refresh", nil). + Header(header.Accept, header.JSON). + Header(header.ContentType, header.JSON+"; charset=utf-8"). + Header(header.Authorization, auth.BuildToken(auth.Bearer, tk1.RefreshToken)). + Do(nil). + Status(http.StatusCreated). + BodyFunc(func(a *assert.Assertion, body []byte) { a.NotError(json.Unmarshal(body, tk2)) }) + a.NotEqual(tk1.AccessToken, tk2.AccessToken). + NotEqual(tk1.AccessToken, tk2.AccessToken) + + // 退出 tk2 + s.Delete("/login"). + Header(header.Accept, header.JSON). + Header(header.Authorization, auth.BuildToken(auth.Bearer, tk2.AccessToken)). + Do(nil). + Status(http.StatusNoContent) + + // tk2 已退出 + s.Get("/info"). + Header(header.Accept, header.JSON). + Header(header.Authorization, auth.BuildToken(auth.Bearer, tk2.AccessToken)). + Do(nil). + Status(http.StatusUnauthorized) +} diff --git a/locales/locales.go b/locales/locales.go index b6a016d8..7b11056d 100644 --- a/locales/locales.go +++ b/locales/locales.go @@ -37,3 +37,11 @@ func MustBeGreaterThan[T any](v T) web.LocaleStringer { func MustBeLessThan[T any](v T) web.LocaleStringer { return web.Phrase("must be less than %v", v) } + +func ErrMustBeGreaterThan[T any](v T) error { + return web.NewLocaleError("must be greater than %v", v) +} + +func ErrMustBeLessThan[T any](v T) error { + return web.NewLocaleError("must be less than %v", v) +}