Skip to content

Commit

Permalink
Add constant enumerating supported formats (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w authored Jan 1, 2019
1 parent 69e70dd commit 18f54e7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
6 changes: 2 additions & 4 deletions tests/test_xtarfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from tempfile import mkstemp
from unittest import TestCase

from xtarfile.xtarfile import SUPPORTED_FORMATS
from xtarfile.xtarfile import get_compression
from xtarfile.xtarfile import xtarfile_open
from xtarfile.xtarfile import HANDLERS


class FileExtensionIdContext:
Expand Down Expand Up @@ -54,11 +54,9 @@ def test_falls_back_to_extension(self):

class OpenTests(TestCase):
def test_roundtrip(self):
plugins = [key for (key, value) in HANDLERS.items() if value]
compressors = ['gz', 'bz2', 'xz'] + plugins
contexts = (ExplicitOpenIdContext, FileExtensionIdContext)

for compressor, ctx in product(compressors, contexts):
for compressor, ctx in product(SUPPORTED_FORMATS, contexts):
context = ctx(self, compressor)
with self.subTest(compressor=compressor, context=str(context)):
self._test_roundtrip(context)
Expand Down
11 changes: 8 additions & 3 deletions xtarfile/xtarfile.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from itertools import chain
from tarfile import open as tarfile_open

from xtarfile.zstd import ZstandardTarfile


HANDLERS = {
_HANDLERS = {
'zstd': ZstandardTarfile
}

_NATIVE_FORMATS = ('gz', 'bz2', 'xz')

SUPPORTED_FORMATS = frozenset(chain(_HANDLERS.keys(), _NATIVE_FORMATS))


def get_compression(path: str, mode: str) -> str:
for delim in (':', '|'):
Expand All @@ -24,10 +29,10 @@ def get_compression(path: str, mode: str) -> str:
def xtarfile_open(path: str, mode: str, **kwargs):
compression = get_compression(path, mode)

if not compression or compression in ('gz', 'bz2', 'xz'):
if not compression or compression in _NATIVE_FORMATS:
return tarfile_open(path, mode, **kwargs)

handler_class = HANDLERS.get(compression)
handler_class = _HANDLERS.get(compression)
if handler_class is not None:
handler = handler_class(**kwargs)
if mode.startswith('r'):
Expand Down

0 comments on commit 18f54e7

Please sign in to comment.