-
Notifications
You must be signed in to change notification settings - Fork 464
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d1104b3
commit 4bf71d1
Showing
74 changed files
with
4,009 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from cerebrum.client import Cerebrum | ||
|
||
class Config: | ||
def __init__(self): | ||
self._global_client: Cerebrum = None | ||
self._base_url = "http://localhost:8000" | ||
self._timeout = 30 | ||
|
||
@property | ||
def global_client(self): | ||
if not self._global_client: | ||
raise ValueError("Client not set. Call config.client = Cerebrum Client") | ||
return self._global_client | ||
|
||
@global_client.setter | ||
def global_client(self, value): | ||
self._global_client = value | ||
|
||
def configure(self, **kwargs): | ||
"""Configure multiple settings at once""" | ||
for key, value in kwargs.items(): | ||
if hasattr(self, f"_{key}"): | ||
setattr(self, f"_{key}", value) | ||
|
||
config = Config() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import os, json | ||
from typing import Any | ||
|
||
from cerebrum.client import Cerebrum | ||
from cerebrum import config | ||
from cerebrum.interface import AutoLLM, AutoTool | ||
|
||
class BaseAgent: | ||
def __init__(self, agent_name, task_input, config_): | ||
self.agent_name = agent_name | ||
# self.config = self._load_config(dir) | ||
self.config = config_ | ||
|
||
config.global_client = Cerebrum() | ||
# self.send_request = AutoLLM.from_dynamic().process | ||
self.send_request = None | ||
|
||
self.tools, self.tool_info = AutoTool.from_batch_preload(self.config["tools"]).values() | ||
|
||
|
||
# def _load_config(self, dir: str): | ||
# # script_path = os.path.abspath(__file__) | ||
# # script_dir = os.path.dirname(script_path) | ||
# # print('script dir', script_dir) | ||
# # config_file = os.path.join(script_dir, "config.json") | ||
# config_file = os.path.join(dir, "config.json") | ||
|
||
# with open(config_file, "r") as f: | ||
# config = json.load(f) | ||
# return config | ||
|
||
def pre_select_tools(self, tool_names): | ||
pre_selected_tools = [] | ||
for tool_name in tool_names: | ||
for tool in self.tools: | ||
if tool["function"]["name"] == tool_name: | ||
pre_selected_tools.append(tool) | ||
break | ||
|
||
return pre_selected_tools | ||
|
||
def run(self) -> Any: | ||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
from dataclasses import asdict | ||
from typing import Optional, Dict, Any, List | ||
import requests | ||
|
||
from cerebrum.llm.communication import LLMQuery | ||
from cerebrum.overrides.layer import OverridesLayer | ||
from cerebrum.memory.layer import MemoryLayer | ||
from cerebrum.llm.layer import LLMLayer | ||
from cerebrum.tool.layer import ToolLayer | ||
from cerebrum.storage.layer import StorageLayer | ||
|
||
|
||
class Cerebrum: | ||
def __init__(self, base_url: str = "http://localhost:8000"): | ||
self.base_url = base_url.rstrip('/') | ||
self._components_initialized = set() | ||
self.results = {} | ||
|
||
def _post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]: | ||
"""Make a POST request to the specified endpoint.""" | ||
response = requests.post(f"{self.base_url}{endpoint}", json=data) | ||
response.raise_for_status() | ||
return response.json() | ||
|
||
def _get(self, endpoint: str) -> Dict[str, Any]: | ||
"""Make a GET request to the specified endpoint.""" | ||
response = requests.get(f"{self.base_url}{endpoint}") | ||
response.raise_for_status() | ||
return response.json() | ||
|
||
def _query_llm(self, agent_name: str, query: LLMQuery): | ||
result = self._post("/query", { | ||
'query_type': 'llm', | ||
'agent_name': agent_name, | ||
'query_data': query.model_dump()}) | ||
|
||
return result | ||
|
||
def add_llm_layer(self, config: LLMLayer) -> 'Cerebrum': | ||
"""Set up the LLM core component.""" | ||
result = self._post("/core/llm/setup", asdict(config)) | ||
self._components_initialized.add("llm") | ||
self.results['llm'] = result | ||
return self | ||
|
||
def add_storage_layer(self, config: StorageLayer) -> 'Cerebrum': | ||
"""Set up the storage manager component.""" | ||
result = self._post("/core/storage/setup", asdict(config)) | ||
self._components_initialized.add("storage") | ||
self.results['storage'] = result | ||
return self | ||
|
||
def add_memory_layer(self, config: MemoryLayer) -> 'Cerebrum': | ||
"""Set up the memory manager component.""" | ||
if "storage" not in self._components_initialized: | ||
raise ValueError( | ||
"Storage manager must be initialized before memory manager") | ||
|
||
result = self._post("/core/memory/setup", asdict(config)) | ||
self._components_initialized.add("memory") | ||
self.results['memory'] = result | ||
return self | ||
|
||
def add_tool_layer(self, config: ToolLayer) -> 'Cerebrum': | ||
"""Set up the tool manager component.""" | ||
result = self._post("/core/tool/setup", asdict(config)) | ||
self._components_initialized.add("tool") | ||
self.results['tool'] = result | ||
return self | ||
|
||
def setup_agent_factory(self, config: OverridesLayer) -> 'Cerebrum': | ||
"""Set up the agent factory for managing agent execution.""" | ||
required_components = {"llm", "memory", "storage", "tool"} | ||
missing_components = required_components - self._components_initialized | ||
|
||
if missing_components: | ||
raise ValueError( | ||
f"Missing required components: {', '.join(missing_components)}" | ||
) | ||
|
||
result = self._post("/core/factory/setup", asdict(config)) | ||
self._components_initialized.add("factory") | ||
self.results['factory'] = result | ||
return self | ||
|
||
def execute(self, agent_id: str, agent_config: Dict[str, Any]) -> Dict[str, Any]: | ||
"""Submit an agent for execution.""" | ||
if "factory" not in self._components_initialized: | ||
raise ValueError( | ||
"Agent factory must be initialized before submitting agents") | ||
|
||
return self._post("/agents/submit", { | ||
"agent_id": agent_id, | ||
"agent_config": agent_config | ||
}) | ||
|
||
def get_agent_status(self, execution_id: str) -> Dict[str, Any]: | ||
"""Get the status of a submitted agent.""" | ||
if "factory" not in self._components_initialized: | ||
raise ValueError( | ||
"Agent factory must be initialized before checking agent status") | ||
|
||
return self._get(f"/agents/{execution_id}/status") | ||
|
||
def poll_agent(self, execution_id: str, polling_interval: float = 1.0, timeout: Optional[float] = None) -> Dict[str, Any]: | ||
"""Wait for an agent to complete execution.""" | ||
import time | ||
start_time = time.time() | ||
|
||
while True: | ||
status = self.get_agent_status(execution_id) | ||
if status["status"] == "completed": | ||
return status["result"] | ||
|
||
if timeout and (time.time() - start_time) > timeout: | ||
raise TimeoutError( | ||
f"Agent execution {execution_id} did not complete within {timeout} seconds") | ||
|
||
time.sleep(polling_interval) | ||
|
||
def override_scheduler(self, config: OverridesLayer) -> 'Cerebrum': | ||
"""Set up the FIFO scheduler with all components.""" | ||
required_components = {"llm", "memory", "storage", "tool"} | ||
missing_components = required_components - self._components_initialized | ||
|
||
if missing_components: | ||
raise ValueError( | ||
f"Missing required components: {', '.join(missing_components)}" | ||
) | ||
|
||
result = self._post("/core/scheduler/setup", asdict(config)) | ||
self._components_initialized.add("scheduler") | ||
self.results['scheduler'] = result | ||
return self | ||
|
||
def get_status(self) -> Dict[str, str]: | ||
"""Get the status of all core components.""" | ||
return self._get("/core/status") | ||
|
||
def cleanup(self) -> Dict[str, Any]: | ||
"""Clean up all active components.""" | ||
result = self._post("/core/cleanup", {}) | ||
self._components_initialized.clear() | ||
return result | ||
|
||
def connect(self) -> 'Cerebrum': | ||
if (self.results.get('scheduler') is None): | ||
self.override_scheduler(OverridesLayer(max_workers=32)) | ||
|
||
self.setup_agent_factory(OverridesLayer(max_workers=32)) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
self.cleanup() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from cerebrum.manager.agent import AgentManager | ||
from cerebrum.manager.tool import ToolManager | ||
from cerebrum.runtime.process import LLMProcessor, RunnableAgent | ||
|
||
from .. import config | ||
|
||
class AutoAgent: | ||
AGENT_MANAGER = AgentManager('https://my.aios.foundation') | ||
|
||
@classmethod | ||
def from_preloaded(cls, agent_name: str): | ||
_client = config.global_client | ||
|
||
return RunnableAgent(_client, agent_name) | ||
|
||
|
||
class AutoLLM: | ||
@classmethod | ||
def from_dynamic(cls): | ||
return LLMProcessor(config.global_client) | ||
|
||
|
||
class AutoTool: | ||
TOOL_MANAGER = ToolManager('https://my.aios.foundation') | ||
|
||
@classmethod | ||
def from_preloaded(cls, tool_name: str): | ||
if tool_name.split('/')[0] != 'core': | ||
author, name, version = cls.TOOL_MANAGER.download_tool( | ||
author=tool_name.split('/')[0], | ||
name=tool_name.split('/')[1] | ||
) | ||
|
||
tool, _ = cls.TOOL_MANAGER.load_tool(author, name, version) | ||
else: | ||
tool, _ = cls.TOOL_MANAGER.load_tool(local=True, name=tool_name) | ||
|
||
#return tool instance, not class | ||
return tool() | ||
|
||
@classmethod | ||
def from_batch_preload(cls, tool_names: list[str]): | ||
response = { | ||
'tools': [], | ||
'tool_info': [] | ||
} | ||
|
||
for tool_name in tool_names: | ||
tool = AutoTool.from_preloaded(tool_name) | ||
|
||
response['tools'].append(tool.get_tool_call_format()) | ||
response['tool_info'].append( | ||
{ | ||
"name": tool.get_tool_call_format()["function"]["name"], | ||
"description": tool.get_tool_call_format()["function"]["description"], | ||
} | ||
) | ||
|
||
|
||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from pydantic import BaseModel, Field | ||
from typing import List, Dict, Optional, Any, Union | ||
from typing_extensions import Literal | ||
|
||
class Request(BaseModel): | ||
pass | ||
|
||
class LLMQuery(Request): | ||
""" | ||
Query class represents the input structure for performing various actions. | ||
Attributes: | ||
messages (List[Dict[str, Union[str, Any]]]): A list of dictionaries where each dictionary | ||
represents a message containing 'role' and 'content' or other key-value pairs. | ||
tools (Optional[List[Dict[str, Any]]]): An optional list of JSON-like objects (dictionaries) | ||
representing tools and their parameters. Default is an empty list. | ||
action_type (Literal): A string that must be one of "message_llm", "call_tool", or "operate_file". | ||
This restricts the type of action the query performs. | ||
message_return_type (str): The type of the response message. Default is "text". | ||
""" | ||
messages: List[Dict[str, Union[str, Any]]] # List of message dictionaries, each containing role and content. | ||
tools: Optional[List[Dict[str, Any]]] = Field(default_factory=list) # List of JSON-like objects (dictionaries) representing tools. | ||
action_type: Literal["chat", "tool_use", "operate_file"] = Field(default="chat") # Restrict the action_type to specific choices. | ||
message_return_type: str = Field(default="text") # Type of the return message, default is "text". | ||
|
||
class Config: | ||
arbitrary_types_allowed = True # Allows the use of arbitrary types such as Any and Dict. | ||
|
||
class Response(BaseModel): | ||
""" | ||
Response class represents the output structure after performing actions. | ||
Attributes: | ||
response_message (Optional[str]): The generated response message. Default is None. | ||
tool_calls (Optional[List[Dict[str, Any]]]): An optional list of JSON-like objects (dictionaries) | ||
representing the tool calls made during processing. Default is None. | ||
""" | ||
response_message: Optional[str] = None # The generated response message, default is None. | ||
tool_calls: Optional[List[Dict[str, Any]]] = None # List of JSON-like objects representing tool calls, default is None. | ||
finished: bool | ||
|
||
class Config: | ||
arbitrary_types_allowed = True # Allows arbitrary types in validation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from dataclasses import dataclass | ||
|
||
@dataclass | ||
class LLMLayer: | ||
llm_name: str | ||
max_gpu_memory: dict | None = None | ||
eval_device: str = "cuda:0" | ||
max_new_tokens: int = 2048 | ||
log_mode: str = "console" | ||
use_backend: str = "default" |
Empty file.
Oops, something went wrong.