diff --git a/pre_commit/clientlib/validate_config.py b/pre_commit/clientlib/validate_config.py index f77d10f5..039f0aeb 100644 --- a/pre_commit/clientlib/validate_config.py +++ b/pre_commit/clientlib/validate_config.py @@ -6,6 +6,7 @@ import re import pre_commit.constants as C from pre_commit.clientlib.validate_base import get_validator +from pre_commit.util import entry class InvalidConfigError(ValueError): pass @@ -63,6 +64,7 @@ validate_config = get_validator( ) +@entry def run(argv): parser = argparse.ArgumentParser() parser.add_argument( @@ -85,3 +87,7 @@ def run(argv): return 1 return 0 + + +if __name__ == '__main__': + run() diff --git a/pre_commit/clientlib/validate_manifest.py b/pre_commit/clientlib/validate_manifest.py index 7f11ff0c..06864b0b 100644 --- a/pre_commit/clientlib/validate_manifest.py +++ b/pre_commit/clientlib/validate_manifest.py @@ -5,6 +5,7 @@ import argparse import pre_commit.constants as C from pre_commit.clientlib.validate_base import get_validator +from pre_commit.util import entry class InvalidManifestError(ValueError): pass @@ -52,6 +53,7 @@ validate_manifest = get_validator( ) +@entry def run(argv): parser = argparse.ArgumentParser() parser.add_argument( @@ -73,4 +75,8 @@ def run(argv): print(str(e.args[1])) return 1 - return 0 \ No newline at end of file + return 0 + + +if __name__ == '__main__': + run() diff --git a/pre_commit/entry_points.py b/pre_commit/entry_points.py deleted file mode 100644 index 4f448ca1..00000000 --- a/pre_commit/entry_points.py +++ /dev/null @@ -1,25 +0,0 @@ - -import functools - -import pre_commit.clientlib.validate_config -import pre_commit.clientlib.validate_manifest -import pre_commit.run - - -def make_entry_point(entry_point_func): - """Decorator which turns a function which takes sys.argv[1:] and returns - an integer into an argumentless function which returns an integer. - - Args: - entry_point_func - A function which takes an array representing argv - """ - @functools.wraps(entry_point_func) - def func(): - import sys - return entry_point_func(sys.argv[1:]) - return func - - -pre_commit_func = make_entry_point(pre_commit.run.run) -validate_manifest_func = make_entry_point(pre_commit.clientlib.validate_manifest.run) -validate_config_func = make_entry_point(pre_commit.clientlib.validate_config.run) \ No newline at end of file diff --git a/pre_commit/run.py b/pre_commit/run.py index 08ba2437..c86776cc 100644 --- a/pre_commit/run.py +++ b/pre_commit/run.py @@ -2,11 +2,11 @@ import argparse import os.path import subprocess -import sys from pre_commit import git from pre_commit.clientlib.validate_config import validate_config from pre_commit.repository import Repository +from pre_commit.util import entry RED = '\033[41m' @@ -90,6 +90,7 @@ def run_single_hook(hook_id, configs=None, run_all_the_things=False): return 1 +@entry def run(argv): parser = argparse.ArgumentParser() @@ -130,4 +131,4 @@ def run(argv): if __name__ == '__main__': - run(sys.argv[1:]) + run() diff --git a/pre_commit/util.py b/pre_commit/util.py index 648b7892..47162be4 100644 --- a/pre_commit/util.py +++ b/pre_commit/util.py @@ -1,6 +1,7 @@ import functools import os +import sys class cached_property(object): @@ -35,3 +36,16 @@ def memoize_by_cwd(func): wrapper._cache = {} return wrapper + + +def entry(func): + """Allows a function that has `argv` as an argument to be used as a + commandline entry. This will make the function callable using either + explicitly passed argv or defaulting to sys.argv[1:] + """ + @functools.wraps(func) + def wrapper(argv=None): + if argv is None: + argv = sys.argv[1:] + return func(argv) + return wrapper diff --git a/setup.py b/setup.py index c50827d5..93bdae3d 100644 --- a/setup.py +++ b/setup.py @@ -19,9 +19,9 @@ setup( ], entry_points={ 'console_scripts': [ - 'pre-commit = pre_commit.entry_points:pre_commit_func', - 'validate-config = pre_commit.entry_points:validate_config_func', - 'validate-manifest = pre_commit.entry_points:validate_manifest_func', + 'pre-commit = pre_commit.run:run', + 'validate-config = pre_commit.clientlib.validate_config:run', + 'validate-manifest = pre_commit.clientlib.validate_manifest:run', ], }, scripts=[ diff --git a/tests/repository_test.py b/tests/repository_test.py index 8ba7f799..dadeb8af 100644 --- a/tests/repository_test.py +++ b/tests/repository_test.py @@ -65,6 +65,7 @@ def test_run_a_hook_lots_of_files(config_for_python_pre_commit_git_repo): os.environ.get('slowtests', None) == 'false', reason="TODO: make this test not super slow", ) +@pytest.mark.integration def test_run_a_node_hook(config_for_node_pre_commit_git_repo): repo = Repository(config_for_node_pre_commit_git_repo) repo.install() diff --git a/tests/util_test.py b/tests/util_test.py index bc79ff94..ba41b521 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -1,9 +1,12 @@ +import mock import pytest import random +import sys from plumbum import local from pre_commit.util import cached_property +from pre_commit.util import entry from pre_commit.util import memoize_by_cwd @@ -59,3 +62,25 @@ def test_memoized_by_cwd_changes_with_different_cwd(memoized_by_cwd): ret2 = memoized_by_cwd('baz') assert ret != ret2 + + +@pytest.fixture +def entry_func(): + @entry + def func(argv): + return argv + + return func + + +def test_explicitly_passed_argv_are_passed(entry_func): + input = object() + ret = entry_func(input) + assert ret is input + + +def test_no_arguments_passed_uses_argv(entry_func): + argv = [1, 2, 3, 4] + with mock.patch.object(sys, 'argv', argv): + ret = entry_func() + assert ret == argv[1:]