mythboxee/mysql/connector/cursor.py

652 lines
20 KiB
Python

# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009,2010, Oracle and/or its affiliates. All rights reserved.
# Use is subject to license terms. (See COPYING)
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# There are special exceptions to the terms and conditions of the GNU
# General Public License as it is applied to this software. View the
# full text of the exception in file EXCEPTIONS-CLIENT in the directory
# of this software distribution or see the FOSS License Exception at
# www.mysql.com.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""Cursor classes
"""
import sys
from collections import deque
import weakref
import re
import constants
import protocol
import errors
import utils
# Matches one C-style /* ... */ comment.  The match is non-greedy so that
# two comments in one statement do not swallow the SQL between them, and
# DOTALL lets a comment span several lines.
RE_SQL_COMMENT = re.compile(r'/\*.*?\*/', re.S)
# Captures the parenthesised value list following VALUES in an INSERT.
RE_SQL_INSERT_VALUES = re.compile(r'\sVALUES\s*(\(.*\))', re.I)
# Recognises the start of an INSERT INTO statement (case-insensitive).
RE_SQL_INSERT_STMT = re.compile(r'INSERT\s+INTO', re.I)
class CursorBase(object):
    """
    Skeleton cursor exposing the attributes and methods named by the
    Python Database API Specification v2.0.

    Every method here is a no-op that returns None; concrete cursors
    override them. It is better to inherit from MySQLCursor than from
    this class directly.
    """

    def __init__(self):
        # DB-API attributes with their "nothing executed yet" values.
        self.arraysize = 1
        self.rowcount = -1
        self.description = None

    def __del__(self):
        # Give the cursor a chance to release itself when collected.
        self.close()

    def callproc(self, procname, args=()):
        """Call a stored procedure (no-op in the skeleton)."""
        return None

    def close(self):
        """Close the cursor (no-op in the skeleton)."""
        return None

    def execute(self, operation, params=()):
        """Execute an operation (no-op in the skeleton)."""
        return None

    def executemany(self, operation, seqparams):
        """Execute an operation for each parameter set (no-op)."""
        return None

    def fetchone(self):
        """Fetch the next row (no-op in the skeleton)."""
        return None

    def fetchmany(self, size=1):
        """Fetch the next batch of rows (no-op in the skeleton)."""
        return None

    def fetchall(self):
        """Fetch all rows (no-op in the skeleton)."""
        return None

    def nextset(self):
        """Advance to the next result set (no-op in the skeleton)."""
        return None

    def setinputsizes(self, sizes):
        """Predefine input sizes (no-op in the skeleton)."""
        return None

    def setoutputsize(self, size, column=None):
        """Predefine an output buffer size (no-op in the skeleton)."""
        return None

    def reset(self):
        """Reset cursor-specific state (no-op in the skeleton)."""
        return None
class MySQLCursor(CursorBase):
"""
Default cursor which fetches all rows and stores it for later
usage. It uses the converter set for the MySQLConnection to map
MySQL types to Python types automatically.
This class should be inherited whenever other functionallity is
required. An example would to change the fetch* member functions
to return dictionaries instead of lists of values.
Implements the Python Database API Specification v2.0.
Possible parameters are:
db
A MySQLConnection instance.
"""
def __init__(self, db=None):
CursorBase.__init__(self)
self.db = None
self._more_results = False
self._results = deque()
self._nextrow = (None, None)
self.lastrowid = None
self._warnings = None
self._warning_count = 0
self._executed = None
self._have_result = False
self._raise_on_warnings = True
if db is not None:
self.set_connection(db)
def __iter__(self):
"""
Iteration over the result set which calls self.fetchone()
and returns the next row.
"""
return iter(self.fetchone, None)
def set_connection(self, db):
try:
if isinstance(db.protocol,protocol.MySQLProtocol):
self.db = weakref.ref(db)
if self not in self.db().cursors:
self.db().cursors.append(self)
except:
raise errors.InterfaceError(errno=2048)
def _reset_result(self):
self.rowcount = -1
self._nextrow = (None, None)
self._have_result = False
try:
self.db().unread_result = False
except:
pass
self._warnings = None
self._warning_count = 0
self.description = ()
self.reset()
def next(self):
"""
Used for iterating over the result set. Calles self.fetchone()
to get the next row.
"""
try:
row = self.fetchone()
except errors.InterfaceError:
raise StopIteration
if not row:
raise StopIteration
return row
def close(self):
"""
Close the cursor, disconnecting it from the MySQL object.
Returns True when succesful, otherwise False.
"""
if self.db is None:
return False
try:
self._reset_result()
self.db().remove_cursor(self)
self.db = None
except:
return False
return True
def _process_params_dict(self, params):
try:
to_mysql = self.db().converter.to_mysql
escape = self.db().converter.escape
quote = self.db().converter.quote
res = {}
for k,v in params.items():
c = v
c = to_mysql(c)
c = escape(c)
c = quote(c)
res[k] = c
except StandardError, e:
raise errors.ProgrammingError(
"Failed processing pyformat-parameters; %s" % e)
else:
return res
return None
def _process_params(self, params):
"""
Process the parameters which were given when self.execute() was
called. It does following using the MySQLConnection converter:
* Convert Python types to MySQL types
* Escapes characters required for MySQL.
* Quote values when needed.
Returns a list.
"""
if isinstance(params,dict):
return self._process_params_dict(params)
try:
res = params
to_mysql = self.db().converter.to_mysql
escape = self.db().converter.escape
quote = self.db().converter.quote
res = map(to_mysql,res)
res = map(escape,res)
res = map(quote,res)
except StandardError, e:
raise errors.ProgrammingError(
"Failed processing format-parameters; %s" % e)
else:
return tuple(res)
return None
def _row_to_python(self, rowdata, desc=None):
res = ()
try:
to_python = self.db().converter.to_python
if not desc:
desc = self.description
for idx,v in enumerate(rowdata):
flddsc = desc[idx]
res += (to_python(flddsc, v),)
except StandardError, e:
raise errors.InterfaceError(
"Failed converting row to Python types; %s" % e)
else:
return res
return None
def _handle_noresultset(self, res):
"""Handles result of execute() when there is no result set
"""
try:
self.rowcount = res['affected_rows']
self.lastrowid = res['insert_id']
self._warning_count = res['warning_count']
if self.db().get_warnings is True and self._warning_count:
self._warnings = self._fetch_warnings()
self._set_more_results(res['server_status'])
except errors.Error:
raise
except StandardError, e:
raise errors.ProgrammingError(
"Failed handling non-resultset; %s" % e)
def _handle_resultset(self):
pass
def _handle_result(self, res):
if isinstance(res, dict):
self.db().unread_result = False
self._have_result = False
self._handle_noresultset(res)
else:
self.description = res[1]
self.db().unread_result = True
self._have_result = True
self._handle_resultset()
def execute(self, operation, params=None):
"""
Executes the given operation. The parameters given through params
are used to substitute %%s in the operation string.
For example, getting all rows where id is 5:
cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,))
If warnings where generated, and db.get_warnings is True, then
self._warnings will be a list containing these warnings.
Raises exceptions when any error happens.
"""
if not operation:
return 0
if self.db().unread_result is True:
raise errors.InternalError("Unread result found.")
self._reset_result()
stmt = ''
try:
if isinstance(operation, unicode):
operation = operation.encode(self.db().charset_name)
if params is not None:
try:
stmt = operation % self._process_params(params)
except TypeError:
raise errors.ProgrammingError(
"Wrong number of arguments during string formatting")
else:
stmt = operation
res = self.db().protocol.cmd_query(stmt)
self._handle_result(res)
except (UnicodeDecodeError,UnicodeEncodeError), e:
raise errors.ProgrammingError(str(e))
except errors.Error:
raise
except StandardError, e:
raise errors.InterfaceError, errors.InterfaceError(
"Failed executing the operation; %s" % e), sys.exc_info()[2]
else:
self._executed = stmt
return self.rowcount
return 0
def executemany(self, operation, seq_params):
"""Loops over seq_params and calls execute()
INSERT statements are optimized by batching the data, that is
using the MySQL multiple rows syntax.
"""
if not operation:
return 0
if self.db().unread_result is True:
raise errors.InternalError("Unread result found.")
# Optimize INSERTs by batching them
if re.match(RE_SQL_INSERT_STMT,operation):
opnocom = re.sub(RE_SQL_COMMENT,'',operation)
m = re.search(RE_SQL_INSERT_VALUES,opnocom)
fmt = m.group(1)
values = []
for params in seq_params:
values.append(fmt % self._process_params(params))
operation = re.sub(re.escape(m.group(1)),
','.join(values),operation,count=1)
self.execute(operation)
else:
rowcnt = 0
try:
for params in seq_params:
self.execute(operation, params)
if self._have_result:
self.fetchall()
rowcnt += self.rowcount
except (ValueError,TypeError), e:
raise errors.InterfaceError(
"Failed executing the operation; %s" % e)
except:
# Raise whatever execute() raises
raise
self.rowcount = rowcnt
return self.rowcount
def _set_more_results(self, flags):
flag = constants.ServerFlag.MORE_RESULTS_EXISTS
self._more_results = constants.flag_is_set(flag, flags)
def next_resultset(self):
"""Gets next result after executing multiple statements
When more results are available, this function will reset the
current result and advance to the next set.
This is useful when executing multiple statements. If you need
to retrieve multiple results after executing a stored procedure
using callproc(), use next_proc_resultset() instead.
"""
if self._more_results is True:
buf = self.db().protocol._recv_packet()
res = self.db().protocol.handle_cmd_result(buf)
self._reset_result()
self._handle_result(res)
return True
return None
def next_proc_resultset(self):
"""Get the next result set after calling a stored procedure
Returns a MySQLCursorBuffered-object"""
try:
return self._results.popleft()
except IndexError:
return None
except:
raise
return None
def callproc(self, procname, args=()):
"""Calls a stored procedue with the given arguments
The arguments will be set during this session, meaning
they will be called like _<procname>__arg<nr> where
<nr> is an enumeration (+1) of the arguments.
Coding Example:
1) Definining the Stored Routine in MySQL:
CREATE PROCEDURE multiply(IN pFac1 INT, IN pFac2 INT, OUT pProd INT)
BEGIN
SET pProd := pFac1 * pFac2;
END
2) Executing in Python:
args = (5,5,0) # 0 is to hold pprod
cursor.callproc(multiply, args)
print cursor.fetchone()
The last print should output ('5', '5', 25L)
Does not return a value, but a result set will be
available when the CALL-statement execute succesfully.
Raises exceptions when something is wrong.
"""
argfmt = "@_%s_arg%d"
self._results = deque()
try:
procargs = self._process_params(args)
argnames = []
for idx,arg in enumerate(procargs):
argname = argfmt % (procname, idx+1)
argnames.append(argname)
setquery = "SET %s=%%s" % argname
self.execute(setquery, (arg,))
call = "CALL %s(%s)" % (procname,','.join(argnames))
res = self.db().protocol.cmd_query(call)
while not isinstance(res, dict):
tmp = MySQLCursorBuffered(self.db())
tmp.description = res[1]
tmp._handle_resultset()
self._results.append(tmp)
buf = self.db().protocol._recv_packet()
res = self.db().protocol.handle_cmd_result(buf)
try:
select = "SELECT %s" % ','.join(argnames)
self.execute(select)
return self.fetchone()
except:
raise
except errors.Error:
raise
except StandardError, e:
raise errors.InterfaceError(
"Failed calling stored routine; %s" % e)
def getlastrowid(self):
return self.lastrowid
def _fetch_warnings(self):
"""
Fetch warnings doing a SHOW WARNINGS. Can be called after getting
the result.
Returns a result set or None when there were no warnings.
"""
res = []
try:
c = self.db().cursor()
cnt = c.execute("SHOW WARNINGS")
res = c.fetchall()
c.close()
except StandardError, e:
raise errors.InterfaceError(
"Failed getting warnings; %s" % e)
if self.db().raise_on_warnings is True:
msg = '; '.join([ "(%s) %s" % (r[1],r[2]) for r in res])
raise errors.get_mysql_exception(res[0][1],res[0][2])
else:
if len(res):
return res
return None
def _handle_eof(self, eof):
self._have_result = False
self.db().unread_result = False
self._nextrow = (None, None)
self._warning_count = eof['warning_count']
if self.db().get_warnings is True and eof['warning_count']:
self._warnings = self._fetch_warnings()
self._set_more_results(eof['status_flag'])
def _fetch_row(self):
if self._have_result is False:
return None
row = None
try:
if self._nextrow == (None, None):
(row, eof) = self.db().protocol.get_row()
else:
(row, eof) = self._nextrow
if row:
(foo, eof) = self._nextrow = \
self.db().protocol.get_row()
if eof is not None:
self._handle_eof(eof)
if self.rowcount == -1:
self.rowcount = 1
else:
self.rowcount += 1
if eof:
self._handle_eof(eof)
except:
raise
else:
return row
return None
def fetchwarnings(self):
return self._warnings
def fetchone(self):
row = self._fetch_row()
if row:
return self._row_to_python(row)
return None
def fetchmany(self,size=None):
res = []
cnt = (size or self.arraysize)
while cnt > 0 and self._have_result:
cnt -= 1
row = self.fetchone()
if row:
res.append(row)
return res
def fetchall(self):
if self._have_result is False:
raise errors.InterfaceError("No result set to fetch from.")
res = []
row = None
while self.db().unread_result:
row = self.fetchone()
if row:
res.append(row)
return res
@property
def column_names(self):
return tuple( [d[0].decode('utf8') for d in self.description] )
def __unicode__(self):
fmt = "MySQLCursor: %s"
if self._executed:
if len(self._executed) > 30:
res = fmt % (self._executed[:30] + '..')
else:
res = fmt % (self._executed)
else:
res = fmt % '(Nothing executed yet)'
return res
def __str__(self):
return repr(self.__unicode__())
class MySQLCursorBuffered(MySQLCursor):
    """Cursor which fetches rows within execute()

    All rows of a result set are read when the result is handled and
    kept in self._rows; self._next_row is the index of the next row the
    fetch* methods will return.
    """

    def __init__(self, db=None):
        MySQLCursor.__init__(self, db)
        self._rows = None
        self._next_row = 0

    def _handle_resultset(self):
        """Read the complete result set from the connection."""
        (self._rows, eof) = self.db().protocol.get_rows()
        self.rowcount = len(self._rows)
        self._handle_eof(eof)
        self._next_row = 0
        try:
            self.db().unread_result = False
        except (AttributeError, TypeError):
            # Best effort: the connection may be gone by now.
            pass

    def reset(self):
        """Drop the buffered rows."""
        self._rows = None

    def _fetch_row(self):
        """Return the next buffered row, or None when there is none."""
        try:
            row = self._rows[self._next_row]
        except (IndexError, TypeError):
            # IndexError: all rows were consumed.
            # TypeError: nothing is buffered (self._rows is None).
            return None
        self._next_row += 1
        return row

    def fetchall(self):
        """
        Returns a list with all rows which were not fetched yet,
        converted to Python types.

        Raises errors.InterfaceError when there is no result set.
        """
        if self._rows is None:
            raise errors.InterfaceError("No result set to fetch from.")
        # Start at the read position: rows already handed out by
        # fetchone()/fetchmany() must not be returned a second time.
        res = [self._row_to_python(row)
               for row in self._rows[self._next_row:]]
        self._next_row = len(self._rows)
        return res

    def fetchmany(self, size=None):
        """
        Returns a list with at most size rows (self.arraysize when size
        is not given); the list is empty when no rows are left.
        """
        res = []
        cnt = (size or self.arraysize)
        while cnt > 0:
            cnt -= 1
            row = self.fetchone()
            if not row:
                # Buffer exhausted; further calls would return None too.
                break
            res.append(row)
        return res
class MySQLCursorRaw(MySQLCursor):
    """Cursor whose fetchone() skips the conversion to Python types
    and hands back the row values as received, wrapped in a tuple."""

    def fetchone(self):
        """Returns the next unconverted row as a tuple, or None."""
        raw = self._fetch_row()
        if not raw:
            return None
        return tuple(raw)
class MySQLCursorBufferedRaw(MySQLCursorBuffered):
    """Buffered cursor which skips the conversion to Python types and
    returns the row values as received, wrapped in tuples."""

    def fetchone(self):
        """Returns the next unconverted row as a tuple, or None."""
        row = self._fetch_row()
        if row:
            return tuple(row)
        return None

    def fetchall(self):
        """
        Returns a list with all unconverted rows which were not fetched
        yet, each as a tuple.

        Raises errors.InterfaceError when there is no result set.
        """
        if self._rows is None:
            raise errors.InterfaceError("No result set to fetch from.")
        # Honour and advance the read position so rows are returned
        # only once across fetchone()/fetchmany()/fetchall().
        res = [tuple(r) for r in self._rows[self._next_row:]]
        self._next_row = len(self._rows)
        return res