diff --git a/xinfer/transformers/blip2.py b/xinfer/transformers/blip2.py index 9875ce1..76304b5 100644 --- a/xinfer/transformers/blip2.py +++ b/xinfer/transformers/blip2.py @@ -17,6 +17,31 @@ "transformers", ModelInputOutput.IMAGE_TEXT_TO_TEXT, ) +@register_model( + "Salesforce/blip2-flan-t5-xl", + "transformers", + ModelInputOutput.IMAGE_TEXT_TO_TEXT, +) +@register_model( + "Salesforce/blip2-flan-t5-xl-coco", + "transformers", + ModelInputOutput.IMAGE_TEXT_TO_TEXT, +) +@register_model( + "Salesforce/blip2-opt-2.7b-coco", + "transformers", + ModelInputOutput.IMAGE_TEXT_TO_TEXT, +) +@register_model( + "Gregor/mblip-mt0-xl", + "transformers", + ModelInputOutput.IMAGE_TEXT_TO_TEXT, +) +@register_model( + "Gregor/mblip-bloomz-7b", + "transformers", + ModelInputOutput.IMAGE_TEXT_TO_TEXT, +) class BLIP2(Vision2SeqModel): def __init__(self, model_id: str, **kwargs): super().__init__(model_id, **kwargs)