diff --git a/pre_commit/commands/autoupdate.py b/pre_commit/commands/autoupdate.py index 620a8a6e..36df87f8 100644 --- a/pre_commit/commands/autoupdate.py +++ b/pre_commit/commands/autoupdate.py @@ -1,6 +1,7 @@ from __future__ import print_function from __future__ import unicode_literals +import re from collections import OrderedDict from aspy.yaml import ordered_dump @@ -8,11 +9,9 @@ from aspy.yaml import ordered_load import pre_commit.constants as C from pre_commit import output -from pre_commit.clientlib import CONFIG_SCHEMA from pre_commit.clientlib import is_local_repo from pre_commit.clientlib import load_config from pre_commit.repository import Repository -from pre_commit.schema import remove_defaults from pre_commit.util import CalledProcessError from pre_commit.util import cmd_output from pre_commit.util import cwd @@ -65,6 +64,43 @@ def _update_repo(repo_config, runner, tags_only): return new_config +SHA_LINE_RE = re.compile(r'^(\s+)sha:(\s*)([^\s#]+)(.*)$', re.DOTALL) +SHA_LINE_FMT = '{}sha:{}{}{}' + + +def _write_new_config_file(path, output): + original_contents = open(path).read() + new_contents = ordered_dump(output, **C.YAML_DUMP_KWARGS) + + lines = original_contents.splitlines(True) + sha_line_indices_rev = list(reversed([ + i for i, line in enumerate(lines) if SHA_LINE_RE.match(line) + ])) + + for line in new_contents.splitlines(True): + if SHA_LINE_RE.match(line): + # It's possible we didn't identify the sha lines in the original + if not sha_line_indices_rev: + break + line_index = sha_line_indices_rev.pop() + original_line = lines[line_index] + orig_match = SHA_LINE_RE.match(original_line) + new_match = SHA_LINE_RE.match(line) + lines[line_index] = SHA_LINE_FMT.format( + orig_match.group(1), orig_match.group(2), + new_match.group(3), orig_match.group(4), + ) + + # If we failed to intelligently rewrite the sha lines, fall back to the + # pretty-formatted yaml output + to_write = ''.join(lines) + if ordered_load(to_write) != output: + to_write = new_contents + + with open(path, 'w') as f: + f.write(to_write) + + def autoupdate(runner, tags_only): """Auto-update the pre-commit config to the latest versions of repos.""" retv = 0 @@ -100,10 +136,6 @@ def autoupdate(runner, tags_only): output_configs.append(repo_config) if changed: - with open(runner.config_file_path, 'w') as config_file: - config_file.write(ordered_dump( - remove_defaults(output_configs, CONFIG_SCHEMA), - **C.YAML_DUMP_KWARGS - )) + _write_new_config_file(runner.config_file_path, output_configs) return retv diff --git a/pre_commit/schema.py b/pre_commit/schema.py index 5f22277d..a911bb43 100644 --- a/pre_commit/schema.py +++ b/pre_commit/schema.py @@ -64,11 +64,6 @@ def _apply_default_optional(self, dct): dct.setdefault(self.key, self.default) -def _remove_default_optional(self, dct): - if dct.get(self.key, MISSING) == self.default: - del dct[self.key] - - def _require_key(self, dct): if self.key not in dct: raise ValidationError('Missing required key: {}'.format(self.key)) @@ -90,10 +85,6 @@ def _apply_default_required_recurse(self, dct): dct[self.key] = apply_defaults(dct[self.key], self.schema) -def _remove_default_required_recurse(self, dct): - dct[self.key] = remove_defaults(dct[self.key], self.schema) - - def _check_conditional(self, dct): if dct.get(self.condition_key, MISSING) == self.condition_value: _check_required(self, dct) @@ -119,22 +110,18 @@ def _check_conditional(self, dct): Required = collections.namedtuple('Required', ('key', 'check_fn')) Required.check = _check_required Required.apply_default = _dct_noop -Required.remove_default = _dct_noop RequiredRecurse = collections.namedtuple('RequiredRecurse', ('key', 'schema')) RequiredRecurse.check = _check_required RequiredRecurse.check_fn = _check_fn_required_recurse RequiredRecurse.apply_default = _apply_default_required_recurse -RequiredRecurse.remove_default = _remove_default_required_recurse Optional = collections.namedtuple('Optional', ('key', 'check_fn', 'default')) Optional.check = _check_optional Optional.apply_default = _apply_default_optional -Optional.remove_default = _remove_default_optional OptionalNoDefault = collections.namedtuple( 'OptionalNoDefault', ('key', 'check_fn'), ) OptionalNoDefault.check = _check_optional OptionalNoDefault.apply_default = _dct_noop -OptionalNoDefault.remove_default = _dct_noop Conditional = collections.namedtuple( 'Conditional', ('key', 'check_fn', 'condition_key', 'condition_value', 'ensure_absent'), @@ -142,7 +129,6 @@ Conditional = collections.namedtuple( Conditional.__new__.__defaults__ = (False,) Conditional.check = _check_conditional Conditional.apply_default = _dct_noop -Conditional.remove_default = _dct_noop class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))): @@ -168,12 +154,6 @@ class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))): item.apply_default(ret) return ret - def remove_defaults(self, v): - ret = v.copy() - for item in self.items: - item.remove_default(ret) - return ret - class Array(collections.namedtuple('Array', ('of',))): __slots__ = () @@ -190,9 +170,6 @@ class Array(collections.namedtuple('Array', ('of',))): def apply_defaults(self, v): return [apply_defaults(val, self.of) for val in v] - def remove_defaults(self, v): - return [remove_defaults(val, self.of) for val in v] - class Not(object): def __init__(self, val): @@ -257,10 +234,6 @@ def apply_defaults(v, schema): return schema.apply_defaults(v) -def remove_defaults(v, schema): - return schema.remove_defaults(v) - - def load_from_filename(filename, schema, load_strategy, exc_tp): with reraise_as(exc_tp): if not os.path.exists(filename): diff --git a/tests/commands/autoupdate_test.py b/tests/commands/autoupdate_test.py index 8dac48c4..1920610a 100644 --- a/tests/commands/autoupdate_test.py +++ b/tests/commands/autoupdate_test.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import pipes import shutil from collections import OrderedDict @@ -123,6 +124,61 @@ def test_autoupdate_out_of_date_repo( assert out_of_date_repo.head_sha in after +def test_does_not_reformat( + out_of_date_repo, mock_out_store_directory, in_tmpdir, +): + fmt = ( + '- repo: {}\n' + ' sha: {} # definitely the version I want!\n' + ' hooks:\n' + ' - id: foo\n' + ' # These args are because reasons!\n' + ' args: [foo, bar, baz]\n' + ) + config = fmt.format(out_of_date_repo.path, out_of_date_repo.original_sha) + with open(C.CONFIG_FILE, 'w') as f: + f.write(config) + + autoupdate(Runner('.', C.CONFIG_FILE), tags_only=False) + after = open(C.CONFIG_FILE).read() + expected = fmt.format(out_of_date_repo.path, out_of_date_repo.head_sha) + assert after == expected + + +def test_loses_formatting_when_not_detectable( + out_of_date_repo, mock_out_store_directory, in_tmpdir, +): + """A best-effort attempt is made at updating sha without rewriting + formatting. When the original formatting cannot be detected, this + is abandoned. + """ + config = ( + '[\n' + ' {{\n' + ' repo: {}, sha: {},\n' + ' hooks: [\n' + ' # A comment!\n' + ' {{id: foo}},\n' + ' ],\n' + ' }}\n' + ']\n'.format( + pipes.quote(out_of_date_repo.path), out_of_date_repo.original_sha, + ) + ) + with open(C.CONFIG_FILE, 'w') as f: + f.write(config) + + autoupdate(Runner('.', C.CONFIG_FILE), tags_only=False) + after = open(C.CONFIG_FILE).read() + expected = ( + '- repo: {}\n' + ' sha: {}\n' + ' hooks:\n' + ' - id: foo\n' + ).format(out_of_date_repo.path, out_of_date_repo.head_sha) + assert after == expected + + @pytest.yield_fixture def tagged_repo(out_of_date_repo): with cwd(out_of_date_repo.path): diff --git a/tests/schema_test.py b/tests/schema_test.py index c2ecf0fa..c133a997 100644 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -21,7 +21,6 @@ from pre_commit.schema import MISSING from pre_commit.schema import Not from pre_commit.schema import Optional from pre_commit.schema import OptionalNoDefault -from pre_commit.schema import remove_defaults from pre_commit.schema import Required from pre_commit.schema import RequiredRecurse from pre_commit.schema import validate @@ -281,37 +280,6 @@ def test_apply_defaults_map_in_list(): assert ret == [{'key': False}] -def test_remove_defaults_copies_object(): - val = {'key': False} - ret = remove_defaults(val, map_optional) - assert ret is not val - - -def test_remove_defaults_removes_defaults(): - ret = remove_defaults({'key': False}, map_optional) - assert ret == {} - - -def test_remove_defaults_nothing_to_remove(): - ret = remove_defaults({}, map_optional) - assert ret == {} - - -def test_remove_defaults_does_not_change_non_default(): - ret = remove_defaults({'key': True}, map_optional) - assert ret == {'key': True} - - -def test_remove_defaults_map_in_list(): - ret = remove_defaults([{'key': False}], Array(map_optional)) - assert ret == [{}] - - -def test_remove_defaults_does_nothing_on_non_optional(): - ret = remove_defaults({'key': True}, map_required) - assert ret == {'key': True} - - nested_schema_required = Map( 'Repository', 'repo', Required('repo', check_any), @@ -342,12 +310,6 @@ def test_apply_defaults_nested(): assert ret == {'repo': 'repo1', 'hooks': [{'key': False}]} -def test_remove_defaults_nested(): - val = {'repo': 'repo1', 'hooks': [{'key': False}]} - ret = remove_defaults(val, nested_schema_optional) - assert ret == {'repo': 'repo1', 'hooks': [{}]} - - class Error(Exception): pass