625 lines
20 KiB
Python
625 lines
20 KiB
Python
# MySQL Connector/Python - MySQL driver written in Python.
|
|
# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved.
|
|
# Use is subject to license terms. (See COPYING)
|
|
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation.
|
|
#
|
|
# There are special exceptions to the terms and conditions of the GNU
|
|
# General Public License as it is applied to this software. View the
|
|
# full text of the exception in file EXCEPTIONS-CLIENT in the directory
|
|
# of this software distribution or see the FOSS License Exception at
|
|
# www.mysql.com.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program; if not, write to the Free Software
|
|
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
|
|
|
"""Implementing the MySQL Client/Server protocol
|
|
"""
|
|
|
|
import re
|
|
import struct
|
|
|
|
try:
|
|
from hashlib import sha1
|
|
except ImportError:
|
|
from sha import new as sha1
|
|
|
|
from datetime import datetime
|
|
from time import strptime
|
|
from decimal import Decimal
|
|
|
|
from constants import *
|
|
import errors
|
|
import utils
|
|
|
|
def packet_is_error(idx=0, label=None):
    """Decorator factory: raise the server's error when a packet is ERR.

    The packet is taken from keyword argument *label* when given, otherwise
    from positional argument *idx*. When the first byte after the 4-byte
    header is 0xff, MySQLProtocol.raise_error() is called (it raises an
    errors.Error). Otherwise the wrapped function is called unchanged;
    packets too short to inspect are passed through as well.

    Raises errors.InterfaceError when the packet argument cannot be found.
    """
    import sys  # for sys.exc_info(); keeps 'except ... as' out of the file

    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError(
                    "Can't check for Error packet; %s" % sys.exc_info()[1])

            try:
                is_error = bool(pktdata) and pktdata[4] == '\xff'
            except (IndexError, TypeError):
                # Too short (or not indexable) to be an ERR packet: let the
                # wrapped function handle it, as before.
                is_error = False
            if is_error:
                MySQLProtocol.raise_error(pktdata)
            return func(*args, **kwargs)

        # Keep the decorated method's name and docstring visible.
        call.__name__ = getattr(func, '__name__', call.__name__)
        call.__doc__ = getattr(func, '__doc__', None)
        call.__dict__.update(getattr(func, '__dict__', {}))
        return call
    return deco
|
|
|
|
def packet_is_ok(idx=0, label=None):
    """Decorator factory: only call the wrapped function for OK packets.

    The packet is taken from keyword argument *label* when given, otherwise
    from positional argument *idx*. It counts as an OK packet when the
    first byte after the 4-byte header is 0x00.

    Raises errors.InterfaceError when the packet argument cannot be found
    or is not an OK packet. Exceptions raised by the wrapped function are
    propagated unchanged (they used to be masked as "Expected OK packet").
    """
    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError("Can't check for OK packet.")

            try:
                is_ok = bool(pktdata) and pktdata[4] == '\x00'
            except (IndexError, TypeError):
                is_ok = False
            if not is_ok:
                raise errors.InterfaceError("Expected OK packet")
            # Called outside any try block so the wrapped function's own
            # errors reach the caller with their original type and message.
            return func(*args, **kwargs)

        # Keep the decorated method's name and docstring visible.
        call.__name__ = getattr(func, '__name__', call.__name__)
        call.__doc__ = getattr(func, '__doc__', None)
        call.__dict__.update(getattr(func, '__dict__', {}))
        return call
    return deco
|
|
|
|
def packet_is_eof(idx=0, label=None):
    """Decorator factory: only call the wrapped function for EOF packets.

    The packet is taken from keyword argument *label* when given, otherwise
    from positional argument *idx*. It counts as an EOF packet when it is
    exactly 9 bytes long (4-byte header plus 5-byte payload) and the first
    payload byte is 0xfe.

    Raises errors.InterfaceError when the packet argument cannot be found
    or is not an EOF packet.
    """
    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError("Can't check for EOF packet.")

            # Check the length before indexing so that short or empty
            # packets give the InterfaceError below, not an IndexError.
            if pktdata and len(pktdata) == 9 and pktdata[4] == '\xfe':
                return func(*args, **kwargs)
            raise errors.InterfaceError("Expected EOF packet")

        # Keep the decorated method's name and docstring visible.
        call.__name__ = getattr(func, '__name__', call.__name__)
        call.__doc__ = getattr(func, '__doc__', None)
        call.__dict__.update(getattr(func, '__dict__', {}))
        return call
    return deco
