Commit 805a0ae5 authored by Thilo Kogge's avatar Thilo Kogge

updated peewee, added peewee playhouse, switched to sqliteQ connector to avoid multithreading problems with DB access, removed deferred foreign field, using cascade for deletion, added tests, added migration
parent 93de136e
2 merge requests: !70 fixed 304 logic, !69 db model updates
Showing 4718 additions and 8 deletions
@@ -50,6 +50,7 @@ DISTFILES += qml/podqast.qml \
python/mygpoclient/util.py \
python/feedparser/ \
python/peewee.py \
+python/playhouse/ \
qml/params.yml \
qml/cover/CoverPage.qml \
qml/pages/GpodderNetPython.qml \
......
@@ -70,7 +70,7 @@ except ImportError:
mysql = None
-__version__ = '3.14.2'
+__version__ = '3.14.4'
__all__ = [
'AsIs',
'AutoField',
@@ -1457,6 +1457,7 @@ class Expression(ColumnBase):
# Set up the appropriate converter if we have a field on the left side.
if isinstance(node, Field) and raw_node._coerce:
overrides['converter'] = node.db_value
+overrides['is_fk_expr'] = isinstance(node, ForeignKeyField)
else:
overrides['converter'] = None
@@ -2633,11 +2634,15 @@ class Insert(_WriteQuery):
if col not in seen:
columns.append(col)
+nullable_columns = set()
value_lookups = {}
for column in columns:
lookups = [column, column.name]
-if isinstance(column, Field) and column.name != column.column_name:
-    lookups.append(column.column_name)
+if isinstance(column, Field):
+    if column.name != column.column_name:
+        lookups.append(column.column_name)
+    if column.null:
+        nullable_columns.add(column)
value_lookups[column] = lookups
ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
@@ -2671,6 +2676,8 @@
val = defaults[column]
if callable_(val):
val = val()
+elif column in nullable_columns:
+    val = None
else:
raise ValueError('Missing value for %s.' % column.name)
@@ -6613,11 +6620,12 @@ class Model(with_metaclass(ModelBase, Node)):
# model instance will return the wrong value; since we would return
# the primary key for a given model instance.
#
-# This checks to see if we have a converter in the scope, and if so,
-# hands the model instance to the converter rather than blindly
-# grabbing the primary-key. In the event the provided converter fails
-# to handle the model instance, then we will return the primary-key.
-if ctx.state.converter is not None:
+# This checks to see if we have a converter in the scope, and that we
+# are converting a foreign-key expression. If so, we hand the model
+# instance to the converter rather than blindly grabbing the primary-
+# key. In the event the provided converter fails to handle the model
+# instance, then we will return the primary-key.
+if ctx.state.converter is not None and ctx.state.is_fk_expr:
try:
return ctx.sql(Value(self, converter=ctx.state.converter))
except (TypeError, ValueError):
......
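The net effect of the two `Insert` hunks above: a multi-row insert may now omit values for nullable columns, where peewee 3.14.2 raised `ValueError`. A minimal sketch of the new behavior (the `Episode` model is a placeholder, not from this repository):

```python
from peewee import Model, SqliteDatabase, TextField

db = SqliteDatabase(':memory:')

class Episode(Model):
    title = TextField()
    description = TextField(null=True)  # nullable, so rows may omit it

    class Meta:
        database = db

db.create_tables([Episode])

# The second row previously raised ValueError('Missing value for description.');
# with 3.14.4 it stores NULL instead.
Episode.insert_many([
    {'title': 'Pilot', 'description': 'First episode.'},
    {'title': 'Untitled'},
]).execute()
```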
## Playhouse
The `playhouse` namespace contains numerous extensions to Peewee. These include vendor-specific database extensions, high-level abstractions to simplify working with databases, and tools for low-level database operations and introspection.
### Vendor extensions
* [SQLite extensions](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html)
* Full-text search (FTS3/4/5)
* BM25 ranking algorithm implemented as SQLite C extension, backported to FTS4
* Virtual tables and C extensions
* Closure tables
* JSON extension support
* LSM1 (key/value database) support
* BLOB API
* Online backup API
* [APSW extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#apsw): use Peewee with the powerful [APSW](https://github.com/rogerbinns/apsw) SQLite driver.
* [SQLCipher](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqlcipher-ext): encrypted SQLite databases.
* [SqliteQ](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqliteq): dedicated writer thread for multi-threaded SQLite applications; a usage sketch follows this list. [More info here](http://charlesleifer.com/blog/multi-threaded-sqlite-without-the-operationalerrors/).
* [Postgresql extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#postgres-ext)
* JSON and JSONB
* HStore
* Arrays
* Server-side cursors
* Full-text search
* [MySQL extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#mysql-ext)
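
Since the commit above switches podqast's DB access to the SqliteQ connector, here is a minimal sketch of wiring it up (the database path and option values are placeholders, not taken from this repository):

```python
from playhouse.sqliteq import SqliteQueueDatabase

# All writes are funneled through a single background writer thread, which
# avoids "database is locked" errors when multiple threads hit the database.
db = SqliteQueueDatabase(
    'podqast.db',          # placeholder path
    use_gevent=False,      # standard threads, not greenlets
    autostart=True,        # start the writer thread immediately
    queue_max_size=64,     # max pending writes before enqueueing blocks
    results_timeout=5.0)   # seconds to wait for a query's result

# ... define models and query as usual; reads execute in the calling thread ...

db.stop()  # flush pending writes and stop the writer thread on shutdown
```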
### High-level libraries
* [Extra fields](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#extra-fields)
* Compressed field
* PickleField
* [Shortcuts / helpers](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#shortcuts)
* Model-to-dict serializer
* Dict-to-model deserializer
* [Hybrid attributes](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#hybrid)
* [Signals](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signals): pre/post-save, pre/post-delete, pre-init.
* [Dataset](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#dataset): high-level API for working with databases popularized by the [project of the same name](https://dataset.readthedocs.io/).
* [Key/Value Store](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#kv): key/value store using SQLite. Supports *smart indexing*, for *Pandas*-style queries.
### Database management and framework support
* [pwiz](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pwiz): generate model code from a pre-existing database.
* [Schema migrations](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#migrate): modify your schema using high-level APIs. Even supports dropping or renaming columns in SQLite.
* [Connection pool](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pool): simple connection pooling.
* [Reflection](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#reflection): low-level, cross-platform database introspection.
* [Database URLs](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#db-url): use URLs to connect to your database; see the example after this list.
* [Test utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#test-utils): helpers for unit-testing Peewee applications.
* [Flask utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#flask-utils): paginated object lists, database connection management, and more.
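
For example, a connection URL can replace manual database construction (a sketch; scheme, credentials, and names are placeholders):

```python
from playhouse.db_url import connect

# The URL scheme selects the Database class from the registry in db_url.py;
# query-string values are coerced to bool/int/float/None and passed through
# as keyword arguments.
db = connect('postgres+pool://postgres:secret@localhost:5432/my_app'
             '?max_connections=16&stale_timeout=300')
```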
/* cache.h - definitions for the LRU cache
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_CACHE_H
#define PYSQLITE_CACHE_H
#include "Python.h"
/* The LRU cache is implemented as a combination of a doubly-linked list with
 * a dictionary. The list items are of type 'Node' and the dictionary has the
 * nodes as values. */
typedef struct _pysqlite_Node
{
PyObject_HEAD
PyObject* key;
PyObject* data;
long count;
struct _pysqlite_Node* prev;
struct _pysqlite_Node* next;
} pysqlite_Node;
typedef struct
{
PyObject_HEAD
int size;
/* a dictionary mapping keys to Node entries */
PyObject* mapping;
/* the factory callable */
PyObject* factory;
pysqlite_Node* first;
pysqlite_Node* last;
/* if set, decrement the factory function when the Cache is deallocated.
* this is almost always desirable, but not in the pysqlite context */
int decref_factory;
} pysqlite_Cache;
extern PyTypeObject pysqlite_NodeType;
extern PyTypeObject pysqlite_CacheType;
int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs);
void pysqlite_node_dealloc(pysqlite_Node* self);
int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs);
void pysqlite_cache_dealloc(pysqlite_Cache* self);
PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args);
int pysqlite_cache_setup_types(void);
#endif
/* connection.h - definitions for the connection type
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_CONNECTION_H
#define PYSQLITE_CONNECTION_H
#include "Python.h"
#include "pythread.h"
#include "structmember.h"
#include "cache.h"
#include "module.h"
#include "sqlite3.h"
typedef struct
{
PyObject_HEAD
sqlite3* db;
/* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a
* bitwise combination thereof makes sense */
int detect_types;
/* the timeout value in seconds for database locks */
double timeout;
/* for internal use in the timeout handler: when did the timeout handler
* first get called with count=0? */
double timeout_started;
/* None for autocommit, otherwise a PyString with the isolation level */
PyObject* isolation_level;
/* NULL for autocommit, otherwise a string with the BEGIN statement; will be
* freed in connection destructor */
char* begin_statement;
/* 1 if a check should be performed for each API call if the connection is
* used from the same thread it was created in */
int check_same_thread;
int initialized;
/* thread identification of the thread the connection was created in */
long thread_ident;
pysqlite_Cache* statement_cache;
/* Lists of weak references to statements and cursors used within this connection */
PyObject* statements;
PyObject* cursors;
/* Counters for how many statements/cursors were created in the connection. May be
* reset to 0 at certain intervals */
int created_statements;
int created_cursors;
PyObject* row_factory;
/* Determines how bytestrings from SQLite are converted to Python objects:
* - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings
* - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created.
* - PyString_Type: PyStrings are created as-is.
* - Any custom callable: Any object returned from the callable called with the bytestring
* as single parameter.
*/
PyObject* text_factory;
/* remember references to functions/classes used in
* create_function/create_aggregate, use these as dictionary keys, so we
* can keep the total system refcount constant by clearing that dictionary
* in connection_dealloc */
PyObject* function_pinboard;
/* a dictionary of registered collation name => collation callable mappings */
PyObject* collations;
/* Exception objects */
PyObject* Warning;
PyObject* Error;
PyObject* InterfaceError;
PyObject* DatabaseError;
PyObject* DataError;
PyObject* OperationalError;
PyObject* IntegrityError;
PyObject* InternalError;
PyObject* ProgrammingError;
PyObject* NotSupportedError;
} pysqlite_Connection;
extern PyTypeObject pysqlite_ConnectionType;
PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware);
void pysqlite_connection_dealloc(pysqlite_Connection* self);
PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args);
PyObject* _pysqlite_connection_begin(pysqlite_Connection* self);
PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args);
PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args);
PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw);
int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor);
int pysqlite_check_thread(pysqlite_Connection* self);
int pysqlite_check_connection(pysqlite_Connection* con);
int pysqlite_connection_setup_types(void);
#endif
/* module.h - definitions for the module
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_MODULE_H
#define PYSQLITE_MODULE_H
#include "Python.h"
#define PYSQLITE_VERSION "2.8.2"
extern PyObject* pysqlite_Error;
extern PyObject* pysqlite_Warning;
extern PyObject* pysqlite_InterfaceError;
extern PyObject* pysqlite_DatabaseError;
extern PyObject* pysqlite_InternalError;
extern PyObject* pysqlite_OperationalError;
extern PyObject* pysqlite_ProgrammingError;
extern PyObject* pysqlite_IntegrityError;
extern PyObject* pysqlite_DataError;
extern PyObject* pysqlite_NotSupportedError;
extern PyObject* pysqlite_OptimizedUnicode;
/* the functions time.time() and time.sleep() */
extern PyObject* time_time;
extern PyObject* time_sleep;
/* A dictionary, mapping column types (INTEGER, VARCHAR, etc.) to converter
 * functions, which convert the SQL value to the appropriate Python value.
* The key is uppercase.
*/
extern PyObject* converters;
extern int _enable_callback_tracebacks;
extern int pysqlite_BaseTypeAdapted;
#define PARSE_DECLTYPES 1
#define PARSE_COLNAMES 2
#endif
This diff is collapsed.
import sys
from difflib import SequenceMatcher
from random import randint
IS_PY3K = sys.version_info[0] == 3
# String UDF.
def damerau_levenshtein_dist(s1, s2):
cdef:
int i, j, del_cost, add_cost, sub_cost
int s1_len = len(s1), s2_len = len(s2)
list one_ago, two_ago, current_row
list zeroes = [0] * (s2_len + 1)
if IS_PY3K:
current_row = list(range(1, s2_len + 2))
else:
current_row = range(1, s2_len + 2)
current_row[-1] = 0
one_ago = None
for i in range(s1_len):
two_ago = one_ago
one_ago = current_row
current_row = list(zeroes)
current_row[-1] = i + 1
for j in range(s2_len):
del_cost = one_ago[j] + 1
add_cost = current_row[j - 1] + 1
sub_cost = one_ago[j - 1] + (s1[i] != s2[j])
current_row[j] = min(del_cost, add_cost, sub_cost)
# Handle transpositions.
if (i > 0 and j > 0 and s1[i] == s2[j - 1]
and s1[i-1] == s2[j] and s1[i] != s2[j]):
current_row[j] = min(current_row[j], two_ago[j - 2] + 1)
return current_row[s2_len - 1]
# String UDF.
def levenshtein_dist(a, b):
cdef:
int add, delete, change
int i, j
int n = len(a), m = len(b)
list current, previous
list zeroes
if n > m:
a, b = b, a
n, m = m, n
zeroes = [0] * (m + 1)
if IS_PY3K:
current = list(range(n + 1))
else:
current = range(n + 1)
for i in range(1, m + 1):
previous = current
current = list(zeroes)
current[0] = i
for j in range(1, n + 1):
add = previous[j] + 1
delete = current[j - 1] + 1
change = previous[j - 1]
if a[j - 1] != b[i - 1]:
change += 1
current[j] = min(add, delete, change)
return current[n]
# String UDF.
def str_dist(a, b):
cdef:
int t = 0
for i in SequenceMatcher(None, a, b).get_opcodes():
if i[0] == 'equal':
continue
t = t + max(i[4] - i[3], i[2] - i[1])
return t
# Math Aggregate.
cdef class median(object):
cdef:
int ct
list items
def __init__(self):
self.ct = 0
self.items = []
cdef selectKth(self, int k, int s=0, int e=-1):
cdef:
int idx
if e < 0:
e = len(self.items)
idx = randint(s, e-1)
idx = self.partition_k(idx, s, e)
if idx > k:
return self.selectKth(k, s, idx)
elif idx < k:
return self.selectKth(k, idx + 1, e)
else:
return self.items[idx]
cdef int partition_k(self, int pi, int s, int e):
cdef:
int i, x
val = self.items[pi]
# Swap pivot w/last item.
self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1]
x = s
for i in range(s, e):
if self.items[i] < val:
self.items[i], self.items[x] = self.items[x], self.items[i]
x += 1
self.items[x], self.items[e-1] = self.items[e-1], self.items[x]
return x
def step(self, item):
self.items.append(item)
self.ct += 1
def finalize(self):
if self.ct == 0:
return None
elif self.ct < 3:
return self.items[0]
else:
return self.selectKth(self.ct / 2)
"""
Peewee integration with APSW, "another python sqlite wrapper".
Project page: https://rogerbinns.github.io/apsw/
APSW is a really neat library that provides a thin wrapper on top of SQLite's
C interface.
Here are just a few reasons to use APSW, taken from the documentation:
* APSW gives all functionality of SQLite, including virtual tables, virtual
file system, blob i/o, backups and file control.
* Connections can be shared across threads without any additional locking.
* Transactions are managed explicitly by your code.
* APSW can handle nested transactions.
* Unicode is handled correctly.
* APSW is faster.
"""
import apsw
from peewee import *
from peewee import __exception_wrapper__
from peewee import BooleanField as _BooleanField
from peewee import DateField as _DateField
from peewee import DateTimeField as _DateTimeField
from peewee import DecimalField as _DecimalField
from peewee import TimeField as _TimeField
from peewee import logger
from playhouse.sqlite_ext import SqliteExtDatabase
class APSWDatabase(SqliteExtDatabase):
server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))
def __init__(self, database, **kwargs):
self._modules = {}
super(APSWDatabase, self).__init__(database, **kwargs)
def register_module(self, mod_name, mod_inst):
self._modules[mod_name] = mod_inst
if not self.is_closed():
self.connection().createmodule(mod_name, mod_inst)
def unregister_module(self, mod_name):
del(self._modules[mod_name])
def _connect(self):
conn = apsw.Connection(self.database, **self.connect_params)
if self._timeout is not None:
conn.setbusytimeout(self._timeout * 1000)
try:
self._add_conn_hooks(conn)
except:
conn.close()
raise
return conn
def _add_conn_hooks(self, conn):
super(APSWDatabase, self)._add_conn_hooks(conn)
self._load_modules(conn) # APSW-only.
def _load_modules(self, conn):
for mod_name, mod_inst in self._modules.items():
conn.createmodule(mod_name, mod_inst)
return conn
def _load_aggregates(self, conn):
for name, (klass, num_params) in self._aggregates.items():
def make_aggregate():
return (klass(), klass.step, klass.finalize)
conn.createaggregatefunction(name, make_aggregate)
def _load_collations(self, conn):
for name, fn in self._collations.items():
conn.createcollation(name, fn)
def _load_functions(self, conn):
for name, (fn, num_params) in self._functions.items():
conn.createscalarfunction(name, fn, num_params)
def _load_extensions(self, conn):
conn.enableloadextension(True)
for extension in self._extensions:
conn.loadextension(extension)
def load_extension(self, extension):
self._extensions.add(extension)
if not self.is_closed():
conn = self.connection()
conn.enableloadextension(True)
conn.loadextension(extension)
def last_insert_id(self, cursor, query_type=None):
return cursor.getconnection().last_insert_rowid()
def rows_affected(self, cursor):
return cursor.getconnection().changes()
def begin(self, lock_type='deferred'):
self.cursor().execute('begin %s;' % lock_type)
def commit(self):
with __exception_wrapper__:
curs = self.cursor()
if curs.getconnection().getautocommit():
return False
curs.execute('commit;')
return True
def rollback(self):
with __exception_wrapper__:
curs = self.cursor()
if curs.getconnection().getautocommit():
return False
curs.execute('rollback;')
return True
def execute_sql(self, sql, params=None, commit=True):
logger.debug((sql, params))
with __exception_wrapper__:
cursor = self.cursor()
cursor.execute(sql, params or ())
return cursor
def nh(s, v):
if v is not None:
return str(v)
class BooleanField(_BooleanField):
def db_value(self, v):
v = super(BooleanField, self).db_value(v)
if v is not None:
return v and 1 or 0
class DateField(_DateField):
db_value = nh
class TimeField(_TimeField):
db_value = nh
class DateTimeField(_DateTimeField):
db_value = nh
class DecimalField(_DecimalField):
db_value = nh
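# A minimal usage sketch for APSWDatabase (the path and timeout below are
# placeholder values): it is a drop-in replacement wherever SqliteExtDatabase
# would be used.
from playhouse.apsw_ext import APSWDatabase

apsw_db = APSWDatabase('app.db', timeout=10)  # busy-timeout, in seconds
apsw_db.connect()
apsw_db.execute_sql('CREATE TABLE IF NOT EXISTS demo (id INTEGER PRIMARY KEY)')
apsw_db.close()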
import functools
import re
from peewee import *
from peewee import _atomic
from peewee import _manual
from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default)
from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table).
from peewee import IndexMetadata
from playhouse.pool import _PooledPostgresqlDatabase
try:
from playhouse.postgres_ext import ArrayField
from playhouse.postgres_ext import BinaryJSONField
from playhouse.postgres_ext import IntervalField
JSONField = BinaryJSONField
except ImportError: # psycopg2 not installed, ignore.
ArrayField = BinaryJSONField = IntervalField = JSONField = None
NESTED_TX_MIN_VERSION = 200100
TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may '
'alternatively use the @transaction context-manager/decorator, '
'which only wraps the outer-most block in transactional logic. '
'To run a transaction with automatic retries, use the '
'run_transaction() helper.')
class ExceededMaxAttempts(OperationalError): pass
class UUIDKeyField(UUIDField):
auto_increment = True
def __init__(self, *args, **kwargs):
if kwargs.get('constraints'):
raise ValueError('%s cannot specify constraints.' % type(self))
kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')]
kwargs.setdefault('primary_key', True)
super(UUIDKeyField, self).__init__(*args, **kwargs)
class RowIDField(AutoField):
field_type = 'INT'
def __init__(self, *args, **kwargs):
if kwargs.get('constraints'):
raise ValueError('%s cannot specify constraints.' % type(self))
kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')]
super(RowIDField, self).__init__(*args, **kwargs)
class CockroachDatabase(PostgresqlDatabase):
field_types = PostgresqlDatabase.field_types.copy()
field_types.update({
'BLOB': 'BYTES',
})
for_update = False
nulls_ordering = False
release_after_rollback = True
def __init__(self, *args, **kwargs):
kwargs.setdefault('user', 'root')
kwargs.setdefault('port', 26257)
super(CockroachDatabase, self).__init__(*args, **kwargs)
def _set_server_version(self, conn):
curs = conn.cursor()
curs.execute('select version()')
raw, = curs.fetchone()
match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw)
if match_obj is not None:
clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups())
self.server_version = int(clean) # 19.1.5 -> 190105.
else:
# Fallback to use whatever cockroachdb tells us via protocol.
super(CockroachDatabase, self)._set_server_version(conn)
def _get_pk_constraint(self, table, schema=None):
query = ('SELECT constraint_name '
'FROM information_schema.table_constraints '
'WHERE table_name = %s AND table_schema = %s '
'AND constraint_type = %s')
cursor = self.execute_sql(query, (table, schema or 'public',
'PRIMARY KEY'))
row = cursor.fetchone()
return row and row[0] or None
def get_indexes(self, table, schema=None):
# The primary-key index is returned by default, so we will just strip
# it out here.
indexes = super(CockroachDatabase, self).get_indexes(table, schema)
pkc = self._get_pk_constraint(table, schema)
return [idx for idx in indexes if (not pkc) or (idx.name != pkc)]
def conflict_statement(self, on_conflict, query):
if not on_conflict._action: return
action = on_conflict._action.lower()
if action in ('replace', 'upsert'):
return SQL('UPSERT')
elif action not in ('ignore', 'nothing', 'update'):
raise ValueError('Unsupported action for conflict resolution. '
'CockroachDB supports REPLACE (UPSERT), IGNORE '
'and UPDATE.')
def conflict_update(self, oc, query):
action = oc._action.lower() if oc._action else ''
if action in ('ignore', 'nothing'):
return SQL('ON CONFLICT DO NOTHING')
elif action in ('replace', 'upsert'):
# No special stuff is necessary, this is just indicated by starting
# the statement with UPSERT instead of INSERT.
return
elif oc._conflict_constraint:
raise ValueError('CockroachDB does not support the usage of a '
'constraint name. Use the column(s) instead.')
return super(CockroachDatabase, self).conflict_update(oc, query)
def extract_date(self, date_part, date_field):
return fn.extract(date_part, date_field)
def from_timestamp(self, date_field):
# CRDB does not allow casting a decimal/float to timestamp, so we first
# cast to int, then to timestamptz.
return date_field.cast('int').cast('timestamptz')
def begin(self, system_time=None, priority=None):
super(CockroachDatabase, self).begin()
if system_time is not None:
self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s',
(system_time,), commit=False)
if priority is not None:
priority = priority.lower()
if priority not in ('low', 'normal', 'high'):
raise ValueError('priority must be low, normal or high')
self.execute_sql('SET TRANSACTION PRIORITY %s' % priority,
commit=False)
def atomic(self, system_time=None, priority=None):
if self.server_version < NESTED_TX_MIN_VERSION:
return _crdb_atomic(self, system_time, priority)
return super(CockroachDatabase, self).atomic(system_time, priority)
def savepoint(self):
if self.server_version < NESTED_TX_MIN_VERSION:
raise NotImplementedError(TXN_ERR_MSG)
return super(CockroachDatabase, self).savepoint()
def retry_transaction(self, max_attempts=None, system_time=None,
priority=None):
def deco(cb):
@functools.wraps(cb)
def new_fn():
return run_transaction(self, cb, max_attempts, system_time,
priority)
return new_fn
return deco
def run_transaction(self, cb, max_attempts=None, system_time=None,
priority=None):
return run_transaction(self, cb, max_attempts, system_time, priority)
class _crdb_atomic(_atomic):
def __enter__(self):
if self.db.transaction_depth() > 0:
if not isinstance(self.db.top_transaction(), _manual):
raise NotImplementedError(TXN_ERR_MSG)
return super(_crdb_atomic, self).__enter__()
def run_transaction(db, callback, max_attempts=None, system_time=None,
priority=None):
"""
Run transactional SQL in a transaction with automatic retries.
User-provided `callback`:
* Must accept one parameter, the `db` instance representing the connection
the transaction is running under.
* Must not attempt to commit, rollback or otherwise manage transactions.
* May be called more than once.
* Should ideally only contain SQL operations.
Additionally, the database must not have any open transaction at the time
this function is called, as CRDB does not support nested transactions.
"""
max_attempts = max_attempts or -1
with db.atomic(system_time=system_time, priority=priority) as txn:
db.execute_sql('SAVEPOINT cockroach_restart')
while max_attempts != 0:
try:
result = callback(db)
db.execute_sql('RELEASE SAVEPOINT cockroach_restart')
return result
except OperationalError as exc:
if exc.orig.pgcode == '40001':
max_attempts -= 1
db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart')
continue
raise
raise ExceededMaxAttempts(None, 'unable to commit transaction')
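# A minimal run_transaction sketch (connection settings and the Account model
# are placeholders): the callback may execute several times if CockroachDB
# requests a retry (pgcode 40001), so it should contain only the SQL itself.
#
#   def transfer(db):
#       Account.update(balance=Account.balance - 10).where(Account.id == 1).execute()
#       Account.update(balance=Account.balance + 10).where(Account.id == 2).execute()
#
#   crdb = CockroachDatabase('bank', host='localhost')
#   run_transaction(crdb, transfer, max_attempts=5)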
class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase):
pass
import csv
import datetime
from decimal import Decimal
import json
import operator
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import sys
from peewee import *
from playhouse.db_url import connect
from playhouse.migrate import migrate
from playhouse.migrate import SchemaMigrator
from playhouse.reflection import Introspector
if sys.version_info[0] == 3:
basestring = str
from functools import reduce
def open_file(f, mode):
return open(f, mode, encoding='utf8')
else:
open_file = open
class DataSet(object):
def __init__(self, url, **kwargs):
if isinstance(url, Database):
self._url = None
self._database = url
self._database_path = self._database.database
else:
self._url = url
parse_result = urlparse(url)
self._database_path = parse_result.path[1:]
# Connect to the database.
self._database = connect(url)
self._database.connect()
# Introspect the database and generate models.
self._introspector = Introspector.from_database(self._database)
self._models = self._introspector.generate_models(
skip_invalid=True,
literal_column_names=True,
**kwargs)
self._migrator = SchemaMigrator.from_database(self._database)
class BaseModel(Model):
class Meta:
database = self._database
self._base_model = BaseModel
self._export_formats = self.get_export_formats()
self._import_formats = self.get_import_formats()
def __repr__(self):
return '<DataSet: %s>' % self._database_path
def get_export_formats(self):
return {
'csv': CSVExporter,
'json': JSONExporter,
'tsv': TSVExporter}
def get_import_formats(self):
return {
'csv': CSVImporter,
'json': JSONImporter,
'tsv': TSVImporter}
def __getitem__(self, table):
if table not in self._models and table in self.tables:
self.update_cache(table)
return Table(self, table, self._models.get(table))
@property
def tables(self):
return self._database.get_tables()
def __contains__(self, table):
return table in self.tables
def connect(self):
self._database.connect()
def close(self):
self._database.close()
def update_cache(self, table=None):
if table:
dependencies = [table]
if table in self._models:
model_class = self._models[table]
dependencies.extend([
related._meta.table_name for _, related, _ in
model_class._meta.model_graph()])
else:
dependencies.extend(self.get_table_dependencies(table))
else:
dependencies = None # Update all tables.
self._models = {}
updated = self._introspector.generate_models(
skip_invalid=True,
table_names=dependencies,
literal_column_names=True)
self._models.update(updated)
def get_table_dependencies(self, table):
stack = [table]
accum = []
seen = set()
while stack:
table = stack.pop()
for fk_meta in self._database.get_foreign_keys(table):
dest = fk_meta.dest_table
if dest not in seen:
stack.append(dest)
accum.append(dest)
return accum
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if not self._database.is_closed():
self.close()
def query(self, sql, params=None, commit=True):
return self._database.execute_sql(sql, params, commit)
def transaction(self):
if self._database.transaction_depth() == 0:
return self._database.transaction()
else:
return self._database.savepoint()
def _check_arguments(self, filename, file_obj, format, format_dict):
if filename and file_obj:
raise ValueError('file is over-specified. Please use either '
'filename or file_obj, but not both.')
if not filename and not file_obj:
raise ValueError('A filename or file-like object must be '
'specified.')
if format not in format_dict:
valid_formats = ', '.join(sorted(format_dict.keys()))
raise ValueError('Unsupported format "%s". Use one of %s.' % (
format, valid_formats))
def freeze(self, query, format='csv', filename=None, file_obj=None,
**kwargs):
self._check_arguments(filename, file_obj, format, self._export_formats)
if filename:
file_obj = open_file(filename, 'w')
exporter = self._export_formats[format](query)
exporter.export(file_obj, **kwargs)
if filename:
file_obj.close()
def thaw(self, table, format='csv', filename=None, file_obj=None,
strict=False, **kwargs):
self._check_arguments(filename, file_obj, format, self._import_formats)
if filename:
file_obj = open_file(filename, 'r')
importer = self._import_formats[format](self[table], strict)
count = importer.load(file_obj, **kwargs)
if filename:
file_obj.close()
return count
class Table(object):
def __init__(self, dataset, name, model_class):
self.dataset = dataset
self.name = name
if model_class is None:
model_class = self._create_model()
model_class.create_table()
self.dataset._models[name] = model_class
@property
def model_class(self):
return self.dataset._models[self.name]
def __repr__(self):
return '<Table: %s>' % self.name
def __len__(self):
return self.find().count()
def __iter__(self):
return iter(self.find().iterator())
def _create_model(self):
class Meta:
table_name = self.name
return type(
str(self.name),
(self.dataset._base_model,),
{'Meta': Meta})
def create_index(self, columns, unique=False):
index = ModelIndex(self.model_class, columns, unique=unique)
self.model_class.add_index(index)
self.dataset._database.execute(index)
def _guess_field_type(self, value):
if isinstance(value, basestring):
return TextField
if isinstance(value, (datetime.date, datetime.datetime)):
return DateTimeField
elif value is True or value is False:
return BooleanField
elif isinstance(value, int):
return IntegerField
elif isinstance(value, float):
return FloatField
elif isinstance(value, Decimal):
return DecimalField
return TextField
@property
def columns(self):
return [f.name for f in self.model_class._meta.sorted_fields]
def _migrate_new_columns(self, data):
new_keys = set(data) - set(self.model_class._meta.fields)
if new_keys:
operations = []
for key in new_keys:
field_class = self._guess_field_type(data[key])
field = field_class(null=True)
operations.append(
self.dataset._migrator.add_column(self.name, key, field))
field.bind(self.model_class, key)
migrate(*operations)
self.dataset.update_cache(self.name)
def __getitem__(self, item):
try:
return self.model_class[item]
except self.model_class.DoesNotExist:
pass
def __setitem__(self, item, value):
if not isinstance(value, dict):
raise ValueError('Table.__setitem__() value must be a dict')
pk = self.model_class._meta.primary_key
value[pk.name] = item
try:
with self.dataset.transaction() as txn:
self.insert(**value)
except IntegrityError:
self.dataset.update_cache(self.name)
self.update(columns=[pk.name], **value)
def __delitem__(self, item):
del self.model_class[item]
def insert(self, **data):
self._migrate_new_columns(data)
return self.model_class.insert(**data).execute()
def _apply_where(self, query, filters, conjunction=None):
conjunction = conjunction or operator.and_
if filters:
expressions = [
(self.model_class._meta.fields[column] == value)
for column, value in filters.items()]
query = query.where(reduce(conjunction, expressions))
return query
def update(self, columns=None, conjunction=None, **data):
self._migrate_new_columns(data)
filters = {}
if columns:
for column in columns:
filters[column] = data.pop(column)
return self._apply_where(
self.model_class.update(**data),
filters,
conjunction).execute()
def _query(self, **query):
return self._apply_where(self.model_class.select(), query)
def find(self, **query):
return self._query(**query).dicts()
def find_one(self, **query):
try:
return self.find(**query).get()
except self.model_class.DoesNotExist:
return None
def all(self):
return self.find()
def delete(self, **query):
return self._apply_where(self.model_class.delete(), query).execute()
def freeze(self, *args, **kwargs):
return self.dataset.freeze(self.all(), *args, **kwargs)
def thaw(self, *args, **kwargs):
return self.dataset.thaw(self.name, *args, **kwargs)
class Exporter(object):
def __init__(self, query):
self.query = query
def export(self, file_obj):
raise NotImplementedError
class JSONExporter(Exporter):
def __init__(self, query, iso8601_datetimes=False):
super(JSONExporter, self).__init__(query)
self.iso8601_datetimes = iso8601_datetimes
def _make_default(self):
datetime_types = (datetime.datetime, datetime.date, datetime.time)
if self.iso8601_datetimes:
def default(o):
if isinstance(o, datetime_types):
return o.isoformat()
elif isinstance(o, Decimal):
return str(o)
raise TypeError('Unable to serialize %r as JSON' % o)
else:
def default(o):
if isinstance(o, datetime_types + (Decimal,)):
return str(o)
raise TypeError('Unable to serialize %r as JSON' % o)
return default
def export(self, file_obj, **kwargs):
json.dump(
list(self.query),
file_obj,
default=self._make_default(),
**kwargs)
class CSVExporter(Exporter):
def export(self, file_obj, header=True, **kwargs):
writer = csv.writer(file_obj, **kwargs)
tuples = self.query.tuples().execute()
tuples.initialize()
if header and getattr(tuples, 'columns', None):
writer.writerow([column for column in tuples.columns])
for row in tuples:
writer.writerow(row)
class TSVExporter(CSVExporter):
def export(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVExporter, self).export(file_obj, header, **kwargs)
class Importer(object):
def __init__(self, table, strict=False):
self.table = table
self.strict = strict
model = self.table.model_class
self.columns = dict(model._meta.columns)  # copy; don't mutate model metadata
self.columns.update(model._meta.fields)
def load(self, file_obj):
raise NotImplementedError
class JSONImporter(Importer):
def load(self, file_obj, **kwargs):
data = json.load(file_obj, **kwargs)
count = 0
for row in data:
if self.strict:
obj = {}
for key in row:
field = self.columns.get(key)
if field is not None:
obj[field.name] = field.python_value(row[key])
else:
obj = row
if obj:
self.table.insert(**obj)
count += 1
return count
class CSVImporter(Importer):
def load(self, file_obj, header=True, **kwargs):
count = 0
reader = csv.reader(file_obj, **kwargs)
if header:
try:
header_keys = next(reader)
except StopIteration:
return count
if self.strict:
header_fields = []
for idx, key in enumerate(header_keys):
if key in self.columns:
header_fields.append((idx, self.columns[key]))
else:
header_fields = list(enumerate(header_keys))
else:
header_fields = list(enumerate(self.table.model_class._meta.sorted_fields))
if not header_fields:
return count
for row in reader:
obj = {}
for idx, field in header_fields:
if self.strict:
obj[field.name] = field.python_value(row[idx])
else:
obj[field] = row[idx]
self.table.insert(**obj)
count += 1
return count
class TSVImporter(CSVImporter):
def load(self, file_obj, header=True, **kwargs):
kwargs.setdefault('delimiter', '\t')
return super(TSVImporter, self).load(file_obj, header, **kwargs)
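# A minimal DataSet sketch (URL, table, and column names are placeholders):
# a table behaves like a dict of rows, and freeze() exports query results
# through the exporters defined above.
from playhouse.dataset import DataSet

ds = DataSet('sqlite:///:memory:')
episodes = ds['episode']                        # table is created on demand
episodes.insert(title='Pilot', listened=False)  # columns migrated on the fly
row = episodes.find_one(title='Pilot')
ds.freeze(episodes.all(), format='json', filename='episodes.json')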
try:
from urlparse import parse_qsl, unquote, urlparse
except ImportError:
from urllib.parse import parse_qsl, unquote, urlparse
from peewee import *
from playhouse.cockroachdb import CockroachDatabase
from playhouse.cockroachdb import PooledCockroachDatabase
from playhouse.pool import PooledMySQLDatabase
from playhouse.pool import PooledPostgresqlDatabase
from playhouse.pool import PooledSqliteDatabase
from playhouse.pool import PooledSqliteExtDatabase
from playhouse.sqlite_ext import SqliteExtDatabase
schemes = {
'cockroachdb': CockroachDatabase,
'cockroachdb+pool': PooledCockroachDatabase,
'crdb': CockroachDatabase,
'crdb+pool': PooledCockroachDatabase,
'mysql': MySQLDatabase,
'mysql+pool': PooledMySQLDatabase,
'postgres': PostgresqlDatabase,
'postgresql': PostgresqlDatabase,
'postgres+pool': PooledPostgresqlDatabase,
'postgresql+pool': PooledPostgresqlDatabase,
'sqlite': SqliteDatabase,
'sqliteext': SqliteExtDatabase,
'sqlite+pool': PooledSqliteDatabase,
'sqliteext+pool': PooledSqliteExtDatabase,
}
def register_database(db_class, *names):
global schemes
for name in names:
schemes[name] = db_class
def parseresult_to_dict(parsed, unquote_password=False):
# urlparse in Python 2.6 is broken: `query` comes back empty and the query
# string is instead appended to `path`, complete with the leading '?'.
path_parts = parsed.path[1:].split('?')
try:
query = path_parts[1]
except IndexError:
query = parsed.query
connect_kwargs = {'database': path_parts[0]}
if parsed.username:
connect_kwargs['user'] = parsed.username
if parsed.password:
connect_kwargs['password'] = parsed.password
if unquote_password:
connect_kwargs['password'] = unquote(connect_kwargs['password'])
if parsed.hostname:
connect_kwargs['host'] = parsed.hostname
if parsed.port:
connect_kwargs['port'] = parsed.port
# Adjust parameters for MySQL.
if parsed.scheme == 'mysql' and 'password' in connect_kwargs:
connect_kwargs['passwd'] = connect_kwargs.pop('password')
elif 'sqlite' in parsed.scheme and not connect_kwargs['database']:
connect_kwargs['database'] = ':memory:'
# Get additional connection args from the query string
qs_args = parse_qsl(query, keep_blank_values=True)
for key, value in qs_args:
if value.lower() == 'false':
value = False
elif value.lower() == 'true':
value = True
elif value.isdigit():
value = int(value)
elif '.' in value and all(p.isdigit() for p in value.split('.', 1)):
try:
value = float(value)
except ValueError:
pass
elif value.lower() in ('null', 'none'):
value = None
connect_kwargs[key] = value
return connect_kwargs
def parse(url, unquote_password=False):
parsed = urlparse(url)
return parseresult_to_dict(parsed, unquote_password)
def connect(url, unquote_password=False, **connect_params):
parsed = urlparse(url)
connect_kwargs = parseresult_to_dict(parsed, unquote_password)
connect_kwargs.update(connect_params)
database_class = schemes.get(parsed.scheme)
if database_class is None:
if parsed.scheme in schemes:
raise RuntimeError('Attempted to use "%s" but a required library '
'could not be imported.' % parsed.scheme)
else:
raise RuntimeError('Unrecognized or unsupported scheme: "%s".' %
parsed.scheme)
return database_class(**connect_kwargs)
# Conditionally register additional databases.
try:
from playhouse.pool import PooledPostgresqlExtDatabase
except ImportError:
pass
else:
register_database(
PooledPostgresqlExtDatabase,
'postgresext+pool',
'postgresqlext+pool')
try:
from playhouse.apsw_ext import APSWDatabase
except ImportError:
pass
else:
register_database(APSWDatabase, 'apsw')
try:
from playhouse.postgres_ext import PostgresqlExtDatabase
except ImportError:
pass
else:
register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext')
try:
import bz2
except ImportError:
bz2 = None
try:
import zlib
except ImportError:
zlib = None
try:
import cPickle as pickle
except ImportError:
import pickle
import sys
from peewee import BlobField
from peewee import buffer_type
PY2 = sys.version_info[0] == 2
class CompressedField(BlobField):
ZLIB = 'zlib'
BZ2 = 'bz2'
algorithm_to_import = {
ZLIB: zlib,
BZ2: bz2,
}
def __init__(self, compression_level=6, algorithm=ZLIB, *args,
**kwargs):
self.compression_level = compression_level
if algorithm not in self.algorithm_to_import:
raise ValueError('Unrecognized algorithm %s' % algorithm)
compress_module = self.algorithm_to_import[algorithm]
if compress_module is None:
raise ValueError('Missing library required for %s.' % algorithm)
self.algorithm = algorithm
self.compress = compress_module.compress
self.decompress = compress_module.decompress
super(CompressedField, self).__init__(*args, **kwargs)
def python_value(self, value):
if value is not None:
return self.decompress(value)
def db_value(self, value):
if value is not None:
return self._constructor(
self.compress(value, self.compression_level))
class PickleField(BlobField):
def python_value(self, value):
if value is not None:
if isinstance(value, buffer_type):
value = bytes(value)
return pickle.loads(value)
def db_value(self, value):
if value is not None:
pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
return self._constructor(pickled)
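# A minimal sketch of the fields above (the model and database are
# placeholders): CompressedField compresses bytes on write and decompresses
# on read; PickleField stores arbitrary Python objects in a BLOB.
from peewee import Model, SqliteDatabase

_db = SqliteDatabase(':memory:')

class Payload(Model):
    raw = CompressedField(compression_level=9, algorithm=CompressedField.ZLIB)
    obj = PickleField(null=True)

    class Meta:
        database = _db

_db.create_tables([Payload])
Payload.create(raw=b'x' * 10000, obj={'tags': ['a', 'b']})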
import math
import sys
from flask import abort
from flask import render_template
from flask import request
from peewee import Database
from peewee import DoesNotExist
from peewee import Model
from peewee import Proxy
from peewee import SelectQuery
from playhouse.db_url import connect as db_url_connect
class PaginatedQuery(object):
def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
check_bounds=False):
self.paginate_by = paginate_by
self.page_var = page_var
self.page = page or None
self.check_bounds = check_bounds
if isinstance(query_or_model, SelectQuery):
self.query = query_or_model
self.model = self.query.model
else:
self.model = query_or_model
self.query = self.model.select()
def get_page(self):
if self.page is not None:
return self.page
curr_page = request.args.get(self.page_var)
if curr_page and curr_page.isdigit():
return max(1, int(curr_page))
return 1
def get_page_count(self):
if not hasattr(self, '_page_count'):
self._page_count = int(math.ceil(
float(self.query.count()) / self.paginate_by))
return self._page_count
def get_object_list(self):
if self.check_bounds and self.get_page() > self.get_page_count():
abort(404)
return self.query.paginate(self.get_page(), self.paginate_by)
def get_object_or_404(query_or_model, *query):
if not isinstance(query_or_model, SelectQuery):
query_or_model = query_or_model.select()
try:
return query_or_model.where(*query).get()
except DoesNotExist:
abort(404)
def object_list(template_name, query, context_variable='object_list',
paginate_by=20, page_var='page', page=None, check_bounds=True,
**kwargs):
paginated_query = PaginatedQuery(
query,
paginate_by=paginate_by,
page_var=page_var,
page=page,
check_bounds=check_bounds)
kwargs[context_variable] = paginated_query.get_object_list()
return render_template(
template_name,
pagination=paginated_query,
page=paginated_query.get_page(),
**kwargs)
def get_current_url():
if not request.query_string:
return request.path
return '%s?%s' % (request.path, request.query_string.decode('utf-8'))
def get_next_url(default='/'):
if request.args.get('next'):
return request.args['next']
elif request.form.get('next'):
return request.form['next']
return default
class FlaskDB(object):
def __init__(self, app=None, database=None, model_class=Model):
self.database = None # Reference to actual Peewee database instance.
self.base_model_class = model_class
self._app = app
self._db = database # dict, url, Database, or None (default).
if app is not None:
self.init_app(app)
def init_app(self, app):
self._app = app
if self._db is None:
if 'DATABASE' in app.config:
initial_db = app.config['DATABASE']
elif 'DATABASE_URL' in app.config:
initial_db = app.config['DATABASE_URL']
else:
raise ValueError('Missing required configuration data for '
'database: DATABASE or DATABASE_URL.')
else:
initial_db = self._db
self._load_database(app, initial_db)
self._register_handlers(app)
def _load_database(self, app, config_value):
if isinstance(config_value, Database):
database = config_value
elif isinstance(config_value, dict):
database = self._load_from_config_dict(dict(config_value))
else:
# Assume a database connection URL.
database = db_url_connect(config_value)
if isinstance(self.database, Proxy):
self.database.initialize(database)
else:
self.database = database
def _load_from_config_dict(self, config_dict):
try:
name = config_dict.pop('name')
engine = config_dict.pop('engine')
except KeyError:
raise RuntimeError('DATABASE configuration must specify a '
'`name` and `engine`.')
if '.' in engine:
path, class_name = engine.rsplit('.', 1)
else:
path, class_name = 'peewee', engine
try:
__import__(path)
module = sys.modules[path]
database_class = getattr(module, class_name)
assert issubclass(database_class, Database)
except ImportError:
raise RuntimeError('Unable to import %s' % engine)
except AttributeError:
raise RuntimeError('Database engine not found %s' % engine)
except AssertionError:
raise RuntimeError('Database engine not a subclass of '
'peewee.Database: %s' % engine)
return database_class(name, **config_dict)
def _register_handlers(self, app):
app.before_request(self.connect_db)
app.teardown_request(self.close_db)
def get_model_class(self):
if self.database is None:
raise RuntimeError('Database must be initialized.')
class BaseModel(self.base_model_class):
class Meta:
database = self.database
return BaseModel
@property
def Model(self):
if self._app is None:
database = getattr(self, 'database', None)
if database is None:
self.database = Proxy()
if not hasattr(self, '_model_class'):
self._model_class = self.get_model_class()
return self._model_class
def connect_db(self):
self.database.connect()
def close_db(self, exc):
if not self.database.is_closed():
self.database.close()
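# A minimal FlaskDB wiring sketch (the DATABASE URL is a placeholder): the
# handlers registered above open a connection before each request and close
# it on teardown; FlaskDB.Model is a base class bound to the database.
from flask import Flask

app = Flask(__name__)
app.config['DATABASE'] = 'sqlite:///app.db'
flask_db = FlaskDB(app)

class Note(flask_db.Model):
    pass  # fields omitted; peewee adds an implicit AutoField primary key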
from peewee import ModelDescriptor
# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
class hybrid_method(ModelDescriptor):
def __init__(self, func, expr=None):
self.func = func
self.expr = expr or func
def __get__(self, instance, instance_type):
if instance is None:
return self.expr.__get__(instance_type, instance_type.__class__)
return self.func.__get__(instance, instance_type)
def expression(self, expr):
self.expr = expr
return self
class hybrid_property(ModelDescriptor):
def __init__(self, fget, fset=None, fdel=None, expr=None):
self.fget = fget
self.fset = fset
self.fdel = fdel
self.expr = expr or fget
def __get__(self, instance, instance_type):
if instance is None:
return self.expr(instance_type)
return self.fget(instance)
def __set__(self, instance, value):
if self.fset is None:
raise AttributeError('Cannot set attribute.')
self.fset(instance, value)
def __delete__(self, instance):
if self.fdel is None:
raise AttributeError('Cannot delete attribute.')
self.fdel(instance)
def setter(self, fset):
self.fset = fset
return self
def deleter(self, fdel):
self.fdel = fdel
return self
def expression(self, expr):
self.expr = expr
return self
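# A minimal hybrid_property sketch (the Interval model is a placeholder):
# accessed on an instance, the function runs as plain Python; accessed on
# the class, it yields a peewee expression usable in queries.
from peewee import IntegerField, Model, SqliteDatabase

class Interval(Model):
    start = IntegerField()
    end = IntegerField()

    class Meta:
        database = SqliteDatabase(':memory:')

    @hybrid_property
    def length(self):
        return self.end - self.start

# Interval(start=1, end=4).length              -> 3 (plain Python)
# Interval.select().where(Interval.length > 5) -> WHERE (end - start) > 5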
import operator
from peewee import *
from peewee import Expression
from playhouse.fields import PickleField
try:
from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
except ImportError:
from playhouse.sqlite_ext import SqliteExtDatabase
Sentinel = type('Sentinel', (object,), {})
class KeyValue(object):
"""
Persistent dictionary.
:param Field key_field: field to use for key. Defaults to CharField.
:param Field value_field: field to use for value. Defaults to PickleField.
:param bool ordered: data should be returned in key-sorted order.
:param Database database: database where key/value data is stored.
:param str table_name: table name for data.
"""
def __init__(self, key_field=None, value_field=None, ordered=False,
database=None, table_name='keyvalue'):
if key_field is None:
key_field = CharField(max_length=255, primary_key=True)
if not key_field.primary_key:
raise ValueError('key_field must have primary_key=True.')
if value_field is None:
value_field = PickleField()
self._key_field = key_field
self._value_field = value_field
self._ordered = ordered
self._database = database or SqliteExtDatabase(':memory:')
self._table_name = table_name
if isinstance(self._database, PostgresqlDatabase):
self.upsert = self._postgres_upsert
self.update = self._postgres_update
else:
self.upsert = self._upsert
self.update = self._update
self.model = self.create_model()
self.key = self.model.key
self.value = self.model.value
# Ensure table exists.
self.model.create_table()
def create_model(self):
class KeyValue(Model):
key = self._key_field
value = self._value_field
class Meta:
database = self._database
table_name = self._table_name
return KeyValue
def query(self, *select):
query = self.model.select(*select).tuples()
if self._ordered:
query = query.order_by(self.key)
return query
def convert_expression(self, expr):
if not isinstance(expr, Expression):
return (self.key == expr), True
return expr, False
def __contains__(self, key):
expr, _ = self.convert_expression(key)
return self.model.select().where(expr).exists()
def __len__(self):
return len(self.model)
def __getitem__(self, expr):
converted, is_single = self.convert_expression(expr)
query = self.query(self.value).where(converted)
item_getter = operator.itemgetter(0)
result = [item_getter(row) for row in query]
if len(result) == 0 and is_single:
raise KeyError(expr)
elif is_single:
return result[0]
return result
def _upsert(self, key, value):
(self.model
.insert(key=key, value=value)
.on_conflict('replace')
.execute())
def _postgres_upsert(self, key, value):
(self.model
.insert(key=key, value=value)
.on_conflict(conflict_target=[self.key],
preserve=[self.value])
.execute())
def __setitem__(self, expr, value):
if isinstance(expr, Expression):
self.model.update(value=value).where(expr).execute()
else:
self.upsert(expr, value)
def __delitem__(self, expr):
converted, _ = self.convert_expression(expr)
self.model.delete().where(converted).execute()
def __iter__(self):
return iter(self.query().execute())
def keys(self):
return map(operator.itemgetter(0), self.query(self.key))
def values(self):
return map(operator.itemgetter(0), self.query(self.value))
def items(self):
return iter(self.query().execute())
def _update(self, __data=None, **mapping):
if __data is not None:
mapping.update(__data)
return (self.model
.insert_many(list(mapping.items()),
fields=[self.key, self.value])
.on_conflict('replace')
.execute())
def _postgres_update(self, __data=None, **mapping):
if __data is not None:
mapping.update(__data)
return (self.model
.insert_many(list(mapping.items()),
fields=[self.key, self.value])
.on_conflict(conflict_target=[self.key],
preserve=[self.value])
.execute())
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def setdefault(self, key, default=None):
try:
return self[key]
except KeyError:
self[key] = default
return default
def pop(self, key, default=Sentinel):
with self._database.atomic():
try:
result = self[key]
except KeyError:
if default is Sentinel:
raise
return default
del self[key]
return result
def clear(self):
self.model.delete().execute()
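# A minimal KeyValue sketch (keys and values are placeholders): it behaves
# like a persistent dict backed by SQLite, and expressions over kv.key give
# range-style lookups.
kv = KeyValue()                    # in-memory SQLite by default
kv['k1'] = 'v1'
kv.update(k2='v2', k3=[1, 2, 3])   # bulk upsert via insert_many()
values = kv[kv.key > 'k1']         # list of the values stored for k2 and k3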
This diff is collapsed.
import json
try:
import mysql.connector as mysql_connector
except ImportError:
mysql_connector = None
from peewee import ImproperlyConfigured
from peewee import InterfaceError
from peewee import MySQLDatabase
from peewee import NodeList
from peewee import SQL
from peewee import TextField
from peewee import fn
class MySQLConnectorDatabase(MySQLDatabase):
def _connect(self):
if mysql_connector is None:
raise ImproperlyConfigured('MySQL connector not installed!')
return mysql_connector.connect(db=self.database, **self.connect_params)
def cursor(self, commit=None):
if self.is_closed():
if self.autoconnect:
self.connect()
else:
raise InterfaceError('Error, database connection not opened.')
return self._state.conn.cursor(buffered=True)
class JSONField(TextField):
field_type = 'JSON'
def db_value(self, value):
if value is not None:
return json.dumps(value)
def python_value(self, value):
if value is not None:
return json.loads(value)
def Match(columns, expr, modifier=None):
if isinstance(columns, (list, tuple)):
match = fn.MATCH(*columns) # Tuple of one or more columns / fields.
else:
match = fn.MATCH(columns) # Single column / field.
args = expr if modifier is None else NodeList((expr, SQL(modifier)))
return NodeList((match, fn.AGAINST(args)))
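# A minimal Match() sketch against a hypothetical Document model with
# full-text-indexed title/body columns on MySQL:
#
#   query = Document.select().where(
#       Match((Document.title, Document.body), 'podcast', 'IN BOOLEAN MODE'))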
"""
Lightweight connection pooling for peewee.
In a multi-threaded application, up to `max_connections` will be opened. Each
thread (or, if using gevent, greenlet) will have its own connection.
In a single-threaded application, only one connection will be created. It will
be continually recycled until either it exceeds the stale timeout or is closed
explicitly (using `.manual_close()`).
By default, all your application needs to do is ensure that connections are
closed when you are finished with them, and they will be returned to the pool.
For web applications, this typically means that at the beginning of a request,
you will open a connection, and when you return a response, you will close the
connection.
Simple Postgres pool example code:
# Use the special postgresql extensions.
from playhouse.pool import PooledPostgresqlExtDatabase
db = PooledPostgresqlExtDatabase(
'my_app',
max_connections=32,
stale_timeout=300, # 5 minutes.
user='postgres')
class BaseModel(Model):
class Meta:
database = db
That's it!
"""
import heapq
import logging
import random
import time
from collections import namedtuple
from itertools import chain
try:
from psycopg2.extensions import TRANSACTION_STATUS_IDLE
from psycopg2.extensions import TRANSACTION_STATUS_INERROR
from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN
except ImportError:
TRANSACTION_STATUS_IDLE = \
TRANSACTION_STATUS_INERROR = \
TRANSACTION_STATUS_UNKNOWN = None
from peewee import MySQLDatabase
from peewee import PostgresqlDatabase
from peewee import SqliteDatabase
logger = logging.getLogger('peewee.pool')
def make_int(val):
if val is not None and not isinstance(val, (int, float)):
return int(val)
return val
class MaxConnectionsExceeded(ValueError): pass
PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection',
'checked_out'))
class PooledDatabase(object):
def __init__(self, database, max_connections=20, stale_timeout=None,
timeout=None, **kwargs):
self._max_connections = make_int(max_connections)
self._stale_timeout = make_int(stale_timeout)
self._wait_timeout = make_int(timeout)
if self._wait_timeout == 0:
self._wait_timeout = float('inf')
# Available / idle connections stored in a heap, sorted oldest first.
self._connections = []
# Mapping of connection id to PoolConnection. Ordinarily we would want
# to use something like a WeakKeyDictionary, but Python typically won't
# allow us to create weak references to connection objects.
self._in_use = {}
# Use the memory address of the connection as the key in the event the
# connection object is not hashable. Connections will not get
# garbage-collected, however, because a reference to them will persist
# in "_in_use" as long as the conn has not been closed.
self.conn_key = id
super(PooledDatabase, self).__init__(database, **kwargs)
def init(self, database, max_connections=None, stale_timeout=None,
timeout=None, **connect_kwargs):
super(PooledDatabase, self).init(database, **connect_kwargs)
if max_connections is not None:
self._max_connections = make_int(max_connections)
if stale_timeout is not None:
self._stale_timeout = make_int(stale_timeout)
if timeout is not None:
self._wait_timeout = make_int(timeout)
if self._wait_timeout == 0:
self._wait_timeout = float('inf')
def connect(self, reuse_if_open=False):
if not self._wait_timeout:
return super(PooledDatabase, self).connect(reuse_if_open)
expires = time.time() + self._wait_timeout
while expires > time.time():
try:
ret = super(PooledDatabase, self).connect(reuse_if_open)
except MaxConnectionsExceeded:
time.sleep(0.1)
else:
return ret
raise MaxConnectionsExceeded('Max connections exceeded, timed out '
'attempting to connect.')
def _connect(self):
while True:
try:
# Remove the oldest connection from the heap.
ts, conn = heapq.heappop(self._connections)
key = self.conn_key(conn)
except IndexError:
ts = conn = None
logger.debug('No connection available in pool.')
break
else:
if self._is_closed(conn):
                    # This connection was closed, but since it was not stale
# it got added back to the queue of available conns. We
# then closed it and marked it as explicitly closed, so
# it's safe to throw it away now.
# (Because Database.close() calls Database._close()).
logger.debug('Connection %s was closed.', key)
ts = conn = None
elif self._stale_timeout and self._is_stale(ts):
# If we are attempting to check out a stale connection,
# then close it. We don't need to mark it in the "closed"
# set, because it is not in the list of available conns
# anymore.
logger.debug('Connection %s was stale, closing.', key)
self._close(conn, True)
ts = conn = None
else:
break
if conn is None:
if self._max_connections and (
len(self._in_use) >= self._max_connections):
raise MaxConnectionsExceeded('Exceeded maximum connections.')
conn = super(PooledDatabase, self)._connect()
            # Subtract a tiny random jitter so connections created within the
            # same clock tick still get distinct, orderable heap timestamps.
            ts = time.time() - random.random() / 1000
key = self.conn_key(conn)
logger.debug('Created new connection %s.', key)
self._in_use[key] = PoolConnection(ts, conn, time.time())
return conn
def _is_stale(self, timestamp):
# Called on check-out and check-in to ensure the connection has
# not outlived the stale timeout.
return (time.time() - timestamp) > self._stale_timeout
def _is_closed(self, conn):
return False
def _can_reuse(self, conn):
# Called on check-in to make sure the connection can be re-used.
return True
def _close(self, conn, close_conn=False):
key = self.conn_key(conn)
if close_conn:
super(PooledDatabase, self)._close(conn)
elif key in self._in_use:
pool_conn = self._in_use.pop(key)
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
logger.debug('Closing stale connection %s.', key)
super(PooledDatabase, self)._close(conn)
elif self._can_reuse(conn):
logger.debug('Returning %s to pool.', key)
heapq.heappush(self._connections, (pool_conn.timestamp, conn))
else:
logger.debug('Closed %s.', key)
def manual_close(self):
"""
Close the underlying connection without returning it to the pool.
"""
if self.is_closed():
return False
# Obtain reference to the connection in-use by the calling thread.
conn = self.connection()
# A connection will only be re-added to the available list if it is
# marked as "in use" at the time it is closed. We will explicitly
# remove it from the "in use" list, call "close()" for the
# side-effects, and then explicitly close the connection.
self._in_use.pop(self.conn_key(conn), None)
self.close()
self._close(conn, close_conn=True)
def close_idle(self):
# Close any open connections that are not currently in-use.
with self._lock:
for _, conn in self._connections:
self._close(conn, close_conn=True)
self._connections = []
def close_stale(self, age=600):
# Close any connections that are in-use but were checked out quite some
# time ago and can be considered stale.
with self._lock:
in_use = {}
cutoff = time.time() - age
n = 0
for key, pool_conn in self._in_use.items():
if pool_conn.checked_out < cutoff:
self._close(pool_conn.connection, close_conn=True)
n += 1
else:
in_use[key] = pool_conn
self._in_use = in_use
return n
def close_all(self):
# Close all connections -- available and in-use. Warning: may break any
# active connections used by other threads.
self.close()
with self._lock:
for _, conn in self._connections:
self._close(conn, close_conn=True)
for pool_conn in self._in_use.values():
self._close(pool_conn.connection, close_conn=True)
self._connections = []
self._in_use = {}
class PooledMySQLDatabase(PooledDatabase, MySQLDatabase):
def _is_closed(self, conn):
try:
conn.ping(False)
        except Exception:
return True
else:
return False
class _PooledPostgresqlDatabase(PooledDatabase):
def _is_closed(self, conn):
if conn.closed:
return True
txn_status = conn.get_transaction_status()
if txn_status == TRANSACTION_STATUS_UNKNOWN:
return True
elif txn_status != TRANSACTION_STATUS_IDLE:
conn.rollback()
return False
def _can_reuse(self, conn):
txn_status = conn.get_transaction_status()
# Do not return connection in an error state, as subsequent queries
# will all fail. If the status is unknown then we lost the connection
# to the server and the connection should not be re-used.
if txn_status == TRANSACTION_STATUS_UNKNOWN:
return False
elif txn_status == TRANSACTION_STATUS_INERROR:
conn.reset()
elif txn_status != TRANSACTION_STATUS_IDLE:
conn.rollback()
return True
class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase):
pass
try:
from playhouse.postgres_ext import PostgresqlExtDatabase
class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase):
pass
except ImportError:
PooledPostgresqlExtDatabase = None
class _PooledSqliteDatabase(PooledDatabase):
def _is_closed(self, conn):
try:
conn.total_changes
        except Exception:
return True
else:
return False
class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase):
pass
try:
from playhouse.sqlite_ext import SqliteExtDatabase
class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase):
pass
except ImportError:
PooledSqliteExtDatabase = None
try:
from playhouse.sqlite_ext import CSqliteExtDatabase
class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase):
pass
except ImportError:
PooledCSqliteExtDatabase = None
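Beyond the per-request open/close cycle, the pool classes expose explicit maintenance hooks. A minimal sketch with `PooledSqliteDatabase` (the file name and timeout values are illustrative):

```python
from playhouse.pool import PooledSqliteDatabase

db = PooledSqliteDatabase(
    'app.db',
    max_connections=8,
    stale_timeout=300,  # recycle connections older than 5 minutes
    timeout=5)          # wait up to 5 seconds for a free connection

db.connect()   # check a connection out of the pool
db.close()     # check it back in; the underlying socket stays open

db.close_idle()          # physically close idle pooled connections
db.close_stale(age=600)  # close in-use connections checked out >10 min ago
db.close_all()           # close everything; may break other threads
```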