"""Database access layer: connection handling and ORM models."""

from datetime import datetime
from contextlib import contextmanager

import sqlalchemy as sa
import sqlalchemy.ext.declarative
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import Column, ForeignKey
from sqlalchemy import Integer, String, Boolean, DateTime, Numeric
from sqlalchemy.orm import relationship
from sqlalchemy_utils import force_auto_coercion
from sqlalchemy_utils.types.password import PasswordType

from autogestionale.logging import get_logger

# The database URL must follow RFC 1738 in the form
# dialect+driver://username:password@host:port/database
ENGINE_GENERIC = "{engine}://{user}:{password}@{host}:{port}/{database}"\
                 "?charset=utf8"
ENGINE_SQLITE = "sqlite:///{path}"
ENGINE_SQLITE_MEMORY = "sqlite://"

PASSWORD_SCHEMES = ['pbkdf2_sha512']

Base = sqlalchemy.ext.declarative.declarative_base()
log = get_logger('database')
force_auto_coercion()


class Database:
    """
    Handle database operations.

    Holds the SQLAlchemy engine and the session factory bound to it.
    """

    Session = None
    engine = None

    def __init__(self, **kwargs):
        """
        Initialize database connection.

        Builds the database URL, creates the engine and the session factory
        and creates every table declared on ``Base`` that does not exist yet.

        :param engine: The SQLAlchemy database backend in the form
            dialect+driver where dialect is the name of a SQLAlchemy dialect
            (sqlite, mysql, postgresql, oracle or mssql) and driver is the
            name of the DBAPI in all lowercase letters. If driver is not
            specified the default DBAPI will be imported if available.
        :param path: Only for SQLite. Path to database. If not specified
            (missing, ``None`` or empty) the database will be kept in memory
            (should be used only for testing).
        :param host: Database server host (non-SQLite engines only).
        :param port: Database server port (non-SQLite engines only).
        :param database: Name of the database (non-SQLite engines only).
        :param user: User name for the connection (non-SQLite engines only).
        :param password: Password for the connection (non-SQLite engines
            only).
        :raises KeyError: If ``engine`` is missing, or if one of the fields
            required by the generic URL template is missing.
        """
        if kwargs['engine'] == 'sqlite':
            # An explicit ``path=None`` must not end up as the literal file
            # name "None", so test the value rather than the key's presence.
            path = kwargs.get('path')
            if path:
                url = ENGINE_SQLITE.format(path=path)
            else:
                url = ENGINE_SQLITE_MEMORY
        else:
            # NOTE(review): values are interpolated verbatim; credentials
            # containing URL-special characters must already be escaped by
            # the caller.
            url = ENGINE_GENERIC.format(**kwargs)

        self.engine = sa.create_engine(url)
        self.Session = sa.orm.sessionmaker(
            bind=self.engine,
            expire_on_commit=False
        )
        Base.metadata.create_all(self.engine)

    @contextmanager
    def get_session(self):
        """
        Provide a session wrapped in a transaction.

        The session is committed when the ``with`` body finishes without
        raising and is always closed on exit. A ``SQLAlchemyError`` raised
        inside the body is logged at critical level, the session is rolled
        back and the error is NOT re-raised, so the caller continues after
        the ``with`` statement. Other exceptions propagate after the session
        has been closed.
        """
        session = self.Session()
        try:
            yield session
        except SQLAlchemyError as e:
            log.critical("Error performing transaction:")
            log.critical(e)
            session.rollback()
        else:
            session.commit()
        finally:
            session.close()


class User(Base):
    """Row of the ``users`` table."""

    __tablename__ = 'users'

    uid = Column(Integer, primary_key=True)
    username = Column(String, nullable=False, unique=True)
    password = Column(PasswordType(schemes=PASSWORD_SCHEMES), nullable=False)
    is_active = Column(Boolean, nullable=False, server_default='1')
    is_authenticated = Column(Boolean, nullable=False, server_default='0')
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    def get_id(self):
        """Return the user's ``uid`` formatted as a string."""
        return u'{}'.format(self.uid)


class Entry(Base):
    """Row of the ``entry`` table; each entry belongs to one event."""

    __tablename__ = 'entry'

    uid = Column(Integer, primary_key=True)
    amount = Column(Numeric(precision=3))
    description = Column(String, nullable=False)
    # entry_category_uid = Column(Integer, ForeignKey('entry_category.uid'),
    #                             nullable=False)
    created_at = Column(DateTime, nullable=False, default=datetime.now)
    event_uid = Column(Integer, ForeignKey('event.uid'), nullable=False)


class Event(Base):
    """Row of the ``event`` table; belongs to a group and owns entries."""

    __tablename__ = 'event'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    starts_at = Column(DateTime, nullable=False, default=datetime.now)
    ends_at = Column(DateTime)
    created_at = Column(DateTime, nullable=False, default=datetime.now)
    group_uid = Column(Integer, ForeignKey('group.uid'), nullable=False)

    entries = relationship('Entry', lazy='joined')

    def get_balance(self):
        """
        Return the sum of the amounts of all entries of this event.

        ``Entry.amount`` is nullable, so entries without an amount are
        skipped instead of breaking the sum. Returns ``0`` when there is
        nothing to add up.
        """
        return sum(
            entry.amount
            for entry in self.entries
            if entry.amount is not None
        )


class Group(Base):
    """Row of the ``group`` table; a group owns events."""

    __tablename__ = 'group'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    description = Column(String, nullable=False)
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    events = relationship('Event', lazy='joined')

    # def to_json(self):
    #     return {
    #         'uid': self.uid,
    #         'name': self.name,
    #         'description': self.description,
    #         'created_at': self.created_at.isoformat(),
    #         'events': [{
    #             'uid': evt.uid,
    #             'name': evt.name
    #         } for evt in self.events]
    #     }


class ProductCategory(Base):
    """Row of the ``product_categories`` table; owns products."""

    __tablename__ = 'product_categories'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    sort = Column(Integer, nullable=False, server_default='0')
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    products = relationship('Product', lazy='joined')


class Product(Base):
    """Row of the ``products`` table; belongs to one product category."""

    __tablename__ = 'products'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    price = Column(Integer, nullable=False)
    sort = Column(Integer, nullable=False, server_default='0')
    category_uid = Column(Integer, ForeignKey('product_categories.uid'),
                          nullable=False)
    is_active = Column(Boolean, nullable=False, server_default='1')
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    category = relationship('ProductCategory', lazy='joined')