|
|
|
|
def set_pktnr(idx=1, label=None):
    """Decorator factory: record a packet's sequence number and strip its header.

    The packet is taken from keyword argument *label* when given, otherwise
    from positional argument *idx*. Byte 3 of the packet is stored in
    args[0].pktnr (the instance, for methods), and the wrapped function
    receives the packet with its 4-byte header removed.

    Raises errors.InterfaceError when the packet argument cannot be found
    or its packet number cannot be read.
    """
    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError(
                    "Can't get packet to read the packet number.")

            try:
                args[0].pktnr = ord(pktdata[3])
                payload = pktdata[4:]
            except Exception:
                raise errors.InterfaceError("Failed getting Packet Number.")

            # Hand the header-less payload to the wrapped function in the
            # same argument slot it came from.
            if label:
                kwargs[label] = payload
            else:
                args = list(args)
                args[idx] = payload
            return func(*args, **kwargs)

        # Keep the decorated method's name and docstring visible.
        call.__name__ = getattr(func, '__name__', call.__name__)
        call.__doc__ = getattr(func, '__doc__', None)
        call.__dict__.update(getattr(func, '__dict__', {}))
        return call
    return deco
|
|
|
|
def reset_pktnr(func):
    """Decorator: set args[0].pktnr to -1 before calling func.

    Used on the cmd_* methods so every command starts a fresh packet
    sequence. When there is no first argument, or it does not accept a
    pktnr attribute, the reset is skipped and func is still called.
    """
    def deco(*args, **kwargs):
        try:
            args[0].pktnr = -1
        except (IndexError, AttributeError):
            # Best effort: no instance (or a read-only one) to reset.
            pass
        return func(*args, **kwargs)

    # Keep the decorated method's name and docstring visible.
    deco.__name__ = getattr(func, '__name__', deco.__name__)
    deco.__doc__ = getattr(func, '__doc__', None)
    deco.__dict__.update(getattr(func, '__dict__', {}))
    return deco
|
|
|
|
class MySQLProtocolBase(object):
    """Common base class for MySQL protocol implementations.

    It defines no behaviour of its own; MySQLProtocol derives from it.
    """
