diff --git a/torndb.py b/torndb.py index c1cad6a..ffb5adf 100644 --- a/torndb.py +++ b/torndb.py @@ -71,7 +71,7 @@ class Connection(object): MySQLdb version >= 1.2.5 and MySQL version > 5.1.12. """ def __init__(self, host, database, user=None, password=None, - max_idle_time=7 * 3600, connect_timeout=0, + max_idle_time=7 * 3600, connect_timeout=10, time_zone="+0:00", charset = "utf8", sql_mode="TRADITIONAL", **kwargs): self.host = host @@ -253,17 +253,21 @@ def __getattr__(self, name): raise AttributeError(name) if MySQLdb is not None: - # Fix the access conversions to properly recognize unicode/binary - FIELD_TYPE = MySQLdb.constants.FIELD_TYPE - FLAG = MySQLdb.constants.FLAG - CONVERSIONS = copy.copy(MySQLdb.converters.conversions) + if MySQLdb.__package__ != 'pymysql': + # Fix the access conversions to properly recognize unicode/binary + FIELD_TYPE = MySQLdb.constants.FIELD_TYPE + FLAG = MySQLdb.constants.FLAG + CONVERSIONS = copy.copy(MySQLdb.converters.conversions) - field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING] - if 'VARCHAR' in vars(FIELD_TYPE): - field_types.append(FIELD_TYPE.VARCHAR) + field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING] + if 'VARCHAR' in vars(FIELD_TYPE): + field_types.append(FIELD_TYPE.VARCHAR) + + for field_type in field_types: + CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type] + else: + CONVERSIONS = {} - for field_type in field_types: - CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type] # Alias some common MySQL exceptions IntegrityError = MySQLdb.IntegrityError