Skip to content

Commit

Permalink
Set better headers for HTTP requests
Browse files Browse the repository at this point in the history
  • Loading branch information
jwodder committed Nov 8, 2023
1 parent aab966e commit b98703b
Showing 1 changed file with 67 additions and 38 deletions.
105 changes: 67 additions & 38 deletions src/datalad_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@
ON_WINDOWS = SYSTEM == "Windows"
ON_POSIX = ON_LINUX or ON_MACOS

USER_AGENT = "datalad-installer/{} ({}) {}/{}".format(
__version__,
__url__,
platform.python_implementation(),
platform.python_version(),
)


class SudoConfirm(Enum):
ASK = "ask"
Expand Down Expand Up @@ -2396,10 +2403,13 @@ def __init__(self, auth_required: bool = True) -> None:
" environment variable or hub.oauthtoken Git config option."
)
token = r.stdout.strip()
self.headers = {
"Accept": "application/vnd.github+json",
"User-Agent": USER_AGENT,
"X-GitHub-Api-Version": "2022-11-28",
}
if token:
self.headers = {"Authorization": f"Bearer {token}"}
else:
self.headers = {}
self.headers["Authorization"] = f"Bearer {token}"

@contextmanager
def get(self, url: str) -> Iterator[Any]:
Expand All @@ -2408,8 +2418,8 @@ def get(self, url: str) -> Iterator[Any]:
try:
with urlopen(req) as r:
yield r
except URLError as e:
raise_for_ratelimit(e, self.headers.get("Authorization"))
except HTTPError as e:
self.raise_for_ratelimit(e)
raise

def getjson(self, url: str) -> Any:
Expand Down Expand Up @@ -2470,7 +2480,15 @@ def download_latest_artifact(
archive_download_url = self.get_archive_download_url(artifacts_url)
if archive_download_url is not None:
log.info("Downloading artifact package from %s", archive_download_url)
download_zipfile(archive_download_url, target_dir, headers=self.headers)
try:
download_zipfile(
archive_download_url,
target_dir,
headers={**self.headers, "Accept": "*/*"},
)
except HTTPError as e:
self.raise_for_ratelimit(e)
raise
return
else:
raise RuntimeError("No workflow runs with artifacts found!")
Expand All @@ -2492,7 +2510,15 @@ def download_last_successful_artifact(
archive_download_url = self.get_archive_download_url(artifacts_url)
if archive_download_url is not None:
log.info("Downloading artifact package from %s", archive_download_url)
download_zipfile(archive_download_url, target_dir, headers=self.headers)
try:
download_zipfile(
archive_download_url,
target_dir,
headers={**self.headers, "Accept": "*/*"},
)
except HTTPError as e:
self.raise_for_ratelimit(e)
raise
return
else:
raise RuntimeError("No workflow runs with artifacts found!")
Expand Down Expand Up @@ -2541,11 +2567,15 @@ def download_release_asset(
else:
asset = self.get_release_asset(repo, tag, ext)
target_dir.mkdir(parents=True, exist_ok=True)
download_file(
asset["browser_download_url"],
target_dir / asset["name"],
headers=self.headers,
)
try:
download_file(
asset["browser_download_url"],
target_dir / asset["name"],
headers={**self.headers, "Accept": "*/*"},
)
except HTTPError as e:
self.raise_for_ratelimit(e)
raise

def get_latest_release(self, repo: str) -> dict:
"""
Expand All @@ -2556,6 +2586,30 @@ def get_latest_release(self, repo: str) -> dict:
assert isinstance(data, dict)
return data

def raise_for_ratelimit(self, e: HTTPError) -> None:
if e.code == 403:
try:
resp = json.load(e)
except Exception:
return
if "API rate limit exceeded" in resp.get("message", ""):
if "Authorization" in self.headers:
url = "https://api.github.com/rate_limit"
log.debug("HTTP request: GET %s", url)
req = Request(url, headers=self.headers)
with urlopen(req) as r:
resp = json.load(r)
log.info(
"GitHub rate limit exceeded; details:\n\n%s\n",
textwrap.indent(json.dumps(resp, indent=4), " " * 4),
)
else:
raise RuntimeError(
"GitHub rate limit exceeded and GITHUB_TOKEN not set;"
" suggest setting GITHUB_TOKEN in order to get increased"
" rate limit"
)


class MethodNotSupportedError(Exception):
"""
Expand All @@ -2566,31 +2620,6 @@ class MethodNotSupportedError(Exception):
pass


def raise_for_ratelimit(e: URLError, auth: Optional[str]) -> None:
if isinstance(e, HTTPError) and e.code == 403:
try:
resp = json.load(e)
except Exception:
return
if "API rate limit exceeded" in resp.get("message", ""):
if auth is not None:
url = "https://api.github.com/rate_limit"
log.debug("HTTP request: GET %s", url)
req = Request(url, headers={"Authorization": auth})
with urlopen(req) as r:
resp = json.load(r)
log.info(
"GitHub rate limit exceeded; details:\n\n%s\n",
textwrap.indent(json.dumps(resp, indent=4), " " * 4),
)
else:
raise RuntimeError(
"GitHub rate limit exceeded and GITHUB_TOKEN not set;"
" suggest setting GITHUB_TOKEN in order to get increased"
" rate limit"
)


def download_file(
url: str, path: str | Path, headers: Optional[dict[str, str]] = None
) -> None:
Expand All @@ -2601,6 +2630,7 @@ def download_file(
log.info("Downloading %s", url)
if headers is None:
headers = {}
headers.setdefault("User-Agent", USER_AGENT)
delays = iter([1, 2, 6, 15, 36])
req = Request(url, headers=headers)
while True:
Expand All @@ -2618,7 +2648,6 @@ def download_file(
return
except URLError as e:
if isinstance(e, HTTPError) and e.code not in (500, 502, 503, 504):
raise_for_ratelimit(e, headers.get("Authorization"))
raise
try:
delay = next(delays)
Expand Down

0 comments on commit b98703b

Please sign in to comment.