# -*- coding: utf-8 -*-
"""
Extend the power of declarative base.
"""
import math
from typing import Union, List, Tuple, Dict, Any
from collections import OrderedDict
from copy import deepcopy
from sqlalchemy import inspect, func, text, select, update, Column
from sqlalchemy.sql.expression import TextClause
from sqlalchemy.engine import Engine
from sqlalchemy.orm import declarative_base, Session, InstrumentedAttribute
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import FlushError
from ..utils import (
ensure_exact_one_arg_is_not_none, ensure_list, grouper_list,
ensure_session, clean_session,
)
Base = declarative_base()
[docs]class ExtendedBase(Base):
    """
    Abstract declarative base class that adds convenience helpers to ORM
    classes.

    Example::

        from sqlalchemy.ext.declarative import declarative_base

        Base = declarative_base()

        class User(Base, ExtendedBase):
            ... do what you do with sqlalchemy ORM

    Three helpers return the primary key names, fields and values as
    sequences:

    - :meth:`ExtendedBase.pk_names`
    - :meth:`ExtendedBase.pk_fields`
    - :meth:`ExtendedBase.pk_values`

    Three more helpers cover the case where the class has exactly one
    primary key column and return the single name, field and value:

    - :meth:`ExtendedBase.id_field_name`
    - :meth:`ExtendedBase.id_field`
    - :meth:`ExtendedBase.id_field_value`

    Return-value conventions for methods that take ``engine_or_session``:

    - insert / update methods: (the original note leaves this item
      unfinished)
    - select-related methods do not return a result proxy. When
      ``engine_or_session`` is an engine, the session used to run the
      statement is a temporary object that is destroyed when the method
      returns. A ``Result`` is bound to its session and should not be used
      once that session is closed, so these methods are advised to return a
      fully materialised list instead of an iterator.
    """
    # SQLAlchemy declarative flag: this class itself is not mapped to a table.
    __abstract__ = True
    # Attributes (names or Column objects) displayed by ``glance``;
    # ``glance`` raises NotImplementedError while this is None.
    _settings_major_attrs: list = None
_cache_pk_names: tuple = None

# --- No DB interaction APIs ---
@classmethod
def pk_names(cls) -> Tuple[str, ...]:
    """
    Return the primary key column names of this mapped class.

    The tuple is computed once per class and then cached. The cache is read
    from the class's own ``__dict__`` rather than through attribute lookup,
    so a mapped subclass never silently reuses the value cached for its
    parent class.

    :return: tuple of column names, in the order reported by
        ``inspect(cls).primary_key``.
    """
    cached = cls.__dict__.get("_cache_pk_names")
    if cached is None:
        cached = tuple(col.name for col in inspect(cls).primary_key)
        cls._cache_pk_names = cached
    return cached
_cache_pk_fields: tuple = None

@classmethod
def pk_fields(cls) -> "Tuple[InstrumentedAttribute, ...]":
    """
    Return the primary key attributes of this class. For example::

        class User(Base):
            id = Column(..., primary_key=True)
            name = Column(...)

        User.pk_fields() # (User.id,)

    The tuple is cached per class. The cache is read from the class's own
    ``__dict__`` so that a subclass gets its own attributes
    (``Child.id``) instead of the ones cached for its parent
    (``Parent.id``).

    :return: tuple of class-level attributes, one per name in
        :meth:`pk_names`.
    """
    cached = cls.__dict__.get("_cache_pk_fields")
    if cached is None:
        cached = tuple(getattr(cls, name) for name in cls.pk_names())
        cls._cache_pk_fields = cached
    return cached
def pk_values(self) -> tuple:
    """
    Return this instance's primary key values.

    :return: tuple holding one value per name in :meth:`pk_names`, in the
        same order.
    """
    collected = []
    for pk_name in self.pk_names():
        collected.append(getattr(self, pk_name))
    return tuple(collected)
# id_field_xxx() methods are only valid if there's only one primary key
_id_field_name: str = None

