613 lines
19 KiB
Python
613 lines
19 KiB
Python
# MySQL Connector/Python - MySQL driver written in Python.
|
|
# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved.
|
|
# Use is subject to license terms. (See COPYING)
|
|
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation.
|
|
#
|
|
# There are special exceptions to the terms and conditions of the GNU
|
|
# General Public License as it is applied to this software. View the
|
|
# full text of the exception in file EXCEPTIONS-CLIENT in the directory
|
|
# of this software distribution or see the FOSS License Exception at
|
|
# www.mysql.com.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program; if not, write to the Free Software
|
|
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
|
|
|
"""Implementing communication to MySQL servers
|
|
"""
|
|
|
|
import sys
|
|
import socket
|
|
import logging
|
|
import os
|
|
import weakref
|
|
from collections import deque
|
|
|
|
import constants
|
|
import conversion
|
|
import protocol
|
|
import errors
|
|
import utils
|
|
import cursor
|
|
|
|
# Shared logger for this package; every module logs under the 'myconnpy' name.
logger = logging.getLogger(name='myconnpy')
|
|
|
|
class MySQLBaseSocket(object):
|
|
"""Base class for MySQL Connections subclasses.
|
|
|
|
Should not be used directly but overloaded, changing the
|
|
open_connection part. Examples of subclasses are
|
|
MySQLTCPSocket
|
|
MySQLUnixSocket
|
|
"""
|
|
def __init__(self):
|
|
self.sock = None # holds the socket connection
|
|
self.connection_timeout = None
|
|
self.buffer = deque()
|
|
self.recvsize = 1024*8
|
|
|
|
def open_connection(self):
|
|
pass
|
|
|
|
def close_connection(self):
|
|
try:
|
|
self.sock.close()
|
|
except:
|
|
pass
|
|
|
|
def get_address(self):
|
|
pass
|
|
|
|
def send(self, buf):
|
|
"""Send packets over the socket
|
|
"""
|
|
pktlen = len(buf)
|
|
try:
|
|
while pktlen:
|
|
pktlen -= self.sock.send(buf)
|
|
except Exception, e:
|
|
raise errors.OperationalError('%s' % e)
|
|
|
|
def recv(self):
|
|
"""Receive packets from the socket
|
|
"""
|
|
try:
|
|
return self.buffer.popleft()
|
|
except IndexError:
|
|
pass
|
|
|
|
pktnr = -1
|
|
try:
|
|
buf = self.sock.recv(self.recvsize)
|
|
while buf:
|
|
totalsize = len(buf)
|
|
if pktnr == -1 and totalsize > 4:
|
|
pktsize = utils.intread(buf[0:3])
|
|
pktnr = utils.intread(buf[3])
|
|
if pktnr > -1 and totalsize >= pktsize+4:
|
|
size = pktsize+4
|
|
self.buffer.append(buf[0:size])
|
|
buf = buf[size:]
|
|
pktnr = -1
|
|
if len(buf) == 0:
|
|
break
|
|
elif len(buf) < pktsize+4:
|
|
buf += self.sock.recv(self.recvsize)
|
|
except socket.timeout, e:
|
|
raise errors.InterfaceError(errno=2013)
|
|
except socket.error, e:
|
|
raise errors.InterfaceError(errno=2055,
|
|
values=dict(socketaddr=self.get_address(),errno=e.errno))
|
|
except:
|
|
raise
|
|
|
|
try:
|
|
return self.buffer.popleft()
|
|
except IndexError, e:
|
|
pass
|
|
|
|
def set_connection_timeout(self, timeout):
|
|
self.connection_timeout = timeout
|
|
|
|
class MySQLUnixSocket(MySQLBaseSocket):
|
|
"""Opens a connection through the UNIX socket of the MySQL Server."""
|
|
|
|
def __init__(self, unix_socket='/tmp/mysql.sock'):
|
|
MySQLBaseSocket.__init__(self)
|
|
self.unix_socket = unix_socket
|
|
|
|
def get_address(self):
|
|
return self.unix_socket
|
|
|
|
def open_connection(self):
|
|
"""Opens a UNIX socket and checks the MySQL handshake."""
|
|
try:
|
|
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
self.sock.settimeout(self.connection_timeout)
|
|
self.sock.connect(self.unix_socket)
|
|
except socket.error, e:
|
|
try:
|
|
m = e.errno
|
|
except:
|
|
m = e
|
|
raise errors.InterfaceError(errno=2002,
|
|
values=dict(socketaddr=self.get_address(),errno=m))
|
|
except StandardError, e:
|
|
raise errors.InterfaceError('%s' % e)
|
|
|
|
class MySQLTCPSocket(MySQLBaseSocket):
|
|
"""Opens a TCP connection to the MySQL Server."""
|
|
|
|
def __init__(self, host='127.0.0.1', port=3306):
|
|
MySQLBaseSocket.__init__(self)
|
|
self.server_host = host
|
|
self.server_port = port
|
|
|
|
def get_address(self):
|
|
return "%s:%s" % (self.server_host,self.server_port)
|
|
|
|
def open_connection(self):
|
|
"""Opens a TCP Connection and checks the MySQL handshake."""
|
|
try:
|
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
self.sock.settimeout(self.connection_timeout)
|
|
self.sock.connect( (self.server_host, self.server_port) )
|
|
except socket.error, e:
|
|
try:
|
|
m = e.errno
|
|
except:
|
|
m = e
|
|
raise errors.InterfaceError(errno=2003,
|
|
values=dict(socketaddr=self.get_address(),errno=m))
|
|
except StandardError, e:
|
|
raise errors.InterfaceError('%s' % e)
|
|
except:
|
|
raise
|
|
|
|
class MySQLConnection(object):
    """Connection to a MySQL server

    Keyword arguments given to the constructor are passed to connect(),
    which then connects right away.
    """

    def __init__(self, *args, **kwargs):
        """Initialize connection state; connect when kwargs are given"""
        self.conn = None # Holding the connection
        self.protocol = None
        self.converter = None
        self.cursors = []

        self.client_flags = constants.ClientFlag.get_default()
        self._charset = 33

        self._username = ''
        self._database = ''
        self._server_host = '127.0.0.1'
        self._server_port = 3306
        self._unix_socket = None
        self.client_host = ''
        self.client_port = 0

        self.affected_rows = 0
        self.server_status = 0
        self.warning_count = 0
        self.field_count = 0
        self.insert_id = 0
        self.info_msg = ''
        self.use_unicode = True
        self.get_warnings = False
        self.raise_on_warnings = False
        self.connection_timeout = None
        self.buffered = False
        self.unread_result = False
        self.raw = False

        if kwargs:
            self.connect(*args, **kwargs)

    def connect(self, database=None, user='', password='',
                host='127.0.0.1', port=3306, unix_socket=None,
                use_unicode=True, charset='utf8', collation=None,
                autocommit=False,
                time_zone=None, sql_mode=None,
                get_warnings=False, raise_on_warnings=False,
                connection_timeout=None, client_flags=0,
                buffered=False, raw=False,
                passwd=None, db=None, connect_timeout=None, dsn=None):
        """Connect to the MySQL server

        db, passwd and connect_timeout are aliases for database, password
        and connection_timeout; they are used only when the latter are
        not given. Any existing connection is closed first. After
        authenticating, collation, autocommit, time_zone and sql_mode are
        applied to the session.

        Raises errors.NotSupportedError when dsn is given.
        """
        if db and not database:
            database = db
        if passwd and not password:
            password = passwd
        if connect_timeout and not connection_timeout:
            connection_timeout = connect_timeout

        if dsn is not None:
            # Previously the exception was created but never raised.
            raise errors.NotSupportedError(
                "Data source name is not supported")

        self._server_host = host
        self._server_port = port
        self._unix_socket = unix_socket
        if database is not None:
            self._database = database.strip()
        else:
            self._database = None
        self._username = user

        self.set_warnings(get_warnings, raise_on_warnings)
        self.connection_timeout = connection_timeout
        self.buffered = buffered
        self.raw = raw
        self.use_unicode = use_unicode
        self.set_client_flags(client_flags)
        self._charset = constants.CharacterSet.get_charset_info(charset)[0]

        if user or password:
            self.set_login(user, password)

        self.disconnect()
        # Authenticate with the stripped database name stored above.
        self._open_connection(username=user, password=password,
                              database=self._database,
                              client_flags=self.client_flags,
                              charset=charset)
        # autocommit used to be dropped here, so it was always OFF.
        self._post_connection(time_zone=time_zone, autocommit=autocommit,
                              sql_mode=sql_mode, collation=collation)

    def _get_connection(self, prtcls=None):
        """Get connection based on configuration

        Returns a MySQLUnixSocket when a unix_socket is configured and
        the OS is not Windows ('nt'), otherwise a MySQLTCPSocket. Either
        one gets the connection timeout.
        """
        if self.unix_socket and os.name != 'nt':
            conn = MySQLUnixSocket(unix_socket=self.unix_socket)
        else:
            conn = MySQLTCPSocket(host=self.server_host,
                                  port=self.server_port)
        conn.set_connection_timeout(self.connection_timeout)
        return conn

    def _open_connection(self, username=None, password=None, database=None,
                         client_flags=None, charset=None):
        """Opens the connection

        Performs the handshake, rejects servers older than 4.1,
        authenticates, and records the character set id and name.

        Raises errors.InterfaceError for unsupported server versions.
        """
        self.protocol = protocol.MySQLProtocol(self._get_connection())
        self.protocol.do_handshake()
        version = self.protocol.server_version
        if version < (4, 1):
            # version is a tuple; formatting it directly with % raised a
            # TypeError instead of the intended InterfaceError.
            raise errors.InterfaceError(
                "MySQL Version %s is not supported." %
                '.'.join([str(part) for part in version]))
        self.protocol.do_auth(username, password, database, client_flags,
                              self._charset)
        (self._charset, self.charset_name, _) = \
            constants.CharacterSet.get_charset_info(charset)

    def _post_connection(self, time_zone=None, autocommit=False,
                         sql_mode=None, collation=None):
        """Post connection session setup

        Should be called after a connection was established. Installs
        the default converter, then sets collation, time_zone and
        sql_mode when given; autocommit is always set.
        """
        self.set_converter_class(conversion.MySQLConverter)
        if collation is not None:
            self.collation = collation
        self.autocommit = autocommit
        if time_zone is not None:
            self.time_zone = time_zone
        if sql_mode is not None:
            self.sql_mode = sql_mode

    def is_connected(self):
        """Check whether we are connected to the MySQL server

        Returns False when no connection was opened; otherwise returns
        the result of protocol.cmd_ping().
        """
        if self.protocol is None:
            return False
        return self.protocol.cmd_ping()
    ping = is_connected

    def disconnect(self):
        """Disconnect from the MySQL server

        Sends QUIT when a socket is open and then always closes the
        socket. A failing QUIT (for example because the server already
        went away) is logged and ignored, so that the connection can
        still be reopened.
        """
        if not self.protocol:
            return

        try:
            try:
                if self.protocol.conn.sock is not None:
                    self.protocol.cmd_quit()
            except errors.Error:
                logger.debug("QUIT failed while disconnecting",
                             exc_info=True)
        finally:
            self.protocol.conn.close_connection()
            self.protocol = None

    def set_converter_class(self, convclass):
        """Set the converter class to be used

        This should be a class overloading methods and members of
        conversion.MySQLConverter.
        """
        self.converter_class = convclass
        self.converter = convclass(self.charset_name, self.use_unicode)

    def get_server_version(self):
        """Returns the server version as a tuple, or None when unknown"""
        try:
            return self.protocol.server_version
        except AttributeError:
            return None

    def get_server_info(self):
        """Returns the server version as a string"""
        return self.protocol.server_version_original

    @property
    def connection_id(self):
        """MySQL connection ID, or None when not connected"""
        try:
            return self.protocol.server_threadid
        except AttributeError:
            return None

    def set_login(self, username=None, password=None):
        """Set login information for MySQL

        Stores the stripped username and password; None becomes ''.
        """
        if username is not None:
            self.username = username.strip()
        else:
            self.username = ''
        if password is not None:
            self.password = password.strip()
        else:
            self.password = ''

    def set_unicode(self, value=True):
        """Toggle unicode mode

        Set whether we return string fields as unicode or not.
        Default is True.
        """
        self.use_unicode = value
        if self.converter:
            self.converter.set_unicode(value)

    def set_charset(self, charset):
        """Switch the session character set with SET NAMES"""
        (idx, charset_name, _) = \
            constants.CharacterSet.get_charset_info(charset)
        self._execute_query("SET NAMES '%s'" % charset_name)
        # Only record the new charset once the server accepted it.
        self._charset = idx
        self.charset_name = charset_name
        self.converter.set_charset(charset_name)

    def get_charset(self):
        return self._info_query(
            "SELECT @@session.character_set_connection")[0]
    charset = property(get_charset, set_charset,
                       doc="Character set for this connection")

    def set_collation(self, collation):
        self._execute_query(
            "SET @@session.collation_connection = '%s'" % collation)

    def get_collation(self):
        return self._info_query(
            "SELECT @@session.collation_connection")[0]
    collation = property(get_collation, set_collation,
                         doc="Collation for this connection")

    def set_warnings(self, fetch=False, raise_on_warnings=False):
        """Set how to handle warnings coming from MySQL

        Set whether we should get warnings whenever an operation produced
        some. If you set raise_on_warnings to True, any warning will be
        raised as a DataError exception (fetching is then enabled too).
        """
        if raise_on_warnings is True:
            self.get_warnings = True
            self.raise_on_warnings = True
        else:
            self.get_warnings = fetch
            self.raise_on_warnings = False

    def set_client_flags(self, flags):
        """Set the client flags

        The flags-argument can be either an int or a list (or tuple) of
        ClientFlag-values. If it is a positive integer, client_flags is
        set to flags.
        If flags is a list (or tuple), each flag will be set, or unset
        when it's negative.

          set_client_flags([ClientFlag.FOUND_ROWS,-ClientFlag.LONG_FLAG])

        Returns self.client_flags
        """
        if isinstance(flags, int) and flags > 0:
            # Previously this called set_client_flags(flags) again,
            # recursing until the stack overflowed.
            self.client_flags = flags
        elif isinstance(flags, (tuple, list)):
            for flag in flags:
                if flag < 0:
                    self.unset_client_flag(abs(flag))
                else:
                    self.set_client_flag(flag)
        return self.client_flags

    def set_client_flag(self, flag):
        """Turn on a single (positive) client flag"""
        if flag > 0:
            self.client_flags |= flag

    def unset_client_flag(self, flag):
        """Turn off a single (positive) client flag"""
        if flag > 0:
            self.client_flags &= ~flag

    def isset_client_flag(self, flag):
        """Return True when flag is set in client_flags"""
        return (self.client_flags & flag) > 0

    @property
    def user(self):
        """User used while connecting to MySQL"""
        return self._username

    @property
    def server_host(self):
        """MySQL server IP address or name"""
        return self._server_host

    @property
    def server_port(self):
        "MySQL server TCP/IP port"
        return self._server_port

    @property
    def unix_socket(self):
        "MySQL Unix socket file location"
        return self._unix_socket

    @staticmethod
    def _quote_value(value):
        """Return value as a single-quoted SQL string literal

        Values that already arrive wrapped in matching single or double
        quotes, and the keyword DEFAULT, are returned unchanged. Callers
        that had to quote values themselves therefore keep working.
        Embedded single quotes are doubled.
        """
        value = '%s' % (value,)
        if value.upper() == 'DEFAULT':
            return value
        if len(value) >= 2 and value[0] == value[-1] and value[0] in '\'"':
            return value
        return "'%s'" % value.replace("'", "''")

    def set_database(self, value):
        # _execute_query adds the unread-result check other setters have.
        self._execute_query("USE %s" % value)

    def get_database(self):
        """Get the current database"""
        return self._info_query("SELECT DATABASE()")[0]
    database = property(get_database, set_database,
                        doc="Current database")

    def set_time_zone(self, value):
        # Unquoted values such as +00:00 were a syntax error.
        self._execute_query(
            "SET @@session.time_zone = %s" % self._quote_value(value))

    def get_time_zone(self):
        return self._info_query("SELECT @@session.time_zone")[0]
    time_zone = property(get_time_zone, set_time_zone,
                         doc="time_zone value for current MySQL session")

    def set_sql_mode(self, value):
        # Quoting is needed for comma-separated mode lists.
        self._execute_query(
            "SET @@session.sql_mode = %s" % self._quote_value(value))

    def get_sql_mode(self):
        return self._info_query("SELECT @@session.sql_mode")[0]
    sql_mode = property(get_sql_mode, set_sql_mode,
                        doc="sql_mode value for current MySQL session")

    def set_autocommit(self, value):
        if value:
            setting = 'ON'
        else:
            setting = 'OFF'
        self._execute_query("SET @@session.autocommit = %s" % setting)

    def get_autocommit(self):
        value = self._info_query("SELECT @@session.autocommit")[0]
        return value == 1
    autocommit = property(get_autocommit, set_autocommit,
                          doc="autocommit value for current MySQL session")

    def close(self):
        """Forget all cursors and disconnect"""
        del self.cursors[:]
        self.disconnect()

    def remove_cursor(self, c):
        """Unregister cursor c; raises ProgrammingError if unknown"""
        try:
            self.cursors.remove(c)
        except ValueError:
            raise errors.ProgrammingError(
                "Cursor could not be removed.")

    def cursor(self, buffered=None, raw=None, cursor_class=None):
        """Instantiates and returns a cursor

        By default, MySQLCursor is returned. buffered and raw default to
        the values given when connecting (None means "use the connection
        default"). Depending on them, a buffered and/or raw cursor is
        instantiated instead.

        It is possible to also give a custom cursor through the
        cursor_class parameter, but it needs to be a subclass of
        mysql.connector.cursor.CursorBase.

        Returns a cursor-object
        """
        if cursor_class is not None:
            if not issubclass(cursor_class, cursor.CursorBase):
                raise errors.ProgrammingError(
                    "Cursor class needs be subclass of cursor.CursorBase")
            c = cursor_class(self)
        else:
            # An explicit False now overrides a buffered/raw connection
            # default; "buffered or self.buffered" could not do that.
            if buffered is None:
                buffered = self.buffered
            if raw is None:
                raw = self.raw

            types = {
                (False, False): cursor.MySQLCursor,
                (True, False): cursor.MySQLCursorBuffered,
                (False, True): cursor.MySQLCursorRaw,
                (True, True): cursor.MySQLCursorBufferedRaw,
            }
            c = types[(bool(buffered), bool(raw))](self)

        if c not in self.cursors:
            self.cursors.append(c)
        return c

    def commit(self):
        """Commit current transaction"""
        self._execute_query("COMMIT")

    def rollback(self):
        """Rollback current transaction"""
        self._execute_query("ROLLBACK")

    def _execute_query(self, query):
        """Send query; raises InternalError if a result is still unread"""
        if self.unread_result is True:
            raise errors.InternalError("Unread result found.")

        self.protocol.cmd_query(query)

    def _info_query(self, query):
        """Run query on a buffered cursor and return its first row

        The cursor is closed even when the query fails.
        """
        cur = self.cursor(buffered=True)
        try:
            cur.execute(query)
            return cur.fetchone()
        finally:
            cur.close()
|