diff --git a/pre_commit/main.py b/pre_commit/main.py index 8d2d6302..59de5f24 100644 --- a/pre_commit/main.py +++ b/pre_commit/main.py @@ -55,12 +55,25 @@ def _add_config_option(parser): ) +class AppendReplaceDefault(argparse.Action): + def __init__(self, *args, **kwargs): + super(AppendReplaceDefault, self).__init__(*args, **kwargs) + self.appended = False + + def __call__(self, parser, namespace, values, option_string=None): + if not self.appended: + setattr(namespace, self.dest, []) + self.appended = True + getattr(namespace, self.dest).append(values) + + def _add_hook_type_option(parser): parser.add_argument( '-t', '--hook-type', choices=( 'pre-commit', 'pre-push', 'prepare-commit-msg', 'commit-msg', ), - action='append', + action=AppendReplaceDefault, + default=['pre-commit'], dest='hook_types', ) @@ -121,11 +134,6 @@ def _adjust_args_and_chdir(args): args.files = [os.path.relpath(filename) for filename in args.files] if args.command == 'try-repo' and os.path.exists(args.repo): args.repo = os.path.relpath(args.repo) - if ( - args.command in {'install', 'uninstall', 'init-templatedir'} and - not args.hook_types - ): - args.hook_types = ['pre-commit'] def main(argv=None): diff --git a/tests/main_test.py b/tests/main_test.py index aad9c4b9..364e0d39 100644 --- a/tests/main_test.py +++ b/tests/main_test.py @@ -13,6 +13,20 @@ from pre_commit.error_handler import FatalError from testing.auto_namedtuple import auto_namedtuple +@pytest.mark.parametrize( + ('argv', 'expected'), + ( + ((), ['f']), + (('--f', 'x'), ['x']), + (('--f', 'x', '--f', 'y'), ['x', 'y']), + ), +) +def test_append_replace_default(argv, expected): + parser = argparse.ArgumentParser() + parser.add_argument('--f', action=main.AppendReplaceDefault, default=['f']) + assert parser.parse_args(argv).f == expected + + class Args(object): def __init__(self, **kwargs): kwargs.setdefault('command', 'help')