Skip to content

Commit

Permalink
feat(cli): enhance dataset listing with additional options for studio (
Browse files Browse the repository at this point in the history
…#561)

* feat(cli): enhance dataset listing with additional options for studio

With this change, the following variations are available to list datasets
instead of `datachain ls-datasets`:

- `datachain datasets --local`
- `datachain datasets --studio`
- `datachain datasets --all` (Default option)
- `datachain datasets --team TEAM_NAME`
- `datachain datasets`

By default, if the user has logged in to Studio using `datachain studio login`,
it will try to fetch the datasets from both local and studio.
If a specific option (local or studio) is passed, only that source is
queried.

The team name is parsed from config (set by `datachain studio team <TEAM_NAME>`).

The same feature is added to the `datachain ls` command, along with error
handling for each scenario.

Relates to #10774

* Modify the result format with tabular format

* Fix tests for ls

* test ls

* Fix tests for ls

* Fix ls remote sources

* Deduplicate
  • Loading branch information
amritghimire authored Nov 7, 2024
1 parent be081a4 commit 10c4e2a
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 72 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ dependencies = [
"huggingface_hub",
"iterative-telemetry>=0.0.9",
"platformdirs",
"dvc-studio-client>=0.21,<1"
"dvc-studio-client>=0.21,<1",
"tabulate"
]

[project.optional-dependencies]
Expand Down Expand Up @@ -98,7 +99,8 @@ dev = [
"types-python-dateutil",
"types-pytz",
"types-PyYAML",
"types-requests"
"types-requests",
"types-tabulate"
]
examples = [
"datachain[tests]",
Expand Down
160 changes: 137 additions & 23 deletions src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
import sys
import traceback
from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Sequence
from importlib.metadata import PackageNotFoundError, version
from itertools import chain
from multiprocessing import freeze_support
from typing import TYPE_CHECKING, Optional, Union

import shtab
from tabulate import tabulate

from datachain import Session, utils
from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
from datachain.config import Config
from datachain.error import DataChainError
from datachain.lib.dc import DataChain
from datachain.studio import process_studio_cli_args
from datachain.studio import list_datasets, process_studio_cli_args
from datachain.telemetry import telemetry

if TYPE_CHECKING:
Expand Down Expand Up @@ -416,7 +419,36 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
help="Dataset labels",
)

subp.add_parser("ls-datasets", parents=[parent_parser], description="List datasets")
datasets_parser = subp.add_parser(
"datasets", parents=[parent_parser], description="List datasets"
)
datasets_parser.add_argument(
"--studio",
action="store_true",
default=False,
help="List the files in the Studio",
)
datasets_parser.add_argument(
"-L",
"--local",
action="store_true",
default=False,
help="List local files only",
)
datasets_parser.add_argument(
"-a",
"--all",
action="store_true",
default=True,
help="List all files including hidden files",
)
datasets_parser.add_argument(
"--team",
action="store",
default=None,
help="The team to list datasets for. By default, it will use team from config.",
)

rm_dataset_parser = subp.add_parser(
"rm-dataset", parents=[parent_parser], description="Removes dataset"
)
Expand Down Expand Up @@ -474,10 +506,30 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
help="List files in the long format",
)
parse_ls.add_argument(
"--remote",
"--studio",
action="store_true",
default=False,
help="List the files in the Studio",
)
parse_ls.add_argument(
"-L",
"--local",
action="store_true",
default=False,
help="List local files only",
)
parse_ls.add_argument(
"-a",
"--all",
action="store_true",
default=True,
help="List all files including hidden files",
)
parse_ls.add_argument(
"--team",
action="store",
default="",
help="Name of remote to use",
default=None,
help="The team to list datasets for. By default, it will use team from config.",
)