@classmethod
def id_field_name(cls) -> str:
    """
    Return the name of the primary key column when there is exactly one.

    The name is cached per class. The cache is read from the class's own
    ``__dict__`` so a subclass does not reuse the name cached for its
    parent.

    :raises ValueError: if :meth:`pk_names` does not contain exactly one
        name.
    """
    cached = cls.__dict__.get("_id_field_name")
    if cached is None:
        # Look the names up once instead of once per check.
        names = cls.pk_names()
        if len(names) != 1:
            raise ValueError(
                "{classname} has more than 1 primary key!"
                .format(classname=cls.__name__)
            )
        cached = names[0]
        cls._id_field_name = cached
    return cached
_id_field = None

@classmethod
def id_field(cls) -> "InstrumentedAttribute":
    """
    Return the class-level attribute of the single primary key column.

    The attribute is cached per class. The cache is read from the class's
    own ``__dict__`` so a subclass gets its own attribute instead of the
    one cached for its parent.

    :raises ValueError: propagated from :meth:`id_field_name` when the
        class does not have exactly one primary key.
    """
    cached = cls.__dict__.get("_id_field")
    if cached is None:
        cached = getattr(cls, cls.id_field_name())
        cls._id_field = cached
    return cached
def id_field_value(self):
    """
    Return the value of the single primary key column of this instance.

    :raises ValueError: propagated from :meth:`id_field_name` when the
        class does not have exactly one primary key.
    """
    pk_name = self.id_field_name()
    return getattr(self, pk_name)
_cache_keys: List[str] = None

@classmethod
def keys(cls) -> List[str]:
    """
    Return the names of all columns declared on ``cls.__table__``.

    The names are cached per class. The cache is read from the class's own
    ``__dict__`` so a subclass mapped to a different table does not get the
    column names cached for its parent. A fresh list is returned on every
    call so callers cannot corrupt the cache by mutating the result.

    :return: list of column names in table order.
    """
    cached = cls.__dict__.get("_cache_keys")
    if cached is None:
        cached = [column.name for column in cls.__table__.columns]
        cls._cache_keys = cached
    return list(cached)
def values(self) -> list:
    """
    Return the value of every declared column of this instance.

    Columns are visited in ``self.__table__.columns`` order; a column whose
    attribute is missing on the instance contributes ``None``.
    """
    collected = []
    for column in self.__table__.columns:
        collected.append(getattr(self, column.name, None))
    return collected
def items(self) -> List[Tuple[str, Any]]:
    """
    Return ``(column name, value)`` pairs for every declared column.

    Columns are visited in ``self.__table__.columns`` order; a column whose
    attribute is missing on the instance is paired with ``None``.
    """
    pairs = []
    for column in self.__table__.columns:
        column_name = column.name
        pairs.append((column_name, getattr(self, column_name, None)))
    return pairs
def __repr__(self):
    """
    Render the instance as ``ClassName(col=value, ...)`` using the pairs
    returned by :meth:`items`.
    """
    rendered = ", ".join([
        "%s=%r" % (name, value)
        for name, value in self.items()
    ])
    return "%s(%s)" % (self.__class__.__name__, rendered)
def __str__(self):
    """Return the same text as :meth:`__repr__`."""
    as_text = self.__repr__()
    return as_text
def to_dict(self, include_null=True) -> Dict[str, Any]:
    """
    Convert the declared columns of this instance to a dict.

    :param include_null: when True, every declared column is included.
        When False, columns whose value is ``None`` are left out.
    :return: mapping of column name to value.

    Both modes are built from :meth:`items`, so the result only ever
    contains declared columns. Reading the instance ``__dict__`` instead
    would keep attributes explicitly set to ``None`` and would leak
    arbitrary non-column attributes stored on the instance.
    """
    pairs = self.items()
    if include_null:
        return dict(pairs)
    return {name: value for name, value in pairs if value is not None}
