diff --git a/tests/git_test.py b/tests/git_test.py index 5a21fec5..c566b1f3 100644 --- a/tests/git_test.py +++ b/tests/git_test.py @@ -1,15 +1,27 @@ +import contextlib import os import pytest from plumbum import local from pre_commit import git -@pytest.fixture + + +@contextlib.contextmanager +def in_dir(dir): + old_path = local.cwd.getpath() + local.cwd.chdir(dir) + try: + yield + finally: + local.cwd.chdir(old_path) + +@pytest.yield_fixture def empty_git_dir(tmpdir): - local.cwd.chdir(tmpdir.strpath) - local['git']['init']() - return tmpdir.strpath + with in_dir(tmpdir.strpath): + local['git']['init']() + yield tmpdir.strpath def test_get_root(empty_git_dir): @@ -17,9 +29,9 @@ def test_get_root(empty_git_dir): foo = local.path('foo') foo.mkdir() - local.cwd.chdir(foo) - assert git.get_root() == empty_git_dir + with in_dir(foo): + assert git.get_root() == empty_git_dir def test_get_pre_commit_path(empty_git_dir):