diff --git a/integration_tests/mysql/main.go b/integration_tests/mysql/main.go index dbfa4896..69ee85f2 100644 --- a/integration_tests/mysql/main.go +++ b/integration_tests/mysql/main.go @@ -76,9 +76,13 @@ const testTypesCreateTableQuery = ` CREATE TABLE %s ( pk INTEGER PRIMARY KEY NOT NULL, c_tinyint TINYINT, + c_tinyint_unsigned TINYINT UNSIGNED, c_smallint SMALLINT, + c_smallint_unsigned SMALLINT UNSIGNED, c_mediumint MEDIUMINT, + c_mediumint_unsigned MEDIUMINT UNSIGNED, c_int INT, + c_int_unsigned INT(15) UNSIGNED, c_bigint BIGINT, c_decimal DECIMAL(7, 5), c_numeric NUMERIC(5, 3), @@ -110,12 +114,20 @@ INSERT INTO %s VALUES ( 1, -- c_tinyint 1, + -- c_smallint_unsigned + 2, -- c_smallint 2, + -- c_smallint_unsigned + 3, -- c_mediumint 3, + -- c_mediumint_unsigned + 4, -- c_int 4, + -- c_int_unsigned + 55, -- c_bigint 5, -- c_decimal @@ -186,6 +198,14 @@ const expectedPayloadTemplate = `{ "name": "", "parameters": null }, + { + "type": "int16", + "optional": false, + "default": null, + "field": "c_tinyint_unsigned", + "name": "", + "parameters": null + }, { "type": "int16", "optional": false, @@ -194,6 +214,14 @@ const expectedPayloadTemplate = `{ "name": "", "parameters": null }, + { + "type": "int32", + "optional": false, + "default": null, + "field": "c_smallint_unsigned", + "name": "", + "parameters": null + }, { "type": "int32", "optional": false, @@ -202,6 +230,14 @@ const expectedPayloadTemplate = `{ "name": "", "parameters": null }, + { + "type": "int32", + "optional": false, + "default": null, + "field": "c_mediumint_unsigned", + "name": "", + "parameters": null + }, { "type": "int32", "optional": false, @@ -210,6 +246,14 @@ const expectedPayloadTemplate = `{ "name": "", "parameters": null }, + { + "type": "int64", + "optional": false, + "default": null, + "field": "c_int_unsigned", + "name": "", + "parameters": null + }, { "type": "int64", "optional": false, @@ -415,15 +459,19 @@ const expectedPayloadTemplate = `{ "c_enum": "medium", "c_float": 90.123, "c_int": 4, + "c_int_unsigned": 55, "c_json": "{\"key1\": \"value1\", \"key2\": \"value2\"}", "c_mediumint": 3, + "c_mediumint_unsigned": 4, "c_numeric": "AN3M", "c_set": "one,two", "c_smallint": 2, + "c_smallint_unsigned": 3, "c_text": "ZXCV", "c_time": 14706000000, "c_timestamp": "2001-02-03T04:05:06Z", "c_tinyint": 1, + "c_tinyint_unsigned": 2, "c_varbinary": "Qk5N", "c_varchar": "GHJKL", "c_year": 2001, diff --git a/lib/mysql/schema/schema.go b/lib/mysql/schema/schema.go index 3a5a0f3d..51c6e368 100644 --- a/lib/mysql/schema/schema.go +++ b/lib/mysql/schema/schema.go @@ -106,12 +106,22 @@ func DescribeTable(db *sql.DB, table string) ([]Column, error) { return result, nil } -func parseColumnDataType(s string) (DataType, *Opts, error) { +func parseColumnDataType(originalS string) (DataType, *Opts, error) { + // Preserve the original value, so we can return the error message without the actual value being mutated. + s := originalS var metadata string + var unsigned bool + if strings.HasSuffix(s, " unsigned") { + // If a number is unsigned, we'll bump them up by one (e.g. int32 -> int64) + unsigned = true + s = strings.TrimSuffix(s, " unsigned") + } + parenIndex := strings.Index(s, "(") if parenIndex != -1 { if s[len(s)-1] != ')' { - return -1, nil, fmt.Errorf("malformed data type: %s", s) + // Make sure the format looks like int (n) unsigned + return -1, nil, fmt.Errorf("malformed data type: %s", originalS) } metadata = s[parenIndex+1 : len(s)-1] s = s[:parenIndex] @@ -124,12 +134,28 @@ func parseColumnDataType(s string) (DataType, *Opts, error) { return Boolean, nil, nil } + if unsigned { + return SmallInt, nil, nil + } + return TinyInt, nil, nil case "smallint": + if unsigned { + return Int, nil, nil + } + return SmallInt, nil, nil case "mediumint": + if unsigned { + return Int, nil, nil + } + return MediumInt, nil, nil case "int": + if unsigned { + return BigInt, nil, nil + } + return Int, nil, nil case "bigint": return BigInt, nil, nil @@ -194,7 +220,7 @@ func parseColumnDataType(s string) (DataType, *Opts, error) { case "json": return JSON, nil, nil default: - return -1, nil, fmt.Errorf("unknown data type: %s", s) + return -1, nil, fmt.Errorf("unknown data type: %s", originalS) } } diff --git a/lib/mysql/schema/schema_test.go b/lib/mysql/schema/schema_test.go index 8ba61207..d1458f3b 100644 --- a/lib/mysql/schema/schema_test.go +++ b/lib/mysql/schema/schema_test.go @@ -41,6 +41,35 @@ func TestParseColumnDataType(t *testing.T) { expectedType: Decimal, expectedOpts: &Opts{Precision: ptr.ToInt(5), Scale: ptr.ToInt(2)}, }, + { + input: "int(10) unsigned", + expectedType: BigInt, + expectedOpts: nil, + }, + { + input: "tinyint unsigned", + expectedType: SmallInt, + expectedOpts: nil, + }, + { + input: "smallint unsigned", + expectedType: Int, + expectedOpts: nil, + }, + { + input: "mediumint unsigned", + expectedType: Int, + expectedOpts: nil, + }, + { + input: "int unsigned", + expectedType: BigInt, + expectedOpts: nil, + }, + { + input: "int(10 unsigned", + expectedErr: "malformed data type: int(10 ", + }, { input: "foo", expectedErr: "unknown data type: foo",