def to_OrderedDict(
    self,
    include_null: bool = True,
) -> OrderedDict:
    """
    Convert the declared columns of this instance to an OrderedDict.

    :param include_null: when True, every declared column is included.
        When False, columns whose value is ``None`` are left out.
    :return: ordered mapping of column name to value, in the order produced
        by :meth:`items`.

    Both modes are built from :meth:`items`, which keeps the result
    consistent with the public column collection and drops values that were
    explicitly assigned ``None`` when ``include_null`` is False.
    """
    pairs = self.items()
    if not include_null:
        pairs = [(name, value) for name, value in pairs if value is not None]
    return OrderedDict(pairs)
_cache_major_attrs: tuple = None

@classmethod
def _major_attrs(cls):
    """
    Return the attribute names listed in ``_settings_major_attrs``.

    Each entry may be a string (used as is) or a ``Column`` (its ``name``
    is used). The resulting tuple is cached per class; the cache is read
    from the class's own ``__dict__`` so a subclass that overrides
    ``_settings_major_attrs`` does not get the tuple cached for its parent.

    :raises TypeError: if an entry is neither a string nor a ``Column``.
    :raises ValueError: if the same name is listed more than once.
    """
    cached = cls.__dict__.get("_cache_major_attrs")
    if cached is None:
        names = []
        for item in cls._settings_major_attrs:
            if isinstance(item, str):
                names.append(item)
            elif isinstance(item, Column):
                names.append(item.name)
            else:  # pragma: no cover
                raise TypeError(
                    "`_settings_major_attrs` items have to be str or "
                    "Column, got %r!" % (item,)
                )
        if len(set(names)) != len(names):  # pragma: no cover
            raise ValueError(
                "`_settings_major_attrs` has duplicate attributes: %r!"
                % (names,)
            )
        cached = tuple(names)
        cls._cache_major_attrs = cached
    return cached
def glance(self, _verbose: bool = True):  # pragma: no cover
    """
    Print a short representation of the instance that only shows the
    attributes configured in :attr:`ExtendedBase._settings_major_attrs`.

    :param _verbose: internal param for unit testing; nothing is printed
        when it is False.
    :raises NotImplementedError: if ``_settings_major_attrs`` is None.
    """
    if self._settings_major_attrs is None:
        msg = ("Please specify attributes you want to include "
               "in `class._settings_major_attrs`!")
        raise NotImplementedError(msg)

    # Read every value first, then format, as two separate passes.
    pairs = []
    for name in self._major_attrs():
        pairs.append((name, getattr(self, name)))
    rendered = ", ".join(["%s=%r" % pair for pair in pairs])

    line = "{classname}({kwargs})".format(
        classname=self.__class__.__name__,
        kwargs=rendered,
    )
    if _verbose:  # pragma: no cover
        print(line)
def absorb(
    self,
    other: 'ExtendedBase',
    ignore_none: bool = True,
) -> 'ExtendedBase':
    """
    Copy column values from ``other`` onto this instance.

    :param other: instance of the same class (or a subclass) to read from.
    :param ignore_none: when True, values of ``other`` that are ``None``
        are skipped; when False every value is copied.
    :return: this instance, to allow chaining.
    :raises TypeError: if ``other`` is not an instance of this class.

    Values are deep-copied so the two instances never share mutable data.
    """
    if not isinstance(other, self.__class__):
        raise TypeError("`other` has to be a instance of %s!" %
                        self.__class__)

    for name, incoming in other.items():
        if ignore_none and incoming is None:
            continue
        setattr(self, name, deepcopy(incoming))
    return self
def revise(
    self,
    data: dict,
    ignore_none: bool = True,
) -> 'ExtendedBase':
    """
    Update attributes of this instance from a dictionary.

    :param data: mapping of attribute name to new value.
    :param ignore_none: when True, entries whose value is ``None`` are
        skipped; when False every entry is applied.
    :return: this instance, to allow chaining.
    :raises TypeError: if ``data`` is not a dict.

    Values are deep-copied so the instance never shares mutable data with
    ``data``.
    """
    if not isinstance(data, dict):
        raise TypeError("`data` has to be a dict!")

    for name, incoming in data.items():
        if ignore_none and incoming is None:
            continue
        setattr(self, name, deepcopy(incoming))
    return self
