Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of docstrings to the files #61

Merged
merged 7 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions google/generativeai/discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
"""Creates a `glm.Message` object from the provided content."""
if isinstance(content, glm.Message):
return content
if isinstance(content, str):
Expand All @@ -39,6 +40,20 @@ def _make_message(content: discuss_types.MessageOptions) -> glm.Message:


def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message]:
"""
Creates a list of `glm.Message` objects from the provided messages.

This function takes a variety of message content inputs, such as strings, dictionaries,
or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that
the authors of the messages alternate appropriately. If authors are not provided,
default authors are assigned based on their position in the list.

Args:
messages: The messages to convert.

Returns:
A list of `glm.Message` objects with alternating authors.
"""
if isinstance(messages, (str, dict, glm.Message)):
messages = [_make_message(messages)]
else:
Expand Down Expand Up @@ -71,6 +86,7 @@ def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message]


def _make_example(item: discuss_types.ExampleOptions) -> glm.Example:
"""Creates a `glm.Example` object from the provided item."""
if isinstance(item, glm.Example):
return item

Expand All @@ -91,6 +107,21 @@ def _make_example(item: discuss_types.ExampleOptions) -> glm.Example:
def _make_examples_from_flat(
examples: List[discuss_types.MessageOptions],
) -> List[glm.Example]:
"""
Creates a list of `glm.Example` objects from a list of message options.

This function takes a list of `discuss_types.MessageOptions` and pairs them into
`glm.Example` objects. The input examples must be in pairs to create valid examples.

Args:
examples: The list of `discuss_types.MessageOptions`.

Returns:
        A list of `glm.Example` objects created by pairing up the provided messages.

Raises:
ValueError: If the provided list of examples is not of even length.
"""
if len(examples) % 2 != 0:
raise ValueError(
textwrap.dedent(
Expand All @@ -116,6 +147,19 @@ def _make_examples_from_flat(


def _make_examples(examples: discuss_types.ExamplesOptions) -> List[glm.Example]:
"""
Creates a list of `glm.Example` objects from the provided examples.

This function takes various types of example content inputs and creates a list
of `glm.Example` objects. It handles the conversion of different input types and ensures
the appropriate structure for creating valid examples.

Args:
examples: The examples to convert.

Returns:
A list of `glm.Example` objects created from the provided examples.
"""
if isinstance(examples, glm.Example):
return [examples]

Expand Down Expand Up @@ -155,6 +199,23 @@ def _make_message_prompt_dict(
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
"""
Creates a `glm.MessagePrompt` object from the provided prompt components.

This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`,
or `messages`. It ensures the proper structure and handling of the input components.

    Either pass a `prompt` or its components `context`, `examples`, and `messages`.

Args:
prompt: The complete prompt components.
context: The context for the prompt.
examples: The examples for the prompt.
messages: The messages for the prompt.

Returns:
A `glm.MessagePrompt` object created from the provided prompt components.
"""
if prompt is None:
prompt = dict(
context=context,
Expand Down Expand Up @@ -201,6 +262,7 @@ def _make_message_prompt(
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
"""Creates a `glm.MessagePrompt` object from the provided prompt components."""
prompt = _make_message_prompt_dict(
prompt=prompt, context=context, examples=examples, messages=messages
)
Expand All @@ -219,6 +281,7 @@ def _make_generate_message_request(
top_k: float | None = None,
prompt: discuss_types.MessagePromptOptions | None = None,
) -> glm.GenerateMessageRequest:
"""Creates a `glm.GenerateMessageRequest` object for generating messages."""
model = model_types.make_model_name(model)

prompt = _make_message_prompt(
Expand All @@ -236,6 +299,8 @@ def _make_generate_message_request(


def set_doc(doc):
"""A decorator to set the docstring of a function."""

def inner(f):
f.__doc__ = doc
return f
Expand Down
47 changes: 43 additions & 4 deletions google/generativeai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import re
from typing import Optional, List
from typing import Optional, List, Iterator

import google.ai.generativelanguage as glm
from google.generativeai.client import get_default_model_client
Expand All @@ -35,6 +35,22 @@ def get_model(name: str, *, client=None) -> model_types.Model:


class ModelsIterable(model_types.ModelsIterable):
"""
An iterable class to traverse through a list of models.

This class allows you to iterate over a list of models, fetching them in pages
if necessary based on the provided `page_size` and `page_token`.

Args:
page_size: The number of `models` to fetch per page.
page_token: Token representing the current page. Pass `None` for the first page.
models: List of models to iterate through.
client: An optional client for the model service.

Returns:
A `ModelsIterable` iterable object that allows iterating through the models.
"""

def __init__(
self,
*,
Expand All @@ -48,21 +64,44 @@ def __init__(
self._models = models
self._client = client

def __iter__(self):
def __iter__(self) -> Iterator[model_types.Model]:
    """Yield models one at a time, transparently walking page by page."""
    current = self
    while current:
        # Emit every model on the current page, then advance; _next_page()
        # returns None once there is no page token left, ending iteration.
        yield from current._models
        current = current._next_page()

def _next_page(self):
def _next_page(self) -> ModelsIterable | None:
    """Fetch the following page of models, or return None when exhausted."""
    token = self._page_token
    if not token:
        # An empty/absent token means the service reported no further pages.
        return None
    return _list_models(
        page_size=self._page_size,
        page_token=token,
        client=self._client,
    )


def _list_models(page_size, page_token, client):
def _list_models(
page_size: int, page_token: str | None, client: glm.ModelServiceClient
) -> ModelsIterable:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here too

Fetches a page of models using the provided client and pagination tokens.

This function queries the `client` to retrieve a page of models based on the given
`page_size` and `page_token`. It then processes the response and returns an iterable
object to traverse through the models.

Args:
        page_size: How many `types.Models` to fetch per page (per API call).
page_token: Token representing the current page.
client: The client to communicate with the model service.

Returns:
An iterable `ModelsIterable` object containing the fetched models and pagination info.
"""
result = client.list_models(page_size=page_size, page_token=page_token)
result = result._response
result = type(result).to_dict(result)
Expand Down
45 changes: 45 additions & 0 deletions google/generativeai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@


def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
"""
Creates a `glm.TextPrompt` object based on the provided prompt input.

Args:
prompt: The prompt input, either a string or a dictionary.

Returns:
glm.TextPrompt: A TextPrompt object containing the prompt text.

Raises:
TypeError: If the provided prompt is neither a string nor a dictionary.
"""
if isinstance(prompt, str):
return glm.TextPrompt(text=prompt)
elif isinstance(prompt, dict):
Expand All @@ -49,6 +61,28 @@ def _make_generate_text_request(
safety_settings: safety_types.SafetySettingOptions | None = None,
stop_sequences: str | Iterable[str] | None = None,
) -> glm.GenerateTextRequest:
"""
Creates a `glm.GenerateTextRequest` object based on the provided parameters.

This function generates a `glm.GenerateTextRequest` object with the specified
parameters. It prepares the input parameters and creates a request that can be
used for generating text using the chosen model.

Args:
model: The model to use for text generation.
prompt: The prompt for text generation. Defaults to None.
temperature: The temperature for randomness in generation. Defaults to None.
candidate_count: The number of candidates to consider. Defaults to None.
max_output_tokens: The maximum number of output tokens. Defaults to None.
top_p: The nucleus sampling probability threshold. Defaults to None.
top_k: The top-k sampling parameter. Defaults to None.
safety_settings: Safety settings for generated text. Defaults to None.
stop_sequences: Stop sequences to halt text generation. Can be a string
or iterable of strings. Defaults to None.

Returns:
`glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters.
"""
model = model_types.make_model_name(model)
prompt = _make_text_prompt(prompt=prompt)
safety_settings = safety_types.normalize_safety_settings(safety_settings)
Expand Down Expand Up @@ -155,6 +189,17 @@ def __init__(self, **kwargs):
def _generate_response(
request: glm.GenerateTextRequest, client: glm.TextServiceClient = None
) -> Completion:
"""
Generates a response using the provided `glm.GenerateTextRequest` and client.

Args:
request: The text generation request.
client: The client to use for text generation. Defaults to None, in which
case the default text client is used.

Returns:
`Completion`: A `Completion` object with the generated text and response information.
"""
if client is None:
client = get_default_text_client()

Expand Down
1 change: 1 addition & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from absl.testing import absltest

import google.ai.generativelanguage as glm

from google.ai.generativelanguage_v1beta2.types import model

from google.generativeai import models
Expand Down