
[llama] Explicit option to use static tables. #155

Merged 3 commits on Aug 26, 2024.
1 change: 1 addition & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py

@@ -50,6 +50,7 @@ def main():
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
     llama_config = LlamaModelConfig(hp)
+    llama_config.static_tables = False  # Rely on the compiler for hoisting tables.
     llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
     model = PagedLlamaModelV1(dataset.root_theta, llama_config)
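The comment in that change captures the trade-off the new flag makes configurable: a table registered as a buffer in `__init__` gets baked into a captured program as constant data, while a table computed on demand stays an explicit, cheap calculation that the compiler can hoist to initialization time. A minimal sketch of the two patterns outside of sharktank (the module and attribute names below are illustrative only, not part of this PR):

import torch

class StaticTableBias(torch.nn.Module):
    # Static pattern: the table exists after __init__ and is exported as constant data.
    def __init__(self, n: int):
        super().__init__()
        self.register_buffer("table", torch.arange(n, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.table


class DynamicTableBias(torch.nn.Module):
    # Dynamic pattern: the table is recomputed in forward; an exporting compiler may hoist it.
    def __init__(self, n: int):
        super().__init__()
        self.n = n

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        table = torch.arange(self.n, dtype=torch.float32, device=x.device)
        return x + table


x = torch.zeros(4)
assert torch.equal(StaticTableBias(4)(x), DynamicTableBias(4)(x))  # same result either way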
17 changes: 11 additions & 6 deletions sharktank/sharktank/layers/causal_llm.py

@@ -28,7 +28,8 @@ def __init__(
         theta: Theta,
         *,
         context_length: int,
-        static_context_mask: bool = True,
+        static_tables: bool = True,
+        static_context_mask: bool = False,
         device: Optional[torch.device] = None,
         activation_dtype: torch.dtype = torch.float32,
         attention_dtype: torch.dtype = torch.float32,
@@ -39,7 +40,7 @@
         self.attention_dtype = attention_dtype
         self.context_length = context_length

-        if static_context_mask:
+        if static_tables:
             self.register_buffer(
                 "causal_context_mask", self.generate_causal_context_mask()
             )
@@ -66,10 +67,12 @@ def _maximally_negative_value(self, dtype):

     def generate_causal_context_mask(self) -> torch.Tensor:
         context_length = self.context_length
+        unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device)
+        context_broadcast_ones = unary_broadcast_ones.expand(
+            context_length, context_length
+        )
         causal_context_mask = torch.triu(
-            torch.ones(
-                [context_length, context_length], dtype=torch.bool, device=self.device
-            ),
+            context_broadcast_ones,
             diagonal=1,
         )[None, None, :, :]
         return causal_context_mask
@@ -114,9 +117,11 @@ def attention_mask(
         scenarios can benefit from managing this in different ways.
         """
         if causal_context_mask is None:
+            # Try to use the statically generated.
             causal_context_mask = self.causal_context_mask
         if causal_context_mask is None:
-            causal_context_mask = self._generate_causal_context_mask()
+            # Fallback to dynamically generated.
+            causal_context_mask = self.generate_causal_context_mask()

         # Combine the causal context mask and input mask.
         dtype = self.attention_dtype
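The rewritten generate_causal_context_mask produces the same strictly upper-triangular mask as before, but it builds it by expanding a [1, 1] ones tensor instead of materializing a full [context_length, context_length] literal, so a capture-time compiler sees a small computation rather than a large embedded constant. A standalone sketch of the construction (the tiny context_length is purely for illustration):

import torch

def causal_context_mask(context_length: int, device=None) -> torch.Tensor:
    # Broadcast a [1, 1] ones tensor up to the full mask shape.
    unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=device)
    context_broadcast_ones = unary_broadcast_ones.expand(context_length, context_length)
    # True strictly above the diagonal marks future positions attention must not see.
    return torch.triu(context_broadcast_ones, diagonal=1)[None, None, :, :]

print(causal_context_mask(4)[0, 0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])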
31 changes: 23 additions & 8 deletions sharktank/sharktank/layers/rotary_embedding.py

@@ -21,14 +21,29 @@ def __init__(
         max_seqlen: int,
         device: Optional[torch.device] = None,
         use_hf: bool = False,
+        static_tables: bool = True,
     ):
         super().__init__()
+        # Force static_tables until compiler limitations are solved.
+        # See https://github.com/nod-ai/sharktank/issues/156
+        static_tables = True
         self.device = device
         self.rope_dimension_count = rope_dimension_count
         self.max_seqlen = max_seqlen
         self.use_hf = use_hf
-        self._table = self._create_rotary_embed_table(
-            max_seqlen=max_seqlen,
-            dim=rope_dimension_count,
-        )
+        if static_tables:
+            self.register_buffer(
+                "static_rotary_embed_table", self._create_rotary_embed_table()
+            )
+        else:
+            self.static_rotary_embed_table = None
+
+    @property
+    def rotary_embed_table(self):
+        if self.static_rotary_embed_table is None:
+            return self._create_rotary_embed_table()
+        else:
+            return self.static_rotary_embed_table

     def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
         # xq_, xk_ shape: bs, sl, _, dim
@@ -80,7 +95,7 @@ def create_ordering_tensor(dim):
         _, sl, _, dim = xq_.shape

         # Offset the table based on starting position.
-        freqs_cis = self._table[start_index : start_index + sl, :]
+        freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :]
         assert freqs_cis.shape[-1] == dim
         assert (
             freqs_cis.shape[0] >= sl
@@ -139,7 +154,7 @@ def compute_batch_mask(
         ) + start_positions.unsqueeze(1)
         # Broadcast lookup to [b, ...].
         self.trace_tensor("rope.positions_seq", positions_seq)
-        freqs_cis = self._table[positions_seq]
+        freqs_cis = self.rotary_embed_table[positions_seq]

         # Unsqueeze a unit dim for attention heads.
         broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -167,10 +182,10 @@ def apply_batched_mask(

     def _create_rotary_embed_table(
         self,
-        max_seqlen: int,
-        dim: int,
         theta_value: float = 10000.0,
     ):
+        dim = self.rope_dimension_count
+        max_seqlen = self.max_seqlen
         freqs = 1.0 / (
             theta_value
             ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
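For reference, the table that is now either registered as a buffer or rebuilt on demand through the rotary_embed_table property is the standard RoPE frequency table: inverse frequencies over the even head dimensions, one rotation per position. The diff truncates the end of _create_rotary_embed_table, so the outer/polar steps below are the conventional construction rather than a quote of the file; treat them as an assumption:

import torch

def create_rotary_embed_table(
    max_seqlen: int,
    dim: int,
    theta_value: float = 10000.0,
    device=None,
) -> torch.Tensor:
    # Inverse frequencies for the even indices of the head dimension.
    freqs = 1.0 / (
        theta_value
        ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
    )
    # One rotation angle per (position, frequency) pair.  (Assumed detail.)
    t = torch.arange(max_seqlen, device=device, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    # Unit-magnitude complex rotations, shape [max_seqlen, dim // 2].  (Assumed detail.)
    return torch.polar(torch.ones_like(angles), angles)

table = create_rotary_embed_table(max_seqlen=2048, dim=128)
print(table.shape)  # torch.Size([2048, 64])

With static_tables=True this result lives in a buffer; with static_tables=False the property recomputes it per call, although this PR forces the static path for now pending nod-ai/sharktank#156.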
10 changes: 10 additions & 0 deletions sharktank/sharktank/models/llama/llama.py

@@ -52,6 +52,14 @@ class LlamaModelConfig:
     # rotary embedding).
     use_hf: bool = False

+    # If true, then the model may pre-initialize certain tables during
+    # init. This can be better for eager execution but when capturing a program,
+    # it is often better to preserve the calculation explicitly and rely on
+    # the compiler to transform it to an initialization time step. This can
+    # be the difference of many gigabytes of static data being embedded in
+    # the program and not.
+    static_tables: bool = True
+
     def create_kv_cache(self) -> BaseKVCache:
         hp = self.hp
         if self.kv_cache_type == "direct":
@@ -110,6 +118,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__(
             theta,
             context_length=config.hp.context_length,
+            static_tables=config.static_tables,
             device=config.device,
             activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,
@@ -131,6 +140,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 max_seqlen=hp.context_length,
                 device=self.device,
                 use_hf=self.use_hf,
+                static_tables=config.static_tables,
             ),
         )
         self.add_module(
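End to end, the flag lives on LlamaModelConfig and is threaded into both the causal base layer and the rotary embedding layer. Usage mirrors the export_paged_llm_v1.py change above; the sketch below reuses only names that appear in this diff and elides dataset loading, so treat the surrounding setup as an assumption rather than a complete script:

# Assumes `dataset` and `configs` are set up as in export_paged_llm_v1.py.
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)

llama_config = LlamaModelConfig(hp)
llama_config.static_tables = True    # eager: build mask/RoPE tables once at init
# llama_config.static_tables = False # export: leave the computation in the program

model = PagedLlamaModelV1(dataset.root_theta, llama_config)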