From 69b90ec3c45a45003098c02e5c7d76496694ca4f Mon Sep 17 00:00:00 2001 From: Bryce Date: Wed, 17 Apr 2024 22:01:08 -0700 Subject: [PATCH] fix: remove unnecessary version checks --- imaginairy/vendored/clip/clip.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/imaginairy/vendored/clip/clip.py b/imaginairy/vendored/clip/clip.py index 197464cf..69f41c6e 100644 --- a/imaginairy/vendored/clip/clip.py +++ b/imaginairy/vendored/clip/clip.py @@ -8,7 +8,6 @@ import torch from PIL import Image -from pkg_resources import packaging from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor from tqdm import tqdm @@ -23,9 +22,6 @@ BICUBIC = Image.BICUBIC -if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): - warnings.warn("PyTorch version 1.7.1 or higher is recommended") - __all__ = ["available_models", "load", "tokenize"] _tokenizer = _Tokenizer() @@ -272,10 +268,7 @@ def tokenize( sot_token = _tokenizer.encoder["<|startoftext|>"] eot_token = _tokenizer.encoder["<|endoftext|>"] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - else: - result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: