mirror of
https://github.com/pre-commit/pre-commit.git
synced 2026-01-13 04:20:28 -06:00
Add envcontext helper
This commit is contained in:
54
pre_commit/envcontext.py
Normal file
54
pre_commit/envcontext.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
from pre_commit import five
|
||||
|
||||
|
||||
UNSET = collections.namedtuple('UNSET', ())()
|
||||
|
||||
|
||||
Var = collections.namedtuple('Var', ('name', 'default'))
|
||||
setattr(Var.__new__, five.defaults_attr, ('',))
|
||||
|
||||
|
||||
def format_env(parts, env):
|
||||
return ''.join(
|
||||
env.get(part.name, part.default)
|
||||
if isinstance(part, Var)
|
||||
else part
|
||||
for part in parts
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def envcontext(patch, _env=None):
|
||||
"""In this context, `os.environ` is modified according to `patch`.
|
||||
|
||||
`patch` is an iterable of 2-tuples (key, value):
|
||||
`key`: string
|
||||
`value`:
|
||||
- string: `environ[key] == value` inside the context.
|
||||
- UNSET: `key not in environ` inside the context.
|
||||
- template: A template is a tuple of strings and Var which will be
|
||||
replaced with the previous environment
|
||||
"""
|
||||
env = os.environ if _env is None else _env
|
||||
before = env.copy()
|
||||
|
||||
for k, v in patch:
|
||||
if v is UNSET:
|
||||
env.pop(k, None)
|
||||
elif isinstance(v, tuple):
|
||||
env[k] = format_env(v, before)
|
||||
else:
|
||||
env[k] = v
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
env.clear()
|
||||
env.update(before)
|
||||
@@ -12,6 +12,8 @@ if PY2: # pragma: no cover (PY2 only)
|
||||
return s
|
||||
else:
|
||||
return s.encode('UTF-8')
|
||||
|
||||
defaults_attr = 'func_defaults'
|
||||
else: # pragma: no cover (PY3 only)
|
||||
text = str
|
||||
|
||||
@@ -21,6 +23,8 @@ else: # pragma: no cover (PY3 only)
|
||||
else:
|
||||
return s.decode('UTF-8')
|
||||
|
||||
defaults_attr = '__defaults__'
|
||||
|
||||
|
||||
def to_text(s):
|
||||
return s if isinstance(s, text) else s.decode('UTF-8')
|
||||
|
||||
109
tests/envcontext_test.py
Normal file
109
tests/envcontext_test.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from pre_commit.envcontext import envcontext
|
||||
from pre_commit.envcontext import UNSET
|
||||
from pre_commit.envcontext import Var
|
||||
|
||||
|
||||
def _test(**kwargs):
|
||||
before = kwargs.pop('before')
|
||||
patch = kwargs.pop('patch')
|
||||
expected = kwargs.pop('expected')
|
||||
assert not kwargs
|
||||
|
||||
env = before.copy()
|
||||
with envcontext(patch, _env=env):
|
||||
assert env == expected
|
||||
assert env == before
|
||||
|
||||
|
||||
def test_trivial():
|
||||
_test(before={}, patch={}, expected={})
|
||||
|
||||
|
||||
def test_noop():
|
||||
_test(before={'foo': 'bar'}, patch=(), expected={'foo': 'bar'})
|
||||
|
||||
|
||||
def test_adds():
|
||||
_test(before={}, patch=[('foo', 'bar')], expected={'foo': 'bar'})
|
||||
|
||||
|
||||
def test_overrides():
|
||||
_test(
|
||||
before={'foo': 'baz'},
|
||||
patch=[('foo', 'bar')],
|
||||
expected={'foo': 'bar'},
|
||||
)
|
||||
|
||||
|
||||
def test_unset_but_nothing_to_unset():
|
||||
_test(before={}, patch=[('foo', UNSET)], expected={})
|
||||
|
||||
|
||||
def test_unset_things_to_remove():
|
||||
_test(
|
||||
before={'PYTHONHOME': ''},
|
||||
patch=[('PYTHONHOME', UNSET)],
|
||||
expected={},
|
||||
)
|
||||
|
||||
|
||||
def test_templated_environment_variable_missing():
|
||||
_test(
|
||||
before={},
|
||||
patch=[('PATH', ('~/bin:', Var('PATH')))],
|
||||
expected={'PATH': '~/bin:'},
|
||||
)
|
||||
|
||||
|
||||
def test_templated_environment_variable_defaults():
|
||||
_test(
|
||||
before={},
|
||||
patch=[('PATH', ('~/bin:', Var('PATH', default='/bin')))],
|
||||
expected={'PATH': '~/bin:/bin'},
|
||||
)
|
||||
|
||||
|
||||
def test_templated_environment_variable_there():
|
||||
_test(
|
||||
before={'PATH': '/usr/local/bin:/usr/bin'},
|
||||
patch=[('PATH', ('~/bin:', Var('PATH')))],
|
||||
expected={'PATH': '~/bin:/usr/local/bin:/usr/bin'},
|
||||
)
|
||||
|
||||
|
||||
def test_templated_environ_sources_from_previous():
|
||||
_test(
|
||||
before={'foo': 'bar'},
|
||||
patch=(
|
||||
('foo', 'baz'),
|
||||
('herp', ('foo: ', Var('foo'))),
|
||||
),
|
||||
expected={'foo': 'baz', 'herp': 'foo: bar'},
|
||||
)
|
||||
|
||||
|
||||
def test_exception_safety():
|
||||
class MyError(RuntimeError):
|
||||
pass
|
||||
|
||||
env = {}
|
||||
with pytest.raises(MyError):
|
||||
with envcontext([('foo', 'bar')], _env=env):
|
||||
raise MyError()
|
||||
assert env == {}
|
||||
|
||||
|
||||
def test_integration_os_environ():
|
||||
with mock.patch.dict(os.environ, {'FOO': 'bar'}, clear=True):
|
||||
assert os.environ == {'FOO': 'bar'}
|
||||
with envcontext([('HERP', 'derp')]):
|
||||
assert os.environ == {'FOO': 'bar', 'HERP': 'derp'}
|
||||
assert os.environ == {'FOO': 'bar'}
|
||||
Reference in New Issue
Block a user