from datetime import datetime, timedelta
from uuid import uuid4
from contextlib import contextmanager

import sqlalchemy as sa
import sqlalchemy.ext.declarative
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import Column, ForeignKey
from sqlalchemy import Integer, String, Boolean, DateTime
from sqlalchemy.orm import relationship
from sqlalchemy_utils import force_auto_coercion
from sqlalchemy_utils.types.password import PasswordType

from pos.logging import get_logger

log = get_logger('database')

# The database URL must follow RFC 1738 in the form
# dialect+driver://username:password@host:port/database
ENGINE_GENERIC = "{engine}://{user}:{password}@{host}:{port}/{database}"\
                 "?charset=utf8"
ENGINE_SQLITE = "sqlite:///{path}"
ENGINE_SQLITE_MEMORY = "sqlite://"

PASSWORD_SCHEMES = ['pbkdf2_sha512']

Base = sqlalchemy.ext.declarative.declarative_base()


def _generate_token():
    """Return a fresh random UUID4 string for a new access token."""
    return str(uuid4())


def _default_expiry():
    """Return the default expiry timestamp: two days after the call."""
    return datetime.now() + timedelta(days=2)


class Database:
    """
    Handle database operations.
    """

    Session = None
    engine = None

    def __init__(self, **kwargs):
        """
        Initialize database connection.

        Builds the connection URL, creates the engine and the session
        factory, creates all tables declared on ``Base`` and enables
        automatic coercion for SQLAlchemy-Utils column types.

        :param engine: The SQLAlchemy database backend in the form
            dialect+driver where dialect is the name of a SQLAlchemy
            dialect (sqlite, mysql, postgresql, oracle or mssql) and
            driver is the name of the DBAPI in all lowercase letters.
            If driver is not specified the default DBAPI will be
            imported if available. Required.
        :param path: Only for SQLite. Path to database. If not specified
            the database will be kept in memory (should be used only
            for testing).
        :param host: Host part of the URL (non-SQLite engines only).
        :param port: Port part of the URL (non-SQLite engines only).
        :param database: Database name (non-SQLite engines only).
        :param user: User name (non-SQLite engines only).
        :param password: Password (non-SQLite engines only).
        :raises KeyError: If ``engine`` is missing, or if any of the
            placeholders of ``ENGINE_GENERIC`` is missing for a
            non-SQLite engine.
        """
        if kwargs['engine'] == 'sqlite':
            if 'path' in kwargs:
                url = ENGINE_SQLITE.format(path=kwargs['path'])
            else:
                url = ENGINE_SQLITE_MEMORY
        else:
            # Every non-SQLite URL gets "?charset=utf8" appended.
            url = ENGINE_GENERIC.format(**kwargs)

        self.engine = sa.create_engine(url)
        self.Session = sa.orm.sessionmaker(
            bind=self.engine,
            expire_on_commit=False
        )

        Base.metadata.create_all(self.engine)
        force_auto_coercion()

    @contextmanager
    def get_session(self):
        """
        Provide a session wrapped in a transaction.

        The session is committed when the ``with`` body finishes without
        raising and is always closed afterwards.

        NOTE: a ``SQLAlchemyError`` raised inside the ``with`` body is
        logged, the session is rolled back and the error is *not*
        re-raised, so callers never see it. Any other exception
        propagates to the caller without an explicit rollback (the
        session is still closed).
        """
        session = self.Session()
        try:
            yield session
        except SQLAlchemyError as e:
            log.critical("Error performing transaction: {}".format(e))
            session.rollback()
        else:
            session.commit()
        finally:
            session.close()


class User(Base):
    """A user account; the password goes through ``PasswordType``."""

    __tablename__ = 'users'

    uid = Column(Integer, primary_key=True)
    username = Column(String, nullable=False, unique=True)
    password = Column(PasswordType(schemes=PASSWORD_SCHEMES), nullable=False)
    is_active = Column(Boolean, nullable=False, server_default='1')
    created_at = Column(DateTime, nullable=False, default=datetime.now)


class Event(Base):
    """An event that transactions are attached to."""

    __tablename__ = 'events'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    starts_at = Column(DateTime, nullable=False, default=datetime.now)
    ends_at = Column(DateTime)
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    transactions = relationship('Transaction', lazy='joined')


class ProductCategory(Base):
    """A named, sortable group of products."""

    __tablename__ = 'product_categories'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    sort = Column(Integer, nullable=False, server_default='0')
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    products = relationship('Product', lazy='joined')


class Product(Base):
    """A product belonging to exactly one category."""

    __tablename__ = 'products'

    uid = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    price = Column(Integer, nullable=False)
    sort = Column(Integer, nullable=False, server_default='0')
    category_uid = Column(Integer, ForeignKey('product_categories.uid'),
                          nullable=False)
    is_active = Column(Boolean, nullable=False, server_default='1')
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    category = relationship('ProductCategory', lazy='joined')


class Transaction(Base):
    """A transaction recorded for an event; groups a set of orders."""

    __tablename__ = 'transactions'

    uid = Column(Integer, primary_key=True)
    event_uid = Column(Integer, ForeignKey('events.uid'), nullable=False)
    created_at = Column(DateTime, nullable=False, default=datetime.now)

    event = relationship('Event', lazy='joined')
    orders = relationship('Order', lazy='joined')


class Order(Base):
    """A quantity of one product within a transaction."""

    __tablename__ = 'orders'

    uid = Column(Integer, primary_key=True)
    product_uid = Column(Integer, ForeignKey('products.uid'), nullable=False)
    quantity = Column(Integer, nullable=False)
    transaction_uid = Column(Integer, ForeignKey('transactions.uid'),
                             nullable=False)

    product = relationship('Product', lazy='joined')
    transaction = relationship('Transaction', lazy='joined')


class AccessToken(Base):
    """An access token issued to a user, valid for a limited time."""

    __tablename__ = 'access_tokens'

    uid = Column(Integer, primary_key=True)
    user_uid = Column(Integer, ForeignKey('users.uid'), nullable=False)
    # The defaults below must be callables: a plain value would be
    # computed once at import time and then shared by every row
    # (same token for everyone, expiry frozen at process start + 2 days).
    token = Column(String(36), nullable=False, default=_generate_token)
    is_active = Column(Boolean, nullable=False, server_default='1')
    created_at = Column(DateTime, nullable=False, default=datetime.now)
    expires_at = Column(DateTime, nullable=False, default=_default_expiry)

    user = relationship('User', lazy='joined')

    def is_valid(self) -> bool:
        """
        Tell whether the token can currently be used.

        :return: ``True`` when the token is active, was created before
            the current local time and expires after it.
        """
        # Read the clock once so both comparisons use the same instant.
        now = datetime.now()
        return all([
            self.is_active,
            self.created_at < now,
            self.expires_at > now
        ])