fix: prevent verbose json output in router
joshprk committed Jan 7, 2025
Commit df56cf1 (parent 5e7c2b7)
Showing 2 changed files with 17 additions and 14 deletions.
aios/llm_core/adapter.py (19 changes: 11 additions & 8 deletions)
@@ -254,14 +254,17 @@ def address_syscall(
         model = self.strategy()

         if isinstance(model, (str, HfLocalBackend, VLLMLocalBackend, OllamaBackend)):
-            res = model(
-                messages=messages,
-                temperature=temperature,
-            ) if not isinstance(model, str) else str(completion(
-                model=model,
-                messages=messages,
-                temperature=temperature,
-            ))
+            if not isinstance(model, str):
+                res = model(
+                    messages=messages,
+                    temperature=temperature,
+                )
+            else:
+                res = completion(
+                    model=model,
+                    messages=messages,
+                    temperature=temperature,
+                ).choices[0].message.content
         else:
             raise RuntimeError(f"Unsupported model type: {type(model)}")

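What changed, in short: litellm's completion() returns a ModelResponse object, and wrapping it in str() pushed the entire serialized object (id, choices metadata, token usage) into the router's output; taking choices[0].message.content keeps only the assistant text. A minimal sketch of the difference, assuming litellm is installed and credentials for the placeholder model are configured (model name and prompt are illustrative, not from this commit):

    from litellm import completion

    messages = [{"role": "user", "content": "Hello"}]

    # Placeholder model; any provider litellm supports behaves the same way here.
    response = completion(model="gpt-4o-mini", messages=messages, temperature=0.3)

    # Old behaviour: str() dumps the whole ModelResponse as verbose JSON-like text.
    verbose = str(response)

    # New behaviour: only the assistant's reply text is passed along.
    text = response.choices[0].message.content

    print(verbose[:120])  # object dump, including ids and usage fields
    print(text)           # just the model's reply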
aios/llm_core/local.py (12 changes: 6 additions & 6 deletions)
@@ -35,13 +35,13 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
         self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"

     def inference_online(self, messages, temperature, stream=False):
-        return str(completion(
+        return completion(
             model="huggingface/" + self.model_name,
             messages=messages,
             temperature=temperature,
             api_base=self.hostname,
-        ))
+        ).choices[0].message.content

     def __call__(
         self,
         messages,
@@ -50,7 +50,7 @@ def __call__(
     ):
         if self.hostname is not None:
             return self.inference_online(messages, temperature, stream=stream)

         if stream:
             raise NotImplementedError

@@ -105,12 +105,12 @@ def __init__(self, model_name, device="auto", max_gpu_memory=None, hostname=None
             print("Error loading vllm model:", err)

     def inference_online(self, messages, temperature, stream=False):
-        return str(completion(
+        return completion(
             model="hosted_vllm/" + self.model_name,
             messages=messages,
             temperature=temperature,
             api_base=self.hostname,
-        ))
+        ).choices[0].message.content

     def __call__(
         self,
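The same extraction is applied to the HuggingFace and vLLM online paths above, so both backends now hand plain text back to the adapter. One caveat worth noting: in OpenAI-style responses, choices[0].message.content can be None (for example when the reply carries tool calls instead of text), so a caller that wants to be defensive could wrap the lookup. A hypothetical helper, not part of this commit:

    def message_text(response, default: str = "") -> str:
        # Hypothetical helper: pull the assistant text out of a litellm/OpenAI-style
        # response object, falling back to a default when content is None.
        content = response.choices[0].message.content
        return content if content is not None else default

    # e.g. res = message_text(completion(model=model, messages=messages, temperature=temperature))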
