From 581270c6c1228b4d3570c28a27a08d3daa97a1f3 Mon Sep 17 00:00:00 2001 From: Blallo Date: Fri, 20 Nov 2020 11:16:07 +0100 Subject: [PATCH] Add flag on cli + black Copied from argparse==3.9 the BooleanOptionalAction, to allow one to have boolean (optional) flags on the cli. Also enforced black formatting. --- src/phi/cli.py | 15 ++++++++--- src/phi/compat/__init__.py | 0 src/phi/compat/argparse.py | 51 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 src/phi/compat/__init__.py create mode 100644 src/phi/compat/argparse.py diff --git a/src/phi/cli.py b/src/phi/cli.py index d3de618..4eff430 100644 --- a/src/phi/cli.py +++ b/src/phi/cli.py @@ -2,16 +2,17 @@ import sys import argparse import inspect from phi.logging import setup_logging, get_logger +from phi.compat.argparse import BooleanOptionalAction log = get_logger(__name__) parser = argparse.ArgumentParser() -subparses = parser.add_subparsers(title='actions', dest='action') +subparses = parser.add_subparsers(title="actions", dest="action") cli_callbacks = {} -def register(action_info='', param_infos=[]): +def register(action_info="", param_infos=[]): def decorator(action): # Get function name and arguments action_name = action.__name__ @@ -21,7 +22,7 @@ def register(action_info='', param_infos=[]): subparser = subparses.add_parser(action_name, help=action_info) for i, name in enumerate(param_names): - info = param_infos[i] if i= 3.9 + +import argparse + + +class BooleanOptionalAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None, + ): + + _option_strings = [] + for option_string in option_strings: + _option_strings.append(option_string) + + if option_string.startswith("--"): + option_string = "--no-" + option_string[2:] + _option_strings.append(option_string) + + if help is not None and default is not None: + help += f" (default: {default})" + + super().__init__( + option_strings=_option_strings, + dest=dest, + nargs=0, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar, + ) + + def __call__(self, parser, namespace, values, option_string=None): + if option_string in self.option_strings: + setattr(namespace, self.dest, not option_string.startswith("--no-")) + + def format_usage(self): + return " | ".join(self.option_strings)