From 773a93401fd5e3e16e70b983352a9855e45f5887 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Mon, 20 Jan 2025 16:02:24 -0500 Subject: [PATCH] Add put to CLI (#197) --- python/hdfs_native/cli.py | 139 +++++++++++++++++++++++++++++++++++++- python/tests/test_cli.py | 57 ++++++++++++++++ 2 files changed, 195 insertions(+), 1 deletion(-) diff --git a/python/hdfs_native/cli.py b/python/hdfs_native/cli.py index 3fb602c..38ccdd5 100644 --- a/python/hdfs_native/cli.py +++ b/python/hdfs_native/cli.py @@ -1,7 +1,9 @@ import functools +import glob import os import re import shutil +import stat import sys from argparse import ArgumentParser, Namespace from concurrent.futures import ThreadPoolExecutor, as_completed @@ -9,6 +11,7 @@ from urllib.parse import urlparse from hdfs_native import Client +from hdfs_native._internal import WriteOptions @functools.cache @@ -52,6 +55,10 @@ def _glob_path(client: Client, glob: str) -> List[str]: return [glob] +def _glob_local_path(glob_pattern: str) -> List[str]: + return glob.glob(glob_pattern) + + def _download_file( client: Client, remote_src: str, @@ -80,13 +87,44 @@ def _upload_file( client: Client, local_src: str, remote_dst: str, + direct: bool = False, force: bool = False, preserve: bool = False, ) -> None: + if not direct and not force: + # Check if file already exists before we write it to a temporary file + try: + client.get_file_info(remote_dst) + raise FileExistsError( + f"{remote_dst} already exists, use --force to overwrite" + ) + except FileNotFoundError: + pass + + if direct: + write_destination = remote_dst + else: + write_destination = f"{remote_dst}.__COPYING__" + with open(local_src, "rb") as local_file: - with client.create(remote_dst) as remote_file: + with client.create( + write_destination, + WriteOptions(overwrite=force), + ) as remote_file: shutil.copyfileobj(local_file, remote_file) + if preserve: + st = os.stat(local_src) + client.set_times( + write_destination, + int(st.st_mtime * 1000), + int(st.st_atime * 1000), + ) + client.set_permission(write_destination, stat.S_IMODE(st.st_mode)) + + if not direct: + client.rename(write_destination, remote_dst, overwrite=force) + def cat(args: Namespace): for src in args.src: @@ -220,6 +258,59 @@ def mv(args: Namespace): client.rename(src_path, target_path) +def put(args: Namespace): + paths: List[str] = [] + + for pattern in args.localsrc: + for path in _glob_local_path(pattern): + paths.append(path) + + if len(paths) == 0: + raise FileNotFoundError("No files matched patterns") + + client = _client_for_url(args.dst) + dst_path = _path_for_url(args.dst) + + dst_is_dir = False + try: + dst_is_dir = client.get_file_info(dst_path).isdir + except FileNotFoundError: + pass + + if len(paths) > 1 and not dst_is_dir: + raise ValueError("Destination must be directory when copying multiple files") + elif not dst_is_dir: + _upload_file( + client, + paths[0], + dst_path, + direct=args.direct, + force=args.force, + preserve=args.preserve, + ) + else: + with ThreadPoolExecutor(args.threads) as executor: + futures = [] + for path in paths: + filename = os.path.basename(path) + + futures.append( + executor.submit( + _upload_file, + client, + path, + os.path.join(dst_path, filename), + direct=args.direct, + force=args.force, + preserve=args.preserve, + ) + ) + + # Iterate to raise any exceptions thrown + for f in as_completed(futures): + f.result() + + def main(in_args: Optional[Sequence[str]] = None): parser = ArgumentParser( description="""Command line utility for interacting with HDFS using hdfs-native. @@ -340,6 +431,52 @@ def main(in_args: Optional[Sequence[str]] = None): mv_parser.add_argument("dst", help="Target destination of file or directory") mv_parser.set_defaults(func=mv) + put_parser = subparsers.add_parser( + "put", + aliases=["copyFromLocal"], + help="Copy local files to a remote destination", + description="""Copy files matching a pattern to a remote destination. + When copying multiple files, the destination must be a directory""", + ) + put_parser.add_argument( + "-p", + "--preserve", + action="store_true", + default=False, + help="Preserve timestamps, ownership, and the mode", + ) + put_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Overwrite the destination if it already exists", + ) + put_parser.add_argument( + "-d", + "--direct", + action="store_true", + default=False, + help="Skip creation of temporary file (._COPYING_) and write directly to file", + ) + put_parser.add_argument( + "-t", + "--threads", + type=int, + help="Number of threads to use", + default=1, + ) + put_parser.add_argument( + "localsrc", + nargs="+", + help="Source patterns to copy", + ) + put_parser.add_argument( + "dst", + help="Local destination to write to", + ) + put_parser.set_defaults(func=put) + args = parser.parse_args(in_args) args.func(args) diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 597e038..425d1e2 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -190,3 +190,60 @@ def test_mv(client: Client): client.get_file_info("/testdir/testfile1") client.get_file_info("/testdir/testfile2") + + +def test_put(client: Client): + data = b"0123456789" + + with pytest.raises(FileNotFoundError): + cli_main(["put", "testfile", "/testfile"]) + + with TemporaryDirectory() as tmp_dir: + with open(os.path.join(tmp_dir, "testfile"), "wb") as file: + file.write(data) + + cli_main(["put", os.path.join(tmp_dir, "testfile"), "/remotefile"]) + with client.read("/remotefile") as file: + assert file.read() == data + + cli_main(["put", os.path.join(tmp_dir, "testfile"), "/"]) + with client.read("/testfile") as file: + assert file.read() == data + + with pytest.raises(FileExistsError): + cli_main(["put", os.path.join(tmp_dir, "testfile"), "/"]) + + cli_main(["put", "-f", "-p", os.path.join(tmp_dir, "testfile"), "/"]) + st = os.stat(os.path.join(tmp_dir, "testfile")) + status = client.get_file_info("/testfile") + assert stat.S_IMODE(st.st_mode) == status.permission + assert int(st.st_atime * 1000) == status.access_time + assert int(st.st_mtime * 1000) == status.modification_time + + with open(os.path.join(tmp_dir, "testfile2"), "wb") as file: + file.write(data) + + with pytest.raises(ValueError): + cli_main( + [ + "put", + os.path.join(tmp_dir, "testfile"), + os.path.join(tmp_dir, "testfile2"), + "/notadir", + ] + ) + + client.mkdirs("/testdir") + cli_main( + [ + "put", + os.path.join(tmp_dir, "testfile"), + os.path.join(tmp_dir, "testfile2"), + "/testdir", + ] + ) + + with client.read("/testdir/testfile") as file: + assert file.read() == data + with client.read("/testdir/testfile2") as file: + assert file.read() == data