diff --git a/cli.py b/cli.py index aa82a73..e2cd7e4 100755 --- a/cli.py +++ b/cli.py @@ -78,18 +78,48 @@ def tabulate_events(events): return tabulate(tab, headers='firstrow') +def get_overlapping_events(session, starts_at, ends_at): + q = session.query(Event) + + if ends_at is None: + q = q.filter(Event.starts_at <= starts_at) + else: + q = q.filter(Event.ends_at >= starts_at)\ + .filter(Event.starts_at <= ends_at) + + return q.all() + + @event.command('add') @click.argument('name') -@click.option('--start') -@click.option('--end') -def event_add(name, start, end): - start = datetime.strptime(start, "%Y-%m-%d %H:%M") - end = datetime.strptime(end, "%Y-%m-%d %H:%M") - event = Event(name=name, starts_at=start, ends_at=end) +@click.argument('starts_at') +@click.argument('ends_at', required=False) +def event_add(name, starts_at, ends_at): + starts_at = datetime.strptime(starts_at, "%Y-%m-%d %H:%M") + ends_at = (datetime.strptime(ends_at, "%Y-%m-%d %H:%M") + if ends_at else None) + + if ends_at and starts_at >= ends_at: + print("Could now add event: specified start date ({}) " + "is past the end date ({})." + .format(starts_at.strftime("%Y-%m-%d %H:%M"), + ends_at.strftime("%Y-%m-%d %H:%M"))) + return with db.get_session() as session: + events = get_overlapping_events(session, starts_at, ends_at) + if events: + print("Could not add event: another event is overlapping the date " + "range you have specified.") + print(tabulate_events(events)) + return + + with db.get_session() as session: + event = Event(name=name, starts_at=starts_at, ends_at=ends_at) session.add(event) + session.flush() print("Event succesfully added.") + print(tabulate_events([event])) @event.command('list') @@ -103,6 +133,71 @@ def event_list(): print("No events found.") +@event.command('set') +@click.option('-n', '--name') +@click.option('-s', '--start') +@click.option('-e', '--end') +@click.argument('event_uid') +def event_set(event_uid, name, start, end): + with db.get_session() as session: + event = session.query(Event).get(event_uid) + + if not event: + print("No event found with id #{}.".format(event_uid)) + return + + if name: + event.name = name + + if start: + starts_at = datetime.strptime(start, "%Y-%m-%d %H:%M") + + if starts_at >= event.ends_at: + print("Could not edit event #{}: specified start date ({}) " + "is past the end date ({})" + .format(event.uid, + starts_at.strftime("%Y-%m-%d %H:%M"), + event.ends_at.strftime("%Y-%m-%d %H:%M"))) + return + + event.starts_at = starts_at + + if end: + if end == 'none': + event.ends_at = None + elif end == 'now': + event.ends_at = datetime.now() + else: + ends_at = datetime.strptime(end, "%Y-%m-%d %H:%M") + + if ends_at <= event.starts_at: + print("Could not edit event #{}: specified end date ({}) " + "is before the start date ({})" + .format(event.uid, + ends_at.strftime("%Y-%m-%d %H:%M"), + event.starts_at.strftime("%Y-%m-%d %H:%M"))) + return + + event.ends_at = datetime.strptime(end, "%Y-%m-%d %H:%M") + + if event.starts_at and event.ends_at: + with db.get_session() as session: + events = get_overlapping_events(session, + event.starts_at, event.ends_at) + if events: + print("Could not edit event: another event is overlapping the " + "date range you have specified.") + print(tabulate_events(events)) + return + + if any([name, start, end]): + with db.get_session() as session: + session.add(event) + session.flush() + print("Event succesfully edited.") + print(tabulate_events([event])) + + @cli.group('product') def product(): pass