# --- DB interaction APIs ---
@classmethod
def by_pk(
    cls,
    engine_or_session: Union[Engine, Session],
    id_: Union[Any, List[Any], Tuple],
):
    """
    Get one object by its primary key value(s).

    Examples::

        class User(Base):
            id = Column(Integer, primary_key=True)
            name = Column(String)

        with Session(engine) as session:
            session.add(User(id=1, name="Alice"))
            session.commit()

        # User(id=1, name="Alice")
        print(User.by_pk(engine, 1))
        print(User.by_pk(engine, (1,)))
        print(User.by_pk(engine, [1, ]))

        with Session(engine) as session:
            print(User.by_pk(session, 1))
            print(User.by_pk(session, (1,)))
            print(User.by_pk(session, [1, ]))

    :param engine_or_session: engine or session used to run the lookup.
    :param id_: primary key value, or a sequence of values.
    :return: whatever ``Session.get`` returns for ``(cls, id_)``.

    A small piece of syntactic sugar that lets the caller fetch a single
    object directly from its primary key value.
    """
    ses, auto_close = ensure_session(engine_or_session)
    try:
        return ses.get(cls, id_)
    finally:
        # Run the session cleanup even when the lookup raises.
        clean_session(ses, auto_close)
@classmethod
def by_sql(
    cls,
    engine_or_session: Union[Engine, Session],
    sql: Union[str, TextClause],
) -> List['ExtendedBase']:
    """
    Query ORM objects with a raw SQL string or a ``text()`` clause.

    Examples::

        class User(Base):
            id = Column(Integer, primary_key=True)
            name = Column(String)

        with Session(engine) as session:
            user_list = [
                User(id=1, name="Alice"),
                User(id=2, name="Bob"),
                User(id=3, name="Cathy"),
            ]
            session.add_all(user_list)
            session.commit()

        results = User.by_sql(
            engine,
            "SELECT * FROM extended_declarative_base_user",
        )
        # [User(id=1, name="Alice"), User(id=2, name="Bob"), User(id=3, name="Cathy")]
        print(results)

    :param engine_or_session: engine or session used to run the query.
    :param sql: SQL string (wrapped in ``text()``) or a ``TextClause``.
    :return: list of instances of this class built from the result rows.
    :raises TypeError: if ``sql`` is neither a str nor a ``TextClause``.

    A small piece of syntactic sugar that lets the caller query with a
    plain SQL string.
    """
    if isinstance(sql, str):
        sql_stmt = text(sql)
    elif isinstance(sql, TextClause):
        sql_stmt = sql
    else:  # pragma: no cover
        raise TypeError(
            "`sql` has to be a str or TextClause, got %r!" % (type(sql),)
        )

    ses, auto_close = ensure_session(engine_or_session)
    try:
        return ses.scalars(select(cls).from_statement(sql_stmt)).all()
    finally:
        # Run the session cleanup even when the query raises.
        clean_session(ses, auto_close)
