From 10c4e2aa8f4fb9b21fb8871df5ca4607cae866f5 Mon Sep 17 00:00:00 2001
From: Amrit Ghimire
Date: Thu, 7 Nov 2024 15:18:23 +0545
Subject: [PATCH] feat(cli): enhance dataset listing with additional options
 for studio (#561)

* feat(cli): enhance dataset listing with additional options for studio

With this change, the following variations are available for listing
datasets, replacing `datachain ls-datasets`:

- `datachain datasets --local`
- `datachain datasets --studio`
- `datachain datasets --all` (default)
- `datachain datasets --team TEAM_NAME`
- `datachain datasets`

By default, if the user has logged in to Studio using
`datachain studio login`, datasets are fetched from both the local
catalog and Studio. If a specific option (`--local` or `--studio`) is
passed, only that source is listed (a standalone sketch of this flag
resolution follows the diff). The team name is read from the config
(set by `datachain studio team `).

The same options are added to the `datachain ls` command, along with
error handling for each scenario.

Relates to #10774

* Modify the result format to a tabular format

* Fix tests for ls

* test ls

* Fix tests for ls

* Fix ls remote sources

* Deduplicate

---
 pyproject.toml                 |   6 +-
 src/datachain/cli.py           | 160 ++++++++++++++++++++++++++++-----
 src/datachain/remote/studio.py |  14 ++-
 src/datachain/studio.py        |  24 +++--
 tests/conftest.py              |  28 ++++++
 tests/func/test_ls.py          |  17 +---
 tests/test_cli_e2e.py          |  25 ++++--
 tests/test_cli_studio.py       |  82 +++++++++++++----
 8 files changed, 284 insertions(+), 72 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 36560b7a4..3532420a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,8 @@ dependencies = [
   "huggingface_hub",
   "iterative-telemetry>=0.0.9",
   "platformdirs",
-  "dvc-studio-client>=0.21,<1"
+  "dvc-studio-client>=0.21,<1",
+  "tabulate"
 ]
 
 [project.optional-dependencies]
@@ -98,7 +99,8 @@ dev = [
   "types-python-dateutil",
   "types-pytz",
   "types-PyYAML",
-  "types-requests"
+  "types-requests",
+  "types-tabulate"
 ]
 examples = [
   "datachain[tests]",
diff --git a/src/datachain/cli.py b/src/datachain/cli.py
index f37d7a9e4..9859c1aff 100644
--- a/src/datachain/cli.py
+++ b/src/datachain/cli.py
@@ -4,18 +4,21 @@
 import sys
 import traceback
 from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
-from collections.abc import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Sequence
 from importlib.metadata import PackageNotFoundError, version
 from itertools import chain
 from multiprocessing import freeze_support
 from typing import TYPE_CHECKING, Optional, Union
 
 import shtab
+from tabulate import tabulate
 
 from datachain import Session, utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
+from datachain.config import Config
+from datachain.error import DataChainError
 from datachain.lib.dc import DataChain
-from datachain.studio import process_studio_cli_args
+from datachain.studio import list_datasets, process_studio_cli_args
 from datachain.telemetry import telemetry
 
 if TYPE_CHECKING:
@@ -416,7 +419,36 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Dataset labels",
     )
 
-    subp.add_parser("ls-datasets", parents=[parent_parser], description="List datasets")
+    datasets_parser = subp.add_parser(
+        "datasets", parents=[parent_parser], description="List datasets"
+    )
+    datasets_parser.add_argument(
+        "--studio",
+        action="store_true",
+        default=False,
+        help="List the files in the Studio",
+    )
+    datasets_parser.add_argument(
+        "-L",
+        "--local",
+        action="store_true",
+        default=False,
help="List local files only", + ) + datasets_parser.add_argument( + "-a", + "--all", + action="store_true", + default=True, + help="List all files including hidden files", + ) + datasets_parser.add_argument( + "--team", + action="store", + default=None, + help="The team to list datasets for. By default, it will use team from config.", + ) + rm_dataset_parser = subp.add_parser( "rm-dataset", parents=[parent_parser], description="Removes dataset" ) @@ -474,10 +506,30 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 help="List files in the long format", ) parse_ls.add_argument( - "--remote", + "--studio", + action="store_true", + default=False, + help="List the files in the Studio", + ) + parse_ls.add_argument( + "-L", + "--local", + action="store_true", + default=False, + help="List local files only", + ) + parse_ls.add_argument( + "-a", + "--all", + action="store_true", + default=True, + help="List all files including hidden files", + ) + parse_ls.add_argument( + "--team", action="store", - default="", - help="Name of remote to use", + default=None, + help="The team to list datasets for. By default, it will use team from config.", ) parse_du = subp.add_parser( @@ -758,11 +810,12 @@ def format_ls_entry(entry: str) -> str: def ls_remote( paths: Iterable[str], long: bool = False, + team: Optional[str] = None, ): from datachain.node import long_line_str from datachain.remote.studio import StudioClient - client = StudioClient() + client = StudioClient(team=team) first = True for path, response in client.ls(paths): if not first: @@ -789,28 +842,66 @@ def ls_remote( def ls( sources, long: bool = False, - remote: str = "", - config: Optional[Mapping[str, str]] = None, + studio: bool = False, + local: bool = False, + all: bool = True, + team: Optional[str] = None, **kwargs, ): - if config is None: - from .config import Config + token = Config().read().get("studio", {}).get("token") + all, local, studio = _determine_flavors(studio, local, all, token) - config = Config().get_remote_config(remote=remote) - remote_type = config["type"] - if remote_type == "local": + if all or local: ls_local(sources, long=long, **kwargs) - else: - ls_remote( - sources, - long=long, + + if (all or studio) and token: + ls_remote(sources, long=long, team=team) + + +def datasets( + catalog: "Catalog", + studio: bool = False, + local: bool = False, + all: bool = True, + team: Optional[str] = None, +): + token = Config().read().get("studio", {}).get("token") + all, local, studio = _determine_flavors(studio, local, all, token) + + local_datasets = set(list_datasets_local(catalog)) if all or local else set() + studio_datasets = ( + set(list_datasets(team=team)) if (all or studio) and token else set() + ) + + rows = [ + _datasets_tabulate_row( + name=name, + version=version, + both=(all or (local and studio)) and token, + local=(name, version) in local_datasets, + studio=(name, version) in studio_datasets, ) + for name, version in local_datasets.union(studio_datasets) + ] + + print(tabulate(rows, headers="keys")) -def ls_datasets(catalog: "Catalog"): +def list_datasets_local(catalog: "Catalog"): for d in catalog.ls_datasets(): for v in d.versions: - print(f"{d.name} (v{v.version})") + yield (d.name, v.version) + + +def _datasets_tabulate_row(name, version, both, local, studio): + row = { + "Name": name, + "Version": version, + } + if both: + row["Studio"] = "\u2714" if studio else "\u2716" + row["Local"] = "\u2714" if local else "\u2716" + return row def rm_dataset( @@ -953,6 +1044,20 @@ def completion(shell: str) -> str: 
) +def _determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]): + if studio and not token: + raise DataChainError( + "Not logged in to Studio. Log in with 'datachain studio login'." + ) + + if local or studio: + all = False + + all = all and not (local or studio) + + return all, local, studio + + def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR0915 # Required for Windows multiprocessing support freeze_support() @@ -1032,12 +1137,21 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09 ls( args.sources, long=bool(args.long), - remote=args.remote, + studio=args.studio, + local=args.local, + all=args.all, + team=args.team, update=bool(args.update), client_config=client_config, ) - elif args.command == "ls-datasets": - ls_datasets(catalog) + elif args.command == "datasets": + datasets( + catalog=catalog, + studio=args.studio, + local=args.local, + all=args.all, + team=args.team, + ) elif args.command == "show": show( catalog, diff --git a/src/datachain/remote/studio.py b/src/datachain/remote/studio.py index 6e432549c..c0b1bb001 100644 --- a/src/datachain/remote/studio.py +++ b/src/datachain/remote/studio.py @@ -131,6 +131,12 @@ def _send_request_msgpack(self, route: str, data: dict[str, Any]) -> Response[An timeout=self.timeout, ) ok = response.ok + if not ok: + if response.status_code == 403: + message = f"Not authorized for the team {self.team}" + raise DataChainError(message) + logger.error("Got bad response from Studio") + content = msgpack.unpackb(response.content, ext_hook=self._unpacker_hook) response_data = content.get("data") if ok and response_data is None: @@ -177,8 +183,12 @@ def _send_request(self, route: str, data: dict[str, Any]) -> Response[Any]: response.content.decode("utf-8"), ) if response.status_code == 403: - message = "Not authorized" + message = f"Not authorized for the team {self.team}" else: + logger.error( + "Got bad response from Studio, content is %s", + response.content.decode("utf-8"), + ) message = data.get("message", "") else: message = "" @@ -214,7 +224,7 @@ def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]: # to handle cases where a path will be expanded (i.e. 
globs) response: Response[LsData] for path in paths: - response = self._send_request_msgpack("ls", {"source": path}) + response = self._send_request_msgpack("datachain/ls", {"source": path}) yield path, response def ls_datasets(self) -> Response[LsData]: diff --git a/src/datachain/studio.py b/src/datachain/studio.py index 373b4a7b3..9ef390179 100644 --- a/src/datachain/studio.py +++ b/src/datachain/studio.py @@ -1,8 +1,11 @@ import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional + +from tabulate import tabulate from datachain.catalog.catalog import raise_remote_error from datachain.config import Config, ConfigLevel +from datachain.dataset import QUERY_DATASET_PREFIX from datachain.error import DataChainError from datachain.remote.studio import StudioClient from datachain.utils import STUDIO_URL @@ -24,7 +27,13 @@ def process_studio_cli_args(args: "Namespace"): if args.cmd == "token": return token() if args.cmd == "datasets": - return list_datasets(args) + rows = [ + {"Name": name, "Version": version} + for name, version in list_datasets(args.team) + ] + print(tabulate(rows, headers="keys")) + return 0 + if args.cmd == "team": return set_team(args) raise DataChainError(f"Unknown command '{args.cmd}'.") @@ -103,19 +112,22 @@ def token(): print(token) -def list_datasets(args: "Namespace"): - client = StudioClient(team=args.team) +def list_datasets(team: Optional[str] = None): + client = StudioClient(team=team) response = client.ls_datasets() if not response.ok: raise_remote_error(response.message) if not response.data: - print("No datasets found.") return + for d in response.data: name = d.get("name") + if name and name.startswith(QUERY_DATASET_PREFIX): + continue + for v in d.get("versions", []): version = v.get("version") - print(f"{name} (v{version})") + yield (name, version) def save_config(hostname, token): diff --git a/tests/conftest.py b/tests/conftest.py index 83520e079..a9bd17f6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from datachain.catalog import Catalog from datachain.catalog.loader import get_id_generator, get_metastore, get_warehouse from datachain.cli_utils import CommaSeparatedArgs +from datachain.config import Config, ConfigLevel from datachain.data_storage.sqlite import ( SQLiteDatabaseEngine, SQLiteIDGenerator, @@ -26,6 +27,7 @@ from datachain.utils import ( ENV_DATACHAIN_GLOBAL_CONFIG_DIR, ENV_DATACHAIN_SYSTEM_CONFIG_DIR, + STUDIO_URL, DataChainDir, ) @@ -673,3 +675,29 @@ def dataset_rows(): } for i in range(19) ] + + +@pytest.fixture +def studio_datasets(requests_mock): + with Config(ConfigLevel.GLOBAL).edit() as conf: + conf["studio"] = {"token": "isat_access_token", "team": "team_name"} + + datasets = [ + { + "id": 1, + "name": "dogs", + "versions": [{"version": 1}, {"version": 2}], + }, + { + "id": 2, + "name": "cats", + "versions": [{"version": 1}], + }, + { + "id": 3, + "name": "both", + "versions": [{"version": 1}], + }, + ] + + requests_mock.post(f"{STUDIO_URL}/api/datachain/ls-datasets", json=datasets) diff --git a/tests/func/test_ls.py b/tests/func/test_ls.py index 57d300004..06ddeb402 100644 --- a/tests/func/test_ls.py +++ b/tests/func/test_ls.py @@ -15,8 +15,8 @@ from tests.utils import uppercase_scheme -@pytest.fixture(autouse=True) -def studio_config(): +@pytest.fixture +def studio_config(global_config_dir): with Config(ConfigLevel.GLOBAL).edit() as conf: conf["studio"] = {"token": "isat_access_token", "team": "team_name"} @@ -235,20 +235,11 @@ def _pack_extended_types(obj): """ -def 
test_ls_remote_sources(cloud_type, capsys, monkeypatch): +def test_ls_remote_sources(cloud_type, capsys, monkeypatch, studio_config): src = f"{cloud_type}://bucket" - token = "35NmrvSlsGVxTYIglxSsBIQHRrMpi6irSSYcAL0flijOytCHc" # noqa: S105 with monkeypatch.context() as m: m.setattr("requests.post", mock_post) - ls( - [src, f"{src}/dogs/others", f"{src}/dogs"], - config={ - "type": "http", - "url": "http://localhost:8111/api/datachain", - "username": "datachain-team", - "token": f"isat_{token}", - }, - ) + ls([src, f"{src}/dogs/others", f"{src}/dogs"], studio=True) captured = capsys.readouterr() assert captured.out == ls_remote_sources_output.format(src=src) diff --git a/tests/test_cli_e2e.py b/tests/test_cli_e2e.py index 909f4f348..7dc6e5206 100644 --- a/tests/test_cli_e2e.py +++ b/tests/test_cli_e2e.py @@ -4,6 +4,15 @@ from textwrap import dedent import pytest +import tabulate + + +def _tabulated_datasets(name, version): + row = [ + {"Name": name, "Version": version}, + ] + return tabulate.tabulate(row, headers="keys") + MNT_FILE_TREE = { "01375.png": 324, @@ -138,27 +147,27 @@ }, }, { - "command": ("datachain", "ls-datasets"), - "expected": "mnt (v1)\n", + "command": ("datachain", "datasets"), + "expected": _tabulated_datasets("mnt", 1), }, { - "command": ("datachain", "ls-datasets"), - "expected": "mnt (v1)\n", + "command": ("datachain", "datasets"), + "expected": _tabulated_datasets("mnt", 1), }, { "command": ("datachain", "edit-dataset", "mnt", "--new-name", "mnt-new"), "expected": "", }, { - "command": ("datachain", "ls-datasets"), - "expected": "mnt-new (v1)\n", + "command": ("datachain", "datasets"), + "expected": _tabulated_datasets("mnt-new", 1), }, { "command": ("datachain", "rm-dataset", "mnt-new", "--version", "1"), "expected": "", }, { - "command": ("datachain", "ls-datasets"), + "command": ("datachain", "datasets"), "expected": "", }, { @@ -200,7 +209,7 @@ def run_step(step, catalog): step["expected"].lstrip("\n").split("\n") ) else: - assert result.stdout == step["expected"].lstrip("\n") + assert result.stdout.strip("\n") == step["expected"].strip("\n") if step.get("listing"): assert "Listing" in result.stderr else: diff --git a/tests/test_cli_studio.py b/tests/test_cli_studio.py index 81c52e799..02054196e 100644 --- a/tests/test_cli_studio.py +++ b/tests/test_cli_studio.py @@ -1,4 +1,5 @@ from dvc_studio_client.auth import AuthorizationExpiredError +from tabulate import tabulate from datachain.cli import main from datachain.config import Config, ConfigLevel @@ -84,28 +85,20 @@ def test_studio_token(capsys): assert main(["studio", "token"]) == 1 -def test_studio_ls_datasets(capsys, requests_mock): - with Config(ConfigLevel.GLOBAL).edit() as conf: - conf["studio"] = {"token": "isat_access_token", "team": "team_name"} +def test_studio_ls_datasets(capsys, studio_datasets): + assert main(["studio", "datasets"]) == 0 + out = capsys.readouterr().out - datasets = [ + expected_rows = [ + {"Name": "dogs", "Version": "1"}, + {"Name": "dogs", "Version": "2"}, { - "id": 1, - "name": "dogs", - "versions": [{"version": 1}, {"version": 2}], - }, - { - "id": 2, - "name": "cats", - "versions": [{"version": 1}], + "Name": "cats", + "Version": "1", }, + {"Name": "both", "Version": "1"}, ] - - requests_mock.post(f"{STUDIO_URL}/api/datachain/ls-datasets", json=datasets) - - assert main(["studio", "datasets"]) == 0 - out = capsys.readouterr().out - assert out.strip() == "dogs (v1)\ndogs (v2)\ncats (v1)" + assert out.strip() == tabulate(expected_rows, headers="keys") def test_studio_team_local(): 
@@ -118,3 +111,56 @@ def test_studio_team_global(): assert main(["studio", "team", "team_name", "--global"]) == 0 config = Config(ConfigLevel.GLOBAL).read() assert config["studio"]["team"] == "team_name" + + +def test_studio_datasets(capsys, studio_datasets, mocker): + def list_datasets_local(_): + yield "local", 1 + yield "both", 1 + + mocker.patch("datachain.cli.list_datasets_local", side_effect=list_datasets_local) + local_rows = [ + {"Name": "both", "Version": "1"}, + {"Name": "local", "Version": "1"}, + ] + local_output = tabulate(local_rows, headers="keys") + + studio_rows = [ + {"Name": "both", "Version": "1"}, + { + "Name": "cats", + "Version": "1", + }, + {"Name": "dogs", "Version": "1"}, + {"Name": "dogs", "Version": "2"}, + ] + studio_output = tabulate(studio_rows, headers="keys") + + both_rows = [ + {"Name": "both", "Version": "1", "Studio": "\u2714", "Local": "\u2714"}, + {"Name": "cats", "Version": "1", "Studio": "\u2714", "Local": "\u2716"}, + {"Name": "dogs", "Version": "1", "Studio": "\u2714", "Local": "\u2716"}, + {"Name": "dogs", "Version": "2", "Studio": "\u2714", "Local": "\u2716"}, + {"Name": "local", "Version": "1", "Studio": "\u2716", "Local": "\u2714"}, + ] + both_output = tabulate(both_rows, headers="keys") + + assert main(["datasets", "--local"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(local_output.splitlines()) + + assert main(["datasets", "--studio"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(studio_output.splitlines()) + + assert main(["datasets", "--local", "--studio"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(both_output.splitlines()) + + assert main(["datasets", "--all"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(both_output.splitlines()) + + assert main(["datasets"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(both_output.splitlines())
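
A minimal standalone sketch (not part of the patch) of the flag resolution in
`_determine_flavors` and the tabular output introduced above. `DataChainError`
is stubbed locally in place of `datachain.error.DataChainError`, and the sample
rows are illustrative; the real command reads `token` from the "studio" section
of the datachain config.

    # sketch_datasets_flags.py -- illustrative only, mirrors src/datachain/cli.py
    from typing import Optional

    from tabulate import tabulate


    class DataChainError(Exception):  # stand-in for datachain.error.DataChainError
        pass


    def _determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
        # --studio without a saved Studio token is an error.
        if studio and not token:
            raise DataChainError(
                "Not logged in to Studio. Log in with 'datachain studio login'."
            )
        # Any explicit flag overrides the implicit --all default.
        if local or studio:
            all = False
        all = all and not (local or studio)
        return all, local, studio


    # Flag resolution for a logged-in user (token present), matching the tests:
    assert _determine_flavors(False, False, True, "tok") == (True, False, False)  # datasets
    assert _determine_flavors(False, True, True, "tok") == (False, True, False)   # --local
    assert _determine_flavors(True, False, True, "tok") == (False, False, True)   # --studio
    assert _determine_flavors(True, True, True, "tok") == (False, True, True)     # --local --studio

    # When both sources are shown, rows carry Studio/Local presence marks,
    # as in _datasets_tabulate_row (sample data, not real datasets):
    rows = [
        {"Name": "both", "Version": 1, "Studio": "\u2714", "Local": "\u2714"},
        {"Name": "local", "Version": 1, "Studio": "\u2716", "Local": "\u2714"},
    ]
    print(tabulate(rows, headers="keys"))

Running the sketch prints the same two-mark table the `--local --studio` and
`--all` branches of `test_studio_datasets` expect.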