From b6926e8e2ef50d709945f75252e7c6b9cacda290 Mon Sep 17 00:00:00 2001 From: Chris Kuehl Date: Sat, 20 Oct 2018 17:14:50 -0700 Subject: [PATCH] Attempt to partition files to use all possible cores --- pre_commit/xargs.py | 18 +++++++++++++----- tests/xargs_test.py | 42 +++++++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/pre_commit/xargs.py b/pre_commit/xargs.py index aa4f27e0..9c4bc78a 100644 --- a/pre_commit/xargs.py +++ b/pre_commit/xargs.py @@ -1,7 +1,9 @@ from __future__ import absolute_import +from __future__ import division from __future__ import unicode_literals import contextlib +import math import multiprocessing.pool import sys @@ -37,8 +39,13 @@ class ArgumentTooLongError(RuntimeError): pass -def partition(cmd, varargs, _max_length=None): +def partition(cmd, varargs, target_concurrency, _max_length=None): _max_length = _max_length or _get_platform_max_length() + + # Generally, we try to partition evenly into at least `target_concurrency` + # partitions, but we don't want a bunch of tiny partitions. + max_args = max(4, math.ceil(len(varargs) / target_concurrency)) + cmd = tuple(cmd) ret = [] @@ -51,7 +58,10 @@ def partition(cmd, varargs, _max_length=None): arg = varargs.pop() arg_length = _command_length(arg) + 1 - if total_length + arg_length <= _max_length: + if ( + total_length + arg_length <= _max_length + and len(ret_cmd) < max_args + ): ret_cmd.append(arg) total_length += arg_length elif not ret_cmd: @@ -94,9 +104,7 @@ def xargs(cmd, varargs, **kwargs): except parse_shebang.ExecutableNotFoundError as e: return e.to_output() - # TODO: teach partition to intelligently target our desired concurrency - # while still respecting max_length. - partitions = partition(cmd, varargs, **kwargs) + partitions = partition(cmd, varargs, target_concurrency, **kwargs) def run_cmd_partition(run_cmd): return cmd_output(*run_cmd, encoding=None, retcode=None) diff --git a/tests/xargs_test.py b/tests/xargs_test.py index b60a37d6..3dcb6e8a 100644 --- a/tests/xargs_test.py +++ b/tests/xargs_test.py @@ -36,11 +36,11 @@ def linux_mock(): def test_partition_trivial(): - assert xargs.partition(('cmd',), ()) == (('cmd',),) + assert xargs.partition(('cmd',), (), 1) == (('cmd',),) def test_partition_simple(): - assert xargs.partition(('cmd',), ('foo',)) == (('cmd', 'foo'),) + assert xargs.partition(('cmd',), ('foo',), 1) == (('cmd', 'foo'),) def test_partition_limits(): @@ -54,6 +54,7 @@ def test_partition_limits(): '.' * 5, '.' * 6, ), + 1, _max_length=20, ) assert ret == ( @@ -68,21 +69,21 @@ def test_partition_limit_win32_py3(win32_py3_mock): cmd = ('ninechars',) # counted as half because of utf-16 encode varargs = ('😑' * 5,) - ret = xargs.partition(cmd, varargs, _max_length=20) + ret = xargs.partition(cmd, varargs, 1, _max_length=20) assert ret == (cmd + varargs,) def test_partition_limit_win32_py2(win32_py2_mock): cmd = ('ninechars',) varargs = ('😑' * 5,) # 4 bytes * 5 - ret = xargs.partition(cmd, varargs, _max_length=30) + ret = xargs.partition(cmd, varargs, 1, _max_length=30) assert ret == (cmd + varargs,) def test_partition_limit_linux(linux_mock): cmd = ('ninechars',) varargs = ('😑' * 5,) - ret = xargs.partition(cmd, varargs, _max_length=30) + ret = xargs.partition(cmd, varargs, 1, _max_length=30) assert ret == (cmd + varargs,) @@ -90,12 +91,39 @@ def test_argument_too_long_with_large_unicode(linux_mock): cmd = ('ninechars',) varargs = ('😑' * 10,) # 4 bytes * 10 with pytest.raises(xargs.ArgumentTooLongError): - xargs.partition(cmd, varargs, _max_length=20) + xargs.partition(cmd, varargs, 1, _max_length=20) + + +def test_partition_target_concurrency(): + ret = xargs.partition( + ('foo',), ('A',) * 22, + 4, + _max_length=50, + ) + assert ret == ( + ('foo',) + ('A',) * 6, + ('foo',) + ('A',) * 6, + ('foo',) + ('A',) * 6, + ('foo',) + ('A',) * 4, + ) + + +def test_partition_target_concurrency_wont_make_tiny_partitions(): + ret = xargs.partition( + ('foo',), ('A',) * 10, + 4, + _max_length=50, + ) + assert ret == ( + ('foo',) + ('A',) * 4, + ('foo',) + ('A',) * 4, + ('foo',) + ('A',) * 2, + ) def test_argument_too_long(): with pytest.raises(xargs.ArgumentTooLongError): - xargs.partition(('a' * 5,), ('a' * 5,), _max_length=10) + xargs.partition(('a' * 5,), ('a' * 5,), 1, _max_length=10) def test_xargs_smoke():