@classmethod
def smart_insert(
    cls,
    engine_or_session: Union[Engine, Session],
    obj_or_objs: Union['ExtendedBase', List['ExtendedBase']],
    minimal_size: int = 5,
    _op_counter: int = 0,
    _insert_counter: int = 0,
) -> Tuple[int, int]:
    """
    An optimized insert strategy that tolerates conflicting rows.

    :param engine_or_session: engine or session used to run the inserts.
    :param obj_or_objs: single object or list of objects to insert.
    :param minimal_size: internal bulk size for each attempt; a failed
        batch smaller than ``minimal_size ** 2`` is inserted one by one.
    :param _op_counter: number of successful bulk INSERT sql invoked
    :param _insert_counter: number of successfully inserted objects.
    :return: number of bulk INSERT sql invoked. Usually it is
        greatly smaller than ``len(data)``. and also return the number of
        successfully inserted objects.

    .. warning::

        This operation is not atomic, if you force stop the program,
        then it could be only partially completed

    Strategy: when no IntegrityError is expected, a bulk insert is much
    faster than inserting row by row. When that cannot be known up front:

    1. Try a bulk insert; it only commits once at the end, so it is fast.
    2. If it fails, split the data into chunks of ``floor(sqrt(n))`` rows
       and repeat this logic on every chunk.
    3. Once a failed batch is smaller than ``minimal_size ** 2``, insert
       its rows one by one, skipping the rows that fail.

    The chunking needs about ``sqrt(n)`` extra memory, which is small
    compared with the data itself.
    """
    ses, auto_close = ensure_session(engine_or_session)
    try:
        if isinstance(obj_or_objs, list):
            # First attempt: one bulk insert for the whole list.
            try:
                ses.add_all(obj_or_objs)
                ses.commit()
                _op_counter += 1
                _insert_counter += len(obj_or_objs)
            except (IntegrityError, FlushError):
                ses.rollback()
                n = len(obj_or_objs)
                n_chunk = math.floor(math.sqrt(n))
                # Only split when the chunks are strictly smaller than the
                # failed batch; otherwise (minimal_size <= 1 and n == 1)
                # the recursion would retry the same batch forever.
                if n >= minimal_size ** 2 and 0 < n_chunk < n:
                    for chunk in grouper_list(obj_or_objs, n_chunk):
                        (
                            _op_counter,
                            _insert_counter,
                        ) = cls.smart_insert(
                            engine_or_session=ses,
                            obj_or_objs=chunk,
                            minimal_size=minimal_size,
                            _op_counter=_op_counter,
                            _insert_counter=_insert_counter,
                        )
                # Small batch: insert one by one, skipping conflicts.
                else:
                    for obj in obj_or_objs:
                        try:
                            ses.add(obj)
                            ses.commit()
                            _op_counter += 1
                            _insert_counter += 1
                        except (IntegrityError, FlushError):
                            ses.rollback()
        else:
            try:
                ses.add(obj_or_objs)
                ses.commit()
                _op_counter += 1
                _insert_counter += 1
            except (IntegrityError, FlushError):
                ses.rollback()
    finally:
        # Run the session cleanup even when an unexpected error escapes.
        clean_session(ses, auto_close)
    return _op_counter, _insert_counter
@classmethod
def update_all(
    cls,
    engine_or_session: Union[Engine, Session],
    obj_or_objs: Union['ExtendedBase', List['ExtendedBase']],
    include_null: bool = True,
    upsert: bool = False,
) -> Tuple[int, int]:
    """
    The :meth:`sqlalchemy.crud.updating.update_all` function in ORM syntax.

    This operation **IS NOT ATOMIC**. It is a greedy operation, trying to
    update as much as it can.

    :param engine_or_session: engine or session used to run the statements.
    :param obj_or_objs: single object or list of object
    :param include_null: update those None value field or not
    :param upsert: if True, then do insert also.
    :return: tuple ``(update_counter, insert_counter)``: the number of
        objects whose UPDATE matched a row, and the number of objects
        inserted because no row matched.

    The updates are committed before the inserts are attempted, so a
    failed insert cannot roll back updates that are already counted in the
    returned ``update_counter``.
    """
    update_counter = 0
    insert_counter = 0

    ses, auto_close = ensure_session(engine_or_session)
    try:
        objs = ensure_list(obj_or_objs)  # type: List[ExtendedBase]

        objs_to_insert = list()
        for obj in objs:
            stmt = (
                update(cls)
                .where(*[
                    field == value
                    for field, value in zip(obj.pk_fields(), obj.pk_values())
                ])
                .values(**obj.to_dict(include_null=include_null))
            )
            res = ses.execute(stmt)
            if res.rowcount:
                update_counter += 1
            else:
                # No row matched this primary key: insert candidate.
                objs_to_insert.append(obj)

        ses.commit()

        if upsert and objs_to_insert:
            try:
                ses.add_all(objs_to_insert)
                ses.commit()
                insert_counter += len(objs_to_insert)
            except (IntegrityError, FlushError):  # pragma: no cover
                ses.rollback()
    finally:
        # Run the session cleanup even when a statement raises.
        clean_session(ses, auto_close)
    return update_counter, insert_counter
@classmethod
def upsert_all(
    cls,
    engine_or_session: Union[Engine, Session],
    obj_or_objs: Union['ExtendedBase', List['ExtendedBase']],
    include_null: bool = True,
) -> Tuple[int, int]:
    """
    The :meth:`sqlalchemy.crud.updating.upsert_all` function in ORM syntax.

    Delegates to :meth:`update_all` with ``upsert=True``.

    :param engine_or_session: engine or session used to run the statements.
    :param obj_or_objs: single object or list of object
    :param include_null: update those None value field or not
    :return: the tuple returned by :meth:`update_all`.
    """
    options = dict(
        engine_or_session=engine_or_session,
        obj_or_objs=obj_or_objs,
        include_null=include_null,
        upsert=True,
    )
    return cls.update_all(**options)
@classmethod
def delete_all(
    cls,
    engine_or_session: Union[Engine, Session],
):  # pragma: no cover
    """
    Delete all data in this table.

    Executes ``cls.__table__.delete()`` and commits.

    :param engine_or_session: engine or session used to run the statement.

    TODO: add a boolean flag for cascade remove
    """
    ses, auto_close = ensure_session(engine_or_session)
    try:
        ses.execute(cls.__table__.delete())
        ses.commit()
    finally:
        # Run the session cleanup even when the delete raises.
        clean_session(ses, auto_close)
@classmethod
def count_all(
    cls,
    engine_or_session: Union[Engine, Session],
) -> int:
    """
    Return number of rows in this table.

    :param engine_or_session: engine or session used to run the query.
    :return: result of ``SELECT count(*)`` against this class.
    """
    ses, auto_close = ensure_session(engine_or_session)
    try:
        stmt = select(func.count()).select_from(cls)
        return ses.execute(stmt).one()[0]
    finally:
        # Run the session cleanup even when the query raises.
        clean_session(ses, auto_close)
@classmethod
def select_all(
    cls,
    engine_or_session: Union[Engine, Session],
) -> List['ExtendedBase']:
    """
    Return every row of this table as ORM instances.

    :param engine_or_session: engine or session used to run the query.
    :return: fully materialised list of instances of this class.
    """
    ses, auto_close = ensure_session(engine_or_session)
    try:
        return ses.scalars(select(cls)).all()
    finally:
        # Run the session cleanup even when the query raises.
        clean_session(ses, auto_close)
@classmethod
def random_sample(
    cls,
    engine_or_session: Union[Engine, Session],
    limit: int = None,
    perc: int = None,
) -> List['ExtendedBase']:
    """
    Return random ORM instances.

    Exactly one of ``limit`` and ``perc`` has to be given.

    :param engine_or_session: engine or session used to run the query.
    :param limit: select this many rows, ordered by ``random()``.
    :param perc: sample the table with ``TABLESAMPLE bernoulli(perc)``.
        NOTE(review): presumably a percentage and only available on
        backends that support TABLESAMPLE -- confirm.
    :return: list of instances of this class.
    :rtype: List[ExtendedBase]
    """
    # Validate before acquiring a session so invalid arguments cannot leave
    # a session behind.
    ensure_exact_one_arg_is_not_none(limit, perc)

    ses, auto_close = ensure_session(engine_or_session)
    try:
        if limit is not None:
            return ses.scalars(
                select(cls).order_by(func.random()).limit(limit)
            ).all()
        if perc is not None:
            selectable = cls.__table__.tablesample(
                func.bernoulli(perc),
                name="alias",
                seed=func.random()
            )
            columns = [
                getattr(selectable.c, column.name)
                for column in cls.__table__.columns
            ]
            stmt = select(*columns)
            # ``row._mapping`` is the mapping view of a result row;
            # ``dict(row)`` itself no longer works on SQLAlchemy 2.0 rows.
            return [
                cls(**dict(row._mapping))
                for row in ses.execute(stmt)
            ]
        raise ValueError("exactly one of `limit` and `perc` is required!")
    finally:
        # Run the session cleanup even when the query raises.
        clean_session(ses, auto_close)