Skip to content

Commit

Permalink
Fixes #69: support customizable table name mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Jan 7, 2020
1 parent 5df468d commit 6442c82
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 13 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,15 @@ If you want to use a different field as the primary key, tag it with `db:"pk"`.
for composite primary keys. Note that if you also want to explicitly specify the column name for a primary key field,
you should use the tag format `db:"pk,col_name"`.

You can give a common prefix or suffix to your table names by defining your own table name mapping via
`DB.TableMapFunc`. For example, the following code prefixes `tbl_` to all table names.

```go
db.TableMapper = func(a interface{}) string {
return "tbl_" + GetTableName(a)
}
```

### Create

To create (insert) a new row using a model, call the `ModelQuery.Insert()` method. For example,
Expand Down
4 changes: 4 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type (

// FieldMapper maps struct fields to DB columns. Defaults to DefaultFieldMapFunc.
FieldMapper FieldMapFunc
// TableMapper maps structs to table names. Defaults to GetTableName.
TableMapper TableMapFunc
// LogFunc logs the SQL statements being executed. Defaults to nil, meaning no logging.
LogFunc LogFunc
// PerfFunc logs the SQL execution time. Defaults to nil, meaning no performance profiling.
Expand Down Expand Up @@ -85,6 +87,7 @@ func NewFromDB(sqlDB *sql.DB, driverName string) *DB {
driverName: driverName,
sqlDB: sqlDB,
FieldMapper: DefaultFieldMapFunc,
TableMapper: GetTableName,
}
db.Builder = db.newBuilder(db.sqlDB)
return db
Expand Down Expand Up @@ -121,6 +124,7 @@ func (db *DB) Clone() *DB {
driverName: db.driverName,
sqlDB: db.sqlDB,
FieldMapper: db.FieldMapper,
TableMapper: db.TableMapper,
PerfFunc: db.PerfFunc,
LogFunc: db.LogFunc,
QueryLogFunc: db.QueryLogFunc,
Expand Down
18 changes: 18 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,21 @@ func getPreparedDB() *DB {
}
return db
}

// Naming according to issue 49 ( https://github.com/go-ozzo/ozzo-dbx/issues/49 )

type ArtistDAO struct {
nickname string
}

func (ArtistDAO) TableName() string {
return "artists"
}

func Test_TableNameWithPrefix(t *testing.T) {
db := NewFromDB(nil, "mysql")
db.TableMapper = func(a interface{}) string {
return "tbl_" + GetTableName(a)
}
assert.Equal(t, "tbl_artists", db.TableMapper(ArtistDAO{}))
}
2 changes: 1 addition & 1 deletion model_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder
db: db,
ctx: db.ctx,
builder: builder,
model: newStructValue(model, fieldMapFunc),
model: newStructValue(model, fieldMapFunc, db.TableMapper),
}
if q.model == nil {
q.lastError = VarTypeError("must be a pointer to a struct representing the model")
Expand Down
7 changes: 5 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
type SelectQuery struct {
// FieldMapper maps struct field names to DB column names.
FieldMapper FieldMapFunc
// TableMapper maps structs to DB table names.
TableMapper TableMapFunc

builder Builder
ctx context.Context
Expand Down Expand Up @@ -61,6 +63,7 @@ func NewSelectQuery(builder Builder, db *DB) *SelectQuery {
params: Params{},
ctx: db.ctx,
FieldMapper: db.FieldMapper,
TableMapper: db.TableMapper,
}
}

Expand Down Expand Up @@ -286,7 +289,7 @@ func (s *SelectQuery) Build() *Query {
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (s *SelectQuery) One(a interface{}) error {
if len(s.from) == 0 {
if tableName := GetTableName(a); tableName != "" {
if tableName := s.TableMapper(a); tableName != "" {
s.from = []string{tableName}
}
}
Expand Down Expand Up @@ -327,7 +330,7 @@ func (s *SelectQuery) Model(pk, model interface{}) error {
// or the TableName() method if the slice element implements the TableModel interface.
func (s *SelectQuery) All(slice interface{}) error {
if len(s.from) == 0 {
if tableName := GetTableName(slice); tableName != "" {
if tableName := s.TableMapper(slice); tableName != "" {
s.from = []string{tableName}
}
}
Expand Down
14 changes: 9 additions & 5 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ type (
// FieldMapFunc converts a struct field name into a DB column name.
FieldMapFunc func(string) string

// TableMapFunc converts a sample struct into a DB table name.
TableMapFunc func(a interface{}) string

structInfo struct {
nameMap map[string]*fieldInfo // mapping from struct field names to field infos
dbNameMap map[string]*fieldInfo // mapping from db column names to field infos
Expand Down Expand Up @@ -77,16 +80,16 @@ func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo {
return si
}

func newStructValue(model interface{}, mapper FieldMapFunc) *structValue {
func newStructValue(model interface{}, fieldMapFunc FieldMapFunc, tableMapFunc TableMapFunc) *structValue {
value := reflect.ValueOf(model)
if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() {
return nil
}

return &structValue{
structInfo: getStructInfo(reflect.TypeOf(model).Elem(), mapper),
structInfo: getStructInfo(reflect.TypeOf(model).Elem(), fieldMapFunc),
value: value.Elem(),
tableName: GetTableName(model),
tableName: tableMapFunc(model),
}
}

Expand Down Expand Up @@ -246,8 +249,9 @@ func indirect(v reflect.Value) reflect.Value {
return v
}

// GetTableName returns the table name corresponding to the given model struct or slice of structs.
// Do not call this method in the model's TableName() method, or it will cause infinite loop.
// GetTableName implements the default way of determining the table name corresponding to the given model struct
// or slice of structs. To get the actual table name for a model, you should use DB.TableMapFunc() instead.
// Do not call this method in a model's TableName() method because it will cause infinite loop.
func GetTableName(a interface{}) string {
if tm, ok := a.(TableModel); ok {
v := reflect.ValueOf(a)
Expand Down
10 changes: 5 additions & 5 deletions struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func Test_structValue_columns(t *testing.T) {
Status: 2,
Email: "[email protected]",
}
sv := newStructValue(&customer, DefaultFieldMapFunc)
sv := newStructValue(&customer, DefaultFieldMapFunc, GetTableName)
cols := sv.columns(nil, nil)
assert.Equal(t, map[string]interface{}{"id": 1, "name": "abc", "status": 2, "email": "[email protected]", "address": sql.NullString{}}, cols)

Expand All @@ -87,7 +87,7 @@ func Test_structValue_columns(t *testing.T) {
cols = sv.columns(nil, []string{"ID", "Address"})
assert.Equal(t, map[string]interface{}{"name": "abc", "status": 2, "email": "[email protected]"}, cols)

sv = newStructValue(&customer, nil)
sv = newStructValue(&customer, nil, GetTableName)
cols = sv.columns([]string{"ID", "Name"}, []string{"ID"})
assert.Equal(t, map[string]interface{}{"Name": "abc"}, cols)
}
Expand All @@ -103,22 +103,22 @@ func TestIssue37(t *testing.T) {
Customer
Status string
}{customer, "20"}
sv := newStructValue(&ev, nil)
sv := newStructValue(&ev, nil, GetTableName)
cols := sv.columns([]string{"ID", "Status"}, nil)
assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols)

ev2 := struct {
Status string
Customer
}{"20", customer}
sv = newStructValue(&ev2, nil)
sv = newStructValue(&ev2, nil, GetTableName)
cols = sv.columns([]string{"ID", "Status"}, nil)
assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols)
}

type MyCustomer struct{}

func Test_getTableName(t *testing.T) {
func TestGetTableName(t *testing.T) {
var c1 Customer
assert.Equal(t, "customer", GetTableName(c1))

Expand Down

0 comments on commit 6442c82

Please sign in to comment.