Skip to content

Commit

Permalink
Merge pull request #2 from BrandonRoehl/master
Browse files Browse the repository at this point in the history
Switch to use go channels
  • Loading branch information
BrandonRoehl authored Apr 1, 2019
2 parents f61ba94 + db86117 commit 298b70a
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 282 deletions.
241 changes: 130 additions & 111 deletions dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"reflect"
"strings"
"sync"
"text/template"
"time"
)
Expand All @@ -27,15 +26,17 @@ type Data struct {
headerTmpl *template.Template
tableTmpl *template.Template
footerTmpl *template.Template
mux sync.Mutex
wg sync.WaitGroup
err error
}

type table struct {
Name string
SQL string
Values []string
Name string
Err error

data *Data
rows *sql.Rows
types []reflect.Type
values []interface{}
}

type metaData struct {
Expand All @@ -44,8 +45,9 @@ type metaData struct {
CompleteTime string
}

const version = "0.3.5"
const version = "0.4.0"

// takes a *metaData
const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }}
--
-- ------------------------------------------------------
Expand All @@ -63,46 +65,46 @@ const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }}
/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;
`

// takes a *metaData
const footerTmpl = `/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
/*!40101 SET SQL_MODE=@OLD_SQL_MODE */;
/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */;
/*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */;
/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
/*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */;
/*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */;
/*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */;
-- Dump completed on {{ .CompleteTime }}
`

// Takes a *table
const tableTmpl = `
--
-- Table structure for table {{ .Name }}
-- Table structure for table {{ .NameEsc }}
--
DROP TABLE IF EXISTS {{ .Name }};
DROP TABLE IF EXISTS {{ .NameEsc }};
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
{{ .SQL }};
{{ .CreateSQL }};
/*!40101 SET character_set_client = @saved_cs_client */;
--
-- Dumping data for table {{ .Name }}
-- Dumping data for table {{ .NameEsc }}
--
LOCK TABLES {{ .Name }} WRITE;
/*!40000 ALTER TABLE {{ .Name }} DISABLE KEYS */;
{{- if .Values }}
INSERT INTO {{ .Name }} VALUES
{{- range $index, $element := .Values -}}
{{- if $index }},{{ else }} {{ end -}}{{ $element }}
{{- end -}};
LOCK TABLES {{ .NameEsc }} WRITE;
/*!40000 ALTER TABLE {{ .NameEsc }} DISABLE KEYS */;
{{- if .Next }}
INSERT INTO {{ .NameEsc }} VALUES {{ .RowValues }}
{{- range $value := .Stream }},{{ $value }}{{ end -}};
{{- end }}
/*!40000 ALTER TABLE {{ .Name }} ENABLE KEYS */;
/*!40000 ALTER TABLE {{ .NameEsc }} ENABLE KEYS */;
UNLOCK TABLES;
`

const footerTmpl = `/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
/*!40101 SET SQL_MODE=@OLD_SQL_MODE */;
/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */;
/*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */;
/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
/*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */;
/*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */;
/*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */;
-- Dump completed on {{ .CompleteTime }}
`

const nullType = "NULL"

// Dump data using struct
Expand All @@ -128,13 +130,11 @@ func (data *Data) Dump() error {
return err
}

data.wg.Add(len(tables))
for _, name := range tables {
if err := data.dumpTable(name); err != nil {
return err
}
}
data.wg.Wait()
if data.err != nil {
return data.err
}
Expand All @@ -156,24 +156,14 @@ func (data *Data) dumpTable(name string) error {
return err
}

go data.writeTable(table)
return nil
return data.writeTable(table)
}

func (data *Data) writeTable(table *table) {
// Keep a counter of how many tables have been written
defer data.wg.Done()

// Force this method into serial
data.mux.Lock()
defer data.mux.Unlock()

if data.err != nil {
return
} else if err := data.tableTmpl.Execute(data.Out, table); err != nil {
data.err = err
func (data *Data) writeTable(table *table) error {
if err := data.tableTmpl.Execute(data.Out, table); err != nil {
return err
}
return
return table.Err
}

// MARK: get methods
Expand Down Expand Up @@ -237,112 +227,141 @@ func (data *metaData) updateServerVersion(db *sql.DB) (err error) {
// MARK: create methods

func (data *Data) createTable(name string) (*table, error) {
var err error
t := &table{Name: "`" + name + "`"}

if t.SQL, err = data.createTableSQL(name); err != nil {
return nil, err
}

if t.Values, err = data.createTableValues(name); err != nil {
return nil, err
t := &table{
Name: name,
data: data,
}

return t, nil
}

func (data *Data) createTableSQL(name string) (string, error) {
var tableReturn, tableSQL sql.NullString
err := data.Connection.QueryRow("SHOW CREATE TABLE `"+name+"`").Scan(&tableReturn, &tableSQL)
func (table *table) NameEsc() string {
return "`" + table.Name + "`"
}

if err != nil {
func (table *table) CreateSQL() (string, error) {
var tableReturn, tableSQL sql.NullString
if err := table.data.Connection.QueryRow("SHOW CREATE TABLE "+table.NameEsc()).Scan(&tableReturn, &tableSQL); err != nil {
return "", err
}
if tableReturn.String != name {

if tableReturn.String != table.Name {
return "", errors.New("Returned table is not the same as requested table")
}

return tableSQL.String, nil
}

func (data *Data) createTableValues(name string) ([]string, error) {
rows, err := data.Connection.Query("SELECT * FROM `" + name + "`")
// defer rows.Close()
func (table *table) Init() (err error) {
if len(table.types) != 0 {
return errors.New("can't init twice")
}

table.rows, err = table.data.Connection.Query("SELECT * FROM " + table.NameEsc())
if err != nil {
return nil, err
return err
}
defer rows.Close()

columns, err := rows.Columns()
columns, err := table.rows.Columns()
if err != nil {
return nil, err
return err
}
if len(columns) == 0 {
return nil, errors.New("No columns in table " + name + ".")
return errors.New("No columns in table " + table.Name + ".")
}

dataText := make([]string, 0)
tt, err := rows.ColumnTypes()
tt, err := table.rows.ColumnTypes()
if err != nil {
return nil, err
return err
}

types := make([]reflect.Type, len(tt))
table.types = make([]reflect.Type, len(tt))
for i, tp := range tt {
st := tp.ScanType()
if tp.DatabaseTypeName() == "BLOB" {
types[i] = reflect.TypeOf(sql.RawBytes{})
table.types[i] = reflect.TypeOf(sql.RawBytes{})
} else if st != nil && (st.Kind() == reflect.Int ||
st.Kind() == reflect.Int8 ||
st.Kind() == reflect.Int16 ||
st.Kind() == reflect.Int32 ||
st.Kind() == reflect.Int64) {
types[i] = reflect.TypeOf(sql.NullInt64{})
table.types[i] = reflect.TypeOf(sql.NullInt64{})
} else {
types[i] = reflect.TypeOf(sql.NullString{})
table.types[i] = reflect.TypeOf(sql.NullString{})
}
}
values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
table.values = make([]interface{}, len(tt))
for i := range table.values {
table.values[i] = reflect.New(table.types[i]).Interface()
}
for rows.Next() {
if err := rows.Scan(values...); err != nil {
return dataText, err
return nil
}

func (table *table) Next() bool {
if table.rows == nil {
if err := table.Init(); err != nil {
table.Err = err
return false
}
}
// Fallthrough
if table.rows.Next() {
if err := table.rows.Scan(table.values...); err != nil {
table.Err = err
return false
} else if err := table.rows.Err(); err != nil {
table.Err = err
return false
}
} else {
table.rows.Close()
table.rows = nil
return false
}
return true
}

dataStrings := make([]string, len(columns))
func (table *table) RowValues() string {
dataStrings := make([]string, len(table.values))

for key, value := range values {
if value == nil {
for key, value := range table.values {
switch s := value.(type) {
case nil:
dataStrings[key] = nullType
case *sql.NullString:
if s.Valid {
dataStrings[key] = "'" + sanitize(s.String) + "'"
} else {
dataStrings[key] = nullType
}
case *sql.NullInt64:
if s.Valid {
dataStrings[key] = fmt.Sprintf("%d", s.Int64)
} else {
dataStrings[key] = nullType
}
case *sql.RawBytes:
if len(*s) == 0 {
dataStrings[key] = nullType
} else {
switch s := value.(type) {
case *sql.NullString:
if s.Valid {
dataStrings[key] = "'" + sanitize(s.String) + "'"
} else {
dataStrings[key] = nullType
}
case *sql.NullInt64:
if s.Valid {
dataStrings[key] = fmt.Sprintf("%d", s.Int64)
} else {
dataStrings[key] = nullType
}
case *sql.RawBytes:
if len(*s) == 0 {
dataStrings[key] = nullType
} else {
dataStrings[key] = "_binary '" + sanitize(string(*s)) + "'"
}
default:
dataStrings[key] = fmt.Sprint("'", value, "'")
}
dataStrings[key] = "_binary '" + sanitize(string(*s)) + "'"
}
default:
dataStrings[key] = fmt.Sprint("'", value, "'")
}

dataText = append(dataText, "("+strings.Join(dataStrings, ",")+")")
}

return dataText, rows.Err()
return "(" + strings.Join(dataStrings, ",") + ")"
}

func (table *table) Stream() <-chan string {
valueOut := make(chan string, 1)
go func(out chan string) {
defer close(out)
for table.Next() {
out <- table.RowValues()
}
}(valueOut)
return valueOut
}
Loading

0 comments on commit 298b70a

Please sign in to comment.