Skip to content

Commit

Permalink
Ruff the repo
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed May 30, 2024
1 parent c6f011a commit 7412a48
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
2 changes: 0 additions & 2 deletions trackastra/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import Literal

import numpy as np
import torch

# from torch_geometric.nn import GATv2Conv
Expand Down Expand Up @@ -506,4 +505,3 @@ def from_folder(
model.load_state_dict(state)

return model

11 changes: 5 additions & 6 deletions trackastra/model/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import logging
import os
from typing import Optional
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Optional

import requests
from pydantic import validate_call
from tqdm import tqdm
import tempfile
import shutil

logger = logging.getLogger(__name__)

Expand All @@ -23,7 +22,7 @@ def download_and_unzip(url: str, dst: Path):
if dst.exists():
print(f"{dst} already downloaded, skipping.")
return

# get the name of the zipfile
zip_base = Path(url.split("/")[-1])

Expand Down Expand Up @@ -60,7 +59,7 @@ def download_pretrained(name: str, download_dir: Optional[Path] = None):
# TODO make safe, introduce versioning
if download_dir is None:
download_dir = Path("~/.trackastra/.models").expanduser()
else:
else:
download_dir = Path(download_dir)

download_dir.mkdir(exist_ok=True, parents=True)
Expand Down

0 comments on commit 7412a48

Please sign in to comment.