# Source code for sparrow.sql


from functools import wraps
import copy

import psycopg2
import momoko

from .util import *


# Exceptions
# ==========

class SqlError(Error):
    """Raised when executing a query (or command) fails.

    Bundles the underlying psycopg2 error together with the query and the
    data that were being executed, so all three show up in the message.
    """

    def __init__(self, err: psycopg2.Error, query: "Sql", data: dict):
        self.err, self.query, self.data = err, query, data

    def __str__(self):
        template = "While executing this SQL:\n{s.query}\nWith this data:{data}\nThis exception occured:{s.err}"
        return template.format(s=self, data=repr(self.data))
class NotSingle(Error):
    """Raised by `SqlResult.single` when the row count of the result is not exactly one."""
# Classes
# =======
class Database:
    """Postgres database, accessed through a momoko connection pool."""

    def __init__(self, ioloop, dbname, user="postgres", password="postgres", host="localhost", port=5432,
                 momoko_poolsize=5):
        """Build the DSN from the connection parameters and start connecting the pool."""
        settings = {
            "dbname": dbname,
            "user": user,
            "password": password,
            "host": host,
            "port": port,
        }
        dsn = "dbname={dbname} user={user} password={password} host={host} port={port}".format(**settings)
        self.pdb = momoko.Pool(dsn=dsn, size=momoko_poolsize, ioloop=ioloop)
        self.pdb.connect()

    async def get_cursor(self, statement: "Sql", unsafe_dict: dict):
        """Execute `statement` (converted with `str`) with `unsafe_dict` as parameters.

        Returns whatever the pool's `execute` resolves to (the cursor).
        """
        return await self.pdb.execute(str(statement), unsafe_dict)
class GlobalDb:
    """Holder for one globally shared database object (stored on the class itself)."""

    db = None

    @classmethod
    def get(cls):
        """Return the globally registered database (None if none was set)."""
        return cls.db

    @classmethod
    def set(cls, db):
        """Register `db` as the global database."""
        cls.db = db

    @classmethod
    def globalize(cls, db):
        """Return `db` itself, or the global database when `db` is None.

        The previous condition was inverted: it handed back the global
        database when one was passed in, and None when none was given.
        """
        return cls.db if db is None else db
# Helper classes for building a query
# -----------------------------------
class Unsafe:
    """Wrapper for unsafe data. (For data that needs to be inserted later, use Field.)

    The wrapped value is never put in the query text itself; `str()` of the
    wrapper yields a `%(key)s` placeholder and the value is kept in `value`,
    keyed by `key`.

    Wrapping an existing `Unsafe` returns that same object unchanged.
    """

    def __new__(typ, val, *args, **kwargs):
        if isinstance(val, Unsafe):
            return val
        # object.__new__ accepts no extra arguments once __new__ is overridden.
        return object.__new__(typ)

    def __init__(self, value):
        # Python re-runs __init__ on the object __new__ returned; when that is
        # an already-built wrapper, `value` is the wrapper itself and must not
        # overwrite the real value.
        if value is self:
            return
        self.key = str(id(self))
        self.text = "%({0})s".format(self.key)
        self.value = value

    def __getnewargs__(self):
        # Lets copy/pickle call __new__ with its required argument; without
        # this, copy.deepcopy() of a statement holding an Unsafe raises TypeError.
        return (self.value,)

    def __str__(self):
        return self.text
class Field:
    """Named placeholder for data that is supplied later (with `Sql.with_data`)."""

    def __init__(self, name):
        placeholder = "%({0})s".format(name)
        self.text = placeholder

    def __str__(self):
        return self.text
