Skip to content

Commit

Permalink
remote: separate tree functions from main remote classes (#3931)
Browse files Browse the repository at this point in the history
* BaseRemote: start moving tree functions into BaseRemoteTree

* remote: add LocalRemoteTree

* remote: add cloud remote trees

* remote: move remove(), makedirs() into tree

* remote: move move(), copy(), symlink(), hardlink(), reflink() into tree

* fix attributes for moved remote methods

* tests: update unit tests for moved remote tree methods

* tests: mv unit/remote/test_remote_dir.py unit/remote/test_remote_tree.py

* fix attributes for moved remote tree functions

* tests: update func tests for moved remote tree methods

* test fixes

* fix DS warnings
  • Loading branch information
pmrowla authored Jun 3, 2020
1 parent c03ce4a commit 2c0116c
Show file tree
Hide file tree
Showing 26 changed files with 783 additions and 691 deletions.
2 changes: 1 addition & 1 deletion dvc/data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def pull(
def _save_pulled_checksums(self, cache):
for checksum in cache.scheme_keys("local"):
cache_file = self.repo.cache.local.checksum_to_path_info(checksum)
if self.repo.cache.local.exists(cache_file):
if self.repo.cache.local.tree.exists(cache_file):
# We can safely save here, as existing corrupted files will
# be removed upon status, while files corrupted during
# download will not be moved from tmp_file
Expand Down
12 changes: 6 additions & 6 deletions dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def is_dir_checksum(self):

@property
def exists(self):
return self.remote.exists(self.path_info)
return self.remote.tree.exists(self.path_info)

def save_info(self):
return self.remote.save_info(self.path_info)
Expand Down Expand Up @@ -217,13 +217,13 @@ def changed(self):

@property
def is_empty(self):
return self.remote.is_empty(self.path_info)
return self.remote.tree.is_empty(self.path_info)

def isdir(self):
return self.remote.isdir(self.path_info)
return self.remote.tree.isdir(self.path_info)

def isfile(self):
return self.remote.isfile(self.path_info)
return self.remote.tree.isfile(self.path_info)

def ignore(self):
if not self.use_scm_ignore:
Expand Down Expand Up @@ -326,7 +326,7 @@ def checkout(
)

def remove(self, ignore_remove=False):
self.remote.remove(self.path_info)
self.remote.tree.remove(self.path_info)
if self.scheme != "local":
return

Expand All @@ -337,7 +337,7 @@ def move(self, out):
if self.scheme == "local" and self.use_scm_ignore:
self.repo.scm.ignore_remove(self.fspath)

self.remote.move(self.path_info, out.path_info)
self.remote.tree.move(self.path_info, out.path_info)
self.def_path = out.def_path
self.path_info = out.path_info
self.save()
Expand Down
67 changes: 37 additions & 30 deletions dvc/remote/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,53 @@

from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.remote.base import BaseRemote
from dvc.remote.base import BaseRemote, BaseRemoteTree
from dvc.scheme import Schemes

logger = logging.getLogger(__name__)


class AzureRemoteTree(BaseRemoteTree):
@property
def blob_service(self):
return self.remote.blob_service

def _generate_download_url(self, path_info, expires=3600):
from azure.storage.blob import BlobPermissions

expires_at = datetime.utcnow() + timedelta(seconds=expires)

sas_token = self.blob_service.generate_blob_shared_access_signature(
path_info.bucket,
path_info.path,
permission=BlobPermissions.READ,
expiry=expires_at,
)
download_url = self.blob_service.make_blob_url(
path_info.bucket, path_info.path, sas_token=sas_token
)
return download_url

def exists(self, path_info):
paths = self.remote.list_paths(path_info.bucket, path_info.path)
return any(path_info.path == path for path in paths)

def remove(self, path_info):
if path_info.scheme != self.scheme:
raise NotImplementedError

logger.debug(f"Removing {path_info}")
self.blob_service.delete_blob(path_info.bucket, path_info.path)


class AzureRemote(BaseRemote):
scheme = Schemes.AZURE
path_cls = CloudURLInfo
REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5
LIST_OBJECT_PAGE_SIZE = 5000
TREE_CLS = AzureRemoteTree

def __init__(self, repo, config):
super().__init__(repo, config)
Expand Down Expand Up @@ -65,14 +99,7 @@ def get_etag(self, path_info):
def get_file_checksum(self, path_info):
return self.get_etag(path_info)

def remove(self, path_info):
if path_info.scheme != self.scheme:
raise NotImplementedError

logger.debug(f"Removing {path_info}")
self.blob_service.delete_blob(path_info.bucket, path_info.path)

def _list_paths(self, bucket, prefix, progress_callback=None):
def list_paths(self, bucket, prefix, progress_callback=None):
blob_service = self.blob_service
next_marker = None
while True:
Expand All @@ -97,7 +124,7 @@ def list_cache_paths(self, prefix=None, progress_callback=None):
)
else:
prefix = self.path_info.path
return self._list_paths(
return self.list_paths(
self.path_info.bucket, prefix, progress_callback
)

Expand All @@ -122,23 +149,3 @@ def _download(
to_file,
progress_callback=pbar.update_to,
)

def exists(self, path_info):
paths = self._list_paths(path_info.bucket, path_info.path)
return any(path_info.path == path for path in paths)

def _generate_download_url(self, path_info, expires=3600):
from azure.storage.blob import BlobPermissions

expires_at = datetime.utcnow() + timedelta(seconds=expires)

sas_token = self.blob_service.generate_blob_shared_access_signature(
path_info.bucket,
path_info.path,
permission=BlobPermissions.READ,
expiry=expires_at,
)
download_url = self.blob_service.make_blob_url(
path_info.bucket, path_info.path, sas_token=sas_token
)
return download_url
Loading

0 comments on commit 2c0116c

Please sign in to comment.