Skip to content

Commit

Permalink
refactor(wrap_stdio): remake classes
Browse files Browse the repository at this point in the history
  • Loading branch information
saygox committed Nov 22, 2021
1 parent a5860a8 commit f4ac267
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
36 changes: 22 additions & 14 deletions commitizen/wrap_stdio_linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,50 @@

if sys.platform == "linux": # pragma: no cover
import os
from io import IOBase

class WrapStdioLinux:
def __init__(self, stdx: IOBase):
self._fileno = stdx.fileno()
if self._fileno == 0:
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
tty = open(fd, "wb+", buffering=0)
else:
tty = open("/dev/tty", "w") # type: ignore

# from io import IOBase

class WrapStdinLinux:
def __init__(self):
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
tty = open(fd, "wb+", buffering=0)
self.tty = tty

def __getattr__(self, key):
if key == "encoding" and self._fileno == 0:
if key == "encoding":
return "UTF-8"
return getattr(self.tty, key)

def __del__(self):
self.tty.close()

class WrapStdoutLinux:
def __init__(self):
tty = open("/dev/tty", "w")
self.tty = tty

def __getattr__(self, key):
return getattr(self.tty, key)

def __del__(self):
self.tty.close()

backup_stdin = None
backup_stdout = None
backup_stderr = None

def _wrap_stdio():
global backup_stdin
backup_stdin = sys.stdin
sys.stdin = WrapStdioLinux(sys.stdin)
sys.stdin = WrapStdinLinux()

global backup_stdout
backup_stdout = sys.stdout
sys.stdout = WrapStdioLinux(sys.stdout)
sys.stdout = WrapStdoutLinux()

global backup_stderr
backup_stderr = sys.stderr
sys.stderr = WrapStdioLinux(sys.stderr)
sys.stderr = WrapStdoutLinux()

def _unwrap_stdio():
global backup_stdin
Expand Down
32 changes: 19 additions & 13 deletions commitizen/wrap_stdio_unix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
if sys.platform != "win32" and sys.platform != "linux": # pragma: no cover
import os
import selectors
from asyncio import (
from asyncio import ( # get_event_loop_policy,; set_event_loop_policy,
DefaultEventLoopPolicy,
get_event_loop_policy,
set_event_loop_policy,
)
from io import IOBase

class CZEventLoopPolicy(DefaultEventLoopPolicy): # pragma: no cover
def get_event_loop(self):
self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
return self._local._loop
# class CZEventLoopPolicy(DefaultEventLoopPolicy): # pragma: no cover
# def get_event_loop(self):
# self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
# return self._local._loop

class WrapStdioUnix:
def __init__(self, stdx: IOBase):
Expand All @@ -33,15 +31,20 @@ def __getattr__(self, key):
def __del__(self):
self.tty.close()

backup_event_loop_policy = None
backup_event_loop = None
# backup_event_loop_policy = None
backup_stdin = None
backup_stdout = None
backup_stderr = None

def _wrap_stdio():
global backup_event_loop_policy
backup_event_loop_policy = get_event_loop_policy()
set_event_loop_policy(CZEventLoopPolicy())
global backup_event_loop
backup_event_loop = DefaultEventLoopPolicy.get_event_loop()
DefaultEventLoopPolicy.set_event_loop(selectors.SelectSelector())

# global backup_event_loop_policy
# backup_event_loop_policy = get_event_loop_policy()
# set_event_loop_policy(CZEventLoopPolicy())

global backup_stdin
backup_stdin = sys.stdin
Expand All @@ -56,8 +59,11 @@ def _wrap_stdio():
sys.stderr = WrapStdioUnix(sys.stderr)

def _unwrap_stdio():
global backup_event_loop_policy
set_event_loop_policy(backup_event_loop_policy)
global backup_event_loop
DefaultEventLoopPolicy.set_event_loop(backup_event_loop)

# global backup_event_loop_policy
# set_event_loop_policy(backup_event_loop_policy)

global backup_stdin
sys.stdin.close()
Expand Down
30 changes: 20 additions & 10 deletions tests/test_wrap_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,52 @@ def test_warp_stdio_exists():
if sys.platform == "win32": # pragma: no cover
pass
elif sys.platform == "linux":
from commitizen.wrap_stdio_linux import WrapStdioLinux
from commitizen.wrap_stdio_linux import WrapStdinLinux, WrapStdoutLinux

def test_wrap_stdio_linux(mocker):
def test_wrap_stdin_linux(mocker):

tmp_stdin = sys.stdin
tmp_stdout = sys.stdout
tmp_stderr = sys.stderr

mocker.patch("os.open")
readerwriter_mock = mocker.mock_open(read_data="data")
mocker.patch("builtins.open", readerwriter_mock, create=True)

mocker.patch.object(sys.stdin, "fileno", return_value=0)
mocker.patch.object(sys.stdout, "fileno", return_value=1)
mocker.patch.object(sys.stdout, "fileno", return_value=2)

wrap_stdio.wrap_stdio()

assert sys.stdin != tmp_stdin
assert isinstance(sys.stdin, WrapStdioLinux)
assert isinstance(sys.stdin, WrapStdinLinux)
assert sys.stdin.encoding == "UTF-8"
assert sys.stdin.read() == "data"

wrap_stdio.unwrap_stdio()

assert sys.stdin == tmp_stdin

def test_wrap_stdout_linux(mocker):

tmp_stdout = sys.stdout
tmp_stderr = sys.stderr

mocker.patch("os.open")
readerwriter_mock = mocker.mock_open(read_data="data")
mocker.patch("builtins.open", readerwriter_mock, create=True)

wrap_stdio.wrap_stdio()

assert sys.stdout != tmp_stdout
assert isinstance(sys.stdout, WrapStdioLinux)
assert isinstance(sys.stdout, WrapStdoutLinux)
sys.stdout.write("stdout")
readerwriter_mock().write.assert_called_with("stdout")

assert sys.stderr != tmp_stderr
assert isinstance(sys.stderr, WrapStdioLinux)
assert isinstance(sys.stderr, WrapStdoutLinux)
sys.stdout.write("stderr")
readerwriter_mock().write.assert_called_with("stderr")

wrap_stdio.unwrap_stdio()

assert sys.stdin == tmp_stdin
assert sys.stdout == tmp_stdout
assert sys.stderr == tmp_stderr

Expand Down

0 comments on commit f4ac267

Please sign in to comment.