|
|
|
|
class MySQLProtocol(MySQLProtocolBase):
|
|
|
|
def __init__(self, conn):
|
|
self.client_flags = 0
|
|
self.conn = conn
|
|
self.pktnr = -1
|
|
|
|
@classmethod
|
|
def raise_error(cls, buf):
|
|
"""Raise an errors.Error when buffer has a MySQL error"""
|
|
errno = errmsg = None
|
|
try:
|
|
buf = buf[5:]
|
|
(buf,errno) = utils.read_int(buf, 2)
|
|
if buf[0] != '\x23':
|
|
# Error without SQLState
|
|
errmsg = buf
|
|
else:
|
|
(buf,sqlstate) = utils.read_bytes(buf[1:],5)
|
|
errmsg = buf
|
|
except Exception, e:
|
|
raise errors.InterfaceError("Failed getting Error information (%r)"\
|
|
% e)
|
|
else:
|
|
raise errors.get_mysql_exception(errno,errmsg)
|
|
|
|
def _recv_packet(self):
|
|
"""Getting a packet from the MySQL server"""
|
|
buf = self.conn.recv()
|
|
if buf[4] == '\xff':
|
|
MySQLProtocol.raise_error(buf)
|
|
else:
|
|
return buf
|
|
|
|
def _scramble_password(self, passwd, seed):
|
|
"""Scramble a password ready to send to MySQL"""
|
|
hash4 = None
|
|
try:
|
|
hash1 = sha1(passwd).digest()
|
|
hash2 = sha1(hash1).digest() # Password as found in mysql.user()
|
|
hash3 = sha1(seed + hash2).digest()
|
|
xored = [ utils.intread(h1) ^ utils.intread(h3)
|
|
for (h1,h3) in zip(hash1, hash3) ]
|
|
hash4 = struct.pack('20B', *xored)
|
|
except Exception, e:
|
|
raise errors.InterfaceError('Failed scrambling password; %s' % e)
|
|
|
|
return hash4
|
|
|
|
def _pkt_make_header(self, pktlength, pktnr=None):
|
|
"""Make the header for a MySQL packet"""
|
|
pktnr = pktnr or self.pktnr+1
|
|
return utils.int3store(pktlength) + utils.int1store(pktnr)
|
|
|
|
def _prepare_auth(self, usr, pwd, db, flags, seed):
|
|
|
|
if usr is not None and len(usr) > 0:
|
|
_username = usr + '\x00'
|
|
else:
|
|
_username = '\x00'
|
|
|
|
if pwd is not None and len(pwd) > 0:
|
|
_password = utils.int1store(20) +\
|
|
self._scramble_password(pwd,seed)
|
|
else:
|
|
_password = '\x00'
|
|
|
|
if db is not None and len(db):
|
|
_database = db + '\x00'
|
|
else:
|
|
_database = '\x00'
|
|
|
|
return (_username, _password, _database)
|
|
|
|
def _pkt_make_auth(self, username=None, password=None, database=None,
|
|
seed=None, charset=33, client_flags=0):
|
|
"""Make a MySQL Authentication packet"""
|
|
try:
|
|
seed = seed or self.scramble
|
|
except:
|
|
raise errors.ProgrammingError('Seed missing')
|
|
|
|
(_username, _password, _database) = self._prepare_auth(
|
|
username, password, database, client_flags, seed)
|
|
data = utils.int4store(client_flags) +\
|
|
utils.int4store(10 * 1024 * 1024) +\
|
|
utils.int1store(charset) +\
|
|
'\x00'*23 +\
|
|
_username +\
|
|
_password +\
|
|
_database
|
|
|
|
header = self._pkt_make_header(len(data))
|
|
return header+data
|
|
|
|
def _pkt_make_command(self, command, argument=None):
|
|
"""Make a MySQL packet containing a command"""
|
|
data = utils.int1store(command)
|
|
if argument is not None:
|
|
data += str(argument)
|
|
header = self._pkt_make_header(len(data))
|
|
return header+data
|
|
|
|
def _pkt_make_changeuser(self, username=None, password=None,
|
|
database=None, charset=8, seed=None):
|
|
"""Make a MySQL packet with the Change User command"""
|
|
try:
|
|
seed = seed or self.scramble
|
|
except:
|
|
raise errors.ProgrammingError('Seed missing')
|
|
|
|
(_username, _password, _database) = self._prepare_auth(
|
|
username, password, database, self.client_flags, seed)
|
|
data = utils.int1store(ServerCmd.CHANGE_USER) +\
|
|
_username +\
|
|
_password +\
|
|
_database +\
|
|
utils.int2store(charset)
|
|
|
|
header = self._pkt_make_header(len(data))
|
|
return header+data
|
|
|
|
@set_pktnr(1)
|
|
def _pkt_parse_handshake(self, buf):
|
|
"""Parse a MySQL Handshake-packet"""
|
|
res = {}
|
|
(buf,res['protocol']) = utils.read_int(buf,1)
|
|
(buf,res['server_version_original']) = utils.read_string(buf,end='\x00')
|
|
(buf,res['server_threadid']) = utils.read_int(buf,4)
|
|
(buf,res['scramble']) = utils.read_bytes(buf, 8)
|
|
buf = buf[1:] # Filler 1 * \x00
|
|
(buf,res['capabilities']) = utils.read_int(buf,2)
|
|
(buf,res['charset']) = utils.read_int(buf,1)
|
|
(buf,res['server_status']) = utils.read_int(buf,2)
|
|
buf = buf[13:] # Filler 13 * \x00
|
|
(buf,scramble_next) = utils.read_bytes(buf,12)
|
|
res['scramble'] += scramble_next
|
|
return res
|
|
|
|
@set_pktnr(1)
|
|
def _pkt_parse_ok(self, buf):
|
|
"""Parse a MySQL OK-packet"""
|
|
ok = {}
|
|
(buf,ok['field_count']) = utils.read_int(buf,1)
|
|
(buf,ok['affected_rows']) = utils.read_lc_int(buf)
|
|
(buf,ok['insert_id']) = utils.read_lc_int(buf)
|
|
(buf,ok['server_status']) = utils.read_int(buf,2)
|
|
(buf,ok['warning_count']) = utils.read_int(buf,2)
|
|
if buf:
|
|
(buf,ok['info_msg']) = utils.read_lc_string(buf)
|
|
return ok
|
|
|
|
@set_pktnr(1)
|
|
def _pkt_parse_field(self, buf):
|
|
"""Parse a MySQL Field-packet"""
|
|
field = {}
|
|
(buf,field['catalog']) = utils.read_lc_string(buf)
|
|
(buf,field['db']) = utils.read_lc_string(buf)
|
|
(buf,field['table']) = utils.read_lc_string(buf)
|
|
(buf,field['org_table']) = utils.read_lc_string(buf)
|
|
(buf,field['name']) = utils.read_lc_string(buf)
|
|
(buf,field['org_name']) = utils.read_lc_string(buf)
|
|
buf = buf[1:] # filler 1 * \x00
|
|
(buf,field['charset']) = utils.read_int(buf, 2)
|
|
(buf,field['length']) = utils.read_int(buf, 4)
|
|
(buf,field['type']) = utils.read_int(buf, 1)
|
|
(buf,field['flags']) = utils.read_int(buf, 2)
|
|
(buf,field['decimal']) = utils.read_int(buf, 1)
|
|
buf = buf[2:] # filler 2 * \x00
|
|
|
|
res = (
|
|
field['name'],
|
|
field['type'],
|
|
None, # display_size
|
|
None, # internal_size
|
|
None, # precision
|
|
None, # scale
|
|
~field['flags'] & FieldFlag.NOT_NULL, # null_ok
|
|
field['flags'], # MySQL specific
|
|
)
|
|
return res
|
|
|
|
@set_pktnr(1)
|
|
def _pkt_parse_eof(self, buf):
|
|
"""Parse a MySQL EOF-packet"""
|
|
res = {}
|
|
buf = buf[1:] # disregard the first checking byte
|
|
(buf, res['warning_count']) = utils.read_int(buf, 2)
|
|
(buf, res['status_flag']) = utils.read_int(buf, 2)
|
|
return res
|
|
|
|
def do_handshake(self):
|
|
"""Get the handshake from the MySQL server"""
|
|
try:
|
|
self.conn.open_connection()
|
|
buf = self._recv_packet()
|
|
self.handle_handshake(buf)
|
|
except:
|
|
raise
|
|
|
|
def do_auth(self, username=None, password=None, database=None,
|
|
client_flags=0, charset=33):
|
|
"""Authenticate with the MySQL server
|
|
"""
|
|
pkt = self._pkt_make_auth(username=username, password=password,
|
|
database=database, charset=charset,
|
|
client_flags=client_flags)
|
|
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
if buf[4] == '\xfe':
|
|
raise errors.NotSupportedError(
|
|
"Authentication with old (insecure) passwords "\
|
|
"is not supported: "\
|
|
"http://dev.mysql.com/doc/refman/5.1/en/password-hashing.html")
|
|
|
|
try:
|
|
if not (client_flags & ClientFlag.CONNECT_WITH_DB) and database:
|
|
self.cmd_init_db(database)
|
|
except:
|
|
raise
|
|
|
|
return True
|
|
|
|
def handle_handshake(self, buf):
|
|
"""Check and handle the MySQL server's handshake
|
|
|
|
Check whether the buffer is a valid handshake. If it is, we set some
|
|
member variables for later usage. The handshake packet is returned for later
|
|
usuage, e.g. authentication.
|
|
"""
|
|
try:
|
|
res = self._pkt_parse_handshake(buf)
|
|
for k,v in res.items():
|
|
self.__dict__[k] = v
|
|
|
|
regex_ver = re.compile("^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)")
|
|
m = regex_ver.match(self.server_version_original)
|
|
if not m:
|
|
raise errors.InterfaceError("Failed parsing MySQL version number.")
|
|
self.server_version = tuple([ int(v) for v in m.groups()[0:3]])
|
|
except errors.Error:
|
|
raise
|
|
except Exception, e:
|
|
raise errors.InterfaceError('Failed handling handshake; %s' % e)
|
|
|
|
@packet_is_ok(1)
|
|
def _handle_ok(self, buf):
|
|
try:
|
|
return self._pkt_parse_ok(buf)
|
|
except:
|
|
raise errors.InterfaceError("Failed parsing OK packet.")
|
|
|
|
@packet_is_eof(1)
|
|
def _handle_eof(self, buf):
|
|
try:
|
|
return self._pkt_parse_eof(buf)
|
|
except:
|
|
raise errors.InterfaceError("Failed parsing EOF packet.")
|
|
|
|
@packet_is_error(1)
|
|
@set_pktnr(1)
|
|
def _handle_resultset(self, buf):
|
|
(buf,nrflds) = utils.read_lc_int(buf)
|
|
if nrflds == 0:
|
|
raise errors.InterfaceError('Empty result set.')
|
|
|
|
fields = []
|
|
for i in xrange(0,nrflds):
|
|
buf = self._recv_packet()
|
|
fields.append(self._pkt_parse_field(buf))
|
|
|
|
buf = self._recv_packet()
|
|
eof = self._handle_eof(buf)
|
|
return (nrflds, fields, eof)
|
|
|
|
|
|
def get_rows(self, cnt=None):
|
|
"""Get all rows
|
|
|
|
Returns a tuple with 2 elements: a list with all rows and
|
|
the EOF packet.
|
|
"""
|
|
rows = []
|
|
eof = None
|
|
rowdata = None
|
|
i = 0
|
|
while True:
|
|
if eof is not None:
|
|
break
|
|
if i == cnt:
|
|
break
|
|
buf = self._recv_packet()
|
|
if buf[4] == '\xfe':
|
|
eof = self._handle_eof(buf)
|
|
rowdata = None
|
|
else:
|
|
eof = None
|
|
rowdata = utils.read_lc_string_list(buf[4:])
|
|
if eof is None and rowdata is not None:
|
|
rows.append(rowdata)
|
|
i += 1
|
|
return (rows,eof)
|
|
|
|
def get_row(self):
|
|
(rows,eof) = self.get_rows(cnt=1)
|
|
if len(rows):
|
|
return (rows[0],eof)
|
|
return (None,eof)
|
|
|
|
def handle_cmd_result(self, buf):
|
|
if buf[4] == '\x00':
|
|
return self._handle_ok(buf)
|
|
else:
|
|
return self._handle_resultset(buf)[0:2]
|
|
|
|
@reset_pktnr
|
|
def cmd_query(self, query):
|
|
"""Sends a query to the MySQL server
|
|
|
|
Returns a tuple, when the query returns a result. The tuple
|
|
consist number of fields and a list containing their descriptions.
|
|
If the query doesn't return a result set, a dictionary with
|
|
information contained in an OKResult packet will be returned.
|
|
"""
|
|
nrflds = 0
|
|
fields = None
|
|
try:
|
|
pkt = self._pkt_make_command(ServerCmd.QUERY,query)
|
|
self.conn.send(pkt) # Errors handled in _handle_error()
|
|
return self.handle_cmd_result(self._recv_packet())
|
|
except:
|
|
raise
|
|
|
|
@reset_pktnr
|
|
def cmd_refresh(self, opts):
|
|
"""Send the Refresh command to the MySQL server
|
|
|
|
The argument should be a bitwise value using contants.RefreshOption.
|
|
Usage example:
|
|
|
|
RefreshOption = mysql.connector.RefreshOption
|
|
refresh = RefreshOption.LOG | RefreshOption.THREADS
|
|
db.protocol().cmd_refresh(refresh)
|
|
|
|
Returns a dict() with OK-packet information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.REFRESH, opts)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_ok(buf)
|
|
|
|
@reset_pktnr
|
|
def cmd_quit(self):
|
|
"""Closes the current connection with the server
|
|
|
|
Returns the packet that was send.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.QUIT)
|
|
self.conn.send(pkt)
|
|
return pkt
|
|
|
|
@reset_pktnr
|
|
def cmd_init_db(self, database):
|
|
"""Change the current database
|
|
|
|
Change the current (default) database.
|
|
|
|
Returns a dict() with OK-packet information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.INIT_DB, database)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_ok(buf)
|
|
|
|
@reset_pktnr
|
|
def cmd_shutdown(self):
|
|
"""Shuts down the MySQL Server
|
|
|
|
Careful with this command if you have SUPER privileges! (Which your
|
|
scripts probably don't need!)
|
|
|
|
Returns a dict() with OK-packet information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.SHUTDOWN)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_eof(buf)
|
|
|
|
@reset_pktnr
|
|
def cmd_statistics(self):
|
|
"""Sends statistics command to the MySQL Server
|
|
|
|
Returns a dictionary with various statistical information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.STATISTICS)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
buf = buf[4:]
|
|
errmsg = "Failed getting COM_STATISTICS information"
|
|
res = {}
|
|
# Information is separated by 2 spaces
|
|
pairs = buf.split('\x20\x20')
|
|
for pair in pairs:
|
|
try:
|
|
(lbl,val) = [ v.strip() for v in pair.split(':',2) ]
|
|
except:
|
|
raise errors.InterfaceError(errmsg)
|
|
|
|
# It's either an integer or a decimal
|
|
try:
|
|
res[lbl] = long(val)
|
|
except:
|
|
try:
|
|
res[lbl] = Decimal(val)
|
|
except:
|
|
raise errors.InterfaceError(
|
|
"%s (%s:%s)." % (errmsg, lbl, val))
|
|
return res
|
|
|
|
@reset_pktnr
|
|
def cmd_process_info(self):
|
|
"""Gets the process list from the MySQL Server
|
|
|
|
(Unsupported)
|
|
"""
|
|
raise errors.NotSupportedError(
|
|
"Not implemented. Use a cursor to get processlist information.")
|
|
|
|
@reset_pktnr
|
|
def cmd_process_kill(self, mypid):
|
|
"""Kills a MySQL process using it's ID
|
|
|
|
Returns a dict() with OK-packet information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.PROCESS_KILL,
|
|
utils.int4store(mypid))
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_ok(buf)
|
|
|
|
@reset_pktnr
|
|
def cmd_debug(self):
|
|
"""Send DEBUG command to the MySQL Server
|
|
|
|
Needs SUPER privileges. The output will go to the MySQL server error
|
|
log.
|
|
|
|
Returns a dict() with EOF-packet information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.DEBUG)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_eof(buf)
|
|
|
|
@reset_pktnr
|
|
def cmd_ping(self):
|
|
"""Ping the MySQL server to check if the connection is still alive
|
|
|
|
Raises errors.Error or an error derived from it when it fails
|
|
to ping the MySQL server.
|
|
|
|
Returns a dict() with OK-packet information.
|
|
"""
|
|
pkt = self._pkt_make_command(ServerCmd.PING)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_ok(buf)
|
|
|
|
@reset_pktnr
|
|
def cmd_change_user(self, username='', password='', database=''):
|
|
"""Change the user and optionally the current database
|
|
|
|
Returns a dict() with OK-packet information.
|
|
"""
|
|
_charset = self.charset or 33
|
|
|
|
pkt = self._pkt_make_changeuser(username=username, password=password,
|
|
database=database, charset=_charset, seed=self.scramble)
|
|
self.conn.send(pkt)
|
|
buf = self._recv_packet()
|
|
return self._handle_ok(buf)
|