# Main classes for queries and their results
# ------------------------------------------
class SqlResult:
    """Wraps a database cursor together with the query that produced it.

    Most of the time the shortcut methods on `Sql` are easier. Instead of:

    >>> res = await User.get(User.name == "Evert").exec(db)
    >>> u = res.single()

    you can write:

    >>> u = await User.get(User.name == "Evert").single(db)

    There is no such shortcut for scrolling.

    `single`, `all` and `amount` build their results by calling
    `self.query.cls`, so do not use them when that class is `None`.
    """

    def __init__(self, cursor, query: "Sql"):
        self.cursor, self.query = cursor, query

    def _build(self, row):
        # Turn one raw row into an instance of the query's class.
        return self.query.cls(db_args=row)

    def raw(self):
        """Return one raw row, without any interpretation.

        For anything beyond fetching a single raw value, consider using
        `self.cursor` directly.
        """
        return self.cursor.fetchone()

    def raw_all(self):
        """Return all raw rows."""
        return self.cursor.fetchall()

    def single(self):
        """Return the one object of the result; raise NotSingle if the row count is not 1."""
        found = self.cursor.rowcount
        if found == 1:
            return self._build(self.cursor.fetchone())
        raise NotSingle("Not 1 result but {} result(s).".format(found))

    def all(self):
        """Return every row of the result as an object."""
        return [self._build(row) for row in self.cursor.fetchall()]

    def amount(self, i: int):
        """Return up to `i` rows of the result as objects."""
        # TODO consider creating a version that asserts the amount specified is found
        rows = self.cursor.fetchmany(size=i)
        return [self._build(row) for row in rows]

    def scroll(self, i: int):
        """Scroll the cursor `i` steps (`i` can be negative). Returns self for chaining."""
        self.cursor.scroll(i)
        return self

    def count(self):
        """Return the cursor's row count."""
        return self.cursor.rowcount
def _wrapper_sqlresult(method): @wraps(method) async def wrapper(self, db: Database = None, *args, **kwargs): if db is None: db = GlobalDb.get() result = await self.exec(db) return method(result, *args, **kwargs) wrapper.__doc__ += "\n\nWrapped version, first argument is the database." return wrapper
class Sql:
    """Main class to save a given SQL query/command. Do not use directly, use the subclasses.

    `self.data` maps placeholder names to the values handed to the database
    alongside the statement text (`str(self)`).
    """

    # By default, there is no class
    cls = None

    def __init__(self, data=None):
        """Store `data`, or merge it into `self.data` if `__preinit__` already created one.

        `data` defaults to a fresh dict per instance; a literal `{}` default
        would be one dict object shared by every statement built without data.
        """
        if data is None:
            data = {}
        if hasattr(self, "data"):
            self.data.update(data)
        else:
            self.data = data

    def __preinit__(self):
        """Create an empty `self.data` so `check` can be used before `__init__` runs."""
        self.data = {}

    async def exec(self, db: Database = None):
        """Execute the SQL statement on the given database (the global one if `db` is None).

        Returns a `SqlResult`. A `psycopg2.Error` is re-raised as `SqlError`,
        carrying the statement text and data.
        """
        if db is None:
            db = GlobalDb.get()
        try:
            return SqlResult(await db.get_cursor(str(self), self.data), self)
        except psycopg2.Error as e:
            # Chain explicitly so the driver error stays attached as the cause.
            raise SqlError(e, str(self), self.data) from e

    # Allows you to call these methods immediately on a statement:
    single = _wrapper_sqlresult(SqlResult.single)
    all = _wrapper_sqlresult(SqlResult.all)
    amount = _wrapper_sqlresult(SqlResult.amount)
    count = _wrapper_sqlresult(SqlResult.count)
    raw = _wrapper_sqlresult(SqlResult.raw)
    raw_all = _wrapper_sqlresult(SqlResult.raw_all)

    def with_data(self, **kwargs):
        """Return a copy of the statement with the keyword arguments added to its data."""
        newself = self.copy()
        newself.data.update(kwargs)
        return newself

    # By default simply create a deepcopy
    def copy(self):
        """Create a copy of the statement, by default uses `copy.deepcopy`."""
        return copy.deepcopy(self)

    def check(self, what):
        """Helper function to *parse* parts of a query and insert their data in `self.data`.

        - `Sql`: its data is merged into `self.data`; its raw form is returned.
        - `Unsafe`: its value is stored under its key; the wrapper itself is returned.
        - `Field`: its placeholder text is returned.
        - tuple: each element is checked; a tuple of the results is returned.
        - anything else is returned unchanged.
        """
        if isinstance(what, Sql):
            self.data.update(what.data)
            return what.to_raw()
        if isinstance(what, Field):
            return str(what)
        if isinstance(what, tuple):
            return tuple(self.check(part) for part in what)
        if isinstance(what, Unsafe):
            self.data[what.key] = what.value
        return what

    def to_raw(self):
        """Compile this to a `RawSql` instance for more performance!"""
        return RawSql(str(self), self.data)

    def __str__(self):
        return "undefined so far"
class ClassedSql(Sql):
    """Version of `Sql` that also saves a given class.

    `SqlResult` will later try to parse its result as instances of this class.
    """

    def __init__(self, cls: type, data=None):
        """Store `cls` and initialise the data.

        `data` defaults to a fresh dict per instance; a literal `{}` default
        would be forwarded to `Sql.__init__` and could end up stored as the
        shared `self.data` of every statement built without data.
        """
        self.cls = cls
        Sql.__init__(self, {} if data is None else data)

    def to_raw(self):
        """Compile to a `RawClassedSql` that keeps the class."""
        return RawClassedSql(self.cls, str(self), self.data)
class RawSql(Sql):
    """Simply saves a string, and also some data.

    This is in contrast with `Sql`, which may save the query in a more abstract way.
    """

    def __init__(self, text, data=None):
        """Store the statement text and its data.

        `data` defaults to a fresh dict per instance rather than one dict
        object shared by every `RawSql` created without data.
        """
        self.text = text
        Sql.__init__(self, {} if data is None else data)

    def to_raw(self):
        """Already raw, just return self."""
        return self

    def copy(self):
        """More optimized version of copy: same text, shallow copy of the data."""
        return RawSql(self.text, copy.copy(self.data))

    def __str__(self):
        return self.text
class RawClassedSql(RawSql, ClassedSql):
    """Version of `RawSql` that also saves a given class like `ClassedSql`."""

    def __init__(self, cls, text, data=None):
        """Store the class, the statement text and its data.

        `data` defaults to a fresh dict per instance rather than one dict
        object shared by every instance created without data.
        """
        # TODO possibly make this use super; the parent initialisers are
        # bypassed on purpose for now.
        self.cls = cls
        self.text = text
        Sql.__init__(self, {} if data is None else data)

    def copy(self):
        """Copy with the same class and text, and a shallow copy of the data."""
        return RawClassedSql(self.cls, self.text, copy.copy(self.data))
class Condition(Sql):
    """Base class for the condition statements (`Not`, `MultiCondition`, `Where`)."""
class Not(Condition):
    """Negation of a single condition."""

    def __init__(self, cond):
        Sql.__preinit__(self)
        checked = self.check(cond)
        self.cond = checked
        Sql.__init__(self)

    def __str__(self):
        template = "(NOT {})"
        return template.format(self.cond)
class MultiCondition(Condition):
    """Base for conditions that combine any number of sub-conditions."""

    def __init__(self, *conds):
        Sql.__preinit__(self)
        checked = []
        for cond in conds:
            checked.append(self.check(cond))
        self.conditions = checked
        Sql.__init__(self)
class And(MultiCondition):
    """All sub-conditions joined with AND, wrapped in parentheses."""

    def __str__(self):
        joined = " AND ".join(map(str, self.conditions))
        return "(" + joined + ")"
class Or(MultiCondition):
    """All sub-conditions joined with OR, wrapped in parentheses."""

    def __str__(self):
        joined = " OR ".join(map(str, self.conditions))
        return "(" + joined + ")"
class Where(Condition):
    """Encodes a single 'WHERE' clause."""

    def __init__(self, lfield, op: str, rfield, data={}):
        """Initialize a 'WHERE' clause.

        Parameters:
            - `lfield` and `rfield`: Anything that can be interpreted as a part of an
              SQL query, could be of type string, `Sql`, `Field`, `Unsafe`, ...
            - `op`: The operator placed between both sides, as a string.
            - `data`: Extra data merged into the clause's data.
        """
        Sql.__preinit__(self)
        left = self.check(lfield)
        self.lfield = left
        self.op = op
        right = self.check(rfield)
        self.rfield = right
        Sql.__init__(self, data)

    def __str__(self):
        template = "{s.lfield} {s.op} {s.rfield}"
        return template.format(s=self)
class Order(Sql):
    # TODO order on multiple attributes (might work already?)
    """Encodes an 'ORDER BY' clause."""

    def __init__(self, field, op: str, data={}):
        """Initialize an 'ORDER BY' clause.

        Parameters:
            - `field`: Anything that can be interpreted as a part of an SQL query,
              could be of type string, `Sql`, `Field`, `Unsafe`, ...
            - `op`: Either 'ASC' or 'DESC'.
            - `data`: Extra data merged into the clause's data.
        """
        Sql.__preinit__(self)
        checked = self.check(field)
        self.field = checked
        self.op = op
        Sql.__init__(self, data)

    def __str__(self):
        template = "{s.field} {s.op}"
        return template.format(s=self)
class Select(ClassedSql):
    """Encodes a 'SELECT' query."""

    def __init__(self, cls, where_clauses=(), order: Order = None, offset=None, limit=None):
        """Initialize a 'SELECT' query.

        Most likely you will use `SomeEntityClass.get(...)` instead of this.
        Every part goes through `self.check`, so data carried by `Unsafe`
        values or nested statements ends up in `self.data`.
        """
        Sql.__preinit__(self)
        self.where_clauses = [self.check(c) for c in where_clauses]
        self._order = self.check(order)
        self._offset = self.check(offset)
        self._limit = self.check(limit)
        ClassedSql.__init__(self, cls)

    # The chaining methods below pass their arguments through `self.check`,
    # exactly like `__init__`. Without it, the data of an `Unsafe` value or a
    # `Where` clause is never merged into `self.data`, leaving placeholders in
    # the statement text with no matching value.

    def limit(self, l):
        """'LIMIT' the query. Can be used for chaining."""
        self._limit = self.check(l)
        return self

    def offset(self, o):
        """'OFFSET' the query. Can be used for chaining."""
        self._offset = self.check(o)
        return self

    def where(self, *clauses):
        """Add 'WHERE' clauses to the query. Can be used for chaining.

        Parameters: a number of Where clauses (combined with AND).
        """
        self.where_clauses.extend([self.check(c) for c in clauses])
        return self

    def order(self, _order: Order):
        """'ORDER' the query. Can be used for chaining.

        Anything that is not an `Order` is converted with unary plus first.
        """
        if not isinstance(_order, Order):
            _order = +_order
        self._order = self.check(_order)
        return self

    def __str__(self):
        s = "SELECT {props} FROM {cls._table_name}".format(cls=self.cls, props=self.cls._select_props)
        if self.where_clauses:
            s += " WHERE " + " AND ".join(["(" + str(c) + ")" for c in self.where_clauses])
        if self._order is not None:
            s += " ORDER BY {}".format(self._order)
        if self._limit is not None:
            s += " LIMIT {}".format(self._limit)
        if self._offset is not None:
            s += " OFFSET {}".format(self._offset)
        return s
class Command(ClassedSql):
    """Base for the command statements: INSERT, DELETE, UPDATE, CREATE TABLE, DROP TABLE, ..."""
# Template for CREATE TABLE statements; `CreateTable.__str__` fills in
# `tname` (the table name) and `stuff` (the column definitions and constraints).
sql_create_table_template = """ CREATE TABLE {tname} ( {stuff} ); """
class CreateTable(Command):
    """For CREATE TABLE statements."""

    def __init__(self, cls):
        Command.__init__(self, cls)

    def __str__(self):
        table = self.cls._table_name
        # Column definitions first, then reference constraints, then the key constraint.
        parts = [p.sql_def() for p in self.cls._props]
        parts.extend([r.sql_constraint() for r in self.cls._refs])
        parts.append(self.cls.key.sql_constraint())
        return sql_create_table_template.format(tname=table, stuff=",\n".join(parts))
