Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add params in init datafilter fn #63

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions DPF/filters/multigpu_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ def run_one_process(
results: list[pd.DataFrame],
filter_class: Optional[type[DataFilter]],
filter_kwargs: Optional[dict[str, Any]],
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]],
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]],
datafilter_init_fn_kwargs: dict[str, Any],
device: Union[str, torch.device],
filter_run_kwargs: dict[str, Any]
) -> None:
reader = DatasetReader(connector=connector)
processor = reader.from_df(config, df)
if datafilter_init_fn:
datafilter = datafilter_init_fn(i, device)
datafilter = datafilter_init_fn(i, device, datafilter_init_fn_kwargs)
else:
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device) # type: ignore

Expand All @@ -51,7 +52,8 @@ def __init__(
devices: list[Union[torch.device, str]],
datafilter_class: Optional[type[DataFilter]] = None,
datafilter_params: Optional[dict[str, Any]] = None,
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]] = None,
datafilter_init_fn_kwargs: Optional[dict[str, Any]] = None,
):
"""
Parameters
Expand All @@ -62,19 +64,22 @@ def __init__(
Class of datafilter to use
datafilter_params: Optional[dict[str, Any]] = None
Parameters for datafilter_class initialization
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]] = None
Initialization function for a datafilter. Takes _pbar_position as first arg and device as a second arg
datafilter_init_fn_kwargs: Optional[dict[str, Any]] = None
Additional parameters for datafilter_init_fn
"""
self.filter_class = datafilter_class
self.filter_params = datafilter_params
self.datafilter_init_fn = datafilter_init_fn
self.datafilter_init_fn_kwargs = datafilter_init_fn_kwargs if datafilter_init_fn_kwargs is not None else {}
assert self.datafilter_init_fn or self.filter_class, "One method of filter initialization should be specified"
self.devices = devices
self.num_parts = len(devices)

# getting result columns names
if self.datafilter_init_fn:
datafilter = self.datafilter_init_fn(0, devices[0])
datafilter = self.datafilter_init_fn(0, devices[0], self.datafilter_init_fn_kwargs)
else:
datafilter = self.filter_class(**self.filter_params, device=devices[0]) # type: ignore
self._result_columns = datafilter.result_columns
Expand Down Expand Up @@ -127,6 +132,7 @@ def run(
self.filter_class,
self.filter_params,
self.datafilter_init_fn,
self.datafilter_init_fn_kwargs,
self.devices[i],
filter_run_kwargs
)
Expand Down
9 changes: 5 additions & 4 deletions docs/multi_gpu_filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ To run a complex datafilter or if you want to manually create a datafilter class
from DPF.filters.images.llava_captioning_filter import LLaVaCaptioningFilter
from DPF.filters.multigpu_filter import MultiGPUDataFilter

def init_fn(pbar_pos: int, device: str):
print('INIT FN', pbar_pos, device)
def init_fn(pbar_pos: int, device: str, params: dict):
print('INIT FN', pbar_pos, device, params)

return LLaVaCaptioningFilter(
workers=8, prompt='short', batch_size=16,
workers=8, prompt=params['prompt'], batch_size=16,
device=device, _pbar_position=pbar_pos
)

multigpufilter = MultiGPUDataFilter(
['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'],
datafilter_init_fn=init_fn
datafilter_init_fn=init_fn,
datafilter_init_fn_kwargs={'prompt': 'short'}
)
processor.apply_multi_gpu_data_filter(multigpufilter)
```
Expand Down
Loading