Skip to content

Commit

Permalink
Add put to CLI (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kimahriman authored Jan 20, 2025
1 parent b130ecb commit 773a934
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 1 deletion.
139 changes: 138 additions & 1 deletion python/hdfs_native/cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import functools
import glob
import os
import posixpath
import re
import shutil
import stat
import sys
from argparse import ArgumentParser, Namespace
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Sequence, Tuple
from urllib.parse import urlparse

from hdfs_native import Client
from hdfs_native._internal import WriteOptions


@functools.cache
Expand Down Expand Up @@ -52,6 +55,10 @@ def _glob_path(client: Client, glob: str) -> List[str]:
return [glob]


def _glob_local_path(glob_pattern: str) -> List[str]:
    """Expand a local filesystem glob pattern into the list of matching paths.

    Returns an empty list when nothing matches.
    """
    matches = glob.glob(glob_pattern)
    return matches


def _download_file(
client: Client,
remote_src: str,
def _upload_file(
    client: Client,
    local_src: str,
    remote_dst: str,
    direct: bool = False,
    force: bool = False,
    preserve: bool = False,
) -> None:
    """Upload a single local file to ``remote_dst``.

    Unless ``direct`` is set, the data is written to a temporary
    ``<remote_dst>._COPYING_`` file first and renamed to ``remote_dst`` once
    the copy (and any attribute preservation) has completed.

    Args:
        client: Client used for every remote operation.
        local_src: Path of the local file to read.
        remote_dst: Remote path the file should end up at.
        direct: Write straight to ``remote_dst`` instead of a temporary file.
        force: Allow overwriting the written file and the rename target.
        preserve: Copy the local modification time, access time and
            permission bits onto the remote file.

    Raises:
        FileExistsError: If ``remote_dst`` already exists and neither
            ``direct`` nor ``force`` is set.
    """
    if not direct and not force:
        # Fail early, before any data is written to the temporary file.
        try:
            client.get_file_info(remote_dst)
        except FileNotFoundError:
            pass
        else:
            raise FileExistsError(
                f"{remote_dst} already exists, use --force to overwrite"
            )

    # The suffix matches the one documented in the CLI help for --direct.
    write_destination = remote_dst if direct else f"{remote_dst}._COPYING_"

    with open(local_src, "rb") as local_file:
        with client.create(
            write_destination,
            WriteOptions(overwrite=force),
        ) as remote_file:
            shutil.copyfileobj(local_file, remote_file)

    if preserve:
        # Applied before the rename so the attributes are already in place
        # when the file appears at its final path.
        st = os.stat(local_src)
        client.set_times(
            write_destination,
            int(st.st_mtime * 1000),
            int(st.st_atime * 1000),
        )
        client.set_permission(write_destination, stat.S_IMODE(st.st_mode))

    if not direct:
        client.rename(write_destination, remote_dst, overwrite=force)


def cat(args: Namespace):
for src in args.src:
Expand Down Expand Up @@ -220,6 +258,59 @@ def mv(args: Namespace):
client.rename(src_path, target_path)


def put(args: Namespace):
    """Copy local files matching ``args.localsrc`` patterns to ``args.dst``.

    When the destination is not an existing remote directory, exactly one
    local file may be uploaded and it is written to the destination path
    itself. When the destination is a directory, every matched file is
    uploaded into it under its own base name, using a pool of
    ``args.threads`` worker threads.

    Raises:
        FileNotFoundError: If no local file matches any of the patterns.
        ValueError: If several files matched but the destination is not a
            directory.
    """
    paths: List[str] = []
    for pattern in args.localsrc:
        paths.extend(_glob_local_path(pattern))

    if not paths:
        raise FileNotFoundError("No files matched patterns")

    client = _client_for_url(args.dst)
    dst_path = _path_for_url(args.dst)

    try:
        dst_is_dir = client.get_file_info(dst_path).isdir
    except FileNotFoundError:
        # A missing destination is treated as a new file path.
        dst_is_dir = False

    if not dst_is_dir:
        if len(paths) > 1:
            raise ValueError(
                "Destination must be directory when copying multiple files"
            )
        _upload_file(
            client,
            paths[0],
            dst_path,
            direct=args.direct,
            force=args.force,
            preserve=args.preserve,
        )
        return

    with ThreadPoolExecutor(args.threads) as executor:
        futures = [
            executor.submit(
                _upload_file,
                client,
                path,
                # Remote paths always use "/" regardless of the local OS,
                # so os.path.join would be wrong here on Windows.
                posixpath.join(dst_path, os.path.basename(path)),
                direct=args.direct,
                force=args.force,
                preserve=args.preserve,
            )
            for path in paths
        ]

        # Iterate to raise any exceptions thrown
        for future in as_completed(futures):
            future.result()


def main(in_args: Optional[Sequence[str]] = None):
parser = ArgumentParser(
description="""Command line utility for interacting with HDFS using hdfs-native.
Expand Down Expand Up @@ -340,6 +431,52 @@ def main(in_args: Optional[Sequence[str]] = None):
mv_parser.add_argument("dst", help="Target destination of file or directory")
mv_parser.set_defaults(func=mv)

# "put" sub-command: upload local files to a remote destination.
put_parser = subparsers.add_parser(
    "put",
    aliases=["copyFromLocal"],
    help="Copy local files to a remote destination",
    description="""Copy files matching a pattern to a remote destination.
            When copying multiple files, the destination must be a directory""",
)
put_parser.add_argument(
    "-p",
    "--preserve",
    action="store_true",
    default=False,
    # Only times and permission bits are copied; ownership is not changed.
    help="Preserve access and modification times and the mode",
)
put_parser.add_argument(
    "-f",
    "--force",
    action="store_true",
    default=False,
    help="Overwrite the destination if it already exists",
)
put_parser.add_argument(
    "-d",
    "--direct",
    action="store_true",
    default=False,
    help="Skip creation of temporary file (<dst>._COPYING_) and write directly to file",
)
put_parser.add_argument(
    "-t",
    "--threads",
    type=int,
    help="Number of threads to use",
    default=1,
)
put_parser.add_argument(
    "localsrc",
    nargs="+",
    help="Source patterns to copy",
)
put_parser.add_argument(
    "dst",
    help="Remote destination to write to",
)
put_parser.set_defaults(func=put)

args = parser.parse_args(in_args)
args.func(args)

Expand Down
57 changes: 57 additions & 0 deletions python/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,60 @@ def test_mv(client: Client):

client.get_file_info("/testdir/testfile1")
client.get_file_info("/testdir/testfile2")


def test_put(client: Client):
    """Exercise the ``put`` CLI command end to end.

    Covers: no matching source, upload to a new file path, upload into a
    directory, refusal to overwrite without ``-f``, attribute preservation
    with ``-p``, and multi-file uploads to non-directory and directory
    destinations.
    """
    data = b"0123456789"

    # A source pattern that matches nothing is rejected
    with pytest.raises(FileNotFoundError):
        cli_main(["put", "testfile", "/testfile"])

    with TemporaryDirectory() as tmp_dir:
        first_local = os.path.join(tmp_dir, "testfile")
        second_local = os.path.join(tmp_dir, "testfile2")

        with open(first_local, "wb") as file:
            file.write(data)

        # Destination does not exist: written to that exact path
        cli_main(["put", first_local, "/remotefile"])
        with client.read("/remotefile") as file:
            assert file.read() == data

        # Destination is a directory: written under the local base name
        cli_main(["put", first_local, "/"])
        with client.read("/testfile") as file:
            assert file.read() == data

        # Second upload without --force must not overwrite
        with pytest.raises(FileExistsError):
            cli_main(["put", first_local, "/"])

        # Forced overwrite with preserved mode and timestamps
        cli_main(["put", "-f", "-p", first_local, "/"])
        st = os.stat(first_local)
        status = client.get_file_info("/testfile")
        assert stat.S_IMODE(st.st_mode) == status.permission
        assert int(st.st_atime * 1000) == status.access_time
        assert int(st.st_mtime * 1000) == status.modification_time

        with open(second_local, "wb") as file:
            file.write(data)

        # Several sources require a directory destination
        with pytest.raises(ValueError):
            cli_main(["put", first_local, second_local, "/notadir"])

        client.mkdirs("/testdir")
        cli_main(["put", first_local, second_local, "/testdir"])

        for remote_path in ("/testdir/testfile", "/testdir/testfile2"):
            with client.read(remote_path) as file:
                assert file.read() == data

0 comments on commit 773a934

Please sign in to comment.