class DropTable(Command):
    """For DROP TABLE statements."""

    def __init__(self, cls):
        Command.__init__(self, cls)
        # Remember the table name at construction time.
        self.tname = cls._table_name

    def __str__(self):
        template = "DROP TABLE IF EXISTS {tname} CASCADE"
        return template.format(tname=self.tname)
class EntityCommand(Command):
    """Class for commands that work both on classes (data has to be inserted
    later on) and on objects (data is taken from the object).
    """

    def __preinit__(self, what):
        Sql.__preinit__(self)
        if isinstance(what, type):
            # A class: the data needs to be filled in later on.
            self.cls = what
            return
        # An object: take its class and copy its property values into the data.
        self.cls = type(what)
        for prop in self.cls._complete_props:
            self.data[prop.name] = what.__dict__[prop.dataname]
class Insert(EntityCommand):
    """For INSERT statements."""

    def __init__(self, what, returning=None, replace=False):
        """
        Parameters:
            - `what`: Either an instance of an `Entity` or a subclass of `Entity`.
            - `returning`: What should the database return? Expects a `Property` or None.
            - `replace`: When true, an 'ON CONFLICT ... DO UPDATE' part is added.
        """
        EntityCommand.__preinit__(self, what)
        self._returning = self.check(returning)
        self._replace = replace
        Command.__init__(self, self.cls)

    def returning(self, prop):
        # TODO return multiple attributes?
        """Set what the database should return. Can be chained."""
        self._returning = prop
        return self

    def __str__(self):
        columns = self.cls._complete_props
        pieces = [
            "INSERT INTO {cls._table_name} ({props}) VALUES({vals})".format(
                cls=self.cls,
                props=", ".join(p.name for p in columns),
                vals=", ".join("%(" + p.name + ")s" for p in columns),
            )
        ]
        if self._replace:
            pieces.append(
                " ON CONFLICT ({keys}) DO UPDATE SET {vals}".format(
                    keys=", ".join(k.name for k in self.cls.key.referencing_props()),
                    vals=", ".join("{0} = EXCLUDED.{0}".format(p.name) for p in self.cls._props),
                )
            )
        if self._returning is not None:
            pieces.append(" RETURNING " + str(self._returning))
        return "".join(pieces)
class Update(EntityCommand):
    """For UPDATE statements."""

    def __init__(self, what):
        """
        Parameters:
            - `what`: Either an instance of an `Entity` or a subclass of `Entity`.
        """
        EntityCommand.__preinit__(self, what)
        EntityCommand.__init__(self, self.cls)

    def __str__(self):
        def placeholders(props):
            # One "%(name)s" placeholder per property, comma separated.
            return ", ".join(["%(" + p.name + ")s" for p in props])

        template = "UPDATE {cls._table_name} SET ({props}) = ({vals}) WHERE {cls.key} = ({keyvals})"
        return template.format(
            cls=self.cls,
            props=", ".join([p.name for p in self.cls._complete_props]),
            vals=placeholders(self.cls._complete_props),
            keyvals=placeholders(self.cls.key.referencing_props()),
        )
class Delete(EntityCommand):
    """For DELETE statements.

    Subclasses `EntityCommand` (itself a `Command`), like `Insert` and
    `Update`: the initialiser relies on `EntityCommand.__preinit__`, so the
    class should actually be one instead of borrowing the method from an
    unrelated class.
    """

    def __init__(self, what):
        """
        Parameters:
            - `what`: Either an instance of an `Entity` or a subclass of `Entity`.
        """
        EntityCommand.__preinit__(self, what)
        EntityCommand.__init__(self, self.cls)

    def __str__(self):
        return "DELETE FROM {cls._table_name} WHERE {cls.key} = ({keyvals})".format(
            cls=self.cls,
            keyvals=", ".join(["%(" + p.name + ")s" for p in self.cls.key.referencing_props()]),
        )