Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subclass from transformers for PEFT support and overall wider adoption #227

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/benchmark_generation_mamba_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

if is_mamba:
fn = lambda: model.generate(
input_ids=input_ids,
inputs=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
Expand Down
4 changes: 3 additions & 1 deletion mamba_ssm/models/config_mamba.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import dataclass, field

from transformers import PretrainedConfig


@dataclass
class MambaConfig:
class MambaConfig(PretrainedConfig):

d_model: int = 2560
n_layer: int = 64
Expand Down
10 changes: 6 additions & 4 deletions mamba_ssm/models/mixer_seq_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.generation import MambaGenerationMixin as GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
Expand Down Expand Up @@ -173,7 +175,7 @@ def forward(self, input_ids, inference_params=None):
return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):
class MambaLMHeadModel(PreTrainedModel, GenerationMixin):

def __init__(
self,
Expand All @@ -193,7 +195,8 @@ def __init__(
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}

super().__init__()
PreTrainedModel.__init__(self, config)
GenerationMixin.__init__(self)
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
Expand Down Expand Up @@ -235,7 +238,6 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions mamba_ssm/utils/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer, GenerationMixin


@dataclass
Expand Down Expand Up @@ -241,13 +241,13 @@ def should_stop(current_token, inference_params):
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


class GenerationMixin:
class MambaGenerationMixin(GenerationMixin):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError

def generate(
self,
input_ids,
inputs,
max_length,
top_k=1,
top_p=0.0,
Expand All @@ -258,7 +258,7 @@ def generate(
**kwargs,
):
output = decode(
input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
inputs, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
)
if not output_scores:
output.scores = None
Expand Down