parse_du = subp.add_parser(
Expand Down Expand Up @@ -758,11 +810,12 @@ def format_ls_entry(entry: str) -> str:
def ls_remote(
paths: Iterable[str],
long: bool = False,
team: Optional[str] = None,
):
from datachain.node import long_line_str
from datachain.remote.studio import StudioClient

client = StudioClient()
client = StudioClient(team=team)
first = True
for path, response in client.ls(paths):
if not first:
Expand All @@ -789,28 +842,66 @@ def ls_remote(
def ls(
    sources,
    long: bool = False,
    studio: bool = False,
    local: bool = False,
    all: bool = True,
    team: Optional[str] = None,
    **kwargs,
):
    """List sources locally and/or in Studio.

    Args:
        sources: Paths/URIs to list.
        long: Use the long listing format.
        studio: List Studio sources only.
        local: List local sources only.
        all: List both local and Studio sources (default; turned off by
            ``_determine_flavors`` when ``local`` or ``studio`` is set).
        team: Studio team to query; by default the team from config is used.
        **kwargs: Extra options forwarded to the local listing
            (e.g. ``update``, ``client_config``).
    """
    # Studio access requires a token obtained via `datachain studio login`.
    token = Config().read().get("studio", {}).get("token")
    all, local, studio = _determine_flavors(studio, local, all, token)

    if all or local:
        ls_local(sources, long=long, **kwargs)

    # Only hit Studio when logged in; `_determine_flavors` has already
    # raised if --studio was requested without a token.
    if (all or studio) and token:
        ls_remote(sources, long=long, team=team)


def datasets(
    catalog: "Catalog",
    studio: bool = False,
    local: bool = False,
    all: bool = True,
    team: Optional[str] = None,
):
    """Print a table of datasets from the local catalog and/or Studio.

    Args:
        catalog: Local catalog to enumerate datasets from.
        studio: List Studio datasets only.
        local: List local datasets only.
        all: List both local and Studio datasets (default; turned off by
            ``_determine_flavors`` when ``local`` or ``studio`` is set).
        team: Studio team to query; by default the team from config is used.
    """
    token = Config().read().get("studio", {}).get("token")
    all, local, studio = _determine_flavors(studio, local, all, token)

    local_datasets = set(list_datasets_local(catalog)) if all or local else set()
    studio_datasets = (
        set(list_datasets(team=team)) if (all or studio) and token else set()
    )

    # Show the Studio/Local indicator columns only when both sides are listed.
    both = (all or (local and studio)) and bool(token)

    rows = [
        _datasets_tabulate_row(
            name=name,
            version=version,
            both=both,
            local=(name, version) in local_datasets,
            studio=(name, version) in studio_datasets,
        )
        # Sort for deterministic output; set union has no stable order.
        for name, version in sorted(local_datasets | studio_datasets)
    ]

    print(tabulate(rows, headers="keys"))


def list_datasets_local(catalog: "Catalog"):
    """Yield ``(name, version)`` pairs for every dataset version in *catalog*."""
    for d in catalog.ls_datasets():
        for v in d.versions:
            yield (d.name, v.version)


def _datasets_tabulate_row(name, version, both, local, studio):
row = {
"Name": name,
"Version": version,
}
if both:
row["Studio"] = "\u2714" if studio else "\u2716"
row["Local"] = "\u2714" if local else "\u2716"
return row


def rm_dataset(
Expand Down Expand Up @@ -953,6 +1044,20 @@ def completion(shell: str) -> str:
)


def _determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
if studio and not token:
raise DataChainError(
"Not logged in to Studio. Log in with 'datachain studio login'."
)

if local or studio:
all = False

all = all and not (local or studio)

return all, local, studio


def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR0915
# Required for Windows multiprocessing support
freeze_support()
Expand Down Expand Up @@ -1032,12 +1137,21 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
ls(
args.sources,
long=bool(args.long),
remote=args.remote,
studio=args.studio,
local=args.local,
all=args.all,
team=args.team,
update=bool(args.update),
client_config=client_config,
)
elif args.command == "ls-datasets":
ls_datasets(catalog)
elif args.command == "datasets":
datasets(
catalog=catalog,
studio=args.studio,
local=args.local,
all=args.all,
team=args.team,
)
elif args.command == "show":
show(
catalog,
Expand Down
14 changes: 12 additions & 2 deletions src/datachain/remote/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ def _send_request_msgpack(self, route: str, data: dict[str, Any]) -> Response[An
timeout=self.timeout,
)
ok = response.ok
if not ok:
if response.status_code == 403:
message = f"Not authorized for the team {self.team}"
raise DataChainError(message)
logger.error("Got bad response from Studio")

content = msgpack.unpackb(response.content, ext_hook=self._unpacker_hook)
response_data = content.get("data")
if ok and response_data is None:
Expand Down Expand Up @@ -177,8 +183,12 @@ def _send_request(self, route: str, data: dict[str, Any]) -> Response[Any]:
response.content.decode("utf-8"),
)
if response.status_code == 403:
message = "Not authorized"
message = f"Not authorized for the team {self.team}"
else:
logger.error(
"Got bad response from Studio, content is %s",
response.content.decode("utf-8"),
)
message = data.get("message", "")
else:
message = ""
Expand Down Expand Up @@ -214,7 +224,7 @@ def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
# to handle cases where a path will be expanded (i.e. globs)
response: Response[LsData]
for path in paths:
response = self._send_request_msgpack("ls", {"source": path})
response = self._send_request_msgpack("datachain/ls", {"source": path})
yield path, response

def ls_datasets(self) -> Response[LsData]:
Expand Down
24 changes: 18 additions & 6 deletions src/datachain/studio.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from tabulate import tabulate

from datachain.catalog.catalog import raise_remote_error
from datachain.config import Config, ConfigLevel
from datachain.dataset import QUERY_DATASET_PREFIX
from datachain.error import DataChainError
from datachain.remote.studio import StudioClient
from datachain.utils import STUDIO_URL
Expand All @@ -24,7 +27,13 @@ def process_studio_cli_args(args: "Namespace"):
if args.cmd == "token":
return token()
if args.cmd == "datasets":
return list_datasets(args)
rows = [
{"Name": name, "Version": version}
for name, version in list_datasets(args.team)
]
print(tabulate(rows, headers="keys"))
return 0

if args.cmd == "team":
return set_team(args)
raise DataChainError(f"Unknown command '{args.cmd}'.")
Expand Down Expand Up @@ -103,19 +112,22 @@ def token():
print(token)


def list_datasets(team: Optional[str] = None):
    """Yield ``(name, version)`` pairs for datasets available in Studio.

    NOTE: this is a generator, so the Studio request (and any remote
    error) happens lazily, on first iteration.

    Args:
        team: Studio team to query; by default the team from config is used.

    Raises:
        Remote error (via ``raise_remote_error``) when the Studio request
        fails.
    """
    client = StudioClient(team=team)
    response = client.ls_datasets()
    if not response.ok:
        raise_remote_error(response.message)
    if not response.data:
        print("No datasets found.")
        return

    for d in response.data:
        name = d.get("name")
        # Skip internal datasets created for queries.
        if name and name.startswith(QUERY_DATASET_PREFIX):
            continue

        for v in d.get("versions", []):
            version = v.get("version")
            yield (name, version)


def save_config(hostname, token):
Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from datachain.catalog import Catalog
from datachain.catalog.loader import get_id_generator, get_metastore, get_warehouse
from datachain.cli_utils import CommaSeparatedArgs
from datachain.config import Config, ConfigLevel
from datachain.data_storage.sqlite import (
SQLiteDatabaseEngine,
SQLiteIDGenerator,
Expand All @@ -26,6 +27,7 @@
from datachain.utils import (
ENV_DATACHAIN_GLOBAL_CONFIG_DIR,
ENV_DATACHAIN_SYSTEM_CONFIG_DIR,
STUDIO_URL,
DataChainDir,
)

Expand Down Expand Up @@ -673,3 +675,29 @@ def dataset_rows():
}
for i in range(19)
]


@pytest.fixture
def studio_datasets(requests_mock):
    """Log in to Studio in the global config and mock the ls-datasets endpoint."""
    with Config(ConfigLevel.GLOBAL).edit() as conf:
        conf["studio"] = {"token": "isat_access_token", "team": "team_name"}

    def dataset(id_, name, *versions):
        return {
            "id": id_,
            "name": name,
            "versions": [{"version": v} for v in versions],
        }

    datasets = [
        dataset(1, "dogs", 1, 2),
        dataset(2, "cats", 1),
        dataset(3, "both", 1),
    ]

    requests_mock.post(f"{STUDIO_URL}/api/datachain/ls-datasets", json=datasets)
Loading

0 comments on commit 10c4e2a

Please sign in to comment.