diff --git a/Gemfile b/Gemfile index 73aed3bc..4fbef2ba 100644 --- a/Gemfile +++ b/Gemfile @@ -2,5 +2,5 @@ source 'https://rubygems.org' gemspec -gem 'activerecord', '3.2.22.1' +gem 'activerecord', '4.2.7.1' gem 'pry', '0.10.4' diff --git a/lib/active_record/connection_adapters/odbc_adapter.rb b/lib/active_record/connection_adapters/odbc_adapter.rb index 10881dde..567ff408 100644 --- a/lib/active_record/connection_adapters/odbc_adapter.rb +++ b/lib/active_record/connection_adapters/odbc_adapter.rb @@ -74,6 +74,7 @@ class ODBCAdapter < AbstractAdapter include ::ODBCAdapter::SchemaStatements ADAPTER_NAME = 'ODBC'.freeze + ERR_DUPLICATE_KEY_VALUE = 23505 attr_reader :dbms @@ -115,13 +116,69 @@ def reconnect! else ODBC::Database.new.drvconnect(options[:driver]) end + super end + alias :reset! :reconnect! # Disconnects from the database if already connected. Otherwise, this # method does nothing. def disconnect! @connection.disconnect if @connection.connected? end + + protected + + def initialize_type_map(map) + map.register_type ODBC::SQL_BIT, Type::Boolean.new + map.register_type ODBC::SQL_CHAR, Type::String.new + map.register_type ODBC::SQL_LONGVARCHAR, Type::Text.new + map.register_type ODBC::SQL_TINYINT, Type::Integer.new(limit: 4) + map.register_type ODBC::SQL_SMALLINT, Type::Integer.new(limit: 8) + map.register_type ODBC::SQL_INTEGER, Type::Integer.new(limit: 16) + map.register_type ODBC::SQL_BIGINT, Type::BigInteger.new(limit: 32) + map.register_type ODBC::SQL_REAL, Type::Float.new(limit: 24) + map.register_type ODBC::SQL_FLOAT, Type::Float.new + map.register_type ODBC::SQL_DOUBLE, Type::Float.new(limit: 53) + map.register_type ODBC::SQL_DECIMAL, Type::Float.new + map.register_type ODBC::SQL_NUMERIC, Type::Integer.new + map.register_type ODBC::SQL_BINARY, Type::Binary.new + map.register_type ODBC::SQL_DATE, Type::Date.new + map.register_type ODBC::SQL_DATETIME, Type::DateTime.new + map.register_type ODBC::SQL_TIME, Type::Time.new + map.register_type ODBC::SQL_TIMESTAMP, Type::DateTime.new + map.register_type ODBC::SQL_GUID, Type::String.new + + alias_type map, ODBC::SQL_VARCHAR, ODBC::SQL_CHAR + alias_type map, ODBC::SQL_WCHAR, ODBC::SQL_CHAR + alias_type map, ODBC::SQL_WVARCHAR, ODBC::SQL_CHAR + alias_type map, ODBC::SQL_WLONGVARCHAR, ODBC::SQL_LONGVARCHAR + alias_type map, ODBC::SQL_VARBINARY, ODBC::SQL_BINARY + alias_type map, ODBC::SQL_LONGVARBINARY, ODBC::SQL_BINARY + alias_type map, ODBC::SQL_TYPE_DATE, ODBC::SQL_DATE + alias_type map, ODBC::SQL_TYPE_TIME, ODBC::SQL_TIME + alias_type map, ODBC::SQL_TYPE_TIMESTAMP, ODBC::SQL_TIMESTAMP + end + + def translate_exception(exception, message) + case exception.message[/^\d+/].to_i + when ERR_DUPLICATE_KEY_VALUE + ActiveRecord::RecordNotUnique.new(message, exception) + else + super + end + end + + def new_column(name, default, cast_type, sql_type = nil, null = true, native_type = nil, scale = nil, limit = nil) + ::ODBCAdapter::Column.new(name, default, cast_type, sql_type, null, native_type, scale, limit) + end + + private + + def alias_type(map, new_type, old_type) + map.register_type(new_type) do |_, *args| + map.lookup(old_type, *args) + end + end end end end diff --git a/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb b/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb index 276da92c..be144227 100644 --- a/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb +++ b/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb @@ -9,6 +9,10 @@ class BindSubstitution < Arel::Visitors::MySQL PRIMARY_KEY = 'INT(11) NOT NULL AUTO_INCREMENT PRIMARY KEY' + def truncate(table_name, name = nil) + execute("TRUNCATE TABLE #{quote_table_name(table_name)}", name) + end + def limited_update_conditions(where_sql, _quoted_table_name, _quoted_primary_key) where_sql end @@ -98,7 +102,7 @@ def change_column_default(table_name, column_name, default) def rename_column(table_name, column_name, new_column_name) col = columns(table_name).detect { |c| c.name == column_name.to_s } - current_type = col.sql_type + current_type = col.native_type current_type << "(#{col.limit})" if col.limit execute("ALTER TABLE #{table_name} CHANGE #{column_name} #{new_column_name} #{current_type}") end @@ -111,7 +115,7 @@ def indexes(table_name, name = nil) def options_include_default?(options) # MySQL 5.x doesn't allow DEFAULT NULL for first timestamp column in a table if options.include?(:default) && options[:default].nil? - if options.include?(:column) && options[:column].sql_type =~ /timestamp/i + if options.include?(:column) && options[:column].native_type =~ /timestamp/i options.delete(:default) end end @@ -134,7 +138,7 @@ def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil) end def last_inserted_id(_result) - @connection.last_id + select_value('SELECT LAST_INSERT_ID()').to_i end end end diff --git a/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb b/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb index 8c036ad1..daca1a86 100644 --- a/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb +++ b/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb @@ -7,63 +7,18 @@ class BindSubstitution < Arel::Visitors::PostgreSQL include Arel::Visitors::BindVisitor end - class PostgreSQLColumn < Column - def initialize(name, default, sql_type, native_type, null = true, scale = nil, native_types = nil, limit = nil) - super - @default = extract_default - end - - private - - def extract_default - case @default - when NilClass - nil - # Numeric types - when /\A\(?(-?\d+(\.\d*)?\)?(::bigint)?)\z/ then $1 - # Character types - when /\A\(?'(.*)'::.*\b(?:character varying|bpchar|text)\z/m then $1 - # Binary data types - when /\A'(.*)'::bytea\z/m then $1 - # Date/time types - when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/ then $1 - when /\A'(.*)'::interval\z/ then $1 - # Boolean type - when 'true' then true - when 'false' then false - # Geometric types - when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/ then $1 - # Network address types - when /\A'(.*)'::(?:cidr|inet|macaddr)\z/ then $1 - # Bit string types - when /\AB'(.*)'::"?bit(?: varying)?"?\z/ then $1 - # XML type - when /\A'(.*)'::xml\z/m then $1 - # Arrays - when /\A'(.*)'::"?\D+"?\[\]\z/ then $1 - # Object identifier types - when /\A-?\d+\z/ then $1 - else - # Anything else is blank, some user type, or some function - # and we can't know the value of that, so return nil. - nil - end - end - end - PRIMARY_KEY = 'SERIAL PRIMARY KEY' - # Override the default column class - def column_class - PostgreSQLColumn - end - # Filter for ODBCAdapter#tables # Omits table from #tables if table_filter returns true def table_filter(schema_name, table_type) %w[information_schema pg_catalog].include?(schema_name) || table_type !~ /TABLE/i end + def truncate(table_name, name = nil) + exec_query("TRUNCATE TABLE #{quote_table_name(table_name)}", name) + end + # Returns the sequence name for a table's primary key or some other specified key. def default_sequence_name(table_name, pk = nil) #:nodoc: serial_sequence(table_name, pk || 'id').split('.').last diff --git a/lib/odbc_adapter/column.rb b/lib/odbc_adapter/column.rb index 8126f098..c0d62f28 100644 --- a/lib/odbc_adapter/column.rb +++ b/lib/odbc_adapter/column.rb @@ -1,67 +1,25 @@ module ODBCAdapter class Column < ActiveRecord::ConnectionAdapters::Column - def initialize(name, default, sql_type, native_type, null = true, scale = nil, native_types = nil, limit = nil) + attr_reader :native_type + + def initialize(name, default, cast_type, sql_type, null, native_type, scale, limit) @name = name @default = default - @sql_type = native_type.to_s - @native_type = native_type.to_s + @cast_type = cast_type + @sql_type = sql_type @null = null - @precision = extract_precision(sql_type, limit) - @scale = extract_scale(sql_type, scale) - @type = genericize(sql_type, @scale, native_types) - @primary = nil - end - - private + @native_type = native_type - # Maps an ODBC SQL type to an ActiveRecord abstract data type - # - # c.f. Mappings in ConnectionAdapters::Column#simplified_type based on - # native column type declaration - # - # See also: - # Column#klass (schema_definitions.rb) for the Ruby class corresponding - # to each abstract data type. - def genericize(sql_type, scale, native_types) - case sql_type - when ODBC::SQL_BIT then :boolean - when ODBC::SQL_CHAR, ODBC::SQL_VARCHAR then :string - when ODBC::SQL_LONGVARCHAR then :text - when ODBC::SQL_WCHAR, ODBC::SQL_WVARCHAR then :string - when ODBC::SQL_WLONGVARCHAR then :text - when ODBC::SQL_TINYINT, ODBC::SQL_SMALLINT, ODBC::SQL_INTEGER, ODBC::SQL_BIGINT then :integer - when ODBC::SQL_REAL, ODBC::SQL_FLOAT, ODBC::SQL_DOUBLE then :float - # If SQLGetTypeInfo output of ODBC driver doesn't include a mapping - # to a native type from SQL_DECIMAL/SQL_NUMERIC, map to :float - when ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC then numeric_type(scale, native_types) - when ODBC::SQL_BINARY, ODBC::SQL_VARBINARY, ODBC::SQL_LONGVARBINARY then :binary - # SQL_DATETIME is an alias for SQL_DATE in ODBC's sql.h & sqlext.h - when ODBC::SQL_DATE, ODBC::SQL_TYPE_DATE, ODBC::SQL_DATETIME then :date - when ODBC::SQL_TIME, ODBC::SQL_TYPE_TIME then :time - when ODBC::SQL_TIMESTAMP, ODBC::SQL_TYPE_TIMESTAMP then :timestamp - when ODBC::SQL_GUID then :string - else - # when SQL_UNKNOWN_TYPE - # (ruby-odbc driver doesn't support following ODBC SQL types: - # SQL_WCHAR, SQL_WVARCHAR, SQL_WLONGVARCHAR, SQL_INTERVAL_xxx) - raise ArgumentError, "Unsupported ODBC SQL type [#{odbcSqlType}]" + if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(sql_type) + set_numeric_params(scale, limit) end end - # Ignore the ODBC precision of SQL types which don't take - # an explicit precision when defining a column - def extract_precision(sql_type, precision) - precision if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(sql_type) - end - - # Ignore the ODBC scale of SQL types which don't take - # an explicit scale when defining a column - def extract_scale(sql_type, scale) - scale || 0 if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(sql_type) - end + private - def numeric_type(scale, native_types) - scale.nil? || scale == 0 ? :integer : (native_types[:decimal].nil? ? :float : :decimal) + def set_numeric_params(scale, limit) + @cast_type.instance_variable_set(:@scale, scale || 0) + @cast_type.instance_variable_set(:@precision, limit) end end end diff --git a/lib/odbc_adapter/column_metadata.rb b/lib/odbc_adapter/column_metadata.rb index 5cf1a13a..ca34bedd 100644 --- a/lib/odbc_adapter/column_metadata.rb +++ b/lib/odbc_adapter/column_metadata.rb @@ -21,7 +21,6 @@ def initialize(adapter) @adapter = adapter end - # TODO: implement boolean column surrogates def native_database_types grouped = reported_types.group_by { |row| row[1] } diff --git a/lib/odbc_adapter/database_statements.rb b/lib/odbc_adapter/database_statements.rb index f5f475db..652d2ac1 100644 --- a/lib/odbc_adapter/database_statements.rb +++ b/lib/odbc_adapter/database_statements.rb @@ -18,7 +18,8 @@ def select_rows(sql, name = nil) # Executes the SQL statement in the context of this connection. # Returns the number of rows affected. - def execute(sql, name = nil) + # TODO: Currently ignoring binds until we can get prepared statements working. + def execute(sql, name = nil, binds = []) log(sql, name) do @connection.do(sql) end @@ -47,6 +48,14 @@ def exec_query(sql, name = 'SQL', binds = []) end end + # Executes delete +sql+ statement in the context of this connection using + # +binds+ as the bind substitutes. +name+ is logged along with + # the executed +sql+ statement. + def exec_delete(sql, name, binds) + execute(sql, name, binds) + end + alias :exec_update :exec_delete + # Begins the transaction (and turns off auto-committing). def begin_db_transaction @connection.autocommit = false @@ -60,7 +69,7 @@ def commit_db_transaction # Rolls back the transaction (and turns on auto-committing). Must be # done if the transaction block raises an exception or returns false. - def rollback_db_transaction + def exec_rollback_db_transaction @connection.rollback @connection.autocommit = true end @@ -72,112 +81,8 @@ def default_sequence_name(table, _column) "#{table}_seq" end - def recreate_database(name, options = {}) - drop_database(name) - create_database(name, options) - end - - def current_database - dbms.field_for(ODBC::SQL_DATABASE_NAME).strip - end - - # Returns an array of table names, for database tables visible on the - # current connection. - def tables(_name = nil) - stmt = @connection.tables - result = stmt.fetch_all || [] - stmt.drop - - result.each_with_object([]) do |row, table_names| - schema_name, table_name, table_type = row[1..3] - next if respond_to?(:table_filtered?) && table_filtered?(schema_name, table_type) - table_names << format_case(table_name) - end - end - - # The class of the column to instantiate - def column_class - ::ODBCAdapter::Column - end - - # Returns an array of Column objects for the table specified by +table_name+. - def columns(table_name, name = nil) - stmt = @connection.columns(native_case(table_name.to_s)) - result = stmt.fetch_all || [] - stmt.drop - - result.each_with_object([]) do |col, cols| - col_name = col[3] # SQLColumns: COLUMN_NAME - col_default = col[12] # SQLColumns: COLUMN_DEF - col_sql_type = col[4] # SQLColumns: DATA_TYPE - col_native_type = col[5] # SQLColumns: TYPE_NAME - col_limit = col[6] # SQLColumns: COLUMN_SIZE - col_scale = col[8] # SQLColumns: DECIMAL_DIGITS - - # SQLColumns: IS_NULLABLE, SQLColumns: NULLABLE - col_nullable = nullability(col_name, col[17], col[10]) - - cols << column_class.new(format_case(col_name), col_default, col_sql_type, col_native_type, col_nullable, col_scale, native_database_types, col_limit) - end - end - - # Returns an array of indexes for the given table. - def indexes(table_name, name = nil) - stmt = @connection.indexes(native_case(table_name.to_s)) - result = stmt.fetch_all || [] - stmt.drop unless stmt.nil? - - index_cols = [] - index_name = nil - unique = nil - - result.each_with_object([]).with_index do |(row, indices), row_idx| - # Skip table statistics - next if row[6] == 0 # SQLStatistics: TYPE - - if row[7] == 1 # SQLStatistics: ORDINAL_POSITION - # Start of column descriptor block for next index - index_cols = [] - unique = row[3].zero? # SQLStatistics: NON_UNIQUE - index_name = String.new(row[5]) # SQLStatistics: INDEX_NAME - end - - index_cols << format_case(row[8]) # SQLStatistics: COLUMN_NAME - next_row = result[row_idx + 1] - - if (row_idx == result.length - 1) || (next_row[6] == 0 || next_row[7] == 1) - indices << IndexDefinition.new(table_name, format_case(index_name), unique, index_cols) - end - end - end - - # Returns just a table's primary key - def primary_key(table_name) - stmt = @connection.primary_keys(native_case(table_name.to_s)) - result = stmt.fetch_all || [] - stmt.drop unless stmt.nil? - result[0] && result[0][3] - end - - ERR_DUPLICATE_KEY_VALUE = 23505 - - def translate_exception(exception, message) - case exception.message[/^\d+/].to_i - when ERR_DUPLICATE_KEY_VALUE - ActiveRecord::RecordNotUnique.new(message, exception) - else - super - end - end - protected - # Returns an array of record hashes with the column names as keys and - # column values as values. - def select(sql, name = nil, binds = []) - exec_query(sql, name, binds).to_a - end - # Returns the last auto-generated ID from the affected table. def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil) begin diff --git a/lib/odbc_adapter/quoting.rb b/lib/odbc_adapter/quoting.rb index 6bf87a5f..21d00772 100644 --- a/lib/odbc_adapter/quoting.rb +++ b/lib/odbc_adapter/quoting.rb @@ -1,44 +1,5 @@ module ODBCAdapter module Quoting - # Quotes the column value to help prevent - # {SQL injection attacks}[http://en.wikipedia.org/wiki/SQL_injection]. - def quote(value, column = nil) - # records are quoted as their primary key - return value.quoted_id if value.respond_to?(:quoted_id) - - case value - when String, ActiveSupport::Multibyte::Chars - value = value.to_s - return "'#{quote_string(value)}'" unless column - - case column.type - when :binary then "'#{quote_string(column.string_to_binary(value))}'" - when :integer then value.to_i.to_s - when :float then value.to_f.to_s - else - "'#{quote_string(value)}'" - end - - when true, false - if column && column.type == :integer - value ? '1' : '0' - else - value ? quoted_true : quoted_false - end - # BigDecimals need to be put in a non-normalized form and quoted. - when nil then "NULL" - when BigDecimal then value.to_s('F') - when Numeric then value.to_s - when Symbol then "'#{quote_string(value.to_s)}'" - else - if value.acts_like?(:date) || value.acts_like?(:time) - quoted_date(value) - else - super - end - end - end - # Quotes a string, escaping any ' (single quote) characters. def quote_string(string) string.gsub(/\'/, "''") @@ -71,10 +32,15 @@ def quoted_true # Ideally, we'd return an ODBC date or timestamp literal escape # sequence, but not all ODBC drivers support them. def quoted_date(value) - if value.acts_like?(:time) # Time, DateTime - "'#{value.strftime("%Y-%m-%d %H:%M:%S")}'" - else # Date - "'#{value.strftime("%Y-%m-%d")}'" + if value.acts_like?(:time) + zone_conversion_method = ActiveRecord::Base.default_timezone == :utc ? :getutc : :getlocal + + if value.respond_to?(zone_conversion_method) + value = value.send(zone_conversion_method) + end + value.strftime("%Y-%m-%d %H:%M:%S") # Time, DateTime + else + value.strftime("%Y-%m-%d") # Date end end end diff --git a/lib/odbc_adapter/schema_statements.rb b/lib/odbc_adapter/schema_statements.rb index f70fd154..83c7f410 100644 --- a/lib/odbc_adapter/schema_statements.rb +++ b/lib/odbc_adapter/schema_statements.rb @@ -12,5 +12,83 @@ def index_name(table_name, options) maximum = dbms.field_for(ODBC::SQL_MAX_IDENTIFIER_LEN) || 255 super(table_name, options)[0...maximum] end + + def current_database + dbms.field_for(ODBC::SQL_DATABASE_NAME).strip + end + + # Returns an array of table names, for database tables visible on the + # current connection. + def tables(_name = nil) + stmt = @connection.tables + result = stmt.fetch_all || [] + stmt.drop + + result.each_with_object([]) do |row, table_names| + schema_name, table_name, table_type = row[1..3] + next if respond_to?(:table_filtered?) && table_filtered?(schema_name, table_type) + table_names << format_case(table_name) + end + end + + # Returns an array of Column objects for the table specified by +table_name+. + def columns(table_name, name = nil) + stmt = @connection.columns(native_case(table_name.to_s)) + result = stmt.fetch_all || [] + stmt.drop + + result.each_with_object([]) do |col, cols| + col_name = col[3] # SQLColumns: COLUMN_NAME + col_default = col[12] # SQLColumns: COLUMN_DEF + col_sql_type = col[4] # SQLColumns: DATA_TYPE + col_native_type = col[5] # SQLColumns: TYPE_NAME + col_limit = col[6] # SQLColumns: COLUMN_SIZE + col_scale = col[8] # SQLColumns: DECIMAL_DIGITS + + # SQLColumns: IS_NULLABLE, SQLColumns: NULLABLE + col_nullable = nullability(col_name, col[17], col[10]) + + cast_type = lookup_cast_type(col_sql_type) + cols << new_column(format_case(col_name), col_default, cast_type, col_sql_type, col_nullable, col_native_type, col_scale, col_limit) + end + end + + # Returns an array of indexes for the given table. + def indexes(table_name, name = nil) + stmt = @connection.indexes(native_case(table_name.to_s)) + result = stmt.fetch_all || [] + stmt.drop unless stmt.nil? + + index_cols = [] + index_name = nil + unique = nil + + result.each_with_object([]).with_index do |(row, indices), row_idx| + # Skip table statistics + next if row[6] == 0 # SQLStatistics: TYPE + + if row[7] == 1 # SQLStatistics: ORDINAL_POSITION + # Start of column descriptor block for next index + index_cols = [] + unique = row[3].zero? # SQLStatistics: NON_UNIQUE + index_name = String.new(row[5]) # SQLStatistics: INDEX_NAME + end + + index_cols << format_case(row[8]) # SQLStatistics: COLUMN_NAME + next_row = result[row_idx + 1] + + if (row_idx == result.length - 1) || (next_row[6] == 0 || next_row[7] == 1) + indices << IndexDefinition.new(table_name, format_case(index_name), unique, index_cols) + end + end + end + + # Returns just a table's primary key + def primary_key(table_name) + stmt = @connection.primary_keys(native_case(table_name.to_s)) + result = stmt.fetch_all || [] + stmt.drop unless stmt.nil? + result[0] && result[0][3] + end end end diff --git a/lib/odbc_adapter/version.rb b/lib/odbc_adapter/version.rb index 40d077cd..feb54b09 100644 --- a/lib/odbc_adapter/version.rb +++ b/lib/odbc_adapter/version.rb @@ -1,3 +1,3 @@ module ODBCAdapter - VERSION = '3.2.0' + VERSION = '4.2.0' end diff --git a/test/migrations_test.rb b/test/migrations_test.rb index ecde5a85..76e3fbce 100644 --- a/test/migrations_test.rb +++ b/test/migrations_test.rb @@ -7,7 +7,7 @@ def setup def test_table_crud @connection.create_table(:foos, force: true) do |t| - t.timestamps + t.timestamps null: false end assert_equal 3, @connection.columns(:foos).count diff --git a/test/test_helper.rb b/test/test_helper.rb index ded59c42..d105bb8c 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -16,13 +16,13 @@ t.string :first_name t.string :last_name t.integer :letters - t.timestamps + t.timestamps null: false end create_table :todos, force: true do |t| t.integer :user_id t.text :body - t.timestamps + t.timestamps null: false end end