# mythboxee/mysql/connector/protocol.py
#
# 625 lines
# 20 KiB
# Python

# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved.
# Use is subject to license terms. (See COPYING)
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# There are special exceptions to the terms and conditions of the GNU
# General Public License as it is applied to this software. View the
# full text of the exception in file EXCEPTIONS-CLIENT in the directory
# of this software distribution or see the FOSS License Exception at
# www.mysql.com.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""Implementing the MySQL Client/Server protocol
"""
import re
import struct
try:
    from hashlib import sha1
except ImportError:
    from sha import new as sha1
from datetime import datetime
from time import strptime
from decimal import Decimal

from constants import *
import errors
import utils
def packet_is_error(idx=0,label=None):
def deco(func):
def call(*args, **kwargs):
try:
if label:
pktdata = kwargs[label]
else:
pktdata = args[idx]
except Exception, e:
raise errors.InterfaceError(
"Can't check for Error packet; %s" % e)
try:
if pktdata and pktdata[4] == '\xff':
MySQLProtocol.raise_error(pktdata)
except errors.Error:
raise
except:
pass
return func(*args, **kwargs)
return call
return deco
def packet_is_ok(idx=0,label=None):
    """Decorator factory: only call the function for an OK packet

    The packet is taken from the keyword argument ``label`` when it is
    given, otherwise from the positional argument at ``idx``. The decorated
    function is called only when the first byte after the 4-byte packet
    header is 0x00.

    Raises errors.InterfaceError when the packet argument can't be found
    ("Can't check for OK packet.") or is not an OK packet
    ("Expected OK packet"). Exceptions raised by the decorated function
    itself propagate unchanged.
    """
    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError("Can't check for OK packet.")
            try:
                is_ok = bool(pktdata) and pktdata[4] == '\x00'
            except (IndexError, TypeError):
                # Too short or not indexable: certainly not an OK packet.
                is_ok = False
            if not is_ok:
                raise errors.InterfaceError("Expected OK packet")
            # Called outside any try-block so that errors from func are not
            # masked as "Expected OK packet".
            return func(*args, **kwargs)
        call.__name__ = func.__name__
        call.__doc__ = func.__doc__
        return call
    return deco
def packet_is_eof(idx=0,label=None):
    """Decorator factory: only call the function for an EOF packet

    The packet is taken from the keyword argument ``label`` when it is
    given, otherwise from the positional argument at ``idx``. It counts as
    an EOF packet when it is exactly 9 bytes long (4-byte header plus
    payload) and the first byte after the header is 0xfe.

    Raises errors.InterfaceError when the packet argument can't be found
    ("Can't check for EOF packet.") or is not an EOF packet
    ("Expected EOF packet").
    """
    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError("Can't check for EOF packet.")
            # Check the length first so a short packet is rejected with
            # InterfaceError instead of leaking an IndexError.
            try:
                is_eof = len(pktdata) == 9 and pktdata[4] == '\xfe'
            except TypeError:
                is_eof = False
            if not is_eof:
                raise errors.InterfaceError("Expected EOF packet")
            return func(*args, **kwargs)
        call.__name__ = func.__name__
        call.__doc__ = func.__doc__
        return call
    return deco
def set_pktnr(idx=1,label=None):
    """Decorator factory: record the packet number and strip the header

    The packet is taken from the keyword argument ``label`` when it is
    given, otherwise from the positional argument at ``idx``. Its sequence
    number (byte 3 of the 4-byte header) is stored in ``args[0].pktnr``,
    and the decorated function is called with the payload (the packet
    without its 4-byte header) in place of the packet.

    Raises errors.InterfaceError when the packet argument can't be found or
    the packet number can't be read.
    """
    def deco(func):
        def call(*args, **kwargs):
            try:
                if label:
                    pktdata = kwargs[label]
                else:
                    pktdata = args[idx]
            except Exception:
                raise errors.InterfaceError(
                    "Can't get packet for setting Packet Number.")
            try:
                args[0].pktnr = ord(pktdata[3])
                pktdata = pktdata[4:]
                if label:
                    kwargs[label] = pktdata
                else:
                    # args is a tuple; copy it to replace the packet.
                    args = list(args)
                    args[idx] = pktdata
            except Exception:
                raise errors.InterfaceError("Failed getting Packet Number.")
            return func(*args, **kwargs)
        call.__name__ = func.__name__
        call.__doc__ = func.__doc__
        return call
    return deco
def reset_pktnr(func):
    """Decorator: set ``args[0].pktnr`` to -1 before calling func

    Used on the cmd_*() methods, so that the first packet they build is
    numbered self.pktnr + 1 == 0. When there is no first argument, or it
    can't take the attribute, the reset is skipped and func is still
    called.
    """
    def deco(*args, **kwargs):
        try:
            args[0].pktnr = -1
        except (IndexError, AttributeError, TypeError):
            # Best effort only: nothing to reset on.
            pass
        return func(*args, **kwargs)
    deco.__name__ = func.__name__
    deco.__doc__ = func.__doc__
    return deco
class MySQLProtocolBase(object):
    """Common base class for MySQL protocol implementations

    Adds no attributes or methods of its own; MySQLProtocol derives from it.
    """
class MySQLProtocol(MySQLProtocolBase):
def __init__(self, conn):
self.client_flags = 0
self.conn = conn
self.pktnr = -1
@classmethod
def raise_error(cls, buf):
"""Raise an errors.Error when buffer has a MySQL error"""
errno = errmsg = None
try:
buf = buf[5:]
(buf,errno) = utils.read_int(buf, 2)
if buf[0] != '\x23':
# Error without SQLState
errmsg = buf
else:
(buf,sqlstate) = utils.read_bytes(buf[1:],5)
errmsg = buf
except Exception, e:
raise errors.InterfaceError("Failed getting Error information (%r)"\
% e)
else:
raise errors.get_mysql_exception(errno,errmsg)
def _recv_packet(self):
"""Getting a packet from the MySQL server"""
buf = self.conn.recv()
if buf[4] == '\xff':
MySQLProtocol.raise_error(buf)
else:
return buf
def _scramble_password(self, passwd, seed):
"""Scramble a password ready to send to MySQL"""
hash4 = None
try:
hash1 = sha1(passwd).digest()
hash2 = sha1(hash1).digest() # Password as found in mysql.user()
hash3 = sha1(seed + hash2).digest()
xored = [ utils.intread(h1) ^ utils.intread(h3)
for (h1,h3) in zip(hash1, hash3) ]
hash4 = struct.pack('20B', *xored)
except Exception, e:
raise errors.InterfaceError('Failed scrambling password; %s' % e)
return hash4
def _pkt_make_header(self, pktlength, pktnr=None):
"""Make the header for a MySQL packet"""
pktnr = pktnr or self.pktnr+1
return utils.int3store(pktlength) + utils.int1store(pktnr)
def _prepare_auth(self, usr, pwd, db, flags, seed):
if usr is not None and len(usr) > 0:
_username = usr + '\x00'
else:
_username = '\x00'
if pwd is not None and len(pwd) > 0:
_password = utils.int1store(20) +\
self._scramble_password(pwd,seed)
else:
_password = '\x00'
if db is not None and len(db):
_database = db + '\x00'
else:
_database = '\x00'
return (_username, _password, _database)
def _pkt_make_auth(self, username=None, password=None, database=None,
seed=None, charset=33, client_flags=0):
"""Make a MySQL Authentication packet"""
try:
seed = seed or self.scramble
except:
raise errors.ProgrammingError('Seed missing')
(_username, _password, _database) = self._prepare_auth(
username, password, database, client_flags, seed)
data = utils.int4store(client_flags) +\
utils.int4store(10 * 1024 * 1024) +\
utils.int1store(charset) +\
'\x00'*23 +\
_username +\
_password +\
_database
header = self._pkt_make_header(len(data))
return header+data
def _pkt_make_command(self, command, argument=None):
"""Make a MySQL packet containing a command"""
data = utils.int1store(command)
if argument is not None:
data += str(argument)
header = self._pkt_make_header(len(data))
return header+data
def _pkt_make_changeuser(self, username=None, password=None,
database=None, charset=8, seed=None):
"""Make a MySQL packet with the Change User command"""
try:
seed = seed or self.scramble
except:
raise errors.ProgrammingError('Seed missing')
(_username, _password, _database) = self._prepare_auth(
username, password, database, self.client_flags, seed)
data = utils.int1store(ServerCmd.CHANGE_USER) +\
_username +\
_password +\
_database +\
utils.int2store(charset)
header = self._pkt_make_header(len(data))
return header+data
@set_pktnr(1)
def _pkt_parse_handshake(self, buf):
"""Parse a MySQL Handshake-packet"""
res = {}
(buf,res['protocol']) = utils.read_int(buf,1)
(buf,res['server_version_original']) = utils.read_string(buf,end='\x00')
(buf,res['server_threadid']) = utils.read_int(buf,4)
(buf,res['scramble']) = utils.read_bytes(buf, 8)
buf = buf[1:] # Filler 1 * \x00
(buf,res['capabilities']) = utils.read_int(buf,2)
(buf,res['charset']) = utils.read_int(buf,1)
(buf,res['server_status']) = utils.read_int(buf,2)
buf = buf[13:] # Filler 13 * \x00
(buf,scramble_next) = utils.read_bytes(buf,12)
res['scramble'] += scramble_next
return res
@set_pktnr(1)
def _pkt_parse_ok(self, buf):
"""Parse a MySQL OK-packet"""
ok = {}
(buf,ok['field_count']) = utils.read_int(buf,1)
(buf,ok['affected_rows']) = utils.read_lc_int(buf)
(buf,ok['insert_id']) = utils.read_lc_int(buf)
(buf,ok['server_status']) = utils.read_int(buf,2)
(buf,ok['warning_count']) = utils.read_int(buf,2)
if buf:
(buf,ok['info_msg']) = utils.read_lc_string(buf)
return ok
@set_pktnr(1)
def _pkt_parse_field(self, buf):
"""Parse a MySQL Field-packet"""
field = {}
(buf,field['catalog']) = utils.read_lc_string(buf)
(buf,field['db']) = utils.read_lc_string(buf)
(buf,field['table']) = utils.read_lc_string(buf)
(buf,field['org_table']) = utils.read_lc_string(buf)
(buf,field['name']) = utils.read_lc_string(buf)
(buf,field['org_name']) = utils.read_lc_string(buf)
buf = buf[1:] # filler 1 * \x00
(buf,field['charset']) = utils.read_int(buf, 2)
(buf,field['length']) = utils.read_int(buf, 4)
(buf,field['type']) = utils.read_int(buf, 1)
(buf,field['flags']) = utils.read_int(buf, 2)
(buf,field['decimal']) = utils.read_int(buf, 1)
buf = buf[2:] # filler 2 * \x00
res = (
field['name'],
field['type'],
None, # display_size
None, # internal_size
None, # precision
None, # scale
~field['flags'] & FieldFlag.NOT_NULL, # null_ok
field['flags'], # MySQL specific
)
return res
@set_pktnr(1)
def _pkt_parse_eof(self, buf):
"""Parse a MySQL EOF-packet"""
res = {}
buf = buf[1:] # disregard the first checking byte
(buf, res['warning_count']) = utils.read_int(buf, 2)
(buf, res['status_flag']) = utils.read_int(buf, 2)
return res
def do_handshake(self):
"""Get the handshake from the MySQL server"""
try:
self.conn.open_connection()
buf = self._recv_packet()
self.handle_handshake(buf)
except:
raise
def do_auth(self, username=None, password=None, database=None,
client_flags=0, charset=33):
"""Authenticate with the MySQL server
"""
pkt = self._pkt_make_auth(username=username, password=password,
database=database, charset=charset,
client_flags=client_flags)
self.conn.send(pkt)
buf = self._recv_packet()
if buf[4] == '\xfe':
raise errors.NotSupportedError(
"Authentication with old (insecure) passwords "\
"is not supported: "\
"http://dev.mysql.com/doc/refman/5.1/en/password-hashing.html")
try:
if not (client_flags & ClientFlag.CONNECT_WITH_DB) and database:
self.cmd_init_db(database)
except:
raise
return True
def handle_handshake(self, buf):
"""Check and handle the MySQL server's handshake
Check whether the buffer is a valid handshake. If it is, we set some
member variables for later usage. The handshake packet is returned for later
usuage, e.g. authentication.
"""
try:
res = self._pkt_parse_handshake(buf)
for k,v in res.items():
self.__dict__[k] = v
regex_ver = re.compile("^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)")
m = regex_ver.match(self.server_version_original)
if not m:
raise errors.InterfaceError("Failed parsing MySQL version number.")
self.server_version = tuple([ int(v) for v in m.groups()[0:3]])
except errors.Error:
raise
except Exception, e:
raise errors.InterfaceError('Failed handling handshake; %s' % e)
@packet_is_ok(1)
def _handle_ok(self, buf):
try:
return self._pkt_parse_ok(buf)
except:
raise errors.InterfaceError("Failed parsing OK packet.")
@packet_is_eof(1)
def _handle_eof(self, buf):
try:
return self._pkt_parse_eof(buf)
except:
raise errors.InterfaceError("Failed parsing EOF packet.")
@packet_is_error(1)
@set_pktnr(1)
def _handle_resultset(self, buf):
(buf,nrflds) = utils.read_lc_int(buf)
if nrflds == 0:
raise errors.InterfaceError('Empty result set.')
fields = []
for i in xrange(0,nrflds):
buf = self._recv_packet()
fields.append(self._pkt_parse_field(buf))
buf = self._recv_packet()
eof = self._handle_eof(buf)
return (nrflds, fields, eof)
def get_rows(self, cnt=None):
"""Get all rows
Returns a tuple with 2 elements: a list with all rows and
the EOF packet.
"""
rows = []
eof = None
rowdata = None
i = 0
while True:
if eof is not None:
break
if i == cnt:
break
buf = self._recv_packet()
if buf[4] == '\xfe':
eof = self._handle_eof(buf)
rowdata = None
else:
eof = None
rowdata = utils.read_lc_string_list(buf[4:])
if eof is None and rowdata is not None:
rows.append(rowdata)
i += 1
return (rows,eof)
def get_row(self):
(rows,eof) = self.get_rows(cnt=1)
if len(rows):
return (rows[0],eof)
return (None,eof)
def handle_cmd_result(self, buf):
if buf[4] == '\x00':
return self._handle_ok(buf)
else:
return self._handle_resultset(buf)[0:2]
@reset_pktnr
def cmd_query(self, query):
"""Sends a query to the MySQL server
Returns a tuple, when the query returns a result. The tuple
consist number of fields and a list containing their descriptions.
If the query doesn't return a result set, a dictionary with
information contained in an OKResult packet will be returned.
"""
nrflds = 0
fields = None
try:
pkt = self._pkt_make_command(ServerCmd.QUERY,query)
self.conn.send(pkt) # Errors handled in _handle_error()
return self.handle_cmd_result(self._recv_packet())
except:
raise
@reset_pktnr
def cmd_refresh(self, opts):
"""Send the Refresh command to the MySQL server
The argument should be a bitwise value using contants.RefreshOption.
Usage example:
RefreshOption = mysql.connector.RefreshOption
refresh = RefreshOption.LOG | RefreshOption.THREADS
db.protocol().cmd_refresh(refresh)
Returns a dict() with OK-packet information.
"""
pkt = self._pkt_make_command(ServerCmd.REFRESH, opts)
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_ok(buf)
@reset_pktnr
def cmd_quit(self):
"""Closes the current connection with the server
Returns the packet that was send.
"""
pkt = self._pkt_make_command(ServerCmd.QUIT)
self.conn.send(pkt)
return pkt
@reset_pktnr
def cmd_init_db(self, database):
"""Change the current database
Change the current (default) database.
Returns a dict() with OK-packet information.
"""
pkt = self._pkt_make_command(ServerCmd.INIT_DB, database)
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_ok(buf)
@reset_pktnr
def cmd_shutdown(self):
"""Shuts down the MySQL Server
Careful with this command if you have SUPER privileges! (Which your
scripts probably don't need!)
Returns a dict() with OK-packet information.
"""
pkt = self._pkt_make_command(ServerCmd.SHUTDOWN)
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_eof(buf)
@reset_pktnr
def cmd_statistics(self):
"""Sends statistics command to the MySQL Server
Returns a dictionary with various statistical information.
"""
pkt = self._pkt_make_command(ServerCmd.STATISTICS)
self.conn.send(pkt)
buf = self._recv_packet()
buf = buf[4:]
errmsg = "Failed getting COM_STATISTICS information"
res = {}
# Information is separated by 2 spaces
pairs = buf.split('\x20\x20')
for pair in pairs:
try:
(lbl,val) = [ v.strip() for v in pair.split(':',2) ]
except:
raise errors.InterfaceError(errmsg)
# It's either an integer or a decimal
try:
res[lbl] = long(val)
except:
try:
res[lbl] = Decimal(val)
except:
raise errors.InterfaceError(
"%s (%s:%s)." % (errmsg, lbl, val))
return res
@reset_pktnr
def cmd_process_info(self):
"""Gets the process list from the MySQL Server
(Unsupported)
"""
raise errors.NotSupportedError(
"Not implemented. Use a cursor to get processlist information.")
@reset_pktnr
def cmd_process_kill(self, mypid):
"""Kills a MySQL process using it's ID
Returns a dict() with OK-packet information.
"""
pkt = self._pkt_make_command(ServerCmd.PROCESS_KILL,
utils.int4store(mypid))
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_ok(buf)
@reset_pktnr
def cmd_debug(self):
"""Send DEBUG command to the MySQL Server
Needs SUPER privileges. The output will go to the MySQL server error
log.
Returns a dict() with EOF-packet information.
"""
pkt = self._pkt_make_command(ServerCmd.DEBUG)
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_eof(buf)
@reset_pktnr
def cmd_ping(self):
"""Ping the MySQL server to check if the connection is still alive
Raises errors.Error or an error derived from it when it fails
to ping the MySQL server.
Returns a dict() with OK-packet information.
"""
pkt = self._pkt_make_command(ServerCmd.PING)
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_ok(buf)
@reset_pktnr
def cmd_change_user(self, username='', password='', database=''):
"""Change the user and optionally the current database
Returns a dict() with OK-packet information.
"""
_charset = self.charset or 33
pkt = self._pkt_make_changeuser(username=username, password=password,
database=database, charset=_charset, seed=self.scramble)
self.conn.send(pkt)
buf = self._recv_packet()
return self._handle_ok(buf)