Skip to content

Commit

Permalink
gormschema: support view (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
luantranminh authored May 17, 2024
1 parent ed0d712 commit 797a13b
Show file tree
Hide file tree
Showing 23 changed files with 270 additions and 51 deletions.
66 changes: 63 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,66 @@ env "gorm" {
}
```

### Extra Features

#### Views

> Note: Views are available for logged-in users, run `atlas login` if you haven't already. To learn more about logged-in features for Atlas, visit [Feature Availability](https://atlasgo.io/features#database-features).
To define a Go struct as a database `VIEW`, implement the `ViewDef` method as follow:
```go
// User is a regular gorm.Model stored in the "users" table.
type User struct {
gorm.Model
Name string
Age int
}

// WorkingAgedUsers is mapped to the VIEW definition below.
type WorkingAgedUsers struct {
Name string
Age int
}

func (WorkingAgedUsers) ViewDef(dialect string) []gormschema.ViewOption {
return []gormschema.ViewOption{
gormschema.BuildStmt(func(db *gorm.DB) *gorm.DB {
return db.Model(&User{}).Where("age BETWEEN 18 AND 65").Select("name, age")
}),
}
}
```
In order to pass a plain `CREATE VIEW` statement, use the `CreateStmt` as follows:
```go
type BotlTracker struct {
ID uint
Name string
}

func (BotlTracker) ViewDef(dialect string) []gormschema.ViewOption {
var stmt string
switch dialect {
case "mysql":
stmt = "CREATE VIEW botl_trackers AS SELECT id, name FROM pets WHERE name LIKE 'botl%'"
}
return []gormschema.ViewOption{
gormschema.CreateStmt(stmt),
}
}
```
To include both VIEWs and TABLEs in the migration generation, pass all models to the `Load` function:
```go
stmts, err := gormschema.New("mysql").Load(
&models.User{}, // Table-based model.
&models.WorkingAgedUsers{}, // View-based model.
)
```
The view-based model works just like a regular models in GORM queries. However, make sure the view name is identical to the struct name, and in case they are differ, configure the name using the `TableName` method:
```go
func (WorkingAgedUsers) TableName() string {
return "working_aged_users_custom_name" // View name is different than pluralized struct name.
}
```
### Additional Configuration

To supply custom `gorm.Config{}` object to the provider use the [Go Program Mode](#as-go-file) with
Expand Down Expand Up @@ -210,9 +270,9 @@ The provider supports the following databases:
you should use the following code:
```go
stmts, err := gormschema.New("mysql",
gormschema.WithJoinTable(
&Models.Person{}, "Addresses", &Models.PersonAddress{},
),
gormschema.WithJoinTable(
&Models.Person{}, "Addresses", &Models.PersonAddress{},
),
).Load(&Models.Address{}, &Models.Person{})
```

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
ariga.io/atlas-go-sdk v0.2.3
github.com/alecthomas/kong v0.7.1
github.com/stretchr/testify v1.8.4
golang.org/x/tools v0.17.0
golang.org/x/tools v0.18.0
gorm.io/driver/mysql v1.5.1
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.5.2
Expand Down
11 changes: 6 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand All @@ -136,8 +136,9 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
Expand All @@ -157,8 +158,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc=
golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
Expand Down
110 changes: 96 additions & 14 deletions gormschema/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"slices"

"ariga.io/atlas-go-sdk/recordriver"
Expand Down Expand Up @@ -45,6 +46,18 @@ func WithConfig(cfg *gorm.Config) Option {

// Load loads the models and returns the DDL statements representing the schema.
func (l *Loader) Load(models ...any) (string, error) {
var (
views []ViewDefiner
tables []any
)
for _, obj := range models {
switch view := obj.(type) {
case ViewDefiner:
views = append(views, view)
default:
tables = append(tables, obj)
}
}
var di gorm.Dialector
switch l.dialect {
case "sqlite":
Expand Down Expand Up @@ -91,21 +104,22 @@ func (l *Loader) Load(models ...any) (string, error) {
return "", err
}
}
if err = db.AutoMigrate(models...); err != nil {
if err = db.AutoMigrate(tables...); err != nil {
return "", err
}
db, err = gorm.Open(dialector{Dialector: di}, l.config)
if err != nil {
return "", err
}
cm, ok := db.Migrator().(*migrator)
if !ok {
return "", fmt.Errorf("unexpected migrator type: %T", db.Migrator())
}
if err = cm.CreateViews(views); err != nil {
return "", err
}
if !l.config.DisableForeignKeyConstraintWhenMigrating && l.dialect != "sqlite" {
db, err = gorm.Open(dialector{
Dialector: di,
}, l.config)
if err != nil {
return "", err
}
cm, ok := db.Migrator().(*migrator)
if !ok {
return "", err
}
if err = cm.CreateConstraints(models); err != nil {
if err = cm.CreateConstraints(tables); err != nil {
return "", err
}
}
Expand All @@ -125,8 +139,8 @@ type dialector struct {
gorm.Dialector
}

// Migrator returns a new gorm.Migrator which can be used to automatically create all Constraints
// on existing tables.
// Migrator returns a new gorm.Migrator, which can be used to extend the default migrator,
// helping to create constraints and views ...
func (d dialector) Migrator(db *gorm.DB) gorm.Migrator {
return &migrator{
Migrator: gormig.Migrator{
Expand Down Expand Up @@ -179,6 +193,29 @@ func (m *migrator) CreateConstraints(models []any) error {
return nil
}

// CreateViews creates the given "view-based" models
func (m *migrator) CreateViews(views []ViewDefiner) error {
for _, view := range views {
viewName := m.DB.Config.NamingStrategy.TableName(indirect(reflect.TypeOf(view)).Name())
if namer, ok := view.(interface {
TableName() string
}); ok {
viewName = namer.TableName()
}
viewBuilder := &viewBuilder{
db: m.DB,
viewName: viewName,
}
for _, opt := range view.ViewDef(m.Dialector.Name()) {
opt(viewBuilder)
}
if err := m.DB.Exec(viewBuilder.createStmt).Error; err != nil {
return err
}
}
return nil
}

// WithJoinTable sets up a join table for the given model and field.
func WithJoinTable(model any, field string, jointable any) Option {
return func(l *Loader) {
Expand All @@ -187,3 +224,48 @@ func WithJoinTable(model any, field string, jointable any) Option {
})
}
}

func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}

type (
// ViewOption configures a viewBuilder.
ViewOption func(*viewBuilder)
// ViewDefiner defines a view.
ViewDefiner interface {
ViewDef(dialect string) []ViewOption
}
viewBuilder struct {
db *gorm.DB
createStmt string
// viewName is only used for the BuildStmt option.
// BuildStmt returns only a subquery; viewName helps to create a full CREATE VIEW statement.
viewName string
}
)

// CreateStmt accepts raw SQL to create a CREATE VIEW statement.
func CreateStmt(stmt string) ViewOption {
return func(b *viewBuilder) {
b.createStmt = b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Exec(stmt)
})
}
}

// BuildStmt accepts a function with gorm query builder to create a CREATE VIEW statement.
// With this option, the view's name will be the same as the model's table name
func BuildStmt(fn func(db *gorm.DB) *gorm.DB) ViewOption {
return func(b *viewBuilder) {
vd := b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return fn(tx).
Unscoped(). // Skip gorm deleted_at filtering.
Find(nil) // Execute the query and convert it to SQL.
})
b.createStmt = fmt.Sprintf("CREATE VIEW %s AS %s", b.viewName, vd)
}
}
31 changes: 16 additions & 15 deletions gormschema/gorm_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package gormschema
package gormschema_test

import (
"os"
"testing"

"ariga.io/atlas-go-sdk/recordriver"
"ariga.io/atlas-provider-gorm/gormschema"
ckmodels "ariga.io/atlas-provider-gorm/internal/testdata/circularfks"
"ariga.io/atlas-provider-gorm/internal/testdata/customjointable"
"ariga.io/atlas-provider-gorm/internal/testdata/models"
Expand All @@ -14,12 +15,12 @@ import (

func TestSQLiteConfig(t *testing.T) {
resetSession()
l := New("sqlite")
sql, err := l.Load(models.Pet{}, models.User{}, ckmodels.Event{}, ckmodels.Location{})
l := gormschema.New("sqlite")
sql, err := l.Load(models.WorkingAgedUsers{}, models.Pet{}, models.User{}, ckmodels.Event{}, ckmodels.Location{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlite_default")
resetSession()
l = New("sqlite", WithConfig(&gorm.Config{
l = gormschema.New("sqlite", gormschema.WithConfig(&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
}))
sql, err = l.Load(models.Pet{}, models.User{})
Expand All @@ -30,12 +31,12 @@ func TestSQLiteConfig(t *testing.T) {

func TestPostgreSQLConfig(t *testing.T) {
resetSession()
l := New("postgres")
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{})
l := gormschema.New("postgres")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/postgresql_default")
resetSession()
l = New("postgres", WithConfig(
l = gormschema.New("postgres", gormschema.WithConfig(
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
}))
Expand All @@ -46,12 +47,12 @@ func TestPostgreSQLConfig(t *testing.T) {

func TestMySQLConfig(t *testing.T) {
resetSession()
l := New("mysql")
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{})
l := gormschema.New("mysql")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_default")
resetSession()
l = New("mysql", WithConfig(
l = gormschema.New("mysql", gormschema.WithConfig(
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
},
Expand All @@ -60,20 +61,20 @@ func TestMySQLConfig(t *testing.T) {
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_no_fk")
resetSession()
l = New("mysql", WithJoinTable(&customjointable.Person{}, "Addresses", &customjointable.PersonAddress{}))
sql, err = l.Load(customjointable.Address{}, customjointable.Person{})
l = gormschema.New("mysql", gormschema.WithJoinTable(&customjointable.Person{}, "Addresses", &customjointable.PersonAddress{}))
sql, err = l.Load(customjointable.Address{}, customjointable.Person{}, customjointable.TopCrowdedAddresses{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_custom_join_table")
}

func TestSQLServerConfig(t *testing.T) {
resetSession()
l := New("sqlserver")
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{})
l := gormschema.New("sqlserver")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlserver_default")
resetSession()
l = New("sqlserver", WithConfig(
l = gormschema.New("sqlserver", gormschema.WithConfig(
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
}))
Expand Down
1 change: 1 addition & 0 deletions gormschema/testdata/mysql_custom_join_table
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
CREATE TABLE `addresses` (`id` bigint AUTO_INCREMENT,`name` longtext,PRIMARY KEY (`id`));
CREATE TABLE `people` (`id` bigint AUTO_INCREMENT,`name` longtext,PRIMARY KEY (`id`));
CREATE TABLE `person_addresses` (`person_id` bigint,`address_id` bigint,`created_at` datetime(3) NULL,`deleted_at` datetime(3) NULL,PRIMARY KEY (`person_id`,`address_id`));
CREATE VIEW top_crowded_addresses AS SELECT address_id, COUNT(person_id) AS count FROM person_addresses GROUP BY address_id ORDER BY count DESC LIMIT 10;
ALTER TABLE `person_addresses` ADD CONSTRAINT `fk_person_addresses_address` FOREIGN KEY (`address_id`) REFERENCES `addresses`(`id`);
ALTER TABLE `person_addresses` ADD CONSTRAINT `fk_person_addresses_person` FOREIGN KEY (`person_id`) REFERENCES `people`(`id`);
4 changes: 3 additions & 1 deletion gormschema/testdata/mysql_default
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
CREATE TABLE `events` (`eventId` varchar(191),`locationId` varchar(191),PRIMARY KEY (`eventId`),UNIQUE INDEX `idx_events_location_id` (`locationId`));
CREATE TABLE `locations` (`locationId` varchar(191),`eventId` varchar(191),PRIMARY KEY (`locationId`),UNIQUE INDEX `idx_locations_event_id` (`eventId`));
CREATE TABLE `users` (`id` bigint unsigned AUTO_INCREMENT,`created_at` datetime(3) NULL,`updated_at` datetime(3) NULL,`deleted_at` datetime(3) NULL,`name` longtext,PRIMARY KEY (`id`),INDEX `idx_users_deleted_at` (`deleted_at`));
CREATE TABLE `users` (`id` bigint unsigned AUTO_INCREMENT,`created_at` datetime(3) NULL,`updated_at` datetime(3) NULL,`deleted_at` datetime(3) NULL,`name` longtext,`age` bigint,PRIMARY KEY (`id`),INDEX `idx_users_deleted_at` (`deleted_at`));
CREATE TABLE `pets` (`id` bigint unsigned AUTO_INCREMENT,`created_at` datetime(3) NULL,`updated_at` datetime(3) NULL,`deleted_at` datetime(3) NULL,`name` longtext,`user_id` bigint unsigned,PRIMARY KEY (`id`),INDEX `idx_pets_deleted_at` (`deleted_at`));
CREATE VIEW working_aged_users AS SELECT name, age FROM `users` WHERE age BETWEEN 18 AND 65;
CREATE VIEW top_pet_owners AS SELECT user_id, COUNT(id) AS pet_count FROM pets GROUP BY user_id ORDER BY pet_count DESC LIMIT 10;
ALTER TABLE `events` ADD CONSTRAINT `fk_locations_event` FOREIGN KEY (`locationId`) REFERENCES `locations`(`locationId`);
ALTER TABLE `locations` ADD CONSTRAINT `fk_events_location` FOREIGN KEY (`eventId`) REFERENCES `events`(`eventId`);
ALTER TABLE `pets` ADD CONSTRAINT `fk_users_pets` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`);
4 changes: 3 additions & 1 deletion gormschema/testdata/postgresql_default
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ CREATE TABLE "events" ("eventId" varchar(191),"locationId" varchar(191),PRIMARY
CREATE UNIQUE INDEX IF NOT EXISTS "idx_events_location_id" ON "events" ("locationId");
CREATE TABLE "locations" ("locationId" varchar(191),"eventId" varchar(191),PRIMARY KEY ("locationId"));
CREATE UNIQUE INDEX IF NOT EXISTS "idx_locations_event_id" ON "locations" ("eventId");
CREATE TABLE "users" ("id" bigserial,"created_at" timestamptz,"updated_at" timestamptz,"deleted_at" timestamptz,"name" text,PRIMARY KEY ("id"));
CREATE TABLE "users" ("id" bigserial,"created_at" timestamptz,"updated_at" timestamptz,"deleted_at" timestamptz,"name" text,"age" bigint,PRIMARY KEY ("id"));
CREATE INDEX IF NOT EXISTS "idx_users_deleted_at" ON "users" ("deleted_at");
CREATE TABLE "pets" ("id" bigserial,"created_at" timestamptz,"updated_at" timestamptz,"deleted_at" timestamptz,"name" text,"user_id" bigint,PRIMARY KEY ("id"));
CREATE INDEX IF NOT EXISTS "idx_pets_deleted_at" ON "pets" ("deleted_at");
CREATE VIEW working_aged_users AS SELECT name, age FROM "users" WHERE age BETWEEN 18 AND 65;
CREATE VIEW top_pet_owners AS SELECT user_id, COUNT(id) AS pet_count FROM pets GROUP BY user_id ORDER BY pet_count DESC LIMIT 10;
ALTER TABLE "events" ADD CONSTRAINT "fk_locations_event" FOREIGN KEY ("locationId") REFERENCES "locations"("locationId");
ALTER TABLE "locations" ADD CONSTRAINT "fk_events_location" FOREIGN KEY ("eventId") REFERENCES "events"("eventId");
ALTER TABLE "pets" ADD CONSTRAINT "fk_users_pets" FOREIGN KEY ("user_id") REFERENCES "users"("id");
Loading

0 comments on commit 797a13b

Please sign in to comment.