diff --git a/pre_commit/clientlib.py b/pre_commit/clientlib.py index e69359b0..7fb49d78 100644 --- a/pre_commit/clientlib.py +++ b/pre_commit/clientlib.py @@ -130,6 +130,7 @@ CONFIG_SCHEMA = schema.Map( 'Config', None, schema.RequiredRecurse('repos', schema.Array(CONFIG_REPO_DICT)), + schema.Optional('fail_fast', schema.check_bool, False), ) diff --git a/pre_commit/commands/autoupdate.py b/pre_commit/commands/autoupdate.py index 17588cc3..4dce674f 100644 --- a/pre_commit/commands/autoupdate.py +++ b/pre_commit/commands/autoupdate.py @@ -9,10 +9,12 @@ from aspy.yaml import ordered_load import pre_commit.constants as C from pre_commit import output +from pre_commit.clientlib import CONFIG_SCHEMA from pre_commit.clientlib import is_local_repo from pre_commit.clientlib import load_config from pre_commit.commands.migrate_config import migrate_config from pre_commit.repository import Repository +from pre_commit.schema import remove_defaults from pre_commit.util import CalledProcessError from pre_commit.util import cmd_output from pre_commit.util import cwd @@ -71,6 +73,7 @@ SHA_LINE_FMT = '{}sha:{}{}{}' def _write_new_config_file(path, output): original_contents = open(path).read() + output = remove_defaults(output, CONFIG_SCHEMA) new_contents = ordered_dump(output, **C.YAML_DUMP_KWARGS) lines = original_contents.splitlines(True) @@ -95,7 +98,7 @@ def _write_new_config_file(path, output): # If we failed to intelligently rewrite the sha lines, fall back to the # pretty-formatted yaml output to_write = ''.join(lines) - if ordered_load(to_write) != output: + if remove_defaults(ordered_load(to_write), CONFIG_SCHEMA) != output: to_write = new_contents with open(path, 'w') as f: diff --git a/pre_commit/commands/run.py b/pre_commit/commands/run.py index 99232585..505bb54d 100644 --- a/pre_commit/commands/run.py +++ b/pre_commit/commands/run.py @@ -169,13 +169,15 @@ def _compute_cols(hooks, verbose): return max(cols, 80) -def _run_hooks(repo_hooks, args, environ): +def _run_hooks(config, repo_hooks, args, environ): """Actually run the hooks.""" skips = _get_skips(environ) cols = _compute_cols([hook for _, hook in repo_hooks], args.verbose) retval = 0 for repo, hook in repo_hooks: retval |= _run_single_hook(hook, repo, args, skips, cols) + if retval and config['fail_fast']: + break if ( retval and args.show_diff_on_failure and @@ -251,4 +253,4 @@ def run(runner, args, environ=os.environ): if not hook['stages'] or args.hook_stage in hook['stages'] ] - return _run_hooks(repo_hooks, args, environ) + return _run_hooks(runner.config, repo_hooks, args, environ) diff --git a/pre_commit/runner.py b/pre_commit/runner.py index 346d6021..d853868a 100644 --- a/pre_commit/runner.py +++ b/pre_commit/runner.py @@ -37,10 +37,14 @@ class Runner(object): def config_file_path(self): return os.path.join(self.git_root, self.config_file) + @cached_property + def config(self): + return load_config(self.config_file_path) + @cached_property def repositories(self): """Returns a tuple of the configured repositories.""" - repos = load_config(self.config_file_path)['repos'] + repos = self.config['repos'] repos = tuple(Repository.create(x, self.store) for x in repos) for repo in repos: repo.require_installed() diff --git a/pre_commit/schema.py b/pre_commit/schema.py index f033071f..e20f74cc 100644 --- a/pre_commit/schema.py +++ b/pre_commit/schema.py @@ -64,6 +64,11 @@ def _apply_default_optional(self, dct): dct.setdefault(self.key, self.default) +def _remove_default_optional(self, dct): + if dct.get(self.key, MISSING) == self.default: + del dct[self.key] + + def _require_key(self, dct): if self.key not in dct: raise ValidationError('Missing required key: {}'.format(self.key)) @@ -85,6 +90,10 @@ def _apply_default_required_recurse(self, dct): dct[self.key] = apply_defaults(dct[self.key], self.schema) +def _remove_default_required_recurse(self, dct): + dct[self.key] = remove_defaults(dct[self.key], self.schema) + + def _check_conditional(self, dct): if dct.get(self.condition_key, MISSING) == self.condition_value: _check_required(self, dct) @@ -110,18 +119,22 @@ def _check_conditional(self, dct): Required = collections.namedtuple('Required', ('key', 'check_fn')) Required.check = _check_required Required.apply_default = _dct_noop +Required.remove_default = _dct_noop RequiredRecurse = collections.namedtuple('RequiredRecurse', ('key', 'schema')) RequiredRecurse.check = _check_required RequiredRecurse.check_fn = _check_fn_required_recurse RequiredRecurse.apply_default = _apply_default_required_recurse +RequiredRecurse.remove_default = _remove_default_required_recurse Optional = collections.namedtuple('Optional', ('key', 'check_fn', 'default')) Optional.check = _check_optional Optional.apply_default = _apply_default_optional +Optional.remove_default = _remove_default_optional OptionalNoDefault = collections.namedtuple( 'OptionalNoDefault', ('key', 'check_fn'), ) OptionalNoDefault.check = _check_optional OptionalNoDefault.apply_default = _dct_noop +OptionalNoDefault.remove_default = _dct_noop Conditional = collections.namedtuple( 'Conditional', ('key', 'check_fn', 'condition_key', 'condition_value', 'ensure_absent'), @@ -129,6 +142,7 @@ Conditional = collections.namedtuple( Conditional.__new__.__defaults__ = (False,) Conditional.check = _check_conditional Conditional.apply_default = _dct_noop +Conditional.remove_default = _dct_noop class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))): @@ -158,6 +172,12 @@ class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))): item.apply_default(ret) return ret + def remove_defaults(self, v): + ret = v.copy() + for item in self.items: + item.remove_default(ret) + return ret + class Array(collections.namedtuple('Array', ('of',))): __slots__ = () @@ -174,6 +194,9 @@ class Array(collections.namedtuple('Array', ('of',))): def apply_defaults(self, v): return [apply_defaults(val, self.of) for val in v] + def remove_defaults(self, v): + return [remove_defaults(val, self.of) for val in v] + class Not(object): def __init__(self, val): @@ -238,6 +261,10 @@ def apply_defaults(v, schema): return schema.apply_defaults(v) +def remove_defaults(v, schema): + return schema.remove_defaults(v) + + def load_from_filename(filename, schema, load_strategy, exc_tp): with reraise_as(exc_tp): if not os.path.exists(filename): diff --git a/tests/commands/autoupdate_test.py b/tests/commands/autoupdate_test.py index 7fb21b9d..2877c5b3 100644 --- a/tests/commands/autoupdate_test.py +++ b/tests/commands/autoupdate_test.py @@ -275,7 +275,7 @@ def test_autoupdate_local_hooks(tempdir_factory): runner = Runner(path, C.CONFIG_FILE) assert autoupdate(runner, tags_only=False) == 0 new_config_writen = load_config(runner.config_file_path) - assert len(new_config_writen) == 1 + assert len(new_config_writen['repos']) == 1 assert new_config_writen['repos'][0] == config diff --git a/tests/commands/run_test.py b/tests/commands/run_test.py index 39d3ac0b..53e098b0 100644 --- a/tests/commands/run_test.py +++ b/tests/commands/run_test.py @@ -729,3 +729,18 @@ def test_pass_filenames( ) assert expected_out + b'\nHello World' in printed assert (b'foo.py' in printed) == pass_filenames + + +def test_fail_fast( + cap_out, repo_with_failing_hook, mock_out_store_directory, +): + with cwd(repo_with_failing_hook): + with modify_config() as config: + # More than one hook + config['fail_fast'] = True + config['repos'][0]['hooks'] *= 2 + stage_a_file() + + ret, printed = _do_run(cap_out, repo_with_failing_hook, _get_opts()) + # it should have only run one hook + assert printed.count(b'Failing hook') == 1 diff --git a/tests/schema_test.py b/tests/schema_test.py index c133a997..c2ecf0fa 100644 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -21,6 +21,7 @@ from pre_commit.schema import MISSING from pre_commit.schema import Not from pre_commit.schema import Optional from pre_commit.schema import OptionalNoDefault +from pre_commit.schema import remove_defaults from pre_commit.schema import Required from pre_commit.schema import RequiredRecurse from pre_commit.schema import validate @@ -280,6 +281,37 @@ def test_apply_defaults_map_in_list(): assert ret == [{'key': False}] +def test_remove_defaults_copies_object(): + val = {'key': False} + ret = remove_defaults(val, map_optional) + assert ret is not val + + +def test_remove_defaults_removes_defaults(): + ret = remove_defaults({'key': False}, map_optional) + assert ret == {} + + +def test_remove_defaults_nothing_to_remove(): + ret = remove_defaults({}, map_optional) + assert ret == {} + + +def test_remove_defaults_does_not_change_non_default(): + ret = remove_defaults({'key': True}, map_optional) + assert ret == {'key': True} + + +def test_remove_defaults_map_in_list(): + ret = remove_defaults([{'key': False}], Array(map_optional)) + assert ret == [{}] + + +def test_remove_defaults_does_nothing_on_non_optional(): + ret = remove_defaults({'key': True}, map_required) + assert ret == {'key': True} + + nested_schema_required = Map( 'Repository', 'repo', Required('repo', check_any), @@ -310,6 +342,12 @@ def test_apply_defaults_nested(): assert ret == {'repo': 'repo1', 'hooks': [{'key': False}]} +def test_remove_defaults_nested(): + val = {'repo': 'repo1', 'hooks': [{'key': False}]} + ret = remove_defaults(val, nested_schema_optional) + assert ret == {'repo': 'repo1', 'hooks': [{}]} + + class Error(Exception): pass