From 18f54e7c1096bebf1e5dfd04194fdfa6004c743a Mon Sep 17 00:00:00 2001 From: Clemens Wolff Date: Tue, 1 Jan 2019 12:37:12 -0500 Subject: [PATCH] Add constant enumerating supported formats (#1) --- tests/test_xtarfile.py | 6 ++---- xtarfile/xtarfile.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_xtarfile.py b/tests/test_xtarfile.py index b09a7e8..bd241ff 100644 --- a/tests/test_xtarfile.py +++ b/tests/test_xtarfile.py @@ -6,9 +6,9 @@ from tempfile import mkstemp from unittest import TestCase +from xtarfile.xtarfile import SUPPORTED_FORMATS from xtarfile.xtarfile import get_compression from xtarfile.xtarfile import xtarfile_open -from xtarfile.xtarfile import HANDLERS class FileExtensionIdContext: @@ -54,11 +54,9 @@ def test_falls_back_to_extension(self): class OpenTests(TestCase): def test_roundtrip(self): - plugins = [key for (key, value) in HANDLERS.items() if value] - compressors = ['gz', 'bz2', 'xz'] + plugins contexts = (ExplicitOpenIdContext, FileExtensionIdContext) - for compressor, ctx in product(compressors, contexts): + for compressor, ctx in product(SUPPORTED_FORMATS, contexts): context = ctx(self, compressor) with self.subTest(compressor=compressor, context=str(context)): self._test_roundtrip(context) diff --git a/xtarfile/xtarfile.py b/xtarfile/xtarfile.py index b460bab..9eeedc7 100644 --- a/xtarfile/xtarfile.py +++ b/xtarfile/xtarfile.py @@ -1,12 +1,17 @@ +from itertools import chain from tarfile import open as tarfile_open from xtarfile.zstd import ZstandardTarfile -HANDLERS = { +_HANDLERS = { 'zstd': ZstandardTarfile } +_NATIVE_FORMATS = ('gz', 'bz2', 'xz') + +SUPPORTED_FORMATS = frozenset(chain(_HANDLERS.keys(), _NATIVE_FORMATS)) + def get_compression(path: str, mode: str) -> str: for delim in (':', '|'): @@ -24,10 +29,10 @@ def get_compression(path: str, mode: str) -> str: def xtarfile_open(path: str, mode: str, **kwargs): compression = get_compression(path, mode) - if not compression or compression in ('gz', 'bz2', 'xz'): + if not compression or compression in _NATIVE_FORMATS: return tarfile_open(path, mode, **kwargs) - handler_class = HANDLERS.get(compression) + handler_class = _HANDLERS.get(compression) if handler_class is not None: handler = handler_class(**kwargs) if mode.startswith('r'):