Skip to content

Commit

Permalink
Merge branch 'main' into mm-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Jan 30, 2024
2 parents 0a7a6c1 + 38f7b94 commit 967bd66
Show file tree
Hide file tree
Showing 19 changed files with 865 additions and 1,289 deletions.
5 changes: 1 addition & 4 deletions promptlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from .promptlib import PromptNode, append, begin, begin_chat, end
from .promptlib import PromptNode, append, begin, end
from .promptlib import PromptProgram
from .promptlib import gen, choose, wait
from .promptlib import set_model
from .promptlib import constrain
from .promptlib import LLM, TransformersLLM
from .promptlib import AICI
13 changes: 2 additions & 11 deletions promptlib/promptlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@

from .prompt import PromptNode, append, begin, begin_chat, end, PromptProgram
from .prompt import PromptNode, append, begin, end, PromptProgram
from .gen import gen, choose, wait
from .model import set_model, begin_assistant, begin_user, begin_system
from .constrain import constrain

from .models import LLM, TransformersLLM

from .aici import AICI

setattr(PromptNode, "append", append)
setattr(PromptNode, "begin", begin)
setattr(PromptNode, "begin_chat", begin_chat)
setattr(PromptNode, "begin_system", begin_system)
setattr(PromptNode, "begin_assistant", begin_assistant)
setattr(PromptNode, "begin_user", begin_user)
setattr(PromptNode, "end", end)
setattr(PromptNode, "gen", gen)
setattr(PromptNode, "choose", choose)
setattr(PromptNode, "wait", wait)
setattr(PromptNode, "constrain", constrain)
setattr(PromptNode, "set_model", set_model)

67 changes: 8 additions & 59 deletions promptlib/promptlib/aici.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import subprocess
import requests
import ujson
import json
import sys
import os
import re
from typing import Optional

import pyaici.rest as aici_rest

class AICI:
# TODO remove this default base_url once we deploy a semi-permanent server
def __init__(self, base_url="http://127.0.0.1:8080/v1/", wasm_runner_id=None, wasm_runner_path=None, wasm_runner_buildsh=None ):
def __init__(self, base_url=None, wasm_runner_id=None, wasm_runner_path=None, wasm_runner_buildsh=None ):
self.base_url = base_url

if wasm_runner_id is None:
Expand All @@ -33,7 +34,7 @@ def _compile_wasm(wasm_runner_buildsh, scriptargs=["build"]):
if r.returncode != 0:
raise RuntimeError(f"error compiling aici promptlib module")

file_path = script_dir + "/../target/strip.wasm"
file_path = script_dir + "/target/strip.wasm"
return file_path


Expand All @@ -42,7 +43,8 @@ def _upload_wasm(base_url, wasm_runner_path):
with open(wasm_runner_path, "rb") as f:
resp = requests.post(base_url + "aici_modules", data=f)
if resp.status_code == 200:
dd = resp.json()
d = resp.json()
dd = d["data"]
mod_id = dd["module_id"]
print(
f"{dd['wasm_size']//1024}kB -> {dd['compiled_size']//1024}kB id:{mod_id[0:8]}"
Expand All @@ -55,57 +57,4 @@ def _upload_wasm(base_url, wasm_runner_path):


def _submit_program(base_url, aici_module, aici_arg, temperature=0, max_tokens=200, n=1, log=False):
json = {
"model": "",
"prompt": "",
"max_tokens": max_tokens,
"n": n,
"temperature": temperature,
"stream": True,
"aici_module": aici_module,
"aici_arg": aici_arg,
}
resp = requests.post(base_url + "completions", json=json, stream=True)
if resp.status_code != 200:
raise RuntimeError(
f"bad response to completions: {resp.status_code} {resp.reason}: {resp.text}"
)
full_resp = []
texts = [""] * n
for line in resp.iter_lines():
if line:
decoded_line: str = line.decode("utf-8")
if decoded_line.startswith("data: {"):
d = ujson.decode(decoded_line[6:])
full_resp.append(d)
for ch in d["choices"]:
idx = ch["index"]
if idx == 0:
if log:
l = ch["logs"].rstrip("\n")
if "Previous WASM Error" in l:
raise "Bailing out due to WASM error: " + l
#else:
# print(ch["text"], end="")
# make sure texts is long enough
while len(texts) <= idx:
texts.append("")
texts[idx] += ch["text"]
#elif decoded_line == "data: [DONE]":
# print(" [DONE]")
#else:
# print(decoded_line)

if len(texts) == 1:
print(texts[0])
else:
print(texts)
os.makedirs("tmp", exist_ok=True)
path = "tmp/response.json"
with open(path, "w") as f:
ujson.dump(
{"request": json, "texts": texts, "response": full_resp}, f, indent=1
)

return texts, json, full_resp
#print(f"response saved to {path}")
return aici_rest.completion("", aici_module, aici_arg, temperature, max_tokens, n, ignore_eos=False, base_url=base_url)
30 changes: 0 additions & 30 deletions promptlib/promptlib/constrain.py

This file was deleted.

3 changes: 0 additions & 3 deletions promptlib/promptlib/constraints/__init__.py

This file was deleted.

196 changes: 0 additions & 196 deletions promptlib/promptlib/constraints/cfg_constraint.py

This file was deleted.

20 changes: 0 additions & 20 deletions promptlib/promptlib/constraints/constraint.py

This file was deleted.

Loading

0 comments on commit 967bd66

Please sign in to comment.