From 240a34340af82e8c8e6d6035cd473094ab0dda39 Mon Sep 17 00:00:00 2001
From: Mark Rogers
Date: Thu, 7 Mar 2024 01:55:18 -0600
Subject: [PATCH 1/2] subclass from transformers

---
 mamba_ssm/models/config_mamba.py     | 4 +++-
 mamba_ssm/models/mixer_seq_simple.py | 8 +++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py
index 2aa1e5a6..65dc0082 100644
--- a/mamba_ssm/models/config_mamba.py
+++ b/mamba_ssm/models/config_mamba.py
@@ -1,8 +1,10 @@
 from dataclasses import dataclass, field
 
+from transformers import PretrainedConfig
+
 
 @dataclass
-class MambaConfig:
+class MambaConfig(PretrainedConfig):
 
     d_model: int = 2560
     n_layer: int = 64
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 2f1d97fd..b6a82a08 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -9,6 +9,8 @@
 
 import torch
 import torch.nn as nn
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutput
 
 from mamba_ssm.models.config_mamba import MambaConfig
 from mamba_ssm.modules.mamba_simple import Mamba, Block
@@ -173,7 +175,7 @@ def forward(self, input_ids, inference_params=None):
         return hidden_states
 
 
-class MambaLMHeadModel(nn.Module, GenerationMixin):
+class MambaLMHeadModel(PreTrainedModel, GenerationMixin):
 
     def __init__(
         self,
@@ -193,7 +195,8 @@ def __init__(
         pad_vocab_size_multiple = config.pad_vocab_size_multiple
         factory_kwargs = {"device": device, "dtype": dtype}
 
-        super().__init__()
+        PreTrainedModel.__init__(self, config)
+        GenerationMixin.__init__(self)
         if vocab_size % pad_vocab_size_multiple != 0:
             vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
         self.backbone = MixerModel(
@@ -235,7 +238,6 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)
 
     @classmethod

From 49e6513cc52a5d666225c6498691556d4296e789 Mon Sep 17 00:00:00 2001
From: Mark Rogers
Date: Thu, 7 Mar 2024 11:59:17 -0600
Subject: [PATCH 2/2] subclass generation mixin

---
 benchmarks/benchmark_generation_mamba_simple.py | 2 +-
 mamba_ssm/models/mixer_seq_simple.py            | 2 +-
 mamba_ssm/utils/generation.py                   | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/benchmarks/benchmark_generation_mamba_simple.py b/benchmarks/benchmark_generation_mamba_simple.py
index b7607787..805188e6 100644
--- a/benchmarks/benchmark_generation_mamba_simple.py
+++ b/benchmarks/benchmark_generation_mamba_simple.py
@@ -54,7 +54,7 @@
 
 if is_mamba:
     fn = lambda: model.generate(
-        input_ids=input_ids,
+        inputs=input_ids,
         max_length=max_length,
         cg=True,
         return_dict_in_generate=True,
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index b6a82a08..abeca8dc 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -14,7 +14,7 @@
 
 from mamba_ssm.models.config_mamba import MambaConfig
 from mamba_ssm.modules.mamba_simple import Mamba, Block
-from mamba_ssm.utils.generation import GenerationMixin
+from mamba_ssm.utils.generation import MambaGenerationMixin as GenerationMixin
 from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
 
 try:
diff --git a/mamba_ssm/utils/generation.py b/mamba_ssm/utils/generation.py
index 369c7a14..53a2335b 100644
--- a/mamba_ssm/utils/generation.py
+++ b/mamba_ssm/utils/generation.py
@@ -11,7 +11,7 @@
 from einops import rearrange, repeat
 from torch import Tensor
 from torch.profiler import ProfilerActivity, profile, record_function
-from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
+from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer, GenerationMixin
 
 
 @dataclass
@@ -241,13 +241,13 @@ def should_stop(current_token, inference_params):
     return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
 
 
-class GenerationMixin:
+class MambaGenerationMixin(GenerationMixin):
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         raise NotImplementedError
 
     def generate(
         self,
-        input_ids,
+        inputs,
         max_length,
         top_k=1,
         top_p=0.0,
@@ -258,7 +258,7 @@ def generate(
         **kwargs,
     ):
         output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
+            inputs, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
         )
         if not output_scores:
             output.scores = None
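
Usage sketch (not part of the patches): with the series applied, MambaLMHeadModel inherits from transformers' PreTrainedModel and generate() takes the HF-style `inputs` keyword instead of `input_ids`. A minimal smoke test might look like the following; the checkpoint name, tokenizer, and CUDA device are assumptions borrowed from the repo's benchmark scripts, not part of this diff.

import torch
from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Assumptions: a CUDA device and the checkpoint/tokenizer pair used by the
# benchmark scripts ("state-spaces/mamba-130m" with the GPT-NeoX tokenizer).
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device=device, dtype=torch.float16
)

input_ids = tokenizer("Mamba is a", return_tensors="pt").input_ids.to(device)
# After PATCH 2/2, generate() matches the transformers keyword `inputs`.
out = model.generate(
    inputs=input_ids,
    max_length=32,
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.batch_decode(out.sequences))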