Merge pull request #674 from Bob17293729/zongbo-dev
code execution class and test cases
research4pan authored Dec 2, 2023
2 parents 4124102 + b2aa8fa commit 8516a72
Showing 4 changed files with 170 additions and 0 deletions.
45 changes: 45 additions & 0 deletions examples/tool_inference.py
@@ -0,0 +1,45 @@
import os
import argparse
from lmflow.args import InferencerArguments
from lmflow.args import ModelArguments
from lmflow.args import DatasetArguments
from lmflow.models import hf_decoder_model
from lmflow.pipeline.inferencer import ToolInferencer
def main():
parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='0',
                        help='gpu id; tool inference currently supports only a single gpu')
    parser.add_argument('--model', type=str, default='codellama/CodeLlama-7b-instruct-hf',
                        help='code generation model name or path; currently only '
                             'huggingface decoder-only models are supported')
params = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = params.gpu

model_args = ModelArguments(model_name_or_path=params.model)
model = hf_decoder_model.HFDecoderModel(model_args)
inferencer_args = InferencerArguments()
data_args = DatasetArguments()

toolinf = ToolInferencer(model_args, data_args, inferencer_args)

while True:
try:
text = input("Tool Inference: ")
toolinf_res = toolinf.inference(model, text)
toolinf_res = toolinf_res.replace("<s>","")
toolinf_res = toolinf_res.replace("</s>","")
print('\n\nResult:')
print(toolinf_res)
print('\n\n')
run_code = input("Run code? (y/n): ")
            if run_code == 'y':
                toolinf.code_exec(toolinf_res)
except EOFError:
break

if __name__ == '__main__':
main()
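For reference, a minimal non-interactive round trip through the same API introduced in this PR (a sketch assuming the lmflow interfaces exactly as used in the script above; the model name and prompt are placeholders):

import os
from lmflow.args import DatasetArguments, InferencerArguments, ModelArguments
from lmflow.models import hf_decoder_model
from lmflow.pipeline.inferencer import ToolInferencer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Load the code generation model and wrap it in a ToolInferencer.
model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-instruct-hf")
model = hf_decoder_model.HFDecoderModel(model_args)
toolinf = ToolInferencer(model_args, DatasetArguments(), InferencerArguments())

# Generate code for one prompt, strip the BOS/EOS markers, and execute it.
generated = toolinf.inference(model, "Write Python code that prints the first 5 squares.")
generated = generated.replace("<s>", "").replace("</s>", "")
toolinf.code_exec(generated)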
3 changes: 3 additions & 0 deletions scripts/run_tool.sh
@@ -0,0 +1,3 @@
model="gorilla-llm/gorilla-7b-hf-delta-v1"
python examples/tool_inference.py \
    --model ${model}
86 changes: 86 additions & 0 deletions src/lmflow/pipeline/inferencer.py
@@ -15,6 +15,7 @@
import logging
from typing import Dict, List
from concurrent.futures import ThreadPoolExecutor
import subprocess

from transformers import AutoConfig
import torch.distributed as dist
@@ -553,3 +554,88 @@ def speculative_sampling(input_ids: torch.Tensor,

def stream_inference(self):
raise NotImplementedError("Streaming output for SpeculativeInferencer is not supported yet")

class ToolInferencer(Inferencer):
"""
Initializes the `ToolInferencer` class with given arguments.
Parameters
------------
model_args : ModelArguments object.
Contains the arguments required to load the model.
data_args : DatasetArguments object.
Contains the arguments required to load the dataset.
inferencer_args : InferencerArguments object.
Contains the arguments required to perform inference.
"""
def __init__(self, model_args, data_args, inferencer_args):
super().__init__(model_args, data_args, inferencer_args)

self.model = HFDecoderModel(self.model_args)

def inference(
self,
model: HFDecoderModel,
input: str,
max_new_tokens: int=1024,
):
"""
Perform inference for a model
Parameters
------------
model : HFDecoderModel object.
            The model used to perform inference.
input : str.
The input text (i.e., the prompt) for the model.
max_new_tokens : int.
The maximum number of tokens to be generated by the model.
        Returns
        ------------
        output : str.
            The output text generated by the model.
"""
        if self.inferencer_args.device == "gpu":
            input_id = model.encode(input, return_tensors="pt").to(device=self.local_rank)
        elif self.inferencer_args.device == "cpu":
            input_id = model.encode(input, return_tensors="pt").to(device='cpu')
        else:
            raise NotImplementedError(f'device "{self.inferencer_args.device}" is not supported')
logger.debug(f"input_id: {input_id}")
input_length = input_id.shape[1]
output_id = model.inference(
input_id,
use_accelerator=True,
max_new_tokens=max_new_tokens,
# pad_token_id=model.tokenizer.eos_token_id,
)
# logger.debug(f"output: {output_id}")
output = model.decode(output_id[0])
output = output.replace(input,"")
return output

    def code_exec(self, code):
        # Run the generated code string in a fresh Python interpreter,
        # capturing stdout and stderr as text.
        result = subprocess.run(["python", "-c", code], capture_output=True, text=True)

        # Report the result. Note the asymmetric return contract: stdout
        # (a str) on success, the full CompletedProcess object on failure.
        if result.returncode == 0:
            print("Successfully executed, the result is:")
            print(result.stdout)
            return result.stdout
        else:
            print("Error:")
            print(result.stderr)
            return result
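A note on the return contract of code_exec: on success it returns result.stdout (a str), while on failure it returns the whole subprocess.CompletedProcess, so callers must check returncode or branch on type — the tests below rely on exactly this asymmetry. A standard-library-only sketch of the same pattern:

import subprocess

# Run a code string in a fresh Python interpreter, as code_exec does.
ok = subprocess.run(["python", "-c", "print(2 + 2)"], capture_output=True, text=True)
assert ok.returncode == 0 and ok.stdout == "4\n"

# A failing snippet leaves stdout empty and puts the traceback on stderr.
bad = subprocess.run(["python", "-c", "b = a + 1"], capture_output=True, text=True)
assert bad.returncode != 0 and "NameError" in bad.stderr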
36 changes: 36 additions & 0 deletions tests/models/test_tool_inferencer.py
@@ -0,0 +1,36 @@
from lmflow.pipeline.inferencer import ToolInferencer
import unittest
from lmflow.args import InferencerArguments
from lmflow.args import ModelArguments
from lmflow.args import DatasetArguments
from lmflow.models import hf_decoder_model

CODE_1 = "print(\"hello world\")"
RES_1 = "hello world\n"
CODE_2 = "b=a+1\nprint(b)"
RES_2 = """Traceback (most recent call last):
File "<string>", line 1, in <module>
NameError: name 'a' is not defined
"""

class ToolInferencerTest(unittest.TestCase):
    def setUp(self):
model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-instruct-hf")
model = hf_decoder_model.HFDecoderModel(model_args)
inferencer_args = InferencerArguments()
data_args = DatasetArguments()
self.toolinf = ToolInferencer(model_args, data_args, inferencer_args)

    def test_code_exec_1(self, code=CODE_1, expected_output=RES_1):
toolinf_res = self.toolinf.code_exec(code)
self.assertEqual(toolinf_res, expected_output)

    def test_code_exec_2(self, code=CODE_2):
toolinf_res = self.toolinf.code_exec(code)
self.assertNotEqual(toolinf_res.returncode, 0)

if __name__ == "__main__":
    unittest.main()
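Since code_exec touches no instance state, it can also be exercised without downloading the 7B model at all — a lighter-weight variant (an assumption on my part, not part of this PR) that skips __init__ entirely:

import unittest
from lmflow.pipeline.inferencer import ToolInferencer

class CodeExecOnlyTest(unittest.TestCase):
    def setUp(self):
        # Bypass __init__ (and the model load); code_exec only shells out.
        self.toolinf = ToolInferencer.__new__(ToolInferencer)

    def test_success_returns_stdout(self):
        self.assertEqual(self.toolinf.code_exec('print("hello world")'), "hello world\n")

    def test_failure_returns_completed_process(self):
        res = self.toolinf.code_exec("b = a + 1")
        self.assertNotEqual(res.returncode, 0)

if __name__ == "__main__":
    unittest.main()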