Skip to content

Commit

Permalink
updates for the holidays
Browse files Browse the repository at this point in the history
  • Loading branch information
Mortimerp9 committed Dec 9, 2024
1 parent e4c9f82 commit 4712848
Show file tree
Hide file tree
Showing 36 changed files with 7,036 additions and 1,140 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/lint_and_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
max-parallel: 1
matrix:
platform: [ubuntu-latest]
python-version: [3.8]
python-version: [3.9]

runs-on: ${{ matrix.platform }}

Expand All @@ -30,6 +30,7 @@ jobs:
# Fairseq doesn't install with pip==22.1, so we need to upgrade past it.
# Also the version on pypi is from before Oct 2020.
# wheel is required by fasttext to be installed correctly with recent pip versions
# fairseq+omegaconf do not play nice when using pip >= 24.1, hence the `<24.1` upper bound below
run: |
python --version
python -m pip install --upgrade 'pip>=22.1.2,<24.1'
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi"
name = "stopes"
readme = "README.md"
authors = [{name = "Facebook AI Research"}]
requires-python = ">=3.8"
requires-python = ">=3.9"
dynamic = ["version", "description"]

dependencies = [
Expand All @@ -15,7 +15,7 @@ dependencies = [
"submitit>=1.4.5",
"tqdm",
"posix_ipc",
"pyarrow>=13.0.0"
"pyarrow>=16.1.0"
]
# zip_safe = false
classifiers=[
Expand Down Expand Up @@ -75,7 +75,7 @@ classifiers=[
"torchaudio",
"scipy",
"pandas",
"pyarrow>=13.0.0",
"pyarrow>=16.1.0",
"numba",
"transformers",
"openai-whisper==20230314",
Expand Down
2 changes: 1 addition & 1 deletion stopes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
"""

__version__ = "2.1.0"
__version__ = "2.2.0"
45 changes: 37 additions & 8 deletions stopes/core/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@

from stopes.core import utils
from stopes.core.cache import Cache, MissingCache, NoCache
from stopes.core.jobs_registry.registry import JobsRegistry
from stopes.core.jobs_registry.submitit_slurm_job import SubmititJob
from stopes.core.jobs_registry.registry import JobsRegistry # type: ignore
from stopes.core.jobs_registry.submitit_slurm_job import SubmititJob # type: ignore


@dataclasses.dataclass
class SkipValue:
    """Marker for an entry of a job array that must not be executed.

    When the launcher schedules an array and finds an instance of this
    class among the values, it does not look up the cache or submit a job
    for that entry; the corresponding task is marked as done using
    `expected_result` and scheduling moves on to the next value.
    """

    # Result to attach to the skipped task in place of running the module.
    expected_result: tp.Any


if tp.TYPE_CHECKING:
from stopes.core import StopesModule
Expand Down Expand Up @@ -222,6 +230,8 @@ def __init__(
log_folder: tp.Union[Path, str] = Path("executor_logs"),
cluster: str = "local",
partition: tp.Optional[str] = None,
qos: tp.Optional[str] = None,
account: tp.Optional[str] = None,
supports_mem_spec: bool = True, # some slurm clusters do not support mem_gb, if you set this to False, the Requirements.mem_gb coming from the module will be ignored
disable_tqdm: bool = False, # if you don't want tqdm progress bars
max_retries: int = 0,
Expand All @@ -237,6 +247,8 @@ def __init__(
- `log_folder` where to store execution logs for each job (default: `executor_logs`)
- `cluster`, a submitit cluster spec. `local` to run locally, `slurm` for slurm
- `partition`, the slurm partition to use
- `qos`, the slurm QOS (quality-of-service) to use
- `account`, the slurm account to use
- `supports_mem_spec`, ignore mem requirements for some cluster
- `disable_tqdm`, don't show that fancy progress bar
- `max_retries`, how many retries do we want for each job
Expand All @@ -255,6 +267,8 @@ def __init__(
self.log_folder = Path(log_folder)
self.cluster = cluster
self.partition = partition
self.qos = qos
self.account = account
self.supports_mem_spec = supports_mem_spec
self.disable_tqdm = disable_tqdm
self.progress_bar: tqdm.tqdm = None
Expand All @@ -265,9 +279,11 @@ def __init__(

if throttle:
self.throttle = utils.AsyncIPCSemaphore(
name=throttle.shared_name
if throttle.shared_name
else f"/launcher_{getpass.getuser()}_{uuid.uuid4()}",
name=(
throttle.shared_name
if throttle.shared_name
else f"/launcher_{getpass.getuser()}_{uuid.uuid4()}"
),
flags=posix_ipc.O_CREAT,
initial_value=throttle.limit,
timeout=throttle.timeout,
Expand Down Expand Up @@ -324,17 +340,24 @@ def _get_executor(self, module: "StopesModule") -> submitit.Executor:
name = module.name()
module_log_folder = self.log_folder / name
module_log_folder.mkdir(parents=True, exist_ok=True)
executor = AutoExecutor(folder=module_log_folder, cluster=self.cluster)

reqs = module.requirements()
executor = AutoExecutor(
folder=module_log_folder,
cluster=self.cluster,
slurm_max_num_timeout=3 if reqs is None else reqs.max_num_timeout,
)
# update launcher params
if self.update_parameters:
executor.update_parameters(**self.update_parameters)

# setup parameters
reqs = module.requirements()

executor.update_parameters(
name=module.name(),
slurm_partition=self.partition,
slurm_qos=self.qos,
slurm_account=self.account,
)
if self.partition and self.cluster == "slurm":
executor.update_parameters(
Expand Down Expand Up @@ -439,7 +462,13 @@ async def _schedule_array(

for idx, val in enumerate(value_array):
task = Task(module, idx, val, launcher=self)
# first, look up if this iteration has already been cached
# first, look up if this iteration has been skipped
if isinstance(val, SkipValue):
task = task.done_from_cache(val.expected_result)
self.progress_job_end()
tasks.append(task)
continue
# second, look up if this iteration has already been cached
try:
cached_result = self.cache.get_cache(
module,
Expand Down
1 change: 1 addition & 0 deletions stopes/core/stopes_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Requirements:
cpus_per_task: int = 5
timeout_min: int = 720
constraint: tp.Optional[str] = None
max_num_timeout: int = 10


class StopesModule(ABC):
Expand Down
2 changes: 1 addition & 1 deletion stopes/modules/bitext/mining/mine_bitext_indexes_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
INDICES_FILE_SUFFIX = ".idx"

# when requesting more neighbors than elements in the index, FAISS returns -1
INVALID_INDEX_VALUE = np.uint32(-1)
INVALID_INDEX_VALUE = np.int32(-1)
INVALID_INDEX_REPLACEMENT = 0
INVALID_DISTANCES_REPLACEMENT = 2.0

Expand Down
Loading

0 comments on commit 4712848

Please sign in to comment.