Make entry points simpler

This commit is contained in:
Anthony Sottile
2014-03-22 18:11:30 -07:00
parent 6d1a464c4f
commit 04b421978a
8 changed files with 59 additions and 31 deletions

View File

@@ -6,6 +6,7 @@ import re
import pre_commit.constants as C
from pre_commit.clientlib.validate_base import get_validator
from pre_commit.util import entry
class InvalidConfigError(ValueError): pass
@@ -63,6 +64,7 @@ validate_config = get_validator(
)
@entry
def run(argv):
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -85,3 +87,7 @@ def run(argv):
return 1
return 0
if __name__ == '__main__':
run()

View File

@@ -5,6 +5,7 @@ import argparse
import pre_commit.constants as C
from pre_commit.clientlib.validate_base import get_validator
from pre_commit.util import entry
class InvalidManifestError(ValueError): pass
@@ -52,6 +53,7 @@ validate_manifest = get_validator(
)
@entry
def run(argv):
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -73,4 +75,8 @@ def run(argv):
print(str(e.args[1]))
return 1
return 0
return 0
if __name__ == '__main__':
run()

View File

@@ -1,25 +0,0 @@
import functools
import pre_commit.clientlib.validate_config
import pre_commit.clientlib.validate_manifest
import pre_commit.run
def make_entry_point(entry_point_func):
"""Decorator which turns a function which takes sys.argv[1:] and returns
an integer into an argumentless function which returns an integer.
Args:
entry_point_func - A function which takes an array representing argv
"""
@functools.wraps(entry_point_func)
def func():
import sys
return entry_point_func(sys.argv[1:])
return func
pre_commit_func = make_entry_point(pre_commit.run.run)
validate_manifest_func = make_entry_point(pre_commit.clientlib.validate_manifest.run)
validate_config_func = make_entry_point(pre_commit.clientlib.validate_config.run)

View File

@@ -2,11 +2,11 @@
import argparse
import os.path
import subprocess
import sys
from pre_commit import git
from pre_commit.clientlib.validate_config import validate_config
from pre_commit.repository import Repository
from pre_commit.util import entry
RED = '\033[41m'
@@ -90,6 +90,7 @@ def run_single_hook(hook_id, configs=None, run_all_the_things=False):
return 1
@entry
def run(argv):
parser = argparse.ArgumentParser()
@@ -130,4 +131,4 @@ def run(argv):
if __name__ == '__main__':
run(sys.argv[1:])
run()

View File

@@ -1,6 +1,7 @@
import functools
import os
import sys
class cached_property(object):
@@ -35,3 +36,16 @@ def memoize_by_cwd(func):
wrapper._cache = {}
return wrapper
def entry(func):
"""Allows a function that has `argv` as an argument to be used as a
commandline entry. This will make the function callable using either
explicitly passed argv or defaulting to sys.argv[1:]
"""
@functools.wraps(func)
def wrapper(argv=None):
if argv is None:
argv = sys.argv[1:]
return func(argv)
return wrapper

View File

@@ -19,9 +19,9 @@ setup(
],
entry_points={
'console_scripts': [
'pre-commit = pre_commit.entry_points:pre_commit_func',
'validate-config = pre_commit.entry_points:validate_config_func',
'validate-manifest = pre_commit.entry_points:validate_manifest_func',
'pre-commit = pre_commit.run:run',
'validate-config = pre_commit.clientlib.validate_config:run',
'validate-manifest = pre_commit.clientlib.validate_manifest:run',
],
},
scripts=[

View File

@@ -65,6 +65,7 @@ def test_run_a_hook_lots_of_files(config_for_python_pre_commit_git_repo):
os.environ.get('slowtests', None) == 'false',
reason="TODO: make this test not super slow",
)
@pytest.mark.integration
def test_run_a_node_hook(config_for_node_pre_commit_git_repo):
repo = Repository(config_for_node_pre_commit_git_repo)
repo.install()

View File

@@ -1,9 +1,12 @@
import mock
import pytest
import random
import sys
from plumbum import local
from pre_commit.util import cached_property
from pre_commit.util import entry
from pre_commit.util import memoize_by_cwd
@@ -59,3 +62,25 @@ def test_memoized_by_cwd_changes_with_different_cwd(memoized_by_cwd):
ret2 = memoized_by_cwd('baz')
assert ret != ret2
@pytest.fixture
def entry_func():
@entry
def func(argv):
return argv
return func
def test_explicitly_passed_argv_are_passed(entry_func):
input = object()
ret = entry_func(input)
assert ret is input
def test_no_arguments_passed_uses_argv(entry_func):
argv = [1, 2, 3, 4]
with mock.patch.object(sys, 'argv', argv):
ret = entry_func()
assert ret == argv[1:]