From 6560c75527e07479fceddec4b37e40583917abcf Mon Sep 17 00:00:00 2001 From: Atsuki Yamaguchi <30075338+gucci-j@users.noreply.github.com> Date: Tue, 7 Jan 2025 15:20:48 +0000 Subject: [PATCH] Fix `T_co` import bug (#484) * Fix T_co import bug * Fix styling --- src/lighteval/data.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 7cb105e6..dd1fd534 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -25,8 +25,15 @@ from typing import Iterator, Tuple import torch +from packaging import version from torch.utils.data import Dataset -from torch.utils.data.distributed import DistributedSampler, T_co + + +if version.parse(torch.__version__) >= version.parse("2.5.0"): + from torch.utils.data.distributed import DistributedSampler, _T_co +else: + from torch.utils.data.distributed import DistributedSampler + from torch.utils.data.distributed import T_co as _T_co from lighteval.tasks.requests import ( GreedyUntilRequest, @@ -318,7 +325,7 @@ class GenDistributedSampler(DistributedSampler): as our samples are sorted by length. """ - def __iter__(self) -> Iterator[T_co]: + def __iter__(self) -> Iterator[_T_co]: if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator()