mythboxee/mysql/connector/protocol.py

864 lines
25 KiB
Python
Raw Normal View History

2010-07-30 19:21:27 -07:00
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright 2009 Sun Microsystems, Inc. All rights reserved
# Use is subject to license terms. (See COPYING)
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# There are special exceptions to the terms and conditions of the GNU
# General Public License as it is applied to this software. View the
# full text of the exception in file EXCEPTIONS-CLIENT in the directory
# of this software distribution or see the FOSS License Exception at
# www.mysql.com.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USAs
"""Implementing the MySQL Client/Server protocol
"""
import string
import socket
import re
import struct
try:
from hashlib import sha1
except ImportError:
from sha import new as sha1
from datetime import datetime
from time import strptime
from decimal import Decimal
from constants import *
import errors
import utils
class MySQLProtocol(object):
"""Class handling the MySQL Protocol.
MySQL v4.1 Client/Server Protocol is currently supported.
"""
def __init__(self, conn, handshake=None):
self.client_flags = 0
self.conn = conn # MySQL Connection
if handshake:
self.set_handshake(handshake)
def handle_header(self, buf):
"""Takes a buffer and readers information from header.
Returns a tuple (pktsize, pktnr)
"""
pktsize = utils.int3read(buf[0:3])
pktnr = utils.int1read(buf[3])
return (pktsize, pktnr)
def do_auth(self, username=None, password=None, database=None,
client_flags=None):
"""
Make and send the authentication using information found in the
handshake packet.
"""
if not client_flags:
client_flags = ClientFlag.get_default()
auth = Auth(client_flags=client_flags,
pktnr=self.handshake.pktnr+1)
auth.create(username=username, password=password,
seed=self.handshake.info['seed'], database=database)
self.conn.send(auth.get())
buf = self.conn.recv()[0]
if self.is_eof(buf):
raise errors.InterfaceError("Found EOF after Auth, expecting OK. Using old passwords?")
connect_with_db = client_flags & ClientFlag.CONNECT_WITH_DB
if self.is_ok(buf) and database and not connect_with_db:
self.cmd_init_db(database)
def handle_handshake(self, buf):
"""
Check whether the buffer is a valid handshake. If it is, we set some
member variables for later usage. The handshake packet is returned for later
usuage, e.g. authentication.
"""
if self.is_error(buf):
# an ErrorPacket is returned by the server
self._handle_error(buf)
handshake = None
try:
handshake = Handshake(buf)
except errors.InterfaceError, msg:
raise errors.InterfaceError(msg)
self.set_handshake(handshake)
def set_handshake(self, handshake):
"""Gather data from the given handshake."""
ver = re.compile("^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)")
version = handshake.info['version']
m = ver.match(version)
if not m:
raise errors.InterfaceError("Could not parse MySQL version, was '%s'" % version)
else:
self.server_version = tuple([ int(v) for v in m.groups()[0:3]])
self.server_version_original = handshake.info['version']
self.server_threadid = handshake.info['thrdid']
self.capabilities = handshake.info['capabilities']
self.charset = handshake.info['charset']
self.threadid = handshake.info['thrdid']
self.handshake = handshake
def _handle_error(self, buf):
"""Raise an OperationalError if result is an error
"""
try:
err = ErrorResultPacket(buf)
except errors.InterfaceError, e:
raise e
else:
raise errors.OperationalError(err)
def is_error(self, buf):
"""Check if the given buffer is a MySQL Error Packet.
Buffer should start with \xff.
Returns boolean.
"""
if buf and buf[4] == '\xff':
self._handle_error(buf)
return True
return False
def _handle_ok(self, buf):
"""
Handle an OK Result Packet. If we got an InterfaceError, raise that
instead.
"""
try:
ok = OKResultPacket(buf)
except errors.InterfaceError, e:
raise e
else:
self.server_status = ok.server_status
self.warning_count = ok.warning_count
self.field_count = ok.field_count
self.affected_rows = ok.affected_rows
self.info_msg = ok.info_msg
def is_ok(self, buf):
"""
Check if the given buffer is a MySQL OK Packet. It should
start with \x00.
Returns boolean.
"""
if buf and buf[4] == '\x00':
self._handle_ok(buf)
return True
return False
def _handle_fields(self, nrflds):
"""Reads a number of fields from a result set."""
i = 0
fields = []
while i < nrflds:
buf = self.conn.recv()[0]
fld = FieldPacket(buf)
fields.append(fld)
i += 1
return fields
def is_eof(self, buf):
"""
Check if the given buffer is a MySQL EOF Packet. It should
start with \xfe and be smaller 9 bytes.
Returns boolean.
"""
l = utils.read_int(buf, 3)[1]
if buf and buf[4] == '\xfe' and l < 9:
return True
return False
def _handle_resultset(self, pkt):
"""Processes a resultset getting fields information.
The argument pkt must be a protocol.Packet with length 1, a byte
which contains the number of fields.
"""
if not isinstance(pkt, PacketIn):
raise ValueError("%s is not a protocol.PacketIn" % pkt)
if len(pkt) == 1:
(buf,nrflds) = utils.read_lc_int(pkt.data)
# Get the fields
fields = self._handle_fields(nrflds)
buf = self.conn.recv()[0]
eof = EOFPacket(buf)
return (nrflds, fields, eof)
else:
raise errors.InterfaceError('Something wrong reading result after query.')
def result_get_row(self):
"""Get data for 1 row
Get one row's data. Should be called after getting the field
descriptions.
Returns a tuple with 2 elements: a row's data and the
EOF packet.
"""
buf = self.conn.recv()[0]
if self.is_eof(buf):
eof = EOFPacket(buf)
rowdata = None
else:
eof = None
rowdata = utils.read_lc_string_list(buf[4:])
return (rowdata, eof)
def result_get_rows(self, cnt=None):
"""Get all rows
Returns a tuple with 2 elements: a list with all rows and
the EOF packet.
"""
rows = []
eof = None
rowdata = None
while eof is None:
(rowdata,eof) = self.result_get_row()
if eof is None and rowdata is not None:
rows.append(rowdata)
return (rows,eof)
def cmd_query(self, query):
"""
Sends a query to the server.
Returns a tuple, when the query returns a result. The tuple
consist number of fields and a list containing their descriptions.
If the query doesn't return a result set, the an OKResultPacket
will be returned.
"""
try:
cmd = CommandPacket()
cmd.set_command(ServerCmd.QUERY)
cmd.set_argument(query)
cmd.create()
self.conn.send(cmd.get()) # Errors handled in _handle_error()
buf = self.conn.recv()[0]
if self.is_ok(buf):
# Query does not return a result (INSERT/DELETE/..)
return OKResultPacket(buf)
p = PacketIn(buf)
(nrflds, fields, eof) = self._handle_resultset(p)
except:
raise
else:
return (nrflds, fields)
return (0, ())
def _cmd_simple(self, servercmd, arg=''):
"""Makes a simple CommandPacket with no arguments"""
cmd = CommandPacket()
cmd.set_command(servercmd)
cmd.set_argument(arg)
cmd.create()
return cmd
def cmd_refresh(self, opts):
"""Send the Refresh command to the MySQL server.
The argument should be a bitwise value using the protocol.RefreshOption
constants.
Usage:
RefreshOption = mysql.connector.RefreshOption
refresh = RefreshOption.LOG | RefreshOption.THREADS
db.cmd_refresh(refresh)
"""
cmd = self._cmd_simple(ServerCmd.REFRESH, opts)
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
raise
if self.is_ok(buf):
return True
return False
def cmd_quit(self):
"""Closes the current connection with the server."""
cmd = self._cmd_simple(ServerCmd.QUIT)
self.conn.send(cmd.get())
def cmd_init_db(self, database):
"""
Send command to server to change databases.
"""
cmd = self._cmd_simple(ServerCmd.INIT_DB, database)
self.conn.send(cmd.get())
self.conn.recv()[0]
def cmd_shutdown(self):
"""Shuts down the MySQL Server.
Careful with this command if you have SUPER privileges! (Which your
scripts probably don't need!)
Returns True if it succeeds.
"""
cmd = self._cmd_simple(ServerCmd.SHUTDOWN)
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
raise
return True
def cmd_statistics(self):
"""Sends statistics command to the MySQL Server
Returns a dictionary with various statistical information.
"""
cmd = self._cmd_simple(ServerCmd.STATISTICS)
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
raise
p = Packet(buf)
info = str(p.data)
res = {}
pairs = info.split('\x20\x20') # Information is separated by 2 spaces
for pair in pairs:
(lbl,val) = [ v.strip() for v in pair.split(':') ]
# It's either an integer or a decimal
try:
res[lbl] = long(val)
except:
try:
res[lbl] = Decimal(val)
except:
raise ValueError(
"Got wrong value in COM_STATISTICS information (%s : %s)." % (lbl, val))
return res
def cmd_process_info(self):
"""Gets the process list from the MySQL Server.
Returns a list of dictionaries which corresponds to the output of
SHOW PROCESSLIST of MySQL. The data is converted to Python types.
"""
raise errors.NotSupportedError(
"Not implemented. Use a cursor to get processlist information.")
def cmd_process_kill(self, mypid):
"""Kills a MySQL process using it's ID.
The mypid must be an integer.
"""
cmd = KillPacket(mypid)
cmd.create()
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
raise
if self.is_eof(buf):
return True
return False
def cmd_debug(self):
"""Send DEBUG command to the MySQL Server
Needs SUPER privileges. The output will go to the MySQL server error log.
Returns True when it was succesful.
"""
cmd = self._cmd_simple(ServerCmd.DEBUG)
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
raise
if self.is_eof(buf):
return True
return False
def cmd_ping(self):
"""
Ping the MySQL server to check if the connection is still alive.
Returns True when alive, False when server doesn't respond.
"""
cmd = self._cmd_simple(ServerCmd.PING)
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
return False
else:
if self.is_ok(buf):
return True
return False
def cmd_change_user(self, username, password, database=None):
"""Change the user with given username and password to another optional database.
"""
if not database:
database = self.database
cmd = ChangeUserPacket()
cmd.create(username=username, password=password, database=database,
charset=self.charset, seed=self.handshake.info['seed'])
try:
self.conn.send(cmd.get())
buf = self.conn.recv()[0]
except:
raise
if not self.is_ok(buf):
raise errors.OperationalError(
"Failed getting OK Packet after changing user")
return True
class BasePacket(object):
def __len__(self):
try:
return len(self.data)
except:
return 0
def is_valid(self, buf=None):
if buf is None:
buf = self.data
(l, n) = (buf[0:3], buf[3])
hlength = utils.int3read(l)
rlength = len(buf) - 4
if hlength != rlength:
return False
res = self._is_valid_extra(buf)
if res != None:
return res
return True
def _is_valid_extra(self, buf):
return True
class PacketIn(BasePacket):
def __init__(self, buf=None, pktnr=0):
self.data = ''
self.pktnr = pktnr
self.protocol = 10
if buf:
self.is_valid(buf)
self.data = buf[4:]
if self.data:
self._parse()
def _parse(self):
pass
class PacketOut(BasePacket):
"""
Each packet type used in the MySQL Client Protocol is build on the Packet
class. It defines lots of useful functions for parsing and sending
data to and from the MySQL Server.
"""
def __init__(self, buf=None, pktnr=0):
self.data = ''
self.pktnr = pktnr
self.protocol = 10
if buf:
self.set(buf)
if self.data:
self._parse()
def _make_header(self):
h = utils.int3store(len(self)) + utils.int1store(self.pktnr)
return h
def _parse(self):
pass
def add(self, s):
if not s:
self.add_null()
else:
self.data = self.data + s
def add_1_int(self, i):
self.add(utils.int1store(i))
def add_2_int(self, i):
self.add(utils.int2store(i))
def add_3_int(self, i):
self.add(utils.int3store(i))
def add_4_int(self, i):
self.add(utils.int4store(i))
def add_null(self, nr=1):
self.add('\x00'*nr)
def get(self):
return self._make_header() + self.data
def get_header(self):
return self._make_header()
def set(self, buf):
if not self.is_valid(buf):
raise errors.InterfaceError('Packet not valid.')
self.data = buf[4:]
def _is_valid_extra(self, buf=None):
return None
class Handshake(PacketIn):
def __init__(self, buf=None):
PacketIn.__init__(self, buf)
def _parse(self):
version = ''
options = 0
srvstatus = 0
buf = self.data
(buf,self.protocol) = utils.read_int(buf,1)
(buf,version) = utils.read_string(buf,end='\x00')
(buf,thrdid) = utils.read_int(buf,4)
(buf,scramble) = utils.read_bytes(buf, 8)
buf = buf[1:] # Filler 1 * \x00
(buf,srvcap) = utils.read_int(buf,2)
(buf,charset) = utils.read_int(buf,1)
(buf,serverstatus) = utils.read_int(buf,2)
buf = buf[13:] # Filler 13 * \x00
(buf,scramble_next) = utils.read_bytes(buf,12)
scramble += scramble_next
self.info = {
'version' : version,
'thrdid' : thrdid,
'seed' : scramble,
'capabilities' : srvcap,
'charset' : charset,
'serverstatus' : serverstatus,
}
def get_dict(self):
self._parse()
return self.info
def _is_valid_extra(self, buf):
if buf[3] != '\x00':
return False
return True
class Auth(PacketOut):
def __init__(self, packet=None, client_flags=0, pktnr=0):
PacketOut.__init__(self, packet, pktnr)
self.client_flags = 0
self.username = None
self.password = None
self.database = None
if client_flags:
self.set_client_flags(client_flags)
def set_client_flags(self, flags):
self.client_flags = flags
def set_login(self, username, password, database=None):
self.username = username
self.password = password
self.database = database
def scramble(self, passwd, seed):
hash4 = None
try:
hash1 = sha1(passwd).digest()
hash2 = sha1(hash1).digest() # Password as found in mysql.user()
hash3 = sha1(seed + hash2).digest()
xored = [ utils.int1read(h1) ^ utils.int1read(h3)
for (h1,h3) in zip(hash1, hash3) ]
hash4 = struct.pack('20B', *xored)
except StandardError, e:
raise errors.ProgrammingError('Failed scrambling password; %s' % e)
else:
return hash4
return None
def create(self, username=None, password=None, database=None, seed=None):
self.add_4_int(self.client_flags)
self.add_4_int(10 * 1024 * 1024)
self.add_1_int(8)
self.add_null(23)
self.add(username + '\x00')
if password:
self.add_1_int(20)
self.add(self.scramble(password,seed))
else:
self.add_null(1)
if database:
self.add(database + '\x00')
else:
self.add_null()
class ChangeUserPacket(Auth):
def __init__(self):
self.command = ServerCmd.CHANGE_USER
Auth.__init__(self)
def create(self, username=None, password=None, database=None, charset=8, seed=None):
self.add_1_int(self.command)
self.add(username + '\x00')
if password:
self.add_1_int(20)
self.add(self.scramble(password,seed))
else:
self.add_null(1)
if database:
self.add(database + '\x00')
else:
self.add_null()
self.add_2_int(charset)
class ErrorResultPacket(PacketIn):
def __init__(self, buf=None):
self.errno = 0
self.errmsg = ''
self.sqlstate = None
PacketIn.__init__(self, buf)
def _parse(self):
buf = self.data
if buf[0] != '\xff':
raise errors.InterfaceError('Expected an Error Packet.')
buf = buf[1:]
(buf,self.errno) = utils.read_int(buf, 2)
if buf[0] != '\x23':
# Error without SQLState
self.errmsg = buf
else:
(buf,self.sqlstate) = utils.read_bytes(buf[1:],5)
self.errmsg = buf
class OKResultPacket(PacketIn):
def __init__(self, buf=None):
self.affected_rows = None
self.insert_id = None
self.server_status = 0
self.warning_count = 0
self.field_count = 0
self.info_msg = ''
PacketIn.__init__(self, buf)
def __str__(self):
if self.affected_rows == 1:
lbl_rows = 'row'
else:
lbl_rows = 'rows'
xtr = []
if self.insert_id:
xtr.append('last insert: %d ' % self.insert_id)
if self.warning_count:
xtr.append('warnings: %d' % self.warning_count)
return "Query OK, %d %s affected %s( sec)" % (self.affected_rows,
lbl_rows, ', '.join(xtr))
def _parse(self):
buf = self.data
(buf,self.field_count) = utils.read_int(buf,1)
(buf,self.affected_rows) = utils.read_lc_int(buf)
(buf,self.insert_id) = utils.read_lc_int(buf)
(buf,self.server_status) = utils.read_int(buf,2)
(buf,self.warning_count) = utils.read_int(buf,2)
if buf:
(buf,self.info_msg) = utils.read_lc_string(buf)
class CommandPacket(PacketOut):
def __init__(self, cmd=None, arg=None):
self.command = cmd
self.argument = arg
PacketOut.__init__(self)
def create(self):
self.add_1_int(self.command)
self.add(str(self.argument))
def set_command(self, cmd):
self.command = cmd
def set_argument(self, arg):
self.argument = arg
class KillPacket(CommandPacket):
def __init__(self, arg):
CommandPacket.__init__(self)
self.set_command(ServerCmd.PROCESS_KILL)
self.set_argument(arg)
def create(self):
""""""
self.add_1_int(self.command)
self.add_4_int(self.argument)
def set_argument(self, arg):
if arg and not isinstance(int, long) and arg > 2**32:
raise ValueError, "KillPacket needs integer value as argument not larger than 2^32."
self.argument = arg
class FieldPacket(PacketIn):
def __init__(self, buf=None):
self.catalog = None
self.db = None
self.table = None
self.org_table = None
self.name = None
self.length = None
self.org_name = None
self.charset = None
self.type = None
self.flags = None
PacketIn.__init__(self, buf)
def __str__(self):
flags = []
for k,f in FieldFlag.desc.items():
if int(self.flags) & f[0]:
flags.append(k)
return """
Field: catalog: %s ; db:%s ; table:%s ; org_table: %s ;
name: %s ; org_name: %s ;
charset: %s ; lenght: %s ;
type: %02x ;
flags(%d): %s;
""" % (self.catalog,self.db,self.table,self.org_table,self.name,self.org_name,
self.charset, len(self), self.type,
self.flags, '|'.join(flags))
def _parse(self):
buf = self.data
(buf,self.catalog) = utils.read_lc_string(buf)
(buf,self.db) = utils.read_lc_string(buf)
(buf,self.table) = utils.read_lc_string(buf)
(buf,self.org_table) = utils.read_lc_string(buf)
(buf,self.name) = utils.read_lc_string(buf)
(buf,self.org_name) = utils.read_lc_string(buf)
buf = buf[1:] # filler 1 * \x00
(buf,self.charset) = utils.read_int(buf, 2)
(buf,self.length) = utils.read_int(buf, 4)
(buf,self.type) = utils.read_int(buf, 1)
(buf,self.flags) = utils.read_int(buf, 2)
(buf,self.decimal) = utils.read_int(buf, 1)
buf = buf[2:] # filler 2 * \x00
def get_description(self):
"""Returns a description as a list useful for cursors.
This function returns a list as defined in the Python Db API v2.0
specification.
"""
return (
self.name,
self.type,
None, # display_size
None, # internal_size
None, # precision
None, # scale
~self.flags & FieldFlag.NOT_NULL, # null_ok
self.flags, # MySQL specific
)
class EOFPacket(PacketIn):
def __init__(self, buf=None):
self.warning_count = None
self.status_flag = None
PacketIn.__init__(self, buf)
def __str__(self):
return "EOFPacket: warnings %d / status: %d" % (self.warning_count,self.status_flag)
def _is_valid_extra(self, buf=None):
if not buf:
buf = self.data
else:
buf = buf[4:]
if buf[0] == '\xfe' and len(buf) == 5:
# An EOF should always start with \xfe and smaller than 9 bytes
return True
return False
def _parse(self):
buf = self.data
buf = buf[1:] # disregard the first checking byte
(buf, self.warning_count) = utils.read_int(buf, 2)
(buf, self.status_flag) = utils.read_int(buf, 2)