Skip to content

Commit

Permalink
chore: GenAI - Added unit test for image loading and mime types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627645004
  • Loading branch information
Ark-kun authored and Copybara-Service committed Apr 24, 2024
1 parent c56dd50 commit 5083fb2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
23 changes: 23 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

# pylint: disable=protected-access,bad-continuation
import io
import pytest
from typing import Iterable, MutableSequence, Optional
from unittest import mock
Expand Down Expand Up @@ -907,6 +908,28 @@ def test_chat_automatic_function_calling(self):
chat2.send_message("What is the weather like in Boston?")
assert err.match("Exceeded the maximum")

@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_image_mime_types(self, generative_models: generative_models):
# Importing external library lazily to reduce the scope
from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top

image_format_to_mime_type = {
"PNG": "image/png",
"JPEG": "image/jpeg",
"GIF": "image/gif",
}

pil_image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(200, 200))
for image_format, mime_type in image_format_to_mime_type.items():
image_bytes_io = io.BytesIO()
pil_image.save(image_bytes_io, format=image_format)
image = generative_models.Image.from_bytes(image_bytes_io.getvalue())
image_part = generative_models.Part.from_image(image)
assert image_part.mime_type == mime_type


EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
"title": "get_current_weather",
Expand Down
8 changes: 7 additions & 1 deletion vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,13 @@ def text(self) -> str:

@property
def mime_type(self) -> Optional[str]:
return self._raw_part.mime_type
part_type = self._raw_part._pb.WhichOneof("data")
if part_type == "inline_data":
return self._raw_part.inline_data.mime_type
elif part_type == "file_data":
return self._raw_part.file_data.mime_type
else:
raise AttributeError(f"Part has no mime_type.\nPart:\n{self.to_dict()}")

@property
def inline_data(self) -> gapic_content_types.Blob:
Expand Down

0 comments on commit 5083fb2

Please sign in to comment.