diff --git a/mysql/connector/__init__.py b/mysql/connector/__init__.py index fc0e45d..da47cd5 100644 --- a/mysql/connector/__init__.py +++ b/mysql/connector/__init__.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -25,15 +25,6 @@ MySQL Connector/Python - MySQL drive written in Python """ -import sys -_name = 'MySQL Connector/Python' -if not hasattr(sys, "version_info") or sys.version_info < (2,4): - raise RuntimeError("%s requires Python 2.4 or higher." % (_name)) -elif sys.version_info >= (3,0): - raise RuntimeError("%s does not yet support Python v3." % (_name)) -del _name -del sys - # Python Db API v2 apilevel = '2.0' threadsafety = 1 @@ -43,21 +34,22 @@ paramstyle = 'pyformat' import _version __version__ = _version.version -from mysql import MySQL +from connection import MySQLConnection from errors import * -from constants import FieldFlag, FieldType, CharacterSet, RefreshOption +from constants import FieldFlag, FieldType, CharacterSet,\ + RefreshOption, ClientFlag from dbapi import * def Connect(*args, **kwargs): - """Shortcut for creating a mysql.MySQL object.""" - return MySQL(*args, **kwargs) + """Shortcut for creating a connection.MySQLConnection object.""" + return MySQLConnection(*args, **kwargs) connect = Connect __all__ = [ - 'MySQL', 'Connect', + 'MySQLConnection', 'Connect', # Some useful constants - 'FieldType','FieldFlag','CharacterSet','RefreshOption', + 'FieldType','FieldFlag','ClientFlag','CharacterSet','RefreshOption', # Error handling 'Error','Warning', diff --git a/mysql/connector/_version.py b/mysql/connector/_version.py index 2cf038e..435f6ac 100644 --- a/mysql/connector/_version.py +++ b/mysql/connector/_version.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -25,4 +25,4 @@ """ # Next line is generated -version = (0, 1, 0, 'devel', '') +version = (0, 3, 0, 'devel', '') diff --git a/mysql/connector/connection.py b/mysql/connector/connection.py index c2b21ab..a00dc40 100644 --- a/mysql/connector/connection.py +++ b/mysql/connector/connection.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -24,31 +24,35 @@ """Implementing communication to MySQL servers """ +import sys import socket +import logging import os +import weakref +from collections import deque +import constants +import conversion import protocol import errors -from constants import CharacterSet +import utils +import cursor -class MySQLBaseConnection(object): +logger = logging.getLogger('myconnpy') + +class MySQLBaseSocket(object): """Base class for MySQL Connections subclasses. Should not be used directly but overloaded, changing the - open_connection part. Examples over subclasses are - MySQLTCPConnection - MySQLUNIXConnection + open_connection part. Examples of subclasses are + MySQLTCPSocket + MySQLUnixSocket """ - def __init__(self, prtcls=None): + def __init__(self): self.sock = None # holds the socket connection self.connection_timeout = None - self.protocol = None - self.socket_flags = 0 - try: - self.protocol = prtcls(self) - except: - self.protocol = protocol.MySQLProtocol(self) - self._set_socket_flags() + self.buffer = deque() + self.recvsize = 1024*8 def open_connection(self): pass @@ -58,10 +62,12 @@ class MySQLBaseConnection(object): self.sock.close() except: pass + + def get_address(self): + pass def send(self, buf): - """ - Send packets using the socket to the server. + """Send packets over the socket """ pktlen = len(buf) try: @@ -69,49 +75,57 @@ class MySQLBaseConnection(object): pktlen -= self.sock.send(buf) except Exception, e: raise errors.OperationalError('%s' % e) - + def recv(self): - """ - Receive packets using the socket from the server. + """Receive packets from the socket """ try: - header = self.sock.recv(4, self.socket_flags) - (pktsize, pktnr) = self.protocol.handle_header(header) - buf = header + self.sock.recv(pktsize, self.socket_flags) - self.protocol.is_error(buf) + return self.buffer.popleft() + except IndexError: + pass + + pktnr = -1 + try: + buf = self.sock.recv(self.recvsize) + while buf: + totalsize = len(buf) + if pktnr == -1 and totalsize > 4: + pktsize = utils.intread(buf[0:3]) + pktnr = utils.intread(buf[3]) + if pktnr > -1 and totalsize >= pktsize+4: + size = pktsize+4 + self.buffer.append(buf[0:size]) + buf = buf[size:] + pktnr = -1 + if len(buf) == 0: + break + elif len(buf) < pktsize+4: + buf += self.sock.recv(self.recvsize) + except socket.timeout, e: + raise errors.InterfaceError(errno=2013) + except socket.error, e: + raise errors.InterfaceError(errno=2055, + values=dict(socketaddr=self.get_address(),errno=e.errno)) except: raise - - return (buf, pktsize, pktnr) - - def set_protocol(self, prtcls): + try: - self.protocol = prtcls(self, self.protocol.handshake) - except: - self.protocol = protocol.MySQLProtocol(self) - + return self.buffer.popleft() + except IndexError, e: + pass + def set_connection_timeout(self, timeout): self.connection_timeout = timeout - def _set_socket_flags(self, flags=None): - self.socket_flags = 0 - if flags is None: - if os.name == 'nt': - flags = 0 - else: - flags = socket.MSG_WAITALL - - if flags is not None: - self.socket_flags = flags - - -class MySQLUnixConnection(MySQLBaseConnection): +class MySQLUnixSocket(MySQLBaseSocket): """Opens a connection through the UNIX socket of the MySQL Server.""" - def __init__(self, prtcls=None,unix_socket='/tmp/mysql.sock'): - MySQLBaseConnection.__init__(self, prtcls=prtcls) + def __init__(self, unix_socket='/tmp/mysql.sock'): + MySQLBaseSocket.__init__(self) self.unix_socket = unix_socket - self.socket_flags = socket.MSG_WAITALL + + def get_address(self): + return self.unix_socket def open_connection(self): """Opens a UNIX socket and checks the MySQL handshake.""" @@ -119,19 +133,26 @@ class MySQLUnixConnection(MySQLBaseConnection): self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.sock.settimeout(self.connection_timeout) self.sock.connect(self.unix_socket) + except socket.error, e: + try: + m = e.errno + except: + m = e + raise errors.InterfaceError(errno=2002, + values=dict(socketaddr=self.get_address(),errno=m)) except StandardError, e: - raise errors.OperationalError('%s' % e) + raise errors.InterfaceError('%s' % e) - buf = self.recv()[0] - self.protocol.handle_handshake(buf) - -class MySQLTCPConnection(MySQLBaseConnection): +class MySQLTCPSocket(MySQLBaseSocket): """Opens a TCP connection to the MySQL Server.""" - def __init__(self, prtcls=None, host='127.0.0.1', port=3306): - MySQLBaseConnection.__init__(self, prtcls=prtcls) + def __init__(self, host='127.0.0.1', port=3306): + MySQLBaseSocket.__init__(self) self.server_host = host self.server_port = port + + def get_address(self): + return "%s:%s" % (self.server_host,self.server_port) def open_connection(self): """Opens a TCP Connection and checks the MySQL handshake.""" @@ -139,10 +160,453 @@ class MySQLTCPConnection(MySQLBaseConnection): self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.settimeout(self.connection_timeout) self.sock.connect( (self.server_host, self.server_port) ) + except socket.error, e: + try: + m = e.errno + except: + m = e + raise errors.InterfaceError(errno=2003, + values=dict(socketaddr=self.get_address(),errno=m)) except StandardError, e: - raise errors.OperationalError('%s' % e) - - buf = self.recv()[0] - self.protocol.handle_handshake(buf) + raise errors.InterfaceError('%s' % e) + except: + raise +class MySQLConnection(object): + """MySQL""" + + def __init__(self, *args, **kwargs): + """Initializing""" + self.conn = None # Holding the connection + self.protocol = None + self.converter = None + self.cursors = [] + + self.client_flags = constants.ClientFlag.get_default() + self._charset = 33 + + self._username = '' + self._database = '' + self._server_host = '127.0.0.1' + self._server_port = 3306 + self._unix_socket = None + self.client_host = '' + self.client_port = 0 + + self.affected_rows = 0 + self.server_status = 0 + self.warning_count = 0 + self.field_count = 0 + self.insert_id = 0 + self.info_msg = '' + self.use_unicode = True + self.get_warnings = False + self.raise_on_warnings = False + self.connection_timeout = None + self.buffered = False + self.unread_result = False + self.raw = False + if len(kwargs) > 0: + self.connect(*args, **kwargs) + + def connect(self, database=None, user='', password='', + host='127.0.0.1', port=3306, unix_socket=None, + use_unicode=True, charset='utf8', collation=None, + autocommit=False, + time_zone=None, sql_mode=None, + get_warnings=False, raise_on_warnings=False, + connection_timeout=None, client_flags=0, + buffered=False, raw=False, + passwd=None, db=None, connect_timeout=None, dsn=None): + if db and not database: + database = db + if passwd and not password: + password = passwd + if connect_timeout and not connection_timeout: + connection_timeout = connect_timeout + + if dsn is not None: + errors.NotSupportedError("Data source name is not supported") + + self._server_host = host + self._server_port = port + self._unix_socket = unix_socket + if database is not None: + self._database = database.strip() + else: + self._database = None + self._username = user + + self.set_warnings(get_warnings,raise_on_warnings) + self.connection_timeout = connection_timeout + self.buffered = buffered + self.raw = raw + self.use_unicode = use_unicode + self.set_client_flags(client_flags) + self._charset = constants.CharacterSet.get_charset_info(charset)[0] + + if user or password: + self.set_login(user, password) + + self.disconnect() + self._open_connection(username=user, password=password, database=database, + client_flags=self.client_flags, charset=charset) + self._post_connection(time_zone=time_zone, sql_mode=sql_mode, + collation=collation) + + def _get_connection(self, prtcls=None): + """Get connection based on configuration + + This method will return the appropriated connection object using + the connection parameters. + + Returns subclass of MySQLBaseSocket. + """ + conn = None + if self.unix_socket and os.name != 'nt': + conn = MySQLUnixSocket(unix_socket=self.unix_socket) + else: + conn = MySQLTCPSocket(host=self.server_host, + port=self.server_port) + conn.set_connection_timeout(self.connection_timeout) + return conn + + def _open_connection(self, username=None, password=None, database=None, + client_flags=None, charset=None): + """Opens the connection + + Open the connection, check the MySQL version, and set the + protocol. + """ + try: + self.protocol = protocol.MySQLProtocol(self._get_connection()) + self.protocol.do_handshake() + version = self.protocol.server_version + if version < (4,1): + raise errors.InterfaceError( + "MySQL Version %s is not supported." % version) + self.protocol.do_auth(username, password, database, client_flags, + self._charset) + (self._charset, self.charset_name, c) = \ + constants.CharacterSet.get_charset_info(charset) + except: + raise + + def _post_connection(self, time_zone=None, autocommit=False, + sql_mode=None, collation=None): + """Post connection session setup + + Should be called after a connection was established""" + self.set_converter_class(conversion.MySQLConverter) + try: + if collation is not None: + self.collation = collation + self.autocommit = autocommit + if time_zone is not None: + self.time_zone = time_zone + if sql_mode is not None: + self.sql_mode = sql_mode + except: + raise + + def is_connected(self): + """ + Check whether we are connected to the MySQL server. + """ + return self.protocol.cmd_ping() + ping = is_connected + + def disconnect(self): + """ + Disconnect from the MySQL server. + """ + if not self.protocol: + return + + if self.protocol.conn.sock is not None: + self.protocol.cmd_quit() + try: + self.protocol.conn.close_connection() + except: + pass + self.protocol = None + + def set_converter_class(self, convclass): + """ + Set the converter class to be used. This should be a class overloading + methods and members of conversion.MySQLConverter. + """ + self.converter_class = convclass + self.converter = convclass(self.charset_name, self.use_unicode) + + def get_server_version(self): + """Returns the server version as a tuple""" + try: + return self.protocol.server_version + except: + pass + + return None + + def get_server_info(self): + """Returns the server version as a string""" + return self.protocol.server_version_original + + @property + def connection_id(self): + """MySQL connection ID""" + threadid = None + try: + threadid = self.protocol.server_threadid + except: + pass + return threadid + + def set_login(self, username=None, password=None): + """Set login information for MySQL + + Set the username and/or password for the user connecting to + the MySQL Server. + """ + if username is not None: + self.username = username.strip() + else: + self.username = '' + if password is not None: + self.password = password.strip() + else: + self.password = '' + + def set_unicode(self, value=True): + """Toggle unicode mode + + Set whether we return string fields as unicode or not. + Default is True. + """ + self.use_unicode = value + if self.converter: + self.converter.set_unicode(value) + + def set_charset(self, charset): + try: + (idx, charset_name, c) = \ + constants.CharacterSet.get_charset_info(charset) + self._execute_query("SET NAMES '%s'" % charset_name) + except: + raise + else: + self._charset = idx + self.charset_name = charset_name + self.converter.set_charset(charset_name) + def get_charset(self): + return self._info_query( + "SELECT @@session.character_set_connection")[0] + charset = property(get_charset, set_charset, + doc="Character set for this connection") + + def set_collation(self, collation): + try: + self._execute_query( + "SET @@session.collation_connection = '%s'" % collation) + except: + raise + def get_collation(self): + return self._info_query( + "SELECT @@session.collation_connection")[0] + collation = property(get_collation, set_collation, + doc="Collation for this connection") + + def set_warnings(self, fetch=False, raise_on_warnings=False): + """Set how to handle warnings coming from MySQL + + Set wheter we should get warnings whenever an operation produced some. + If you set raise_on_warnings to True, any warning will be raised + as a DataError exception. + """ + if raise_on_warnings is True: + self.get_warnings = True + self.raise_on_warnings = True + else: + self.get_warnings = fetch + self.raise_on_warnings = False + + def set_client_flags(self, flags): + """Set the client flags + + The flags-argument can be either an int or a list (or tuple) of + ClientFlag-values. If it is an integer, it will set client_flags + to flags. + If flags is a list (or tuple), each flag will be set or unset + when it's negative. + + set_client_flags([ClientFlag.FOUND_ROWS,-ClientFlag.LONG_FLAG]) + + Returns self.client_flags + """ + if isinstance(flags,int) and flags > 0: + self.set_client_flags(flags) + else: + if isinstance(flags,(tuple,list)): + for f in flags: + if f < 0: + self.unset_client_flag(abs(f)) + else: + self.set_client_flag(f) + return self.client_flags + + def set_client_flag(self, flag): + if flag > 0: + self.client_flags |= flag + + def unset_client_flag(self, flag): + if flag > 0: + self.client_flags &= ~flag + + def isset_client_flag(self, flag): + if (self.client_flags & flag) > 0: + return True + return False + + @property + def user(self): + """User used while connecting to MySQL""" + return self._username + + @property + def server_host(self): + """MySQL server IP address or name""" + return self._server_host + + @property + def server_port(self): + "MySQL server TCP/IP port" + return self._server_port + + @property + def unix_socket(self): + "MySQL Unix socket file location" + return self._unix_socket + + def set_database(self, value): + try: + self.protocol.cmd_query("USE %s" % value) + except: + raise + def get_database(self): + """Get the current database""" + return self._info_query("SELECT DATABASE()")[0] + database = property(get_database, set_database, + doc="Current database") + + def set_time_zone(self, value): + try: + self.protocol.cmd_query("SET @@session.time_zone = %s" % value) + except: + raise + def get_time_zone(self): + return self._info_query("SELECT @@session.time_zone")[0] + time_zone = property(get_time_zone, set_time_zone, + doc="time_zone value for current MySQL session") + + def set_sql_mode(self, value): + try: + self.protocol.cmd_query("SET @@session.sql_mode = %s" % value) + except: + raise + def get_sql_mode(self): + return self._info_query("SELECT @@session.sql_mode")[0] + sql_mode = property(get_sql_mode, set_sql_mode, + doc="sql_mode value for current MySQL session") + + def set_autocommit(self, value): + try: + if value: + s = 'ON' + else: + s = 'OFF' + self._execute_query("SET @@session.autocommit = %s" % s) + except: + raise + def get_autocommit(self): + value = self._info_query("SELECT @@session.autocommit")[0] + if value == 1: + return True + return False + autocommit = property(get_autocommit, set_autocommit, + doc="autocommit value for current MySQL session") + + def close(self): + del self.cursors[:] + self.disconnect() + + def remove_cursor(self, c): + try: + self.cursors.remove(c) + except ValueError: + raise errors.ProgrammingError( + "Cursor could not be removed.") + + def cursor(self, buffered=None, raw=None, cursor_class=None): + """Instantiates and returns a cursor + + By default, MySQLCursor is returned. Depending on the options + while connecting, a buffered and/or raw cursor instantiated + instead. + + It is possible to also give a custom cursor through the + cursor_class paramter, but it needs to be a subclass of + mysql.connector.cursor.CursorBase. + + Returns a cursor-object + """ + if cursor_class is not None: + if not issubclass(cursor_class, cursor.CursorBase): + raise errors.ProgrammingError( + "Cursor class needs be subclass of cursor.CursorBase") + c = (cursor_class)(self) + else: + buffered = buffered or self.buffered + raw = raw or self.raw + + t = 0 + if buffered is True: + t |= 1 + if raw is True: + t |= 2 + + types = { + 0 : cursor.MySQLCursor, + 1 : cursor.MySQLCursorBuffered, + 2 : cursor.MySQLCursorRaw, + 3 : cursor.MySQLCursorBufferedRaw, + } + c = (types[t])(self) + + if c not in self.cursors: + self.cursors.append(c) + return c + + def commit(self): + """Commit current transaction""" + self._execute_query("COMMIT") + + def rollback(self): + """Rollback current transaction""" + self._execute_query("ROLLBACK") + + def _execute_query(self, query): + if self.unread_result is True: + raise errors.InternalError("Unread result found.") + + self.protocol.cmd_query(query) + + def _info_query(self, query): + try: + cur = self.cursor(buffered=True) + cur.execute(query) + row = cur.fetchone() + cur.close() + except: + raise + return row diff --git a/mysql/connector/constants.py b/mysql/connector/constants.py index ea66cba..61a1078 100644 --- a/mysql/connector/constants.py +++ b/mysql/connector/constants.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -26,6 +26,14 @@ from errors import ProgrammingError +def flag_is_set(flag, flags): + """Checks if the flag is set + + Returns boolean""" + if (flags & flag) > 0: + return True + return False + class _constants(object): prefix = '' @@ -36,21 +44,20 @@ class _constants(object): @classmethod def get_desc(cls,name): - res = '' try: - res = cls.desc[name][1] - except KeyError, e: - raise KeyError, e - else: - return res + return cls.desc[name][1] + except: + return None @classmethod def get_info(cls,n): - res = () - for k,v in cls.desc.items(): - if v[0] == n: - return v[1] - raise KeyError, e + try: + res = {} + for v in cls.desc.items(): + res[v[1][0]] = v[0] + return res[n] + except: + return None @classmethod def get_full_info(cls): @@ -61,7 +68,20 @@ class _constants(object): res = ('No information found in constant class.%s' % e) return res - + +class _constantflags(_constants): + + @classmethod + def get_bit_info(cls, v): + """Get the name of all bits set + + Returns a list of strings.""" + res = [] + for name,d in cls.desc.items(): + if v & d[0]: + res.append(name) + return res + class FieldType(_constants): prefix = 'FIELD_TYPE_' @@ -155,7 +175,7 @@ class FieldType(_constants): cls.DATETIME, cls.TIMESTAMP, ] -class FieldFlag(_constants): +class FieldFlag(_constantflags): """ Field flags as found in MySQL sources mysql-src/include/mysql_com.h """ @@ -213,8 +233,7 @@ class FieldFlag(_constants): 'FIELD_IN_ADD_INDEX': (1 << 20, "Intern: Field used in ADD INDEX"), 'FIELD_IS_RENAMED': (1 << 21, "Intern: Field is being renamed"), } - - + class ServerCmd(_constants): _prefix = 'COM_' SLEEP = 0 @@ -248,7 +267,7 @@ class ServerCmd(_constants): STMT_FETCH = 28 DAEMON = 29 -class ClientFlag(_constants): +class ClientFlag(_constantflags): """ Client Options as found in the MySQL sources mysql-src/include/mysql_com.h """ @@ -314,7 +333,7 @@ class ClientFlag(_constants): flags |= f return flags -class ServerFlag(_constants): +class ServerFlag(_constantflags): """ Server flags as found in the MySQL sources mysql-src/include/mysql_com.h """ @@ -364,204 +383,330 @@ class RefreshOption(_constants): } class CharacterSet(_constants): - """ - List of supported character sets with their collations. This maps to the - character set we get from the server within the handshake packet. + """MySQL supported character sets and collations - To update this list, use the following query: - SELECT ID,CHARACTER_SET_NAME, COLLATION_NAME - FROM INFORMATION_SCHEMA.COLLATIONS - ORDER BY ID + List of character sets with their collations supported by MySQL. This + maps to the character set we get from the server within the handshake + packet. - This list is hardcoded because we want to avoid doing each time the above - query to get the name of the character set used. + The list is hardcode so we avoid a database query when getting the + name of the used character set or collation. """ - _max_id = 211 # SELECT MAX(ID)+1 FROM INFORMATION_SCHEMA.COLLATIONS - - @classmethod - def _init_desc(cls): - if not cls.__dict__.has_key('desc'): - - # Do not forget to update the tests in test_constants! - cls.desc = [ None for i in range(cls._max_id)] - cls.desc[1] = ('big5','big5_chinese_ci') - cls.desc[2] = ('latin2','latin2_czech_cs') - cls.desc[3] = ('dec8','dec8_swedish_ci') - cls.desc[4] = ('cp850','cp850_general_ci') - cls.desc[5] = ('latin1','latin1_german1_ci') - cls.desc[6] = ('hp8','hp8_english_ci') - cls.desc[7] = ('koi8r','koi8r_general_ci') - cls.desc[8] = ('latin1','latin1_swedish_ci') - cls.desc[9] = ('latin2','latin2_general_ci') - cls.desc[10] = ('swe7','swe7_swedish_ci') - cls.desc[11] = ('ascii','ascii_general_ci') - cls.desc[12] = ('ujis','ujis_japanese_ci') - cls.desc[13] = ('sjis','sjis_japanese_ci') - cls.desc[14] = ('cp1251','cp1251_bulgarian_ci') - cls.desc[15] = ('latin1','latin1_danish_ci') - cls.desc[16] = ('hebrew','hebrew_general_ci') - cls.desc[18] = ('tis620','tis620_thai_ci') - cls.desc[19] = ('euckr','euckr_korean_ci') - cls.desc[20] = ('latin7','latin7_estonian_cs') - cls.desc[21] = ('latin2','latin2_hungarian_ci') - cls.desc[22] = ('koi8u','koi8u_general_ci') - cls.desc[23] = ('cp1251','cp1251_ukrainian_ci') - cls.desc[24] = ('gb2312','gb2312_chinese_ci') - cls.desc[25] = ('greek','greek_general_ci') - cls.desc[26] = ('cp1250','cp1250_general_ci') - cls.desc[27] = ('latin2','latin2_croatian_ci') - cls.desc[28] = ('gbk','gbk_chinese_ci') - cls.desc[29] = ('cp1257','cp1257_lithuanian_ci') - cls.desc[30] = ('latin5','latin5_turkish_ci') - cls.desc[31] = ('latin1','latin1_german2_ci') - cls.desc[32] = ('armscii8','armscii8_general_ci') - cls.desc[33] = ('utf8','utf8_general_ci') - cls.desc[34] = ('cp1250','cp1250_czech_cs') - cls.desc[35] = ('ucs2','ucs2_general_ci') - cls.desc[36] = ('cp866','cp866_general_ci') - cls.desc[37] = ('keybcs2','keybcs2_general_ci') - cls.desc[38] = ('macce','macce_general_ci') - cls.desc[39] = ('macroman','macroman_general_ci') - cls.desc[40] = ('cp852','cp852_general_ci') - cls.desc[41] = ('latin7','latin7_general_ci') - cls.desc[42] = ('latin7','latin7_general_cs') - cls.desc[43] = ('macce','macce_bin') - cls.desc[44] = ('cp1250','cp1250_croatian_ci') - cls.desc[47] = ('latin1','latin1_bin') - cls.desc[48] = ('latin1','latin1_general_ci') - cls.desc[49] = ('latin1','latin1_general_cs') - cls.desc[50] = ('cp1251','cp1251_bin') - cls.desc[51] = ('cp1251','cp1251_general_ci') - cls.desc[52] = ('cp1251','cp1251_general_cs') - cls.desc[53] = ('macroman','macroman_bin') - cls.desc[57] = ('cp1256','cp1256_general_ci') - cls.desc[58] = ('cp1257','cp1257_bin') - cls.desc[59] = ('cp1257','cp1257_general_ci') - cls.desc[63] = ('binary','binary') - cls.desc[64] = ('armscii8','armscii8_bin') - cls.desc[65] = ('ascii','ascii_bin') - cls.desc[66] = ('cp1250','cp1250_bin') - cls.desc[67] = ('cp1256','cp1256_bin') - cls.desc[68] = ('cp866','cp866_bin') - cls.desc[69] = ('dec8','dec8_bin') - cls.desc[70] = ('greek','greek_bin') - cls.desc[71] = ('hebrew','hebrew_bin') - cls.desc[72] = ('hp8','hp8_bin') - cls.desc[73] = ('keybcs2','keybcs2_bin') - cls.desc[74] = ('koi8r','koi8r_bin') - cls.desc[75] = ('koi8u','koi8u_bin') - cls.desc[77] = ('latin2','latin2_bin') - cls.desc[78] = ('latin5','latin5_bin') - cls.desc[79] = ('latin7','latin7_bin') - cls.desc[80] = ('cp850','cp850_bin') - cls.desc[81] = ('cp852','cp852_bin') - cls.desc[82] = ('swe7','swe7_bin') - cls.desc[83] = ('utf8','utf8_bin') - cls.desc[84] = ('big5','big5_bin') - cls.desc[85] = ('euckr','euckr_bin') - cls.desc[86] = ('gb2312','gb2312_bin') - cls.desc[87] = ('gbk','gbk_bin') - cls.desc[88] = ('sjis','sjis_bin') - cls.desc[89] = ('tis620','tis620_bin') - cls.desc[90] = ('ucs2','ucs2_bin') - cls.desc[91] = ('ujis','ujis_bin') - cls.desc[92] = ('geostd8','geostd8_general_ci') - cls.desc[93] = ('geostd8','geostd8_bin') - cls.desc[94] = ('latin1','latin1_spanish_ci') - cls.desc[95] = ('cp932','cp932_japanese_ci') - cls.desc[96] = ('cp932','cp932_bin') - cls.desc[97] = ('eucjpms','eucjpms_japanese_ci') - cls.desc[98] = ('eucjpms','eucjpms_bin') - cls.desc[128] = ('ucs2','ucs2_unicode_ci') - cls.desc[129] = ('ucs2','ucs2_icelandic_ci') - cls.desc[130] = ('ucs2','ucs2_latvian_ci') - cls.desc[131] = ('ucs2','ucs2_romanian_ci') - cls.desc[132] = ('ucs2','ucs2_slovenian_ci') - cls.desc[133] = ('ucs2','ucs2_polish_ci') - cls.desc[134] = ('ucs2','ucs2_estonian_ci') - cls.desc[135] = ('ucs2','ucs2_spanish_ci') - cls.desc[136] = ('ucs2','ucs2_swedish_ci') - cls.desc[137] = ('ucs2','ucs2_turkish_ci') - cls.desc[138] = ('ucs2','ucs2_czech_ci') - cls.desc[139] = ('ucs2','ucs2_danish_ci') - cls.desc[140] = ('ucs2','ucs2_lithuanian_ci') - cls.desc[141] = ('ucs2','ucs2_slovak_ci') - cls.desc[142] = ('ucs2','ucs2_spanish2_ci') - cls.desc[143] = ('ucs2','ucs2_roman_ci') - cls.desc[144] = ('ucs2','ucs2_persian_ci') - cls.desc[145] = ('ucs2','ucs2_esperanto_ci') - cls.desc[146] = ('ucs2','ucs2_hungarian_ci') - cls.desc[192] = ('utf8','utf8_unicode_ci') - cls.desc[193] = ('utf8','utf8_icelandic_ci') - cls.desc[194] = ('utf8','utf8_latvian_ci') - cls.desc[195] = ('utf8','utf8_romanian_ci') - cls.desc[196] = ('utf8','utf8_slovenian_ci') - cls.desc[197] = ('utf8','utf8_polish_ci') - cls.desc[198] = ('utf8','utf8_estonian_ci') - cls.desc[199] = ('utf8','utf8_spanish_ci') - cls.desc[200] = ('utf8','utf8_swedish_ci') - cls.desc[201] = ('utf8','utf8_turkish_ci') - cls.desc[202] = ('utf8','utf8_czech_ci') - cls.desc[203] = ('utf8','utf8_danish_ci') - cls.desc[204] = ('utf8','utf8_lithuanian_ci') - cls.desc[205] = ('utf8','utf8_slovak_ci') - cls.desc[206] = ('utf8','utf8_spanish2_ci') - cls.desc[207] = ('utf8','utf8_roman_ci') - cls.desc[208] = ('utf8','utf8_persian_ci') - cls.desc[209] = ('utf8','utf8_esperanto_ci') - cls.desc[210] = ('utf8','utf8_hungarian_ci') + desc = [ + # (character set name, collation, default) + None, + ("big5","big5_chinese_ci",True), # 1 + ("latin2","latin2_czech_cs",False), # 2 + ("dec8","dec8_swedish_ci",True), # 3 + ("cp850","cp850_general_ci",True), # 4 + ("latin1","latin1_german1_ci",False), # 5 + ("hp8","hp8_english_ci",True), # 6 + ("koi8r","koi8r_general_ci",True), # 7 + ("latin1","latin1_swedish_ci",True), # 8 + ("latin2","latin2_general_ci",True), # 9 + ("swe7","swe7_swedish_ci",True), # 10 + ("ascii","ascii_general_ci",True), # 11 + ("ujis","ujis_japanese_ci",True), # 12 + ("sjis","sjis_japanese_ci",True), # 13 + ("cp1251","cp1251_bulgarian_ci",False), # 14 + ("latin1","latin1_danish_ci",False), # 15 + ("hebrew","hebrew_general_ci",True), # 16 + None, + ("tis620","tis620_thai_ci",True), # 18 + ("euckr","euckr_korean_ci",True), # 19 + ("latin7","latin7_estonian_cs",False), # 20 + ("latin2","latin2_hungarian_ci",False), # 21 + ("koi8u","koi8u_general_ci",True), # 22 + ("cp1251","cp1251_ukrainian_ci",False), # 23 + ("gb2312","gb2312_chinese_ci",True), # 24 + ("greek","greek_general_ci",True), # 25 + ("cp1250","cp1250_general_ci",True), # 26 + ("latin2","latin2_croatian_ci",False), # 27 + ("gbk","gbk_chinese_ci",True), # 28 + ("cp1257","cp1257_lithuanian_ci",False), # 29 + ("latin5","latin5_turkish_ci",True), # 30 + ("latin1","latin1_german2_ci",False), # 31 + ("armscii8","armscii8_general_ci",True), # 32 + ("utf8","utf8_general_ci",True), # 33 + ("cp1250","cp1250_czech_cs",False), # 34 + ("ucs2","ucs2_general_ci",True), # 35 + ("cp866","cp866_general_ci",True), # 36 + ("keybcs2","keybcs2_general_ci",True), # 37 + ("macce","macce_general_ci",True), # 38 + ("macroman","macroman_general_ci",True), # 39 + ("cp852","cp852_general_ci",True), # 40 + ("latin7","latin7_general_ci",True), # 41 + ("latin7","latin7_general_cs",False), # 42 + ("macce","macce_bin",False), # 43 + ("cp1250","cp1250_croatian_ci",False), # 44 + None, + None, + ("latin1","latin1_bin",False), # 47 + ("latin1","latin1_general_ci",False), # 48 + ("latin1","latin1_general_cs",False), # 49 + ("cp1251","cp1251_bin",False), # 50 + ("cp1251","cp1251_general_ci",True), # 51 + ("cp1251","cp1251_general_cs",False), # 52 + ("macroman","macroman_bin",False), # 53 + None, + None, + None, + ("cp1256","cp1256_general_ci",True), # 57 + ("cp1257","cp1257_bin",False), # 58 + ("cp1257","cp1257_general_ci",True), # 59 + None, + None, + None, + ("binary","binary",True), # 63 + ("armscii8","armscii8_bin",False), # 64 + ("ascii","ascii_bin",False), # 65 + ("cp1250","cp1250_bin",False), # 66 + ("cp1256","cp1256_bin",False), # 67 + ("cp866","cp866_bin",False), # 68 + ("dec8","dec8_bin",False), # 69 + ("greek","greek_bin",False), # 70 + ("hebrew","hebrew_bin",False), # 71 + ("hp8","hp8_bin",False), # 72 + ("keybcs2","keybcs2_bin",False), # 73 + ("koi8r","koi8r_bin",False), # 74 + ("koi8u","koi8u_bin",False), # 75 + None, + ("latin2","latin2_bin",False), # 77 + ("latin5","latin5_bin",False), # 78 + ("latin7","latin7_bin",False), # 79 + ("cp850","cp850_bin",False), # 80 + ("cp852","cp852_bin",False), # 81 + ("swe7","swe7_bin",False), # 82 + ("utf8","utf8_bin",False), # 83 + ("big5","big5_bin",False), # 84 + ("euckr","euckr_bin",False), # 85 + ("gb2312","gb2312_bin",False), # 86 + ("gbk","gbk_bin",False), # 87 + ("sjis","sjis_bin",False), # 88 + ("tis620","tis620_bin",False), # 89 + ("ucs2","ucs2_bin",False), # 90 + ("ujis","ujis_bin",False), # 91 + ("geostd8","geostd8_general_ci",True), # 92 + ("geostd8","geostd8_bin",False), # 93 + ("latin1","latin1_spanish_ci",False), # 94 + ("cp932","cp932_japanese_ci",True), # 95 + ("cp932","cp932_bin",False), # 96 + ("eucjpms","eucjpms_japanese_ci",True), # 97 + ("eucjpms","eucjpms_bin",False), # 98 + ("cp1250","cp1250_polish_ci",False), # 99 + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ("ucs2","ucs2_unicode_ci",False), # 128 + ("ucs2","ucs2_icelandic_ci",False), # 129 + ("ucs2","ucs2_latvian_ci",False), # 130 + ("ucs2","ucs2_romanian_ci",False), # 131 + ("ucs2","ucs2_slovenian_ci",False), # 132 + ("ucs2","ucs2_polish_ci",False), # 133 + ("ucs2","ucs2_estonian_ci",False), # 134 + ("ucs2","ucs2_spanish_ci",False), # 135 + ("ucs2","ucs2_swedish_ci",False), # 136 + ("ucs2","ucs2_turkish_ci",False), # 137 + ("ucs2","ucs2_czech_ci",False), # 138 + ("ucs2","ucs2_danish_ci",False), # 139 + ("ucs2","ucs2_lithuanian_ci",False), # 140 + ("ucs2","ucs2_slovak_ci",False), # 141 + ("ucs2","ucs2_spanish2_ci",False), # 142 + ("ucs2","ucs2_roman_ci",False), # 143 + ("ucs2","ucs2_persian_ci",False), # 144 + ("ucs2","ucs2_esperanto_ci",False), # 145 + ("ucs2","ucs2_hungarian_ci",False), # 146 + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ("utf8","utf8_unicode_ci",False), # 192 + ("utf8","utf8_icelandic_ci",False), # 193 + ("utf8","utf8_latvian_ci",False), # 194 + ("utf8","utf8_romanian_ci",False), # 195 + ("utf8","utf8_slovenian_ci",False), # 196 + ("utf8","utf8_polish_ci",False), # 197 + ("utf8","utf8_estonian_ci",False), # 198 + ("utf8","utf8_spanish_ci",False), # 199 + ("utf8","utf8_swedish_ci",False), # 200 + ("utf8","utf8_turkish_ci",False), # 201 + ("utf8","utf8_czech_ci",False), # 202 + ("utf8","utf8_danish_ci",False), # 203 + ("utf8","utf8_lithuanian_ci",False), # 204 + ("utf8","utf8_slovak_ci",False), # 205 + ("utf8","utf8_spanish2_ci",False), # 206 + ("utf8","utf8_roman_ci",False), # 207 + ("utf8","utf8_persian_ci",False), # 208 + ("utf8","utf8_esperanto_ci",False), # 209 + ("utf8","utf8_hungarian_ci",False), # 210 + ] @classmethod def get_info(cls,setid): - """Returns information about the charset for given MySQL ID.""" - cls._init_desc() - res = () - errmsg = "Character set with id '%d' unsupported." % (setid) + """Retrieves character set information as tuple using an ID + + Retrieves character set and collation information based on the + given MySQL ID. + + Returns a tuple. + """ try: - res = cls.desc[setid] + r = cls.desc[setid] + if r is None: + raise + return r[0:2] except: - raise ProgrammingError, errmsg - - if res is None: - raise ProgrammingError, errmsg - - return res + raise ProgrammingError("Character set '%d' unsupported" % (setid)) @classmethod def get_desc(cls,setid): - """Returns info string about the charset for given MySQL ID.""" - res = () + """Retrieves character set information as string using an ID + + Retrieves character set and collation information based on the + given MySQL ID. + + Returns a tuple. + """ try: - res = "%s/%s" % self.get_info(setid) - except ProgrammingError, e: + return "%s/%s" % cls.get_info(setid) + except: raise - else: - return res @classmethod - def get_charset_info(cls, name, collation=None): - """Returns information about the charset and optional collation.""" - cls._init_desc() - l = len(cls.desc) - errmsg = "Character set '%s' unsupported." % (name) + def get_default_collation(cls, charset): + """Retrieves the default collation for given character set + + Raises ProgrammingError when character set is not supported. + + Returns list (collation, charset, index) + """ + if isinstance(charset, int): + try: + c = cls.desc[charset] + return c[1], c[0], charset + except: + ProgrammingError("Character set ID '%s' unsupported." % ( + charset)) + + for cid, c in enumerate(cls.desc): + if c is None: + continue + if c[0] == charset and c[2] is True: + return c[1], c[0], cid + + raise ProgrammingError("Character set '%s' unsupported." % (charset)) + + @classmethod + def get_charset_info(cls, charset, collation=None): + """Retrieves character set information as tuple using a name + + Retrieves character set and collation information based on the + given a valid name. If charset is an integer, it will look up + the character set based on the MySQL's ID. + + Raises ProgrammingError when character set is not supported. + + Returns a tuple. + """ + idx = None + + if isinstance(charset, int): + try: + c = cls.desc[charset] + return charset, c[0], c[1] + except: + ProgrammingError("Character set ID '%s' unsupported." % ( + charset)) if collation is None: - collation = '%s_general_ci' % (name) - - # Search the list and return when found - idx = 0 - for info in cls.desc: - if info and info[0] == name and info[1] == collation: - return (idx,info[0],info[1]) - idx += 1 - - # If we got here, we didn't find the charset - raise ProgrammingError, errmsg + collation, charset, idx = cls.get_default_collation(charset) + else: + for cid, c in enumerate(cls.desc): + if c is None: + continue + if c[0] == charset and c[1] == collation: + idx = cid + break + + if idx is not None: + return (idx,charset,collation) + else: + raise ProgrammingError("Character set '%s' unsupported." % ( + charset)) @classmethod def get_supported(cls): - """Returns a list with names of all supproted character sets.""" + """Retrieves a list with names of all supproted character sets + + Returns a tuple. + """ res = [] for info in cls.desc: if info and info[0] not in res: diff --git a/mysql/connector/conversion.py b/mysql/connector/conversion.py index fc54ece..3272780 100644 --- a/mysql/connector/conversion.py +++ b/mysql/connector/conversion.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -24,8 +24,7 @@ """Converting MySQL and Python types """ -from types import NoneType -import re +import struct import datetime import time from decimal import Decimal @@ -79,8 +78,7 @@ class MySQLConverter(ConverterBase): """ def __init__(self, charset=None, use_unicode=True): ConverterBase.__init__(self, charset, use_unicode) - - # Python types + self.python_types = { int : int, str : self._str_to_mysql, @@ -88,7 +86,7 @@ class MySQLConverter(ConverterBase): float : float, unicode : self._unicode_to_mysql, bool : self._bool_to_mysql, - NoneType : self._none_to_mysql, + type(None) : self._none_to_mysql, datetime.datetime : self._datetime_to_mysql, datetime.date : self._date_to_mysql, datetime.time : self._time_to_mysql, @@ -97,7 +95,6 @@ class MySQLConverter(ConverterBase): Decimal : self._decimal_to_mysql, } - # MySQL types self.mysql_types = { FieldType.TINY : self._int, FieldType.SHORT : self._int, @@ -116,7 +113,9 @@ class MySQLConverter(ConverterBase): FieldType.NEWDATE : self._DATE_to_python, FieldType.DATETIME : self._DATETIME_to_python, FieldType.TIMESTAMP : self._DATETIME_to_python, - FieldType.BLOB : self._STRING_to_python, + FieldType.BLOB : self._BLOB_to_python, + FieldType.YEAR: self._YEAR_to_python, + FieldType.BIT: self._BIT_to_python, } def escape(self, value): @@ -131,9 +130,8 @@ class MySQLConverter(ConverterBase): return value elif isinstance(value, (int,float,long,Decimal)): return value - backslash = re.compile(r'\134') res = value - res = backslash.sub(r'\\\\', res) + res = res.replace('\\','\\\\') res = res.replace('\n','\\n') res = res.replace('\r','\\r') res = res.replace('\047','\134\047') # single quotes @@ -152,7 +150,7 @@ class MySQLConverter(ConverterBase): """ if isinstance(buf, (int,float,long,Decimal)): return str(buf) - elif isinstance(buf, NoneType): + elif isinstance(buf, type(None)): return "NULL" else: # Anything else would be a string @@ -195,12 +193,11 @@ class MySQLConverter(ConverterBase): If the instance isn't a datetime.datetime type, it return None. - Returns a string or None when not valid. + Returns a string. """ - if isinstance(value, datetime.datetime): - return value.strftime('%Y-%m-%d %H:%M:%S') - - return None + return '%d-%02d-%02d %02d:%02d:%02d' % ( + value.year, value.month, value.day, + value.hour, value.minute, value.second) def _date_to_mysql(self, value): """ @@ -209,13 +206,9 @@ class MySQLConverter(ConverterBase): If the instance isn't a datetime.date type, it return None. - Returns a string or None when not valid. + Returns a string. """ - if isinstance(value, datetime.date): - return value.strftime('%Y-%m-%d') - - - return None + return '%d-%02d-%02d' % (value.year, value.month, value.day) def _time_to_mysql(self, value): """ @@ -226,10 +219,7 @@ class MySQLConverter(ConverterBase): Returns a string or None when not valid. """ - if isinstance(value, datetime.time): - return value.strftime('%H:%M:%S') - - return None + return value.strftime('%H:%M:%S') def _struct_time_to_mysql(self, value): """ @@ -239,24 +229,19 @@ class MySQLConverter(ConverterBase): Returns a string or None when not valid. """ - if isinstance(value, time.struct_time): - return time.strftime('%Y-%m-%d %H:%M:%S',value) - return None + return time.strftime('%Y-%m-%d %H:%M:%S',value) def _timedelta_to_mysql(self, value): """ Converts a timedelta instance to a string suitable for MySQL. The returned string has format: %H:%M:%S - Returns a string or None when not valid. + Returns a string. """ - if isinstance(value, datetime.timedelta): - secs = value.seconds%60 - mins = value.seconds%3600/60 - hours = value.seconds/3600+(value.days*24) - return '%d:%02d:%02d' % (hours,mins,secs) - - return None + (hours, r) = divmod(value.seconds, 3600) + (mins, secs) = divmod(r, 60) + hours = hours + (value.days * 24) + return '%02d:%02d:%02d' % (hours,mins,secs) def _decimal_to_mysql(self, value): """ @@ -280,7 +265,7 @@ class MySQLConverter(ConverterBase): """ res = value - if value == '\x00': + if value == '\x00' and flddsc[1] != FieldType.BIT: # Don't go further when we hit a NULL value return None if value is None: @@ -316,7 +301,7 @@ class MySQLConverter(ConverterBase): """ Returns v as long type. """ - return long(v) + return int(v) def _decimal(self, v, desc=None): """ @@ -330,6 +315,13 @@ class MySQLConverter(ConverterBase): """ return str(v) + def _BIT_to_python(self, v, dsc=None): + """Returns BIT columntype as integer""" + s = v + if len(s) < 8: + s = '\x00'*(8-len(s)) + s + return struct.unpack('>Q', s)[0] + def _DATE_to_python(self, v, dsc=None): """ Returns DATE column type as datetime.date type. @@ -361,24 +353,37 @@ class MySQLConverter(ConverterBase): """ pv = None try: - pv = datetime.datetime(*time.strptime(v, "%Y-%m-%d %H:%M:%S")[0:6]) + (sd,st) = v.split(' ') + dt = [ int(v) for v in sd.split('-') ] +\ + [ int(v) for v in st.split(':') ] + pv = datetime.datetime(*dt) except ValueError: pv = None return pv + + def _YEAR_to_python(self, v, desc=None): + """Returns YEAR column type as integer""" + try: + year = int(v) + except ValueError: + raise ValueError("Failed converting YEAR to int (%s)" % v) + + return year def _SET_to_python(self, v, dsc=None): - """ + """Returns SET column typs as set + Actually, MySQL protocol sees a SET as a string type field. So this code isn't called directly, but used by STRING_to_python() method. - Returns SET column type as string splitted using a comma. + Returns SET column type as a set. """ pv = None try: - pv = v.split(',') + pv = set(v.split(',')) except ValueError: - raise ValueError, "Could not convert set %s to a sequence." % v + raise ValueError, "Could not convert SET %s to a set." % v return pv def _STRING_to_python(self, v, dsc=None): @@ -392,6 +397,8 @@ class MySQLConverter(ConverterBase): # Check if we deal with a SET if dsc[7] & FieldFlag.SET: return self._SET_to_python(v, dsc) + if dsc[7] & FieldFlag.BINARY: + return v if self.use_unicode: try: @@ -399,3 +406,11 @@ class MySQLConverter(ConverterBase): except: raise return str(v) + + def _BLOB_to_python(self, v, dsc=None): + if dsc is not None: + if dsc[7] & FieldFlag.BINARY: + return v + + return self._STRING_to_python(v, dsc) + diff --git a/mysql/connector/cursor.py b/mysql/connector/cursor.py index 53b5e49..83d3da4 100644 --- a/mysql/connector/cursor.py +++ b/mysql/connector/cursor.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -24,14 +24,20 @@ """Cursor classes """ -import exceptions +import sys +from collections import deque +import weakref +import re -import mysql -import connection +import constants import protocol import errors import utils +RE_SQL_COMMENT = re.compile("\/\*.*\*\/") +RE_SQL_INSERT_VALUES = re.compile(r'\sVALUES\s*(\(.*\))', re.I) +RE_SQL_INSERT_STMT = re.compile(r'INSERT\s+INTO', re.I) + class CursorBase(object): """ Base for defining MySQLCursor. This class is a skeleton and defines @@ -78,6 +84,9 @@ class CursorBase(object): def setoutputsize(self, size, column=None): pass + + def reset(self): + pass class MySQLCursor(CursorBase): """ @@ -100,17 +109,15 @@ class MySQLCursor(CursorBase): def __init__(self, db=None): CursorBase.__init__(self) self.db = None - self.fields = () - self.nrflds = 0 - self._result = [] + self._more_results = False + self._results = deque() self._nextrow = (None, None) self.lastrowid = None self._warnings = None self._warning_count = 0 self._executed = None self._have_result = False - self._get_warnings = False - + self._raise_on_warnings = True if db is not None: self.set_connection(db) @@ -120,41 +127,28 @@ class MySQLCursor(CursorBase): and returns the next row. """ return iter(self.fetchone, None) - - def _valid_protocol(self,db): - if not hasattr(db,'conn'): - raise errors.InterfaceError( - "MySQL connection object connection not valid.") - - try: - if not isinstance(db.conn.protocol,protocol.MySQLProtocol): - raise errors.InterfaceError( - "MySQL connection has no protocol set.") - except AttributeError: - raise errors.InterfaceError( - "MySQL connection object connection not valid.") - - return True def set_connection(self, db): - if isinstance(db,mysql.MySQLBase): - if self._valid_protocol(db): - self.db = db - self.protocol = db.conn.protocol - self.db.register_cursor(self) - self._get_warnings = self.db.get_warnings - else: - raise errors.InterfaceError( - "MySQLCursor db-argument must subclass of mysql.MySQLBase") + try: + if isinstance(db.protocol,protocol.MySQLProtocol): + self.db = weakref.ref(db) + if self not in self.db().cursors: + self.db().cursors.append(self) + except: + raise errors.InterfaceError(errno=2048) def _reset_result(self): - del self._result[:] self.rowcount = -1 self._nextrow = (None, None) self._have_result = False + try: + self.db().unread_result = False + except: + pass self._warnings = None self._warning_count = 0 - self._fields = () + self.description = () + self.reset() def next(self): """ @@ -177,20 +171,21 @@ class MySQLCursor(CursorBase): """ if self.db is None: return False + try: - self.db.remove_cursor(self) + self._reset_result() + self.db().remove_cursor(self) self.db = None except: return False - del self._result[:] return True def _process_params_dict(self, params): try: - to_mysql = self.db.converter.to_mysql - escape = self.db.converter.escape - quote = self.db.converter.quote + to_mysql = self.db().converter.to_mysql + escape = self.db().converter.escape + quote = self.db().converter.quote res = {} for k,v in params.items(): c = v @@ -222,9 +217,9 @@ class MySQLCursor(CursorBase): try: res = params - to_mysql = self.db.converter.to_mysql - escape = self.db.converter.escape - quote = self.db.converter.quote + to_mysql = self.db().converter.to_mysql + escape = self.db().converter.escape + quote = self.db().converter.quote res = map(to_mysql,res) res = map(escape,res) @@ -235,35 +230,11 @@ class MySQLCursor(CursorBase): else: return tuple(res) return None - - def _get_description(self, res=None): - """ - Gets the description of the fields out of a result we got from - the MySQL Server. If res is None then self.description is - returned (which can be None). - - Returns a list or None when no descriptions are available. - """ - if not res: - return self.description - - desc = [] - try: - for fld in res[1]: - if not isinstance(fld, protocol.FieldPacket): - raise errors.ProgrammingError( - "Can only get description from protocol.FieldPacket") - desc.append(fld.get_description()) - except TypeError: - raise errors.ProgrammingError( - "_get_description needs a list as argument." - ) - return desc def _row_to_python(self, rowdata, desc=None): res = () try: - to_python = self.db.converter.to_python + to_python = self.db().converter.to_python if not desc: desc = self.description for idx,v in enumerate(rowdata): @@ -278,19 +249,34 @@ class MySQLCursor(CursorBase): return None def _handle_noresultset(self, res): - """Handles result of execute() when there is no result set.""" + """Handles result of execute() when there is no result set + """ try: - self.rowcount = res.affected_rows - self.lastrowid = res.insert_id - self._warning_count = res.warning_count - if self._get_warnings is True and self._warning_count: + self.rowcount = res['affected_rows'] + self.lastrowid = res['insert_id'] + self._warning_count = res['warning_count'] + if self.db().get_warnings is True and self._warning_count: self._warnings = self._fetch_warnings() + self._set_more_results(res['server_status']) + except errors.Error: + raise except StandardError, e: raise errors.ProgrammingError( "Failed handling non-resultset; %s" % e) def _handle_resultset(self): pass + + def _handle_result(self, res): + if isinstance(res, dict): + self.db().unread_result = False + self._have_result = False + self._handle_noresultset(res) + else: + self.description = res[1] + self.db().unread_result = True + self._have_result = True + self._handle_resultset() def execute(self, operation, params=None): """ @@ -306,32 +292,34 @@ class MySQLCursor(CursorBase): """ if not operation: return 0 + if self.db().unread_result is True: + raise errors.InternalError("Unread result found.") + self._reset_result() stmt = '' - # Make sure we send the query in correct character set try: if isinstance(operation, unicode): - operation.encode(self.db.charset_name) + operation = operation.encode(self.db().charset_name) + if params is not None: - stmt = operation % self._process_params(params) + try: + stmt = operation % self._process_params(params) + except TypeError: + raise errors.ProgrammingError( + "Wrong number of arguments during string formatting") else: stmt = operation - res = self.protocol.cmd_query(stmt) - if isinstance(res, protocol.OKResultPacket): - self._have_result = False - self._handle_noresultset(res) - else: - self.description = self._get_description(res) - self._have_result = True - self._handle_resultset() - except errors.ProgrammingError: - raise - except errors.OperationalError: + + res = self.db().protocol.cmd_query(stmt) + self._handle_result(res) + except (UnicodeDecodeError,UnicodeEncodeError), e: + raise errors.ProgrammingError(str(e)) + except errors.Error: raise except StandardError, e: - raise errors.InterfaceError( - "Failed executing the operation; %s" % e) + raise errors.InterfaceError, errors.InterfaceError( + "Failed executing the operation; %s" % e), sys.exc_info()[2] else: self._executed = stmt return self.rowcount @@ -339,45 +327,99 @@ class MySQLCursor(CursorBase): return 0 def executemany(self, operation, seq_params): - """Loops over seq_params and calls excute()""" + """Loops over seq_params and calls execute() + + INSERT statements are optimized by batching the data, that is + using the MySQL multiple rows syntax. + """ if not operation: return 0 - - rowcnt = 0 - try: + if self.db().unread_result is True: + raise errors.InternalError("Unread result found.") + + # Optimize INSERTs by batching them + if re.match(RE_SQL_INSERT_STMT,operation): + opnocom = re.sub(RE_SQL_COMMENT,'',operation) + m = re.search(RE_SQL_INSERT_VALUES,opnocom) + fmt = m.group(1) + values = [] for params in seq_params: - self.execute(operation, params) - if self._have_result: - self.fetchall() - rowcnt += self.rowcount - except (ValueError,TypeError), e: - raise errors.InterfaceError( - "Failed executing the operation; %s" % e) - except: - # Raise whatever execute() raises - raise - - return rowcnt + values.append(fmt % self._process_params(params)) + operation = re.sub(re.escape(m.group(1)), + ','.join(values),operation,count=1) + self.execute(operation) + else: + rowcnt = 0 + try: + for params in seq_params: + self.execute(operation, params) + if self._have_result: + self.fetchall() + rowcnt += self.rowcount + except (ValueError,TypeError), e: + raise errors.InterfaceError( + "Failed executing the operation; %s" % e) + except: + # Raise whatever execute() raises + raise + self.rowcount = rowcnt + return self.rowcount + def _set_more_results(self, flags): + flag = constants.ServerFlag.MORE_RESULTS_EXISTS + self._more_results = constants.flag_is_set(flag, flags) + + def next_resultset(self): + """Gets next result after executing multiple statements + + When more results are available, this function will reset the + current result and advance to the next set. + + This is useful when executing multiple statements. If you need + to retrieve multiple results after executing a stored procedure + using callproc(), use next_proc_resultset() instead. + """ + if self._more_results is True: + buf = self.db().protocol._recv_packet() + res = self.db().protocol.handle_cmd_result(buf) + self._reset_result() + self._handle_result(res) + return True + + return None + + def next_proc_resultset(self): + """Get the next result set after calling a stored procedure + + Returns a MySQLCursorBuffered-object""" + try: + return self._results.popleft() + except IndexError: + return None + except: + raise + + return None + def callproc(self, procname, args=()): """Calls a stored procedue with the given arguments - + The arguments will be set during this session, meaning they will be called like ___arg where is an enumeration (+1) of the arguments. - + Coding Example: 1) Definining the Stored Routine in MySQL: CREATE PROCEDURE multiply(IN pFac1 INT, IN pFac2 INT, OUT pProd INT) BEGIN SET pProd := pFac1 * pFac2; END - + 2) Executing in Python: args = (5,5,0) # 0 is to hold pprod cursor.callproc(multiply, args) print cursor.fetchone() - + The last print should output ('5', '5', 25L) Does not return a value, but a result set will be @@ -385,24 +427,36 @@ class MySQLCursor(CursorBase): Raises exceptions when something is wrong. """ argfmt = "@_%s_arg%d" - + self._results = deque() + try: procargs = self._process_params(args) argnames = [] - + for idx,arg in enumerate(procargs): argname = argfmt % (procname, idx+1) argnames.append(argname) setquery = "SET %s=%%s" % argname self.execute(setquery, (arg,)) - + call = "CALL %s(%s)" % (procname,','.join(argnames)) - res = self.protocol.cmd_query(call) + res = self.db().protocol.cmd_query(call) - select = "SELECT %s" % ','.join(argnames) - self.execute(select) - - except errors.ProgrammingError: + while not isinstance(res, dict): + tmp = MySQLCursorBuffered(self.db()) + tmp.description = res[1] + tmp._handle_resultset() + self._results.append(tmp) + buf = self.db().protocol._recv_packet() + res = self.db().protocol.handle_cmd_result(buf) + try: + select = "SELECT %s" % ','.join(argnames) + self.execute(select) + return self.fetchone() + except: + raise + + except errors.Error: raise except StandardError, e: raise errors.InterfaceError( @@ -420,13 +474,17 @@ class MySQLCursor(CursorBase): """ res = [] try: - c = self.db.cursor() + c = self.db().cursor() cnt = c.execute("SHOW WARNINGS") res = c.fetchall() c.close() except StandardError, e: - raise errors.ProgrammingError( + raise errors.InterfaceError( "Failed getting warnings; %s" % e) + + if self.db().raise_on_warnings is True: + msg = '; '.join([ "(%s) %s" % (r[1],r[2]) for r in res]) + raise errors.get_mysql_exception(res[0][1],res[0][2]) else: if len(res): return res @@ -435,22 +493,25 @@ class MySQLCursor(CursorBase): def _handle_eof(self, eof): self._have_result = False + self.db().unread_result = False self._nextrow = (None, None) - self._warning_count = eof.warning_count - if self.db.get_warnings is True and eof.warning_count: + self._warning_count = eof['warning_count'] + if self.db().get_warnings is True and eof['warning_count']: self._warnings = self._fetch_warnings() - + self._set_more_results(eof['status_flag']) + def _fetch_row(self): if self._have_result is False: return None row = None try: if self._nextrow == (None, None): - (row, eof) = self.protocol.result_get_row() + (row, eof) = self.db().protocol.get_row() else: (row, eof) = self._nextrow if row: - (foo, eof) = self._nextrow = self.protocol.result_get_row() + (foo, eof) = self._nextrow = \ + self.db().protocol.get_row() if eof is not None: self._handle_eof(eof) if self.rowcount == -1: @@ -491,11 +552,15 @@ class MySQLCursor(CursorBase): raise errors.InterfaceError("No result set to fetch from.") res = [] row = None - while self._have_result: + while self.db().unread_result: row = self.fetchone() if row: res.append(row) return res + + @property + def column_names(self): + return tuple( [d[0].decode('utf8') for d in self.description] ) def __unicode__(self): fmt = "MySQLCursor: %s" @@ -516,27 +581,71 @@ class MySQLCursorBuffered(MySQLCursor): def __init__(self, db=None): MySQLCursor.__init__(self, db) - self._rows = [] + self._rows = None self._next_row = 0 def _handle_resultset(self): - self._get_all_rows() - - def _get_all_rows(self): - (self._rows, eof) = self.protocol.result_get_rows() + (self._rows, eof) = self.db().protocol.get_rows() self.rowcount = len(self._rows) self._handle_eof(eof) self._next_row = 0 - self._have_result = True + try: + self.db().unread_result = False + except: + pass + + def reset(self): + self._rows = None def _fetch_row(self): row = None try: row = self._rows[self._next_row] except: - self._have_result = False return None else: self._next_row += 1 return row return None + + def fetchall(self): + if self._rows is None: + raise errors.InterfaceError("No result set to fetch from.") + res = [] + for row in self._rows: + res.append(self._row_to_python(row)) + self._next_row = len(self._rows) + return res + + def fetchmany(self,size=None): + res = [] + cnt = (size or self.arraysize) + while cnt > 0: + cnt -= 1 + row = self.fetchone() + if row: + res.append(row) + + return res + +class MySQLCursorRaw(MySQLCursor): + + def fetchone(self): + row = self._fetch_row() + if row: + return tuple(row) + return None + +class MySQLCursorBufferedRaw(MySQLCursorBuffered): + + def fetchone(self): + row = self._fetch_row() + if row: + return tuple(row) + return None + + def fetchall(self): + if self._rows is None: + raise errors.InterfaceError("No result set to fetch from.") + return [ tuple(r) for r in self._rows ] + diff --git a/mysql/connector/dbapi.py b/mysql/connector/dbapi.py index 892a5cb..57fdcf1 100644 --- a/mysql/connector/dbapi.py +++ b/mysql/connector/dbapi.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify diff --git a/mysql/connector/errors.py b/mysql/connector/errors.py index f5f61c6..f3dd480 100644 --- a/mysql/connector/errors.py +++ b/mysql/connector/errors.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -24,32 +24,152 @@ """Python exceptions """ -import exceptions -import protocol +import logging + +logger = logging.getLogger('myconnpy') + +# see get_mysql_exceptions method for errno ranges and smaller lists +__programming_errors = ( + 1083,1084,1090,1091,1093,1096,1097,1101,1102,1103,1107,1108,1110,1111, + 1113,1120,1124,1125,1128,1136,1366,1139,1140,1146,1149,) +__operational_errors = ( + 1028,1029,1030,1053,1077,1078,1079,1080,1081,1095,1104,1106,1114,1116, + 1117,1119,1122,1123,1126,1133,1135,1137,1145,1147,) + +def get_mysql_exception(errno,msg): + + exception = OperationalError + + if (errno >= 1046 and errno <= 1052) or \ + (errno >= 1054 and errno <= 1061) or \ + (errno >= 1063 and errno <= 1075) or \ + errno in __programming_errors: + exception = ProgrammingError + elif errno in (1097,1109,1118,1121,1138,1292): + exception = DataError + elif errno in (1031,1089,1112,1115,1127,1148,1149): + exception = NotSupportedError + elif errno in (1062,1082,1099,1100): + exception = IntegrityError + elif errno in (1085,1086,1094,1098): + exception = InternalError + elif (errno >= 1004 and errno <= 1030) or \ + (errno >= 1132 and errno <= 1045) or \ + (errno >= 1141 and errno <= 1145) or \ + (errno >= 1129 and errno <= 1133) or \ + errno in __operational_errors: + exception = OperationalError + + return exception(msg,errno=errno) + +class ClientError(object): + + client_errors = { + 2000: "Unknown MySQL error", + 2001: "Can't create UNIX socket (%(socketaddr)d)", + 2002: "Can't connect to local MySQL server through socket '%(socketaddr)s' (%(errno)s)", + 2003: "Can't connect to MySQL server on '%(socketaddr)s' (%(errno)s)", + 2004: "Can't create TCP/IP socket (%s)", + 2005: "Unknown MySQL server host '%(socketaddr)s' (%s)", + 2006: "MySQL server has gone away", + 2007: "Protocol mismatch; server version = %(server_version)d, client version = %(client_version)d", + 2008: "MySQL client ran out of memory", + 2009: "Wrong host info", + 2010: "Localhost via UNIX socket", + 2011: "%(misc)s via TCP/IP", + 2012: "Error in server handshake", + 2013: "Lost connection to MySQL server during query", + 2014: "Commands out of sync; you can't run this command now", + 2015: "Named pipe: %(socketaddr)s", + 2016: "Can't wait for named pipe to host: %(host)s pipe: %(socketaddr)s (%(errno)d)", + 2017: "Can't open named pipe to host: %s pipe: %s (%(errno)d)", + 2018: "Can't set state of named pipe to host: %(host)s pipe: %(socketaddr)s (%(errno)d)", + 2019: "Can't initialize character set %(charset)s (path: %(misc)s)", + 2020: "Got packet bigger than 'max_allowed_packet' bytes", + 2021: "Embedded server", + 2022: "Error on SHOW SLAVE STATUS:", + 2023: "Error on SHOW SLAVE HOSTS:", + 2024: "Error connecting to slave:", + 2025: "Error connecting to master:", + 2026: "SSL connection error", + 2027: "Malformed packet", + 2028: "This client library is licensed only for use with MySQL servers having '%s' license", + 2029: "Invalid use of null pointer", + 2030: "Statement not prepared", + 2031: "No data supplied for parameters in prepared statement", + 2032: "Data truncated", + 2033: "No parameters exist in the statement", + 2034: "Invalid parameter number", + 2035: "Can't send long data for non-string/non-binary data types (parameter: %d)", + 2036: "Using unsupported buffer type: %d (parameter: %d)", + 2037: "Shared memory: %s", + 2038: "Can't open shared memory; client could not create request event (%d)", + 2039: "Can't open shared memory; no answer event received from server (%d)", + 2040: "Can't open shared memory; server could not allocate file mapping (%d)", + 2041: "Can't open shared memory; server could not get pointer to file mapping (%d)", + 2042: "Can't open shared memory; client could not allocate file mapping (%d)", + 2043: "Can't open shared memory; client could not get pointer to file mapping (%d)", + 2044: "Can't open shared memory; client could not create %s event (%d)", + 2045: "Can't open shared memory; no answer from server (%d)", + 2046: "Can't open shared memory; cannot send request event to server (%d)", + 2047: "Wrong or unknown protocol", + 2048: "Invalid connection handle", + 2049: "Connection using old (pre-4.1.1) authentication protocol refused (client option 'secure_auth' enabled)", + 2050: "Row retrieval was canceled by mysql_stmt_close() call", + 2051: "Attempt to read column without prior row fetch", + 2052: "Prepared statement contains no metadata", + 2053: "Attempt to read a row while there is no result set associated with the statement", + 2054: "This feature is not implemented yet", + 2055: "Lost connection to MySQL server at '%(socketaddr)s', system error: %(errno)d", + 2056: "Statement closed indirectly because of a preceeding %s() call", + 2057: "The number of columns in the result set differs from the number of bound buffers. You must reset the statement, rebind the result set columns, and execute the statement again", + } + + def __new__(cls): + raise TypeError, "Can not instanciate from %s" % cls.__name__ + + @classmethod + def get_error_msg(cls,errno,values=None): + res = None + if cls.client_errors.has_key(errno): + if values is not None: + try: + res = cls.client_errors[errno] % values + except: + logger.debug("Missing values for errno %d" % errno) + res = cls.client_errors[errno], "(missing values!)" + else: + res = cls.client_errors[errno] + if res is None: + res = "Unknown client error %d" % errno + logger.debug(res) + return res class Error(StandardError): - def __init__(self, m): - if isinstance(m,protocol.ErrorResultPacket): + def __init__(self, m, errno=None, values=None): + try: # process MySQL error packet self._process_packet(m) - else: - # else the message should be a string - self.errno = -1 - self.errmsg = str(m) + except: + self.errno = errno or -1 self.sqlstate = -1 - self.msg = str(m) - + if m is None and (errno >= 2000 and errno < 3000): + m = ClientError.get_error_msg(errno,values) + elif m is None: + m = 'Unknown error' + if self.errno != -1: + self.msg = "%s: %s" % (self.errno,m) + else: + self.msg = m + def _process_packet(self, packet): self.errno = packet.errno - self.errmsg = packet.errmsg self.sqlstate = packet.sqlstate if self.sqlstate: - m = '%d (%s): %s' % (self.errno, self.sqlstate, self.errmsg) + self.msg = '%d (%s): %s' % (self.errno,self.sqlstate,packet.errmsg) else: - m = '%d: %s' % (self.errno, self.errmsg) - self.errmsglong = m - self.msg = m + self.msg = '%d: %s' % (self.errno, packet.errmsg) def __str__(self): return self.msg @@ -61,12 +181,12 @@ class Warning(StandardError): pass class InterfaceError(Error): - def __init__(self, msg): - Error.__init__(self, msg) + def __init__(self, m=None, errno=None, values=None): + Error.__init__(self, m, errno, values) class DatabaseError(Error): - def __init__(self, msg): - Error.__init__(self, msg) + def __init__(self, m=None, errno=None, values=None): + Error.__init__(self, m, errno, values) class InternalError(DatabaseError): pass diff --git a/mysql/connector/mysql.py b/mysql/connector/mysql.py deleted file mode 100644 index c6861a2..0000000 --- a/mysql/connector/mysql.py +++ /dev/null @@ -1,414 +0,0 @@ -# MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved -# Use is subject to license terms. (See COPYING) - -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation. -# -# There are special exceptions to the terms and conditions of the GNU -# General Public License as it is applied to this software. View the -# full text of the exception in file EXCEPTIONS-CLIENT in the directory -# of this software distribution or see the FOSS License Exception at -# www.mysql.com. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program; if not, write to the Free Software -# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - -"""Main classes for interacting with MySQL -""" - -import socket, string, os - -from connection import * -import constants -import conversion -import protocol -import errors -import utils -import cursor - - -class MySQLBase(object): - """MySQLBase""" - - def __init__(self): - """Initializing""" - self.conn = None # Holding the connection - self.converter = None - - self.client_flags = constants.ClientFlag.get_default() - (self.charset, - self.charset_name, - self.collation_name) = constants.CharacterSet.get_charset_info('utf8') - - self.username = '' - self.password = '' - self.database = '' - self.client_host = '' - self.client_port = 0 - - self.affected_rows = 0 - self.server_status = 0 - self.warning_count = 0 - self.field_count = 0 - self.insert_id = 0 - self.info_msg = '' - self.use_unicode = True - self.get_warnings = False - self.autocommit = False - self.connection_timeout = None - self.buffered = False - - def connect(self): - """To be implemented while subclassing MySQLBase.""" - pass - - def _set_connection(self, prtcls=None): - """Automatically chooses based on configuration which connection type to setup.""" - if self.unix_socket and os.name != 'nt': - self.conn = MySQLUnixConnection(prtcls=prtcls, - unix_socket=self.unix_socket) - else: - self.conn = MySQLTCPConnection(prtcls=prtcls, - host=self.server_host, port=self.server_port) - self.conn.set_connection_timeout(self.connection_timeout) - - def _open_connection(self): - """Opens the connection and sets the appropriated protocol.""" - # We don't know yet the MySQL version we connect too - self._set_connection() - try: - self.conn.open_connection() - version = self.conn.protocol.server_version - if version < (4,1): - raise InterfaceError("MySQL Version %s is not supported." % version) - else: - self.conn.set_protocol(protocol.MySQLProtocol) - self.protocol = self.conn.protocol - self.protocol.do_auth(username=self.username, password=self.password, - database=self.database) - except: - raise - - def _post_connection(self): - """Should be called after a connection was established""" - self.get_characterset_info() - self.set_converter_class(conversion.MySQLConverter) - - try: - self.set_charset(self.charset_name) - self.set_autocommit(self.autocommit) - except: - raise - - def is_connected(self): - """ - Check whether we are connected to the MySQL server. - """ - return self.protocol.cmd_ping() - ping = is_connected - - def disconnect(self): - """ - Disconnect from the MySQL server. - """ - if not self.conn: - return - - if self.conn.sock is not None: - self.protocol.cmd_quit() - try: - self.conn.close_connection() - except: - pass - self.protocol = None - self.conn = None - - def set_converter_class(self, convclass): - """ - Set the converter class to be used. This should be a class overloading - methods and members of conversion.MySQLConverter. - """ - self.converter_class = convclass - self.converter = self.converter_class(self.charset_name, self.use_unicode) - - def get_characterset_info(self): - try: - (self.charset_name, self.collation_name) = constants.CharacterSet.get_info(self.charset) - except: - raise ProgrammingError, "Illegal character set information (id=%d)" % self.charset - return (self.charset_name, self.collation_name) - - def get_server_version(self): - """Returns the server version as a tuple""" - try: - return self.protocol.server_version - except: - pass - - return None - - def get_server_info(self): - """Returns the server version as a string""" - return self.protocol.server_version_original - - def get_server_threadid(self): - """Returns the MySQL threadid of the connection.""" - threadid = None - try: - threadid = self.protocol.server_threadid - except: - pass - - return threadid - - def set_host(self, host): - """ - Set the host for connection to the MySQL server. - """ - self.server_host = host - - def set_port(self, port): - """ - Set the TCP port to be used when connecting to the server, usually 3306. - """ - self.server_port = port - - def set_login(self, username=None, password=None): - """ - Set the username and/or password for the user connecting to the MySQL Server. - """ - self.username = username - self.password = password - - def set_unicode(self, value=True): - """ - Set whether we return string fields as unicode or not. - Default is True. - """ - self.use_unicode = value - if self.converter: - self.converter.set_unicode(value) - - def set_database(self, database): - """ - Set the database to be used after connection succeeded. - """ - self.database = database - - def set_charset(self, name): - """ - Set the character set used for the connection. This is the recommended - way of change it per connection basis. It does execute SET NAMES - internally, but it's good not to use this command directly, since we - are setting some other members accordingly. - """ - if name not in constants.CharacterSet.get_supported(): - raise errors.ProgrammingError, "Character set '%s' not supported." % name - return - try: - info = constants.CharacterSet.get_charset_info(name) - except errors.ProgrammingError, e: - raise - - try: - self.protocol.cmd_query("SET NAMES '%s'" % name) - except: - raise - else: - (self.charset, self.charset_name, self.collation_name) = info - self.converter.set_charset(self.charset_name) - - def set_getwarnings(self, bool): - """ - Set wheter we should get warnings whenever an operation produced some. - """ - self.get_warnings = bool - - def set_autocommit(self, switch): - """ - Set auto commit on or off. The argument 'switch' must be a boolean type. - """ - if not isinstance(switch, bool): - raise ValueError, "The switch argument must be boolean." - - s = 'OFF' - if switch: - s = 'ON' - - try: - self.protocol.cmd_query("SET AUTOCOMMIT = %s" % s) - except: - raise - else: - self.autocommit = switch - - def set_unixsocket(self, loc): - """Set the UNIX Socket location. Does not check if it exists.""" - self.unix_socket = loc - - def set_connection_timeout(self, timeout): - self.connection_timeout = timeout - - def set_client_flags(self, flags): - self.client_flags = flags - - def set_buffered(self, val=False): - """Sets whether cursor .execute() fetches rows""" - self.buffered = val - -class MySQL(MySQLBase): - """ - Class implementing Python DB API v2.0. - """ - - def __init__(self, *args, **kwargs): - """ - Initializes the MySQL object. Calls connect() to open the connection - when an instance is created. - """ - MySQLBase.__init__(self) - self.cursors = [] - self.affected_rows = 0 - self.server_status = 0 - self.warning_count = 0 - self.field_count = 0 - self.insert_id = 0 - self.info_msg = '' - - self.connect(*args, **kwargs) - - def connect(self, dsn='', user='', password='', host='127.0.0.1', - port=3306, db=None, database=None, use_unicode=True, charset='utf8', get_warnings=False, - autocommit=False, unix_socket=None, - connection_timeout=None, client_flags=None, buffered=False): - """ - Establishes a connection to the MySQL Server. Called also when instansiating - a new MySQLConnection object through the __init__ method. - - Possible parameters are: - - dsn - (not used) - user - The username used to authenticate with the MySQL Server. - - password - The password to authenticate the user with the MySQL Server. - - host - The hostname or the IP address of the MySQL Server we are connecting with. - (default 127.0.0.1) - - port - TCP port to use for connecting to the MySQL Server. - (default 3306) - - database - db - Initial database to use once we are connected with the MySQL Server. - The db argument is synonym, but database takes precedence. - - use_unicode - If set to true, string values received from MySQL will be returned - as Unicode strings. - Default: True - - charset - Which character shall we use for sending data to MySQL. One can still - override this by using the SET NAMES command directly, but this is - discouraged. Instead, use the set_charset() method if you - want to change it. - Default: Whatever the MySQL server has default. - - get_warnings - If set to true, whenever a query gives a warning, a SHOW WARNINGS will - be done to fetch them. They will be available as MySQLCursor.warnings. - The default is to ignore these warnings, for debugging it's good to - enable it though, or use strict mode in MySQL to make most of these - warnings errors. - Default: False - - autocommit - Auto commit is OFF by default, which is required by the Python Db API - 2.0 specification. - Default: False - - unix_socket - Full path to the MySQL Server UNIX socket. By default TCP connection will - be used using the address specified by the host argument. - - connection_timeout - Timeout for the TCP and UNIX socket connection. - - client_flags - Allows to set flags for the connection. Check following for possible flags: - >>> from mysql.connector.constants import ClientFlag - >>> print '\n'.join(ClientFlag.get_full_info()) - - buffered - When set to True .execute() will fetch the rows immediatly. - - """ - # db is not part of Db API v2.0, but MySQLdb supports it. - if db and not database: - database = db - - self.set_host(host) - self.set_port(port) - self.set_database(database) - self.set_getwarnings(get_warnings) - self.set_unixsocket(unix_socket) - self.set_connection_timeout(connection_timeout) - self.set_client_flags(client_flags) - self.set_buffered(buffered) - - if user or password: - self.set_login(user, password) - - self.disconnect() - self._open_connection() - self._post_connection() - - def close(self): - del self.cursors[:] - self.disconnect() - - def remove_cursor(self, c): - try: - self.cursors.remove(c) - except ValueError: - raise errors.ProgrammingError( - "Cursor could not be removed.") - - def register_cursor(self, c): - try: - self.cursors.append(c) - except: - raise - - def cursor(self): - if self.buffered: - c = (cursor.MySQLCursorBuffered)(self) - else: - c = (cursor.MySQLCursor)(self) - - self.register_cursor(c) - return c - - def commit(self): - """Shortcut for executing COMMIT.""" - self.protocol.cmd_query("COMMIT") - - def rollback(self): - """Shortcut for executing ROLLBACK""" - self.protocol.cmd_query("ROLLBACK") - - diff --git a/mysql/connector/protocol.py b/mysql/connector/protocol.py index 53dc9d1..bfc2891 100644 --- a/mysql/connector/protocol.py +++ b/mysql/connector/protocol.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -24,8 +24,6 @@ """Implementing the MySQL Client/Server protocol """ -import string -import socket import re import struct @@ -42,326 +40,515 @@ from constants import * import errors import utils -class MySQLProtocol(object): - """Class handling the MySQL Protocol. - - MySQL v4.1 Client/Server Protocol is currently supported. - """ - def __init__(self, conn, handshake=None): +def packet_is_error(idx=0,label=None): + def deco(func): + def call(*args, **kwargs): + try: + if label: + pktdata = kwargs[label] + else: + pktdata = args[idx] + except Exception, e: + raise errors.InterfaceError( + "Can't check for Error packet; %s" % e) + + try: + if pktdata and pktdata[4] == '\xff': + MySQLProtocol.raise_error(pktdata) + except errors.Error: + raise + except: + pass + return func(*args, **kwargs) + return call + return deco + +def packet_is_ok(idx=0,label=None): + def deco(func): + def call(*args, **kwargs): + try: + if label: + pktdata = kwargs[label] + else: + pktdata = args[idx] + except: + raise errors.InterfaceError("Can't check for OK packet.") + + try: + if pktdata and pktdata[4] == '\x00': + return func(*args, **kwargs) + else: + raise + except Exception, e: + raise errors.InterfaceError("Expected OK packet") + return call + return deco + +def packet_is_eof(idx=0,label=None): + def deco(func): + def call(*args, **kwargs): + try: + if label: + pktdata = kwargs[label] + else: + pktdata = args[idx] + except: + raise errors.InterfaceError("Can't check for EOF packet.") + if pktdata[4] == '\xfe' and len(pktdata) == 9: + return func(*args, **kwargs) + else: + raise errors.InterfaceError("Expected EOF packet") + + return call + return deco + +def set_pktnr(idx=1,label=None): + def deco(func): + def call(*args, **kwargs): + try: + if label: + pktdata = kwargs[label] + else: + pktdata = args[idx] + except: + raise errors.InterfaceError("Can't check for EOF packet.") + + try: + args[0].pktnr = ord(pktdata[3]) + pktdata = pktdata[4:] + if label: + kwargs[label] = pktdata + else: + args = list(args) + args[idx] = pktdata + except: + raise errors.InterfaceError("Failed getting Packet Number.") + return func(*args,**kwargs) + return call + return deco + +def reset_pktnr(func): + def deco(*args, **kwargs): + try: + args[0].pktnr = -1 + except: + pass + return func(*args, **kwargs) + + return deco + +class MySQLProtocolBase(object): + pass + +class MySQLProtocol(MySQLProtocolBase): + + def __init__(self, conn): self.client_flags = 0 - self.conn = conn # MySQL Connection - if handshake: - self.set_handshake(handshake) + self.conn = conn + self.pktnr = -1 - def handle_header(self, buf): - """Takes a buffer and readers information from header. + @classmethod + def raise_error(cls, buf): + """Raise an errors.Error when buffer has a MySQL error""" + errno = errmsg = None + try: + buf = buf[5:] + (buf,errno) = utils.read_int(buf, 2) + if buf[0] != '\x23': + # Error without SQLState + errmsg = buf + else: + (buf,sqlstate) = utils.read_bytes(buf[1:],5) + errmsg = buf + except Exception, e: + raise errors.InterfaceError("Failed getting Error information (%r)"\ + % e) + else: + raise errors.get_mysql_exception(errno,errmsg) + + def _recv_packet(self): + """Getting a packet from the MySQL server""" + buf = self.conn.recv() + if buf[4] == '\xff': + MySQLProtocol.raise_error(buf) + else: + return buf - Returns a tuple (pktsize, pktnr) - """ - pktsize = utils.int3read(buf[0:3]) - pktnr = utils.int1read(buf[3]) + def _scramble_password(self, passwd, seed): + """Scramble a password ready to send to MySQL""" + hash4 = None + try: + hash1 = sha1(passwd).digest() + hash2 = sha1(hash1).digest() # Password as found in mysql.user() + hash3 = sha1(seed + hash2).digest() + xored = [ utils.intread(h1) ^ utils.intread(h3) + for (h1,h3) in zip(hash1, hash3) ] + hash4 = struct.pack('20B', *xored) + except Exception, e: + raise errors.InterfaceError('Failed scrambling password; %s' % e) - return (pktsize, pktnr) + return hash4 + + def _pkt_make_header(self, pktlength, pktnr=None): + """Make the header for a MySQL packet""" + pktnr = pktnr or self.pktnr+1 + return utils.int3store(pktlength) + utils.int1store(pktnr) + def _prepare_auth(self, usr, pwd, db, flags, seed): + + if usr is not None and len(usr) > 0: + _username = usr + '\x00' + else: + _username = '\x00' + + if pwd is not None and len(pwd) > 0: + _password = utils.int1store(20) +\ + self._scramble_password(pwd,seed) + else: + _password = '\x00' + + if db is not None and len(db): + _database = db + '\x00' + else: + _database = '\x00' + + return (_username, _password, _database) + + def _pkt_make_auth(self, username=None, password=None, database=None, + seed=None, charset=33, client_flags=0): + """Make a MySQL Authentication packet""" + try: + seed = seed or self.scramble + except: + raise errors.ProgrammingError('Seed missing') + + (_username, _password, _database) = self._prepare_auth( + username, password, database, client_flags, seed) + data = utils.int4store(client_flags) +\ + utils.int4store(10 * 1024 * 1024) +\ + utils.int1store(charset) +\ + '\x00'*23 +\ + _username +\ + _password +\ + _database + + header = self._pkt_make_header(len(data)) + return header+data + + def _pkt_make_command(self, command, argument=None): + """Make a MySQL packet containing a command""" + data = utils.int1store(command) + if argument is not None: + data += str(argument) + header = self._pkt_make_header(len(data)) + return header+data + + def _pkt_make_changeuser(self, username=None, password=None, + database=None, charset=8, seed=None): + """Make a MySQL packet with the Change User command""" + try: + seed = seed or self.scramble + except: + raise errors.ProgrammingError('Seed missing') + + (_username, _password, _database) = self._prepare_auth( + username, password, database, self.client_flags, seed) + data = utils.int1store(ServerCmd.CHANGE_USER) +\ + _username +\ + _password +\ + _database +\ + utils.int2store(charset) + + header = self._pkt_make_header(len(data)) + return header+data + + @set_pktnr(1) + def _pkt_parse_handshake(self, buf): + """Parse a MySQL Handshake-packet""" + res = {} + (buf,res['protocol']) = utils.read_int(buf,1) + (buf,res['server_version_original']) = utils.read_string(buf,end='\x00') + (buf,res['server_threadid']) = utils.read_int(buf,4) + (buf,res['scramble']) = utils.read_bytes(buf, 8) + buf = buf[1:] # Filler 1 * \x00 + (buf,res['capabilities']) = utils.read_int(buf,2) + (buf,res['charset']) = utils.read_int(buf,1) + (buf,res['server_status']) = utils.read_int(buf,2) + buf = buf[13:] # Filler 13 * \x00 + (buf,scramble_next) = utils.read_bytes(buf,12) + res['scramble'] += scramble_next + return res + + @set_pktnr(1) + def _pkt_parse_ok(self, buf): + """Parse a MySQL OK-packet""" + ok = {} + (buf,ok['field_count']) = utils.read_int(buf,1) + (buf,ok['affected_rows']) = utils.read_lc_int(buf) + (buf,ok['insert_id']) = utils.read_lc_int(buf) + (buf,ok['server_status']) = utils.read_int(buf,2) + (buf,ok['warning_count']) = utils.read_int(buf,2) + if buf: + (buf,ok['info_msg']) = utils.read_lc_string(buf) + return ok + + @set_pktnr(1) + def _pkt_parse_field(self, buf): + """Parse a MySQL Field-packet""" + field = {} + (buf,field['catalog']) = utils.read_lc_string(buf) + (buf,field['db']) = utils.read_lc_string(buf) + (buf,field['table']) = utils.read_lc_string(buf) + (buf,field['org_table']) = utils.read_lc_string(buf) + (buf,field['name']) = utils.read_lc_string(buf) + (buf,field['org_name']) = utils.read_lc_string(buf) + buf = buf[1:] # filler 1 * \x00 + (buf,field['charset']) = utils.read_int(buf, 2) + (buf,field['length']) = utils.read_int(buf, 4) + (buf,field['type']) = utils.read_int(buf, 1) + (buf,field['flags']) = utils.read_int(buf, 2) + (buf,field['decimal']) = utils.read_int(buf, 1) + buf = buf[2:] # filler 2 * \x00 + + res = ( + field['name'], + field['type'], + None, # display_size + None, # internal_size + None, # precision + None, # scale + ~field['flags'] & FieldFlag.NOT_NULL, # null_ok + field['flags'], # MySQL specific + ) + return res + + @set_pktnr(1) + def _pkt_parse_eof(self, buf): + """Parse a MySQL EOF-packet""" + res = {} + buf = buf[1:] # disregard the first checking byte + (buf, res['warning_count']) = utils.read_int(buf, 2) + (buf, res['status_flag']) = utils.read_int(buf, 2) + return res + + def do_handshake(self): + """Get the handshake from the MySQL server""" + try: + self.conn.open_connection() + buf = self._recv_packet() + self.handle_handshake(buf) + except: + raise + def do_auth(self, username=None, password=None, database=None, - client_flags=None): + client_flags=0, charset=33): + """Authenticate with the MySQL server """ - Make and send the authentication using information found in the - handshake packet. - """ - if not client_flags: - client_flags = ClientFlag.get_default() + pkt = self._pkt_make_auth(username=username, password=password, + database=database, charset=charset, + client_flags=client_flags) - auth = Auth(client_flags=client_flags, - pktnr=self.handshake.pktnr+1) - auth.create(username=username, password=password, - seed=self.handshake.info['seed'], database=database) - - self.conn.send(auth.get()) - buf = self.conn.recv()[0] - if self.is_eof(buf): - raise errors.InterfaceError("Found EOF after Auth, expecting OK. Using old passwords?") + self.conn.send(pkt) + buf = self._recv_packet() + if buf[4] == '\xfe': + raise errors.NotSupportedError( + "Authentication with old (insecure) passwords "\ + "is not supported: "\ + "http://dev.mysql.com/doc/refman/5.1/en/password-hashing.html") - connect_with_db = client_flags & ClientFlag.CONNECT_WITH_DB - if self.is_ok(buf) and database and not connect_with_db: - self.cmd_init_db(database) + try: + if not (client_flags & ClientFlag.CONNECT_WITH_DB) and database: + self.cmd_init_db(database) + except: + raise + + return True def handle_handshake(self, buf): - """ + """Check and handle the MySQL server's handshake + Check whether the buffer is a valid handshake. If it is, we set some member variables for later usage. The handshake packet is returned for later usuage, e.g. authentication. """ - - if self.is_error(buf): - # an ErrorPacket is returned by the server - self._handle_error(buf) - - handshake = None try: - handshake = Handshake(buf) - except errors.InterfaceError, msg: - raise errors.InterfaceError(msg) - self.set_handshake(handshake) - - def set_handshake(self, handshake): - """Gather data from the given handshake.""" - ver = re.compile("^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)") - version = handshake.info['version'] - m = ver.match(version) - if not m: - raise errors.InterfaceError("Could not parse MySQL version, was '%s'" % version) - else: - self.server_version = tuple([ int(v) for v in m.groups()[0:3]]) - - self.server_version_original = handshake.info['version'] - self.server_threadid = handshake.info['thrdid'] - self.capabilities = handshake.info['capabilities'] - self.charset = handshake.info['charset'] - self.threadid = handshake.info['thrdid'] - self.handshake = handshake - - def _handle_error(self, buf): - """Raise an OperationalError if result is an error - """ - try: - err = ErrorResultPacket(buf) - except errors.InterfaceError, e: - raise e - else: - raise errors.OperationalError(err) - - def is_error(self, buf): - """Check if the given buffer is a MySQL Error Packet. - - Buffer should start with \xff. - - Returns boolean. - """ - if buf and buf[4] == '\xff': - self._handle_error(buf) - return True - return False - - def _handle_ok(self, buf): - """ - Handle an OK Result Packet. If we got an InterfaceError, raise that - instead. - """ - try: - ok = OKResultPacket(buf) - except errors.InterfaceError, e: - raise e - else: - self.server_status = ok.server_status - self.warning_count = ok.warning_count - self.field_count = ok.field_count - self.affected_rows = ok.affected_rows - self.info_msg = ok.info_msg - - def is_ok(self, buf): - """ - Check if the given buffer is a MySQL OK Packet. It should - start with \x00. - - Returns boolean. - """ - if buf and buf[4] == '\x00': - self._handle_ok(buf) - return True - return False - - def _handle_fields(self, nrflds): - """Reads a number of fields from a result set.""" - i = 0 - fields = [] - while i < nrflds: - buf = self.conn.recv()[0] - fld = FieldPacket(buf) - fields.append(fld) - i += 1 - return fields - - def is_eof(self, buf): - """ - Check if the given buffer is a MySQL EOF Packet. It should - start with \xfe and be smaller 9 bytes. - - Returns boolean. - """ - l = utils.read_int(buf, 3)[1] - if buf and buf[4] == '\xfe' and l < 9: - return True - return False - - def _handle_resultset(self, pkt): - """Processes a resultset getting fields information. - - The argument pkt must be a protocol.Packet with length 1, a byte - which contains the number of fields. - """ - if not isinstance(pkt, PacketIn): - raise ValueError("%s is not a protocol.PacketIn" % pkt) - - if len(pkt) == 1: - (buf,nrflds) = utils.read_lc_int(pkt.data) + res = self._pkt_parse_handshake(buf) + for k,v in res.items(): + self.__dict__[k] = v - # Get the fields - fields = self._handle_fields(nrflds) - - buf = self.conn.recv()[0] - eof = EOFPacket(buf) - - return (nrflds, fields, eof) - else: - raise errors.InterfaceError('Something wrong reading result after query.') - - def result_get_row(self): - """Get data for 1 row - - Get one row's data. Should be called after getting the field - descriptions. - - Returns a tuple with 2 elements: a row's data and the - EOF packet. - """ - buf = self.conn.recv()[0] - if self.is_eof(buf): - eof = EOFPacket(buf) - rowdata = None - else: - eof = None - rowdata = utils.read_lc_string_list(buf[4:]) - return (rowdata, eof) + regex_ver = re.compile("^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)") + m = regex_ver.match(self.server_version_original) + if not m: + raise errors.InterfaceError("Failed parsing MySQL version number.") + self.server_version = tuple([ int(v) for v in m.groups()[0:3]]) + except errors.Error: + raise + except Exception, e: + raise errors.InterfaceError('Failed handling handshake; %s' % e) - def result_get_rows(self, cnt=None): - """Get all rows + @packet_is_ok(1) + def _handle_ok(self, buf): + try: + return self._pkt_parse_ok(buf) + except: + raise errors.InterfaceError("Failed parsing OK packet.") + @packet_is_eof(1) + def _handle_eof(self, buf): + try: + return self._pkt_parse_eof(buf) + except: + raise errors.InterfaceError("Failed parsing EOF packet.") + + @packet_is_error(1) + @set_pktnr(1) + def _handle_resultset(self, buf): + (buf,nrflds) = utils.read_lc_int(buf) + if nrflds == 0: + raise errors.InterfaceError('Empty result set.') + + fields = [] + for i in xrange(0,nrflds): + buf = self._recv_packet() + fields.append(self._pkt_parse_field(buf)) + + buf = self._recv_packet() + eof = self._handle_eof(buf) + return (nrflds, fields, eof) + + + def get_rows(self, cnt=None): + """Get all rows + Returns a tuple with 2 elements: a list with all rows and the EOF packet. """ rows = [] eof = None rowdata = None - while eof is None: - (rowdata,eof) = self.result_get_row() + i = 0 + while True: + if eof is not None: + break + if i == cnt: + break + buf = self._recv_packet() + if buf[4] == '\xfe': + eof = self._handle_eof(buf) + rowdata = None + else: + eof = None + rowdata = utils.read_lc_string_list(buf[4:]) if eof is None and rowdata is not None: rows.append(rowdata) + i += 1 return (rows,eof) - + + def get_row(self): + (rows,eof) = self.get_rows(cnt=1) + if len(rows): + return (rows[0],eof) + return (None,eof) + + def handle_cmd_result(self, buf): + if buf[4] == '\x00': + return self._handle_ok(buf) + else: + return self._handle_resultset(buf)[0:2] + + @reset_pktnr def cmd_query(self, query): - """ - Sends a query to the server. - + """Sends a query to the MySQL server + Returns a tuple, when the query returns a result. The tuple consist number of fields and a list containing their descriptions. - If the query doesn't return a result set, the an OKResultPacket - will be returned. + If the query doesn't return a result set, a dictionary with + information contained in an OKResult packet will be returned. """ + nrflds = 0 + fields = None try: - cmd = CommandPacket() - cmd.set_command(ServerCmd.QUERY) - cmd.set_argument(query) - cmd.create() - self.conn.send(cmd.get()) # Errors handled in _handle_error() - - buf = self.conn.recv()[0] - if self.is_ok(buf): - # Query does not return a result (INSERT/DELETE/..) - return OKResultPacket(buf) - - p = PacketIn(buf) - (nrflds, fields, eof) = self._handle_resultset(p) + pkt = self._pkt_make_command(ServerCmd.QUERY,query) + self.conn.send(pkt) # Errors handled in _handle_error() + return self.handle_cmd_result(self._recv_packet()) except: raise - else: - return (nrflds, fields) - - return (0, ()) - - def _cmd_simple(self, servercmd, arg=''): - """Makes a simple CommandPacket with no arguments""" - cmd = CommandPacket() - cmd.set_command(servercmd) - cmd.set_argument(arg) - cmd.create() - - return cmd + @reset_pktnr def cmd_refresh(self, opts): - """Send the Refresh command to the MySQL server. - - The argument should be a bitwise value using the protocol.RefreshOption - constants. - - Usage: - + """Send the Refresh command to the MySQL server + + The argument should be a bitwise value using contants.RefreshOption. + Usage example: + RefreshOption = mysql.connector.RefreshOption refresh = RefreshOption.LOG | RefreshOption.THREADS - db.cmd_refresh(refresh) - - """ - cmd = self._cmd_simple(ServerCmd.REFRESH, opts) - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - raise + db.protocol().cmd_refresh(refresh) - if self.is_ok(buf): - return True + Returns a dict() with OK-packet information. + """ + pkt = self._pkt_make_command(ServerCmd.REFRESH, opts) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_ok(buf) - return False - + @reset_pktnr def cmd_quit(self): - """Closes the current connection with the server.""" - cmd = self._cmd_simple(ServerCmd.QUIT) - self.conn.send(cmd.get()) - + """Closes the current connection with the server + + Returns the packet that was send. + """ + pkt = self._pkt_make_command(ServerCmd.QUIT) + self.conn.send(pkt) + return pkt + + @reset_pktnr def cmd_init_db(self, database): - """ - Send command to server to change databases. - """ - cmd = self._cmd_simple(ServerCmd.INIT_DB, database) - self.conn.send(cmd.get()) - self.conn.recv()[0] + """Change the current database + Change the current (default) database. + + Returns a dict() with OK-packet information. + """ + pkt = self._pkt_make_command(ServerCmd.INIT_DB, database) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_ok(buf) + + @reset_pktnr def cmd_shutdown(self): - """Shuts down the MySQL Server. - + """Shuts down the MySQL Server + Careful with this command if you have SUPER privileges! (Which your scripts probably don't need!) - - Returns True if it succeeds. - """ - cmd = self._cmd_simple(ServerCmd.SHUTDOWN) - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - raise - return True + Returns a dict() with OK-packet information. + """ + pkt = self._pkt_make_command(ServerCmd.SHUTDOWN) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_eof(buf) + @reset_pktnr def cmd_statistics(self): """Sends statistics command to the MySQL Server - + Returns a dictionary with various statistical information. """ - cmd = self._cmd_simple(ServerCmd.STATISTICS) - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - raise - - p = Packet(buf) - info = str(p.data) - + pkt = self._pkt_make_command(ServerCmd.STATISTICS) + self.conn.send(pkt) + buf = self._recv_packet() + buf = buf[4:] + errmsg = "Failed getting COM_STATISTICS information" res = {} - pairs = info.split('\x20\x20') # Information is separated by 2 spaces + # Information is separated by 2 spaces + pairs = buf.split('\x20\x20') for pair in pairs: - (lbl,val) = [ v.strip() for v in pair.split(':') ] + try: + (lbl,val) = [ v.strip() for v in pair.split(':',2) ] + except: + raise errors.InterfaceError(errmsg) + # It's either an integer or a decimal try: res[lbl] = long(val) @@ -369,495 +556,69 @@ class MySQLProtocol(object): try: res[lbl] = Decimal(val) except: - raise ValueError( - "Got wrong value in COM_STATISTICS information (%s : %s)." % (lbl, val)) + raise errors.InterfaceError( + "%s (%s:%s)." % (errmsg, lbl, val)) return res + @reset_pktnr def cmd_process_info(self): - """Gets the process list from the MySQL Server. - - Returns a list of dictionaries which corresponds to the output of - SHOW PROCESSLIST of MySQL. The data is converted to Python types. + """Gets the process list from the MySQL Server + + (Unsupported) """ raise errors.NotSupportedError( "Not implemented. Use a cursor to get processlist information.") - - def cmd_process_kill(self, mypid): - """Kills a MySQL process using it's ID. - - The mypid must be an integer. - - """ - cmd = KillPacket(mypid) - cmd.create() - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - raise - - if self.is_eof(buf): - return True - return False - + @reset_pktnr + def cmd_process_kill(self, mypid): + """Kills a MySQL process using it's ID + + Returns a dict() with OK-packet information. + """ + pkt = self._pkt_make_command(ServerCmd.PROCESS_KILL, + utils.int4store(mypid)) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_ok(buf) + + @reset_pktnr def cmd_debug(self): """Send DEBUG command to the MySQL Server - - Needs SUPER privileges. The output will go to the MySQL server error log. - - Returns True when it was succesful. + + Needs SUPER privileges. The output will go to the MySQL server error + log. + + Returns a dict() with EOF-packet information. """ - cmd = self._cmd_simple(ServerCmd.DEBUG) - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - raise - - if self.is_eof(buf): - return True - - return False - + pkt = self._pkt_make_command(ServerCmd.DEBUG) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_eof(buf) + + @reset_pktnr def cmd_ping(self): + """Ping the MySQL server to check if the connection is still alive + + Raises errors.Error or an error derived from it when it fails + to ping the MySQL server. + + Returns a dict() with OK-packet information. """ - Ping the MySQL server to check if the connection is still alive. + pkt = self._pkt_make_command(ServerCmd.PING) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_ok(buf) - Returns True when alive, False when server doesn't respond. + @reset_pktnr + def cmd_change_user(self, username='', password='', database=''): + """Change the user and optionally the current database + + Returns a dict() with OK-packet information. """ - cmd = self._cmd_simple(ServerCmd.PING) - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - return False - else: - if self.is_ok(buf): - return True - - return False - - def cmd_change_user(self, username, password, database=None): - """Change the user with given username and password to another optional database. - """ - if not database: - database = self.database + _charset = self.charset or 33 - cmd = ChangeUserPacket() - cmd.create(username=username, password=password, database=database, - charset=self.charset, seed=self.handshake.info['seed']) - try: - self.conn.send(cmd.get()) - buf = self.conn.recv()[0] - except: - raise - - if not self.is_ok(buf): - raise errors.OperationalError( - "Failed getting OK Packet after changing user") - - return True - -class BasePacket(object): - - def __len__(self): - try: - return len(self.data) - except: - return 0 - - def is_valid(self, buf=None): - if buf is None: - buf = self.data - - (l, n) = (buf[0:3], buf[3]) - hlength = utils.int3read(l) - rlength = len(buf) - 4 - - if hlength != rlength: - return False - - res = self._is_valid_extra(buf) - if res != None: - return res - - return True - - def _is_valid_extra(self, buf): - return True - -class PacketIn(BasePacket): - def __init__(self, buf=None, pktnr=0): - self.data = '' - self.pktnr = pktnr - self.protocol = 10 - - if buf: - self.is_valid(buf) - self.data = buf[4:] - - if self.data: - self._parse() - - def _parse(self): - pass - -class PacketOut(BasePacket): - """ - Each packet type used in the MySQL Client Protocol is build on the Packet - class. It defines lots of useful functions for parsing and sending - data to and from the MySQL Server. - """ - - def __init__(self, buf=None, pktnr=0): - self.data = '' - self.pktnr = pktnr - self.protocol = 10 - - if buf: - self.set(buf) - - if self.data: - self._parse() - - def _make_header(self): - h = utils.int3store(len(self)) + utils.int1store(self.pktnr) - return h - - def _parse(self): - pass - - def add(self, s): - if not s: - self.add_null() - else: - self.data = self.data + s - - def add_1_int(self, i): - self.add(utils.int1store(i)) - - def add_2_int(self, i): - self.add(utils.int2store(i)) - - def add_3_int(self, i): - self.add(utils.int3store(i)) - - def add_4_int(self, i): - self.add(utils.int4store(i)) - - def add_null(self, nr=1): - self.add('\x00'*nr) - - def get(self): - return self._make_header() + self.data - - def get_header(self): - return self._make_header() - - def set(self, buf): - if not self.is_valid(buf): - raise errors.InterfaceError('Packet not valid.') - - self.data = buf[4:] - - def _is_valid_extra(self, buf=None): - return None - -class Handshake(PacketIn): - - def __init__(self, buf=None): - PacketIn.__init__(self, buf) - - def _parse(self): - version = '' - options = 0 - srvstatus = 0 - - buf = self.data - (buf,self.protocol) = utils.read_int(buf,1) - (buf,version) = utils.read_string(buf,end='\x00') - (buf,thrdid) = utils.read_int(buf,4) - (buf,scramble) = utils.read_bytes(buf, 8) - buf = buf[1:] # Filler 1 * \x00 - (buf,srvcap) = utils.read_int(buf,2) - (buf,charset) = utils.read_int(buf,1) - (buf,serverstatus) = utils.read_int(buf,2) - buf = buf[13:] # Filler 13 * \x00 - (buf,scramble_next) = utils.read_bytes(buf,12) - scramble += scramble_next - - self.info = { - 'version' : version, - 'thrdid' : thrdid, - 'seed' : scramble, - 'capabilities' : srvcap, - 'charset' : charset, - 'serverstatus' : serverstatus, - } - - def get_dict(self): - self._parse() - return self.info - - def _is_valid_extra(self, buf): - - if buf[3] != '\x00': - return False - - return True - -class Auth(PacketOut): - - def __init__(self, packet=None, client_flags=0, pktnr=0): - PacketOut.__init__(self, packet, pktnr) - self.client_flags = 0 - self.username = None - self.password = None - self.database = None - if client_flags: - self.set_client_flags(client_flags) - - def set_client_flags(self, flags): - self.client_flags = flags - - def set_login(self, username, password, database=None): - self.username = username - self.password = password - self.database = database - - def scramble(self, passwd, seed): - - hash4 = None - try: - hash1 = sha1(passwd).digest() - hash2 = sha1(hash1).digest() # Password as found in mysql.user() - hash3 = sha1(seed + hash2).digest() - xored = [ utils.int1read(h1) ^ utils.int1read(h3) - for (h1,h3) in zip(hash1, hash3) ] - hash4 = struct.pack('20B', *xored) - except StandardError, e: - raise errors.ProgrammingError('Failed scrambling password; %s' % e) - else: - return hash4 - return None - - def create(self, username=None, password=None, database=None, seed=None): - self.add_4_int(self.client_flags) - self.add_4_int(10 * 1024 * 1024) - self.add_1_int(8) - self.add_null(23) - self.add(username + '\x00') - if password: - self.add_1_int(20) - self.add(self.scramble(password,seed)) - else: - self.add_null(1) - - if database: - self.add(database + '\x00') - else: - self.add_null() - - -class ChangeUserPacket(Auth): - def __init__(self): - self.command = ServerCmd.CHANGE_USER - Auth.__init__(self) - - def create(self, username=None, password=None, database=None, charset=8, seed=None): - self.add_1_int(self.command) - self.add(username + '\x00') - if password: - self.add_1_int(20) - self.add(self.scramble(password,seed)) - else: - self.add_null(1) - if database: - self.add(database + '\x00') - else: - self.add_null() - - self.add_2_int(charset) - -class ErrorResultPacket(PacketIn): - - def __init__(self, buf=None): - self.errno = 0 - self.errmsg = '' - self.sqlstate = None - PacketIn.__init__(self, buf) - - def _parse(self): - buf = self.data - - if buf[0] != '\xff': - raise errors.InterfaceError('Expected an Error Packet.') - buf = buf[1:] - - (buf,self.errno) = utils.read_int(buf, 2) - - if buf[0] != '\x23': - # Error without SQLState - self.errmsg = buf - else: - (buf,self.sqlstate) = utils.read_bytes(buf[1:],5) - self.errmsg = buf - -class OKResultPacket(PacketIn): - def __init__(self, buf=None): - self.affected_rows = None - self.insert_id = None - self.server_status = 0 - self.warning_count = 0 - self.field_count = 0 - self.info_msg = '' - PacketIn.__init__(self, buf) - - def __str__(self): - if self.affected_rows == 1: - lbl_rows = 'row' - else: - lbl_rows = 'rows' - - xtr = [] - if self.insert_id: - xtr.append('last insert: %d ' % self.insert_id) - if self.warning_count: - xtr.append('warnings: %d' % self.warning_count) - - return "Query OK, %d %s affected %s( sec)" % (self.affected_rows, - lbl_rows, ', '.join(xtr)) - - def _parse(self): - buf = self.data - (buf,self.field_count) = utils.read_int(buf,1) - (buf,self.affected_rows) = utils.read_lc_int(buf) - (buf,self.insert_id) = utils.read_lc_int(buf) - (buf,self.server_status) = utils.read_int(buf,2) - (buf,self.warning_count) = utils.read_int(buf,2) - if buf: - (buf,self.info_msg) = utils.read_lc_string(buf) - -class CommandPacket(PacketOut): - def __init__(self, cmd=None, arg=None): - self.command = cmd - self.argument = arg - PacketOut.__init__(self) - - def create(self): - self.add_1_int(self.command) - self.add(str(self.argument)) - - def set_command(self, cmd): - self.command = cmd - - def set_argument(self, arg): - self.argument = arg - -class KillPacket(CommandPacket): - - def __init__(self, arg): - CommandPacket.__init__(self) - self.set_command(ServerCmd.PROCESS_KILL) - self.set_argument(arg) - - def create(self): - """""" - self.add_1_int(self.command) - self.add_4_int(self.argument) - - def set_argument(self, arg): - if arg and not isinstance(int, long) and arg > 2**32: - raise ValueError, "KillPacket needs integer value as argument not larger than 2^32." - self.argument = arg - -class FieldPacket(PacketIn): - def __init__(self, buf=None): - self.catalog = None - self.db = None - self.table = None - self.org_table = None - self.name = None - self.length = None - self.org_name = None - self.charset = None - self.type = None - self.flags = None - PacketIn.__init__(self, buf) - - def __str__(self): - flags = [] - for k,f in FieldFlag.desc.items(): - if int(self.flags) & f[0]: - flags.append(k) - return """ - Field: catalog: %s ; db:%s ; table:%s ; org_table: %s ; - name: %s ; org_name: %s ; - charset: %s ; lenght: %s ; - type: %02x ; - flags(%d): %s; - """ % (self.catalog,self.db,self.table,self.org_table,self.name,self.org_name, - self.charset, len(self), self.type, - self.flags, '|'.join(flags)) - - def _parse(self): - buf = self.data - - (buf,self.catalog) = utils.read_lc_string(buf) - (buf,self.db) = utils.read_lc_string(buf) - (buf,self.table) = utils.read_lc_string(buf) - (buf,self.org_table) = utils.read_lc_string(buf) - (buf,self.name) = utils.read_lc_string(buf) - (buf,self.org_name) = utils.read_lc_string(buf) - buf = buf[1:] # filler 1 * \x00 - (buf,self.charset) = utils.read_int(buf, 2) - (buf,self.length) = utils.read_int(buf, 4) - (buf,self.type) = utils.read_int(buf, 1) - (buf,self.flags) = utils.read_int(buf, 2) - (buf,self.decimal) = utils.read_int(buf, 1) - buf = buf[2:] # filler 2 * \x00 - - def get_description(self): - """Returns a description as a list useful for cursors. - - This function returns a list as defined in the Python Db API v2.0 - specification. - - """ - return ( - self.name, - self.type, - None, # display_size - None, # internal_size - None, # precision - None, # scale - ~self.flags & FieldFlag.NOT_NULL, # null_ok - self.flags, # MySQL specific - ) - -class EOFPacket(PacketIn): - def __init__(self, buf=None): - self.warning_count = None - self.status_flag = None - PacketIn.__init__(self, buf) - - def __str__(self): - return "EOFPacket: warnings %d / status: %d" % (self.warning_count,self.status_flag) - - def _is_valid_extra(self, buf=None): - if not buf: - buf = self.data - else: - buf = buf[4:] - if buf[0] == '\xfe' and len(buf) == 5: - # An EOF should always start with \xfe and smaller than 9 bytes - return True - return False - - def _parse(self): - buf = self.data - - buf = buf[1:] # disregard the first checking byte - (buf, self.warning_count) = utils.read_int(buf, 2) - (buf, self.status_flag) = utils.read_int(buf, 2) + pkt = self._pkt_make_changeuser(username=username, password=password, + database=database, charset=_charset, seed=self.scramble) + self.conn.send(pkt) + buf = self._recv_packet() + return self._handle_ok(buf) diff --git a/mysql/connector/utils.py b/mysql/connector/utils.py index de2ed84..197d938 100644 --- a/mysql/connector/utils.py +++ b/mysql/connector/utils.py @@ -1,5 +1,5 @@ # MySQL Connector/Python - MySQL driver written in Python. -# Copyright 2009 Sun Microsystems, Inc. All rights reserved +# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify @@ -28,91 +28,22 @@ __MYSQL_DEBUG__ = False import struct -def int1read(c): - """ - Takes a bytes and returns it was an integer. - - Returns integer. - """ - if isinstance(c,int): - if c < 0 or c > 254: - raise ValueError('excepts int 0 <= x <= 254') - return c - elif len(c) > 1: - raise ValueError('excepts 1 byte long bytes-object or int') - - return int('%02x' % ord(c),16) - -def int2read(s): - """ - Takes a string of 2 bytes and unpacks it as unsigned integer. - - Returns integer. - """ - if len(s) > 2: - raise ValueError('int2read require s length of maximum 3 bytes') - elif len(s) < 2: - s = s + '\x00' - return struct.unpack(' 3: - raise ValueError('int3read require s length of maximum 3 bytes') - elif len(s) < 4: - s = s + '\x00'*(4-len(s)) - return struct.unpack(' 4: - raise ValueError('int4read require s length of maximum 4 bytes') - elif len(s) < 4: - s = s + '\x00'*(4-len(s)) - return struct.unpack(' 8: - raise ValueError('int4read require s length of maximum 8 bytes') - elif len(s) < 8: - s = s + '\x00'*(8-len(s)) - return struct.unpack(' 4: - raise ValueError('intread expects a string not longer than 4 bytes') - if not isinstance(s, str): - raise ValueError('intread expects a string') - fs = { - 1 : int1read, - 2 : int2read, - 3 : int3read, - 4 : int4read, - 8 : int8read, - } - return fs[l](s) +def intread(b): + """Unpacks the given buffer to an integer""" + try: + if isinstance(b,int): + return b + l = len(b) + if l == 1: + return int(ord(b)) + if l <= 4: + tmp = b + '\x00'*(4-l) + return struct.unpack(' 1: - print "%s: %s" % (label,string.join( [ "%s" % chr(ord(c)) for c in buf ], '')) - except: - raise +def _digest_buffer(buf): + return ''.join([ "\\x%02x" % ord(c) for c in buf ])