""" Database API (part of web.py) """ __all__ = [ "UnknownParamstyle", "UnknownDB", "TransactionError", "sqllist", "sqlors", "reparam", "sqlquote", "SQLQuery", "SQLParam", "sqlparam", "SQLLiteral", "sqlliteral", "database", 'DB', ] import time try: import datetime except ImportError: datetime = None try: set except NameError: from sets import Set as set from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode try: # db module can work independent of web.py from webapi import debug, config except: import sys debug = sys.stderr config = storage() class UnknownDB(Exception): """raised for unsupported dbms""" pass class _ItplError(ValueError): def __init__(self, text, pos): ValueError.__init__(self) self.text = text self.pos = pos def __str__(self): return "unfinished expression in %s at char %d" % ( repr(self.text), self.pos) class TransactionError(Exception): pass class UnknownParamstyle(Exception): """ raised for unsupported db paramstyles (currently supported: qmark, numeric, format, pyformat) """ pass class SQLParam(object): """ Parameter in SQLQuery. >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")]) >>> q >>> q.query() 'SELECT * FROM test WHERE name=%s' >>> q.values() ['joe'] """ __slots__ = ["value"] def __init__(self, value): self.value = value def get_marker(self, paramstyle='pyformat'): if paramstyle == 'qmark': return '?' elif paramstyle == 'numeric': return ':1' elif paramstyle is None or paramstyle in ['format', 'pyformat']: return '%s' raise UnknownParamstyle, paramstyle def sqlquery(self): return SQLQuery([self]) def __add__(self, other): return self.sqlquery() + other def __radd__(self, other): return other + self.sqlquery() def __str__(self): return str(self.value) def __repr__(self): return '' % repr(self.value) sqlparam = SQLParam class SQLQuery(object): """ You can pass this sort of thing as a clause in any db function. Otherwise, you can pass a dictionary to the keyword argument `vars` and the function will call reparam for you. Internally, consists of `items`, which is a list of strings and SQLParams, which get concatenated to produce the actual query. """ __slots__ = ["items"] # tested in sqlquote's docstring def __init__(self, items=None): r"""Creates a new SQLQuery. >>> SQLQuery("x") >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)]) >>> q >>> q.query(), q.values() ('SELECT * FROM test WHERE x=%s', [1]) >>> SQLQuery(SQLParam(1)) """ if items is None: self.items = [] elif isinstance(items, list): self.items = items elif isinstance(items, SQLParam): self.items = [items] elif isinstance(items, SQLQuery): self.items = list(items.items) else: self.items = [items] # Take care of SQLLiterals for i, item in enumerate(self.items): if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral): self.items[i] = item.value.v def append(self, value): self.items.append(value) def __add__(self, other): if isinstance(other, basestring): items = [other] elif isinstance(other, SQLQuery): items = other.items else: return NotImplemented return SQLQuery(self.items + items) def __radd__(self, other): if isinstance(other, basestring): items = [other] else: return NotImplemented return SQLQuery(items + self.items) def __iadd__(self, other): if isinstance(other, (basestring, SQLParam)): self.items.append(other) elif isinstance(other, SQLQuery): self.items.extend(other.items) else: return NotImplemented return self def __len__(self): return len(self.query()) def query(self, paramstyle=None): """ Returns the query part of the sql query. >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')]) >>> q.query() 'SELECT * FROM test WHERE name=%s' >>> q.query(paramstyle='qmark') 'SELECT * FROM test WHERE name=?' """ s = [] for x in self.items: if isinstance(x, SQLParam): x = x.get_marker(paramstyle) s.append(safestr(x)) else: x = safestr(x) # automatically escape % characters in the query # For backward compatability, ignore escaping when the query looks already escaped if paramstyle in ['format', 'pyformat']: if '%' in x and '%%' not in x: x = x.replace('%', '%%') s.append(x) return "".join(s) def values(self): """ Returns the values of the parameters used in the sql query. >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')]) >>> q.values() ['joe'] """ return [i.value for i in self.items if isinstance(i, SQLParam)] def join(items, sep=' ', prefix=None, suffix=None, target=None): """ Joins multiple queries. >>> SQLQuery.join(['a', 'b'], ', ') Optinally, prefix and suffix arguments can be provided. >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')') If target argument is provided, the items are appended to target instead of creating a new SQLQuery. """ if target is None: target = SQLQuery() target_items = target.items if prefix: target_items.append(prefix) for i, item in enumerate(items): if i != 0: target_items.append(sep) if isinstance(item, SQLQuery): target_items.extend(item.items) else: target_items.append(item) if suffix: target_items.append(suffix) return target join = staticmethod(join) def _str(self): try: return self.query() % tuple([sqlify(x) for x in self.values()]) except (ValueError, TypeError): return self.query() def __str__(self): return safestr(self._str()) def __unicode__(self): return safeunicode(self._str()) def __repr__(self): return '' % repr(str(self)) class SQLLiteral: """ Protects a string from `sqlquote`. >>> sqlquote('NOW()') >>> sqlquote(SQLLiteral('NOW()')) """ def __init__(self, v): self.v = v def __repr__(self): return self.v sqlliteral = SQLLiteral def _sqllist(values): """ >>> _sqllist([1, 2, 3]) """ items = [] items.append('(') for i, v in enumerate(values): if i != 0: items.append(', ') items.append(sqlparam(v)) items.append(')') return SQLQuery(items) def reparam(string_, dictionary): """ Takes a string and a dictionary and interpolates the string using values from the dictionary. Returns an `SQLQuery` for the result. >>> reparam("s = $s", dict(s=True)) >>> reparam("s IN $s", dict(s=[1, 2])) """ dictionary = dictionary.copy() # eval mucks with it vals = [] result = [] for live, chunk in _interpolate(string_): if live: v = eval(chunk, dictionary) result.append(sqlquote(v)) else: result.append(chunk) return SQLQuery.join(result, '') def sqlify(obj): """ converts `obj` to its proper SQL version >>> sqlify(None) 'NULL' >>> sqlify(True) "'t'" >>> sqlify(3) '3' """ # because `1 == True and hash(1) == hash(True)` # we have to do this the hard way... if obj is None: return 'NULL' elif obj is True: return "'t'" elif obj is False: return "'f'" elif datetime and isinstance(obj, datetime.datetime): return repr(obj.isoformat()) else: if isinstance(obj, unicode): obj = obj.encode('utf8') return repr(obj) def sqllist(lst): """ Converts the arguments for use in something like a WHERE clause. >>> sqllist(['a', 'b']) 'a, b' >>> sqllist('a') 'a' >>> sqllist(u'abc') u'abc' """ if isinstance(lst, basestring): return lst else: return ', '.join(lst) def sqlors(left, lst): """ `left is a SQL clause like `tablename.arg = ` and `lst` is a list of values. Returns a reparam-style pair featuring the SQL that ORs together the clause for each item in the lst. >>> sqlors('foo = ', []) >>> sqlors('foo = ', [1]) >>> sqlors('foo = ', 1) >>> sqlors('foo = ', [1,2,3]) """ if isinstance(lst, iters): lst = list(lst) ln = len(lst) if ln == 0: return SQLQuery("1=2") if ln == 1: lst = lst[0] if isinstance(lst, iters): return SQLQuery(['('] + sum([[left, sqlparam(x), ' OR '] for x in lst], []) + ['1=2)'] ) else: return left + sqlparam(lst) def sqlwhere(dictionary, grouping=' AND '): """ Converts a `dictionary` to an SQL WHERE clause `SQLQuery`. >>> sqlwhere({'cust_id': 2, 'order_id':3}) >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ') >>> sqlwhere({'a': 'a', 'b': 'b'}).query() 'a = %s AND b = %s' """ return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping) def sqlquote(a): """ Ensures `a` is quoted properly for use in a SQL query. >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3) >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3]) """ if isinstance(a, list): return _sqllist(a) else: return sqlparam(a).sqlquery() class Transaction: """Database transaction.""" def __init__(self, ctx): self.ctx = ctx self.transaction_count = transaction_count = len(ctx.transactions) class transaction_engine: """Transaction Engine used in top level transactions.""" def do_transact(self): ctx.commit(unload=False) def do_commit(self): ctx.commit() def do_rollback(self): ctx.rollback() class subtransaction_engine: """Transaction Engine used in sub transactions.""" def query(self, q): db_cursor = ctx.db.cursor() ctx.db_execute(db_cursor, SQLQuery(q % transaction_count)) def do_transact(self): self.query('SAVEPOINT webpy_sp_%s') def do_commit(self): self.query('RELEASE SAVEPOINT webpy_sp_%s') def do_rollback(self): self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s') class dummy_engine: """Transaction Engine used instead of subtransaction_engine when sub transactions are not supported.""" do_transact = do_commit = do_rollback = lambda self: None if self.transaction_count: # nested transactions are not supported in some databases if self.ctx.get('ignore_nested_transactions'): self.engine = dummy_engine() else: self.engine = subtransaction_engine() else: self.engine = transaction_engine() self.engine.do_transact() self.ctx.transactions.append(self) def __enter__(self): return self def __exit__(self, exctype, excvalue, traceback): if exctype is not None: self.rollback() else: self.commit() def commit(self): if len(self.ctx.transactions) > self.transaction_count: self.engine.do_commit() self.ctx.transactions = self.ctx.transactions[:self.transaction_count] def rollback(self): if len(self.ctx.transactions) > self.transaction_count: self.engine.do_rollback() self.ctx.transactions = self.ctx.transactions[:self.transaction_count] class DB: """Database""" def __init__(self, db_module, keywords): """Creates a database. """ # some DB implementaions take optional paramater `driver` to use a specific driver modue # but it should not be passed to connect keywords.pop('driver', None) self.db_module = db_module self.keywords = keywords self._ctx = threadeddict() # flag to enable/disable printing queries self.printing = config.get('debug_sql', config.get('debug', False)) self.supports_multiple_insert = False try: import DBUtils # enable pooling if DBUtils module is available. self.has_pooling = True except ImportError: self.has_pooling = False # Pooling can be disabled by passing pooling=False in the keywords. self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling def _getctx(self): if not self._ctx.get('db'): self._load_context(self._ctx) return self._ctx ctx = property(_getctx) def _load_context(self, ctx): ctx.dbq_count = 0 ctx.transactions = [] # stack of transactions if self.has_pooling: ctx.db = self._connect_with_pooling(self.keywords) else: ctx.db = self._connect(self.keywords) ctx.db_execute = self._db_execute if not hasattr(ctx.db, 'commit'): ctx.db.commit = lambda: None if not hasattr(ctx.db, 'rollback'): ctx.db.rollback = lambda: None def commit(unload=True): # do db commit and release the connection if pooling is enabled. ctx.db.commit() if unload and self.has_pooling: self._unload_context(self._ctx) def rollback(): # do db rollback and release the connection if pooling is enabled. ctx.db.rollback() if self.has_pooling: self._unload_context(self._ctx) ctx.commit = commit ctx.rollback = rollback def _unload_context(self, ctx): del ctx.db def _connect(self, keywords): return self.db_module.connect(**keywords) def _connect_with_pooling(self, keywords): def get_pooled_db(): from DBUtils import PooledDB # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator` # see Bug#122112 if PooledDB.__version__.split('.') < '0.9.3'.split('.'): return PooledDB.PooledDB(dbapi=self.db_module, **keywords) else: return PooledDB.PooledDB(creator=self.db_module, **keywords) if getattr(self, '_pooleddb', None) is None: self._pooleddb = get_pooled_db() return self._pooleddb.connection() def _db_cursor(self): return self.ctx.db.cursor() def _param_marker(self): """Returns parameter marker based on paramstyle attribute if this database.""" style = getattr(self, 'paramstyle', 'pyformat') if style == 'qmark': return '?' elif style == 'numeric': return ':1' elif style in ['format', 'pyformat']: return '%s' raise UnknownParamstyle, style def _db_execute(self, cur, sql_query): """executes an sql query""" self.ctx.dbq_count += 1 try: a = time.time() query, params = self._process_query(sql_query) out = cur.execute(query, params) b = time.time() except: if self.printing: print >> debug, 'ERR:', str(sql_query) if self.ctx.transactions: self.ctx.transactions[-1].rollback() else: self.ctx.rollback() raise if self.printing: print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query)) return out def _process_query(self, sql_query): """Takes the SQLQuery object and returns query string and parameters. """ paramstyle = getattr(self, 'paramstyle', 'pyformat') query = sql_query.query(paramstyle) params = sql_query.values() return query, params def _where(self, where, vars): if isinstance(where, (int, long)): where = "id = " + sqlparam(where) #@@@ for backward-compatibility elif isinstance(where, (list, tuple)) and len(where) == 2: where = SQLQuery(where[0], where[1]) elif isinstance(where, SQLQuery): pass else: where = reparam(where, vars) return where def query(self, sql_query, vars=None, processed=False, _test=False): """ Execute SQL query `sql_query` using dictionary `vars` to interpolate it. If `processed=True`, `vars` is a `reparam`-style list to use instead of interpolating. >>> db = DB(None, {}) >>> db.query("SELECT * FROM foo", _test=True) >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True) >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True) """ if vars is None: vars = {} if not processed and not isinstance(sql_query, SQLQuery): sql_query = reparam(sql_query, vars) if _test: return sql_query db_cursor = self._db_cursor() self._db_execute(db_cursor, sql_query) if db_cursor.description: names = [x[0] for x in db_cursor.description] def iterwrapper(): row = db_cursor.fetchone() while row: yield storage(dict(zip(names, row))) row = db_cursor.fetchone() out = iterbetter(iterwrapper()) out.__len__ = lambda: int(db_cursor.rowcount) out.list = lambda: [storage(dict(zip(names, x))) \ for x in db_cursor.fetchall()] else: out = db_cursor.rowcount if not self.ctx.transactions: self.ctx.commit() return out def select(self, tables, vars=None, what='*', where=None, order=None, group=None, limit=None, offset=None, _test=False): """ Selects `what` from `tables` with clauses `where`, `order`, `group`, `limit`, and `offset`. Uses vars to interpolate. Otherwise, each clause can be a SQLQuery. >>> db = DB(None, {}) >>> db.select('foo', _test=True) >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True) """ if vars is None: vars = {} sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset) clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None] qout = SQLQuery.join(clauses) if _test: return qout return self.query(qout, processed=True) def where(self, table, what='*', order=None, group=None, limit=None, offset=None, _test=False, **kwargs): """ Selects from `table` where keys are equal to values in `kwargs`. >>> db = DB(None, {}) >>> db.where('foo', bar_id=3, _test=True) >>> db.where('foo', source=2, crust='dewey', _test=True) >>> db.where('foo', _test=True) """ where_clauses = [] for k, v in kwargs.iteritems(): where_clauses.append(k + ' = ' + sqlquote(v)) if where_clauses: where = SQLQuery.join(where_clauses, " AND ") else: where = None return self.select(table, what=what, order=order, group=group, limit=limit, offset=offset, _test=_test, where=where) def sql_clauses(self, what, tables, where, group, order, limit, offset): return ( ('SELECT', what), ('FROM', sqllist(tables)), ('WHERE', where), ('GROUP BY', group), ('ORDER BY', order), ('LIMIT', limit), ('OFFSET', offset)) def gen_clause(self, sql, val, vars): if isinstance(val, (int, long)): if sql == 'WHERE': nout = 'id = ' + sqlquote(val) else: nout = SQLQuery(val) #@@@ elif isinstance(val, (list, tuple)) and len(val) == 2: nout = SQLQuery(val[0], val[1]) # backwards-compatibility elif isinstance(val, SQLQuery): nout = val else: nout = reparam(val, vars) def xjoin(a, b): if a and b: return a + ' ' + b else: return a or b return xjoin(sql, nout) def insert(self, tablename, seqname=None, _test=False, **values): """ Inserts `values` into `tablename`. Returns current sequence ID. Set `seqname` to the ID if it's not the default, or to `False` if there isn't one. >>> db = DB(None, {}) >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True) >>> q >>> q.query() 'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())' >>> q.values() [2, 'bob'] """ def q(x): return "(" + x + ")" if values: _keys = SQLQuery.join(values.keys(), ', ') _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ') sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values) else: sql_query = SQLQuery(self._get_insert_default_values_query(tablename)) if _test: return sql_query db_cursor = self._db_cursor() if seqname is not False: sql_query = self._process_insert_query(sql_query, tablename, seqname) if isinstance(sql_query, tuple): # for some databases, a separate query has to be made to find # the id of the inserted row. q1, q2 = sql_query self._db_execute(db_cursor, q1) self._db_execute(db_cursor, q2) else: self._db_execute(db_cursor, sql_query) try: out = db_cursor.fetchone()[0] except Exception: out = None if not self.ctx.transactions: self.ctx.commit() return out def _get_insert_default_values_query(self, table): return "INSERT INTO %s DEFAULT VALUES" % table def multiple_insert(self, tablename, values, seqname=None, _test=False): """ Inserts multiple rows into `tablename`. The `values` must be a list of dictioanries, one for each row to be inserted, each with the same set of keys. Returns the list of ids of the inserted rows. Set `seqname` to the ID if it's not the default, or to `False` if there isn't one. >>> db = DB(None, {}) >>> db.supports_multiple_insert = True >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}] >>> db.multiple_insert('person', values=values, _test=True) """ if not values: return [] if not self.supports_multiple_insert: out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values] if seqname is False: return None else: return out keys = values[0].keys() #@@ make sure all keys are valid # make sure all rows have same keys. for v in values: if v.keys() != keys: raise ValueError, 'Bad data' sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys))) for i, row in enumerate(values): if i != 0: sql_query.append(", ") SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")") if _test: return sql_query db_cursor = self._db_cursor() if seqname is not False: sql_query = self._process_insert_query(sql_query, tablename, seqname) if isinstance(sql_query, tuple): # for some databases, a separate query has to be made to find # the id of the inserted row. q1, q2 = sql_query self._db_execute(db_cursor, q1) self._db_execute(db_cursor, q2) else: self._db_execute(db_cursor, sql_query) try: out = db_cursor.fetchone()[0] out = range(out-len(values)+1, out+1) except Exception: out = None if not self.ctx.transactions: self.ctx.commit() return out def update(self, tables, where, vars=None, _test=False, **values): """ Update `tables` with clause `where` (interpolated using `vars`) and setting `values`. >>> db = DB(None, {}) >>> name = 'Joseph' >>> q = db.update('foo', where='name = $name', name='bob', age=2, ... created=SQLLiteral('NOW()'), vars=locals(), _test=True) >>> q >>> q.query() 'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s' >>> q.values() [2, 'bob', 'Joseph'] """ if vars is None: vars = {} where = self._where(where, vars) query = ( "UPDATE " + sqllist(tables) + " SET " + sqlwhere(values, ', ') + " WHERE " + where) if _test: return query db_cursor = self._db_cursor() self._db_execute(db_cursor, query) if not self.ctx.transactions: self.ctx.commit() return db_cursor.rowcount def delete(self, table, where, using=None, vars=None, _test=False): """ Deletes from `table` with clauses `where` and `using`. >>> db = DB(None, {}) >>> name = 'Joe' >>> db.delete('foo', where='name = $name', vars=locals(), _test=True) """ if vars is None: vars = {} where = self._where(where, vars) q = 'DELETE FROM ' + table if using: q += ' USING ' + sqllist(using) if where: q += ' WHERE ' + where if _test: return q db_cursor = self._db_cursor() self._db_execute(db_cursor, q) if not self.ctx.transactions: self.ctx.commit() return db_cursor.rowcount def _process_insert_query(self, query, tablename, seqname): return query def transaction(self): """Start a transaction.""" return Transaction(self.ctx) class PostgresDB(DB): """Postgres driver.""" def __init__(self, **keywords): if 'pw' in keywords: keywords['password'] = keywords.pop('pw') db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None)) if db_module.__name__ == "psycopg2": import psycopg2.extensions psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) # if db is not provided postgres driver will take it from PGDATABASE environment variable if 'db' in keywords: keywords['database'] = keywords.pop('db') self.dbname = "postgres" self.paramstyle = db_module.paramstyle DB.__init__(self, db_module, keywords) self.supports_multiple_insert = True self._sequences = None def _process_insert_query(self, query, tablename, seqname): if seqname is None: # when seqname is not provided guess the seqname and make sure it exists seqname = tablename + "_id_seq" if seqname not in self._get_all_sequences(): seqname = None if seqname: query += "; SELECT currval('%s')" % seqname return query def _get_all_sequences(self): """Query postgres to find names of all sequences used in this database.""" if self._sequences is None: q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'" self._sequences = set([c.relname for c in self.query(q)]) return self._sequences def _connect(self, keywords): conn = DB._connect(self, keywords) try: conn.set_client_encoding('UTF8') except AttributeError: # fallback for pgdb driver conn.cursor().execute("set client_encoding to 'UTF-8'") return conn def _connect_with_pooling(self, keywords): conn = DB._connect_with_pooling(self, keywords) conn._con._con.set_client_encoding('UTF8') return conn class MySQLDB(DB): def __init__(self, **keywords): import MySQLdb as db if 'pw' in keywords: keywords['passwd'] = keywords['pw'] del keywords['pw'] if 'charset' not in keywords: keywords['charset'] = 'utf8' elif keywords['charset'] is None: del keywords['charset'] self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg self.dbname = "mysql" DB.__init__(self, db, keywords) self.supports_multiple_insert = True def _process_insert_query(self, query, tablename, seqname): return query, SQLQuery('SELECT last_insert_id();') def _get_insert_default_values_query(self, table): return "INSERT INTO %s () VALUES()" % table def import_driver(drivers, preferred=None): """Import the first available driver or preferred driver. """ if preferred: drivers = [preferred] for d in drivers: try: return __import__(d, None, None, ['x']) except ImportError: pass raise ImportError("Unable to import " + " or ".join(drivers)) class SqliteDB(DB): def __init__(self, **keywords): db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None)) if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]: db.paramstyle = 'qmark' # sqlite driver doesn't create datatime objects for timestamp columns unless `detect_types` option is passed. # It seems to be supported in sqlite3 and pysqlite2 drivers, not surte about sqlite. keywords.setdefault('detect_types', db.PARSE_DECLTYPES) self.paramstyle = db.paramstyle keywords['database'] = keywords.pop('db') keywords['pooling'] = False # sqlite don't allows connections to be shared by threads self.dbname = "sqlite" DB.__init__(self, db, keywords) def _process_insert_query(self, query, tablename, seqname): return query, SQLQuery('SELECT last_insert_rowid();') def query(self, *a, **kw): out = DB.query(self, *a, **kw) if isinstance(out, iterbetter): del out.__len__ return out class FirebirdDB(DB): """Firebird Database. """ def __init__(self, **keywords): try: import kinterbasdb as db except Exception: db = None pass if 'pw' in keywords: keywords['passwd'] = keywords['pw'] del keywords['pw'] keywords['database'] = keywords['db'] del keywords['db'] DB.__init__(self, db, keywords) def delete(self, table, where=None, using=None, vars=None, _test=False): # firebird doesn't support using clause using=None return DB.delete(self, table, where, using, vars, _test) def sql_clauses(self, what, tables, where, group, order, limit, offset): return ( ('SELECT', ''), ('FIRST', limit), ('SKIP', offset), ('', what), ('FROM', sqllist(tables)), ('WHERE', where), ('GROUP BY', group), ('ORDER BY', order) ) class MSSQLDB(DB): def __init__(self, **keywords): import pymssql as db if 'pw' in keywords: keywords['password'] = keywords.pop('pw') keywords['database'] = keywords.pop('db') self.dbname = "mssql" DB.__init__(self, db, keywords) def _process_query(self, sql_query): """Takes the SQLQuery object and returns query string and parameters. """ # MSSQLDB expects params to be a tuple. # Overwriting the default implementation to convert params to tuple. paramstyle = getattr(self, 'paramstyle', 'pyformat') query = sql_query.query(paramstyle) params = sql_query.values() return query, tuple(params) def sql_clauses(self, what, tables, where, group, order, limit, offset): return ( ('SELECT', what), ('TOP', limit), ('FROM', sqllist(tables)), ('WHERE', where), ('GROUP BY', group), ('ORDER BY', order), ('OFFSET', offset)) def _test(self): """Test LIMIT. Fake presence of pymssql module for running tests. >>> import sys >>> sys.modules['pymssql'] = sys.modules['sys'] MSSQL has TOP clause instead of LIMIT clause. >>> db = MSSQLDB(db='test', user='joe', pw='secret') >>> db.select('foo', limit=4, _test=True) """ pass class OracleDB(DB): def __init__(self, **keywords): import cx_Oracle as db if 'pw' in keywords: keywords['password'] = keywords.pop('pw') #@@ TODO: use db.makedsn if host, port is specified keywords['dsn'] = keywords.pop('db') self.dbname = 'oracle' db.paramstyle = 'numeric' self.paramstyle = db.paramstyle # oracle doesn't support pooling keywords.pop('pooling', None) DB.__init__(self, db, keywords) def _process_insert_query(self, query, tablename, seqname): if seqname is None: # It is not possible to get seq name from table name in Oracle return query else: return query + "; SELECT %s.currval FROM dual" % seqname _databases = {} def database(dburl=None, **params): """Creates appropriate database using params. Pooling will be enabled if DBUtils module is available. Pooling can be disabled by passing pooling=False in params. """ dbn = params.pop('dbn') if dbn in _databases: return _databases[dbn](**params) else: raise UnknownDB, dbn def register_database(name, clazz): """ Register a database. >>> class LegacyDB(DB): ... def __init__(self, **params): ... pass ... >>> register_database('legacy', LegacyDB) >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret') """ _databases[name] = clazz register_database('mysql', MySQLDB) register_database('postgres', PostgresDB) register_database('sqlite', SqliteDB) register_database('firebird', FirebirdDB) register_database('mssql', MSSQLDB) register_database('oracle', OracleDB) def _interpolate(format): """ Takes a format string and returns a list of 2-tuples of the form (boolean, string) where boolean says whether string should be evaled or not. from (public domain, Ka-Ping Yee) """ from tokenize import tokenprog def matchorfail(text, pos): match = tokenprog.match(text, pos) if match is None: raise _ItplError(text, pos) return match, match.end() namechars = "abcdefghijklmnopqrstuvwxyz" \ "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; chunks = [] pos = 0 while 1: dollar = format.find("$", pos) if dollar < 0: break nextchar = format[dollar + 1] if nextchar == "{": chunks.append((0, format[pos:dollar])) pos, level = dollar + 2, 1 while level: match, pos = matchorfail(format, pos) tstart, tend = match.regs[3] token = format[tstart:tend] if token == "{": level = level + 1 elif token == "}": level = level - 1 chunks.append((1, format[dollar + 2:pos - 1])) elif nextchar in namechars: chunks.append((0, format[pos:dollar])) match, pos = matchorfail(format, dollar + 1) while pos < len(format): if format[pos] == "." and \ pos + 1 < len(format) and format[pos + 1] in namechars: match, pos = matchorfail(format, pos + 1) elif format[pos] in "([": pos, level = pos + 1, 1 while level: match, pos = matchorfail(format, pos) tstart, tend = match.regs[3] token = format[tstart:tend] if token[0] in "([": level = level + 1 elif token[0] in ")]": level = level - 1 else: break chunks.append((1, format[dollar + 1:pos])) else: chunks.append((0, format[pos:dollar + 1])) pos = dollar + 1 + (nextchar == "$") if pos < len(format): chunks.append((0, format[pos:])) return chunks if __name__ == "__main__": import doctest doctest.testmod()