From 00a3a9a09bc4541701dc0164bcb3bb826337f120 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Mon, 21 Mar 2016 12:10:40 -0700 Subject: [PATCH] Add envcontext helper --- pre_commit/envcontext.py | 54 +++++++++++++++++++ pre_commit/five.py | 4 ++ tests/envcontext_test.py | 109 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 pre_commit/envcontext.py create mode 100644 tests/envcontext_test.py diff --git a/pre_commit/envcontext.py b/pre_commit/envcontext.py new file mode 100644 index 00000000..2013c723 --- /dev/null +++ b/pre_commit/envcontext.py @@ -0,0 +1,54 @@ +from __future__ import absolute_import +from __future__ import unicode_literals + +import collections +import contextlib +import os + +from pre_commit import five + + +UNSET = collections.namedtuple('UNSET', ())() + + +Var = collections.namedtuple('Var', ('name', 'default')) +setattr(Var.__new__, five.defaults_attr, ('',)) + + +def format_env(parts, env): + return ''.join( + env.get(part.name, part.default) + if isinstance(part, Var) + else part + for part in parts + ) + + +@contextlib.contextmanager +def envcontext(patch, _env=None): + """In this context, `os.environ` is modified according to `patch`. + + `patch` is an iterable of 2-tuples (key, value): + `key`: string + `value`: + - string: `environ[key] == value` inside the context. + - UNSET: `key not in environ` inside the context. + - template: A template is a tuple of strings and Var which will be + replaced with the previous environment + """ + env = os.environ if _env is None else _env + before = env.copy() + + for k, v in patch: + if v is UNSET: + env.pop(k, None) + elif isinstance(v, tuple): + env[k] = format_env(v, before) + else: + env[k] = v + + try: + yield + finally: + env.clear() + env.update(before) diff --git a/pre_commit/five.py b/pre_commit/five.py index 8b9a2b54..2ae91c59 100644 --- a/pre_commit/five.py +++ b/pre_commit/five.py @@ -12,6 +12,8 @@ if PY2: # pragma: no cover (PY2 only) return s else: return s.encode('UTF-8') + + defaults_attr = 'func_defaults' else: # pragma: no cover (PY3 only) text = str @@ -21,6 +23,8 @@ else: # pragma: no cover (PY3 only) else: return s.decode('UTF-8') + defaults_attr = '__defaults__' + def to_text(s): return s if isinstance(s, text) else s.decode('UTF-8') diff --git a/tests/envcontext_test.py b/tests/envcontext_test.py new file mode 100644 index 00000000..c03e9431 --- /dev/null +++ b/tests/envcontext_test.py @@ -0,0 +1,109 @@ +from __future__ import absolute_import +from __future__ import unicode_literals + +import os + +import mock +import pytest + +from pre_commit.envcontext import envcontext +from pre_commit.envcontext import UNSET +from pre_commit.envcontext import Var + + +def _test(**kwargs): + before = kwargs.pop('before') + patch = kwargs.pop('patch') + expected = kwargs.pop('expected') + assert not kwargs + + env = before.copy() + with envcontext(patch, _env=env): + assert env == expected + assert env == before + + +def test_trivial(): + _test(before={}, patch={}, expected={}) + + +def test_noop(): + _test(before={'foo': 'bar'}, patch=(), expected={'foo': 'bar'}) + + +def test_adds(): + _test(before={}, patch=[('foo', 'bar')], expected={'foo': 'bar'}) + + +def test_overrides(): + _test( + before={'foo': 'baz'}, + patch=[('foo', 'bar')], + expected={'foo': 'bar'}, + ) + + +def test_unset_but_nothing_to_unset(): + _test(before={}, patch=[('foo', UNSET)], expected={}) + + +def test_unset_things_to_remove(): + _test( + before={'PYTHONHOME': ''}, + patch=[('PYTHONHOME', UNSET)], + expected={}, + ) + + +def test_templated_environment_variable_missing(): + _test( + before={}, + patch=[('PATH', ('~/bin:', Var('PATH')))], + expected={'PATH': '~/bin:'}, + ) + + +def test_templated_environment_variable_defaults(): + _test( + before={}, + patch=[('PATH', ('~/bin:', Var('PATH', default='/bin')))], + expected={'PATH': '~/bin:/bin'}, + ) + + +def test_templated_environment_variable_there(): + _test( + before={'PATH': '/usr/local/bin:/usr/bin'}, + patch=[('PATH', ('~/bin:', Var('PATH')))], + expected={'PATH': '~/bin:/usr/local/bin:/usr/bin'}, + ) + + +def test_templated_environ_sources_from_previous(): + _test( + before={'foo': 'bar'}, + patch=( + ('foo', 'baz'), + ('herp', ('foo: ', Var('foo'))), + ), + expected={'foo': 'baz', 'herp': 'foo: bar'}, + ) + + +def test_exception_safety(): + class MyError(RuntimeError): + pass + + env = {} + with pytest.raises(MyError): + with envcontext([('foo', 'bar')], _env=env): + raise MyError() + assert env == {} + + +def test_integration_os_environ(): + with mock.patch.dict(os.environ, {'FOO': 'bar'}, clear=True): + assert os.environ == {'FOO': 'bar'} + with envcontext([('HERP', 'derp')]): + assert os.environ == {'FOO': 'bar', 'HERP': 'derp'} + assert os.environ == {'FOO': 'bar'}