better EoS handling
mmoskal committed Jun 26, 2024
1 parent 38f2786 commit 9ccf1ec
Showing 6 changed files with 119 additions and 36 deletions.
43 changes: 41 additions & 2 deletions controllers/llguidance_ctrl/run_g.py
@@ -271,14 +271,23 @@ def character_maker2(lm, id, description, valid_weapons):
"type": "object",
"additionalProperties": False,
"properties": {"age": {"type": "integer"}},
}
},
)
# assert grm.match('{"a": 1}')
prompt = ""
grm = "Here's some JSON:\n" + grm # + "\nAnd some more:\n" + grm

prompt = ""
grm = optional("A")

grm = "Q: Are dolphins fish?\nA: " + gen("dolphins", regex="Yes|No", max_tokens=10) + \
"\nQ: Are sharks fish?\nA: " + gen("sharks", regex="Yes|No", max_tokens=10)

# grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5)

# g = zero_or_more("a") + "b"
# assert not g.match("b")
# assert g.match("b")
# assert g.match("ab")

# lm = guidance.models.Mock(b"<s>1234233234<s>")
# grammar = one_or_more(select(["1", "2"]))
@@ -287,8 +296,38 @@ def character_maker2(lm, id, description, valid_weapons):
max_tokens = 250

serialized = grm.ll_serialize()

x_serialized = {
"grammars": [
{
"greedy_lexer": False,
"nodes": [
{
"GenGrammar": {
"grammar": 1,
"stop_rx": "",
"no_initial_skip": True,
"temperature": 0.0,
}
}
],
"rx_nodes": [],
},
{
"greedy_lexer": True,
"greedy_skip_rx": "[\\x20\\x0A\\x0D\\x09]+",
"nodes": [
{"Lexeme": {"rx": "-?(?:0|[1-9][0-9]*)", "contextual": False}}
#{"Lexeme": {"rx": "[ab][ab]", "contextual": False}}
],
"rx_nodes": [],
},
]
}

serialized["max_tokens"] = max_tokens
llguidance_json = {"grammar": serialized}

llguidance_arg = json.dumps(llguidance_json, indent=1)
# save llguidance_arg to file
with open("tmp/llguidance_arg.json", "w") as f:
24 changes: 23 additions & 1 deletion controllers/llguidance_ctrl/src/earley/lexer.rs
@@ -86,7 +86,29 @@ impl Lexer {

pub fn force_lexeme_end(&self, prev: StateID) -> LexerResult {
let info = self.state_info(prev);
let idx = info.possible.first_bit_set().expect("no allowed lexemes");
match info.possible.first_bit_set() {
Some(idx) => LexerResult::Lexeme(PreLexeme {
idx: LexemeIdx::new(idx),
byte: None,
byte_next_row: false,
hidden_len: 0,
}),
None => LexerResult::Error,
}
}

pub fn try_lexeme_end(&mut self, prev: StateID) -> LexerResult {
let prev_accepting = self.state_info(prev).accepting.first_bit_set();
let eos_state = self.dfa.transition_bytes(prev, EOS_MARKER);
let eos_accepting = self.state_info(eos_state).accepting.first_bit_set();

let idx = match (prev_accepting, eos_accepting) {
(Some(p), Some(e)) if p < e => p,
(_, Some(e)) => e,
(Some(p), None) => p,
(None, None) => return LexerResult::Error,
};

LexerResult::Lexeme(PreLexeme {
idx: LexemeIdx::new(idx),
byte: None,
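A note on the new try_lexeme_end above: it compares the lowest accepting lexeme in the previous DFA state with the lowest accepting lexeme after a simulated transition on EOS_MARKER, and keeps the previous-state lexeme only when its index is strictly smaller. Below is a minimal standalone sketch of that selection rule; the plain usize indices are hypothetical stand-ins for LexemeIdx, and "lower index = higher priority" is an assumption read off the match arms.

// Sketch of the lexeme-selection rule in try_lexeme_end; indices are
// hypothetical stand-ins for LexemeIdx.
fn pick_lexeme(prev_accepting: Option<usize>, eos_accepting: Option<usize>) -> Option<usize> {
    match (prev_accepting, eos_accepting) {
        // both states accept: keep the previous-state lexeme only if it is strictly lower
        (Some(p), Some(e)) if p < e => Some(p),
        // otherwise prefer the lexeme that accepts after the EOS_MARKER transition
        (_, Some(e)) => Some(e),
        // no EOS-side lexeme: fall back to the previous state
        (Some(p), None) => Some(p),
        // nothing accepts: the caller maps this to LexerResult::Error
        (None, None) => None,
    }
}

fn main() {
    assert_eq!(pick_lexeme(Some(0), Some(2)), Some(0));
    assert_eq!(pick_lexeme(Some(3), Some(1)), Some(1));
    assert_eq!(pick_lexeme(None, Some(4)), Some(4));
    assert_eq!(pick_lexeme(None, None), None);
}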
58 changes: 46 additions & 12 deletions controllers/llguidance_ctrl/src/earley/parser.rs
@@ -20,6 +20,7 @@ use super::{
grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx},
lexer::{LexerResult, PreLexeme, StateID},
lexerspec::{Lexeme, LexemeIdx, LexerSpec},
EOS_MARKER,
};

const TRACE: bool = false;
@@ -369,6 +370,25 @@ impl Parser {
Ok(r)
}

pub fn compute_bias(&mut self, trie: &TokTrie, start: &[u8]) -> SimpleVob {
let mut set = trie.alloc_token_set();

trie.compute_bias_ext(self, &mut set, start);

// clean damage from EOS_MARKER
if self.lexer_allows_eos() {
let first_token_of_eos_marker = trie.greedy_tokenize(EOS_MARKER)[0];
set.disallow_token(first_token_of_eos_marker);
}

if set.num_set() == 1 && set.is_allowed(trie.eos_token()) {
// we're going to be stopped outside - we better flush the lexer
self.flush_lexer();
}

set
}

pub fn grammar(&self) -> &CGrammar {
&self.grammar
}
@@ -383,7 +403,7 @@ impl Parser {
self.after_dots().map(|pos| self.grammar.sym_data_at(pos))
}

pub fn can_advance(&self) -> bool {
fn can_advance_inner(&self) -> bool {
let skip = self.grammar.lexeme_to_sym_idx(LexemeIdx::SKIP);
for data in self.after_dots_symdata() {
if data.idx == skip || data.idx == CSymIdx::NULL {
@@ -396,11 +416,15 @@ impl Parser {
false
}

pub fn can_advance(&self) -> bool {
self.has_pending_lexeme_bytes() || self.can_advance_inner()
}

pub fn has_pending_lexeme_bytes(&self) -> bool {
self.curr_row_bytes().len() > 0
}

pub fn row_is_accepting(&self) -> bool {
fn row_is_accepting(&self) -> bool {
for pos in self.after_dots() {
let after_dot = self.grammar.sym_idx_at(pos);
if after_dot == CSymIdx::NULL {
@@ -413,10 +437,6 @@ impl Parser {
false
}

pub fn is_accepting(&self) -> bool {
!self.has_pending_lexeme_bytes() && self.row_is_accepting()
}

pub fn lexer_allows_eos(&mut self) -> bool {
let mut allowed_eos = self.lexer_spec().eos_lexemes();
allowed_eos.and(&self.curr_row().allowed_lexemes);
@@ -729,6 +749,13 @@ impl Parser {
}
}

pub fn is_accepting(&mut self) -> bool {
self.trie_started();
let r = self.flush_lexer() && self.row_is_accepting();
self.trie_finished();
r
}

pub fn try_push_byte_definitive(&mut self, byte: Option<u8>) -> bool {
assert!(self.scratch.definitive);

@@ -764,15 +791,19 @@ impl Parser {
&self.rows[self.lexer_state().row_idx as usize]
}

pub fn model_variables(&self) -> Vec<ModelVariable> {
pub fn model_variables(&mut self) -> Vec<ModelVariable> {
self.trie_started();
let mut vars = vec![];
for sym_data in self.after_dots_symdata() {
if let Some(ref mv) = sym_data.props.model_variable {
if !vars.contains(mv) {
vars.push(mv.clone());
if self.flush_lexer() {
for sym_data in self.after_dots_symdata() {
if let Some(ref mv) = sym_data.props.model_variable {
if !vars.contains(mv) {
vars.push(mv.clone());
}
}
}
}
self.trie_finished();
vars
}

@@ -804,8 +835,11 @@ impl Parser {
}

fn flush_lexer(&mut self) -> bool {
if !self.has_pending_lexeme_bytes() {
return true;
}
let curr = self.lexer_state();
let lex_result = self.lexer.force_lexeme_end(curr.lexer_state);
let lex_result = self.lexer.try_lexeme_end(curr.lexer_state);
self.advance_lexer_or_parser(lex_result, curr)
}

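Taken together, the parser.rs changes make acceptance and advancement aware of buffered lexeme bytes: is_accepting now flushes the lexer before checking the row, and can_advance treats pending lexeme bytes as possible progress. The toy model below restates those two checks with plain booleans; the field names are illustrative, not the real Parser internals.

// Toy restatement of the accept/advance checks after this commit.
// All fields are hypothetical booleans, not the real Parser state.
struct ToyParser {
    pending_lexeme_bytes: bool,    // has_pending_lexeme_bytes()
    lexer_can_close: bool,         // flush_lexer() can finish the buffered lexeme
    row_accepts_after_flush: bool, // row_is_accepting() once flushed
    grammar_can_extend: bool,      // can_advance_inner()
}

impl ToyParser {
    // mirrors: is_accepting() = flush_lexer() && row_is_accepting()
    fn is_accepting(&self) -> bool {
        let flushed = !self.pending_lexeme_bytes || self.lexer_can_close;
        flushed && self.row_accepts_after_flush
    }
    // mirrors: can_advance() = has_pending_lexeme_bytes() || can_advance_inner()
    fn can_advance(&self) -> bool {
        self.pending_lexeme_bytes || self.grammar_can_extend
    }
}

fn main() {
    // buffered bytes that cannot yet close a lexeme block acceptance but still count as progress
    let p = ToyParser {
        pending_lexeme_bytes: true,
        lexer_can_close: false,
        row_accepts_after_flush: true,
        grammar_can_extend: false,
    };
    assert!(!p.is_accepting());
    assert!(p.can_advance());
}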
26 changes: 7 additions & 19 deletions controllers/llguidance_ctrl/src/tokenparser.rs
@@ -2,9 +2,7 @@ use std::sync::Arc;

use crate::{
api::{GenGrammarOptions, TopLevelGrammar},
earley::{
grammars_from_json, CGrammar, CSymIdx, ModelVariable, Parser, ParserStats, EOS_MARKER,
},
earley::{grammars_from_json, CGrammar, CSymIdx, ModelVariable, Parser, ParserStats},
};
use aici_abi::{MidProcessArg, MidProcessResult, TokenId, TokenizerEnv};
use anyhow::Result;
@@ -39,7 +37,6 @@ pub struct TokenParser {
previous_grm_bytes: Vec<u8>,
mid_process_was_accepting: bool,

first_token_of_eos_marker: TokenId,
max_tokens_total: usize,
max_tokens_parser: usize,
compiled_grammars: Vec<Arc<CGrammar>>,
@@ -73,8 +70,6 @@ impl TokenParser {
GenGrammarOptions::default(),
)?;

let first_token_of_eos_marker = token_env.tok_trie().greedy_tokenize(EOS_MARKER)[0];

Ok(TokenParser {
log_level,
token_env,
Expand All @@ -85,7 +80,6 @@ impl TokenParser {
parser_stack: Vec::new(),
previous_grm_bytes: Vec::new(),
compiled_grammars,
first_token_of_eos_marker,
llm_tokens: Vec::new(),
llm_bytes: Vec::new(),
grm_prefix: Vec::new(),
@@ -345,15 +339,15 @@ impl TokenParser {

let inner_done = {
let empty_token_prefix = token_prefix.is_empty();
let row_accepting = self.parser.row_is_accepting();
let no_pending_bytes = !self.parser.has_pending_lexeme_bytes();
let is_accepting = no_pending_bytes && row_accepting;
let lexer_bytes = self.parser.has_pending_lexeme_bytes();
let is_accepting = self.parser.is_accepting();
let can_advance = self.parser.can_advance();
let inner_done = empty_token_prefix && is_accepting && (!can_advance || has_eos);
infoln!(
self,
"inner_done: {inner_done}; can_advance: {can_advance} (eos:{has_eos}); \
accept: {is_accepting} (row:{row_accepting} & lexer:{no_pending_bytes}); \
"inner_done: {inner_done}; lexer_bytes: {lexer_bytes}; \
can_advance: {can_advance} (eos:{has_eos}); \
accept: {is_accepting}; \
empty_token_prefix: {empty_token_prefix}"
);
self.mid_process_was_accepting =
Expand All @@ -362,14 +356,8 @@ impl TokenParser {
};

let trie = self.token_env.tok_trie();
let mut set = trie.alloc_token_set();
// self.parser.print_row(self.parser.num_rows() - 1);
trie.compute_bias_ext(&mut self.parser, &mut set, &token_prefix);

// clean damage from EOS_MARKER
if self.parser.lexer_allows_eos() {
set.disallow_token(self.first_token_of_eos_marker);
}
let set = self.parser.compute_bias(trie, &token_prefix);

if inner_done
|| self.max_tokens_parser == 0
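On the tokenparser.rs side, the mid-process step now just asks the parser: compute_bias handles the EOS_MARKER cleanup that previously lived here, and the stop decision reduces to the expression logged as inner_done. A standalone sketch of that decision follows, with plain booleans standing in for the parser and token-prefix queries.

// Sketch of the inner_done decision from mid_process (illustrative only).
fn inner_done(empty_token_prefix: bool, is_accepting: bool, can_advance: bool, has_eos: bool) -> bool {
    // stop only when nothing is buffered in the token prefix, the grammar
    // accepts, and either it cannot extend further or the model emitted EOS
    empty_token_prefix && is_accepting && (!can_advance || has_eos)
}

fn main() {
    // EOS lets an accepting-but-extendable grammar finish
    assert!(inner_done(true, true, true, true));
    // without EOS, an extendable grammar keeps generating
    assert!(!inner_done(true, true, true, false));
    // a pending token prefix always defers the stop
    assert!(!inner_done(false, true, false, true));
}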
2 changes: 1 addition & 1 deletion py/guidance
Submodule guidance updated 1 file
+7 −5 guidance/_parser.py
2 changes: 1 addition & 1 deletion py/llguidance/pyproject.toml
@@ -17,5 +17,5 @@ find = { where = ["python"] }
target = "llguidance._lib"
binding = "PyO3"
debug = false
# features = ["aici_llguidance_ctrl/logging"]
features = ["aici_llguidance_ctrl/logging"]
# See reference for RustExtension in https://setuptools-rust.readthedocs.io/en/latest/reference.html
