Introduce modular files for speech models #35902

Open
wants to merge 6 commits into
base: main

Conversation

nikosanto13
Contributor

@nikosanto13 nikosanto13 commented Jan 27, 2025

What does this PR do?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

Additional details

  • Added modular files for models that have heavy duplication with classes from modeling_wav2vec2.py: Hubert, WavLM, Data2VecAudio, Wav2Vec2Conformer, Wav2Vec2Bert, UniSpeech, UniSpeechSat
  • Made some modifications to the modular converter script to address issues that came up while writing the modular files above (see inline comments for justification)

"""
for assignment, node in assignments.items():
    should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)

    # If it's a DOCSTRING var and is assigned to None, the parent's docstring is kept.
Contributor Author

I had to add this because, for many of the models I worked on, the docstring was somewhat custom (e.g. it contained a link to the original paper). So instead of just copying the docstring from the modular file, I figured it would be best to adopt this hybrid approach.

If you agree with the change, I should also update the modular docs: https://github.com/huggingface/transformers/blob/main/docs/source/en/modular_transformers.md

Member

Hmm, I don't really get it here. Using the parent's docstring when it's None is already the current behavior.

Contributor Author

Well, I wanted to say "instead of copying the docstring from the parent ..." (my comment on the code is also kinda obscure).
Essentially, there are now two possibilities:

  • either set MYMODEL_INPUT_DOCSTRING = None, in which case the assignment will be copied from the parent (as is already the case)
  • or set it to something else (a new docstring), in which case the assignment will be copied from the modular file

So it is more flexible than the existing approach.
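
The two possibilities above can be sketched as a small merge rule. This is only an illustration: `merge_assignments` and the `DOCSTRING_NAME` pattern are hypothetical names, not part of the modular converter, and the sketch assumes variable names have already been renamed to the new model's prefix.

```python
import re

# Hypothetical pattern for docstring variables; the converter's real
# ASSIGNMENTS_REGEX_TO_KEEP list may look different.
DOCSTRING_NAME = re.compile(r"_DOCSTRING$")


def merge_assignments(parent: dict, modular: dict) -> dict:
    """Merge top-level assignments, letting the modular file override the parent.

    A docstring variable set to None in the modular file keeps the parent's value.
    """
    merged = dict(parent)
    for name, value in modular.items():
        if DOCSTRING_NAME.search(name) and value is None:
            continue  # keep the parent's docstring
        merged[name] = value  # take the modular file's value
    return merged


parent = {"MYMODEL_INPUT_DOCSTRING": "Parent docs.", "MYMODEL_START_DOCSTRING": "Parent intro."}
modular = {"MYMODEL_INPUT_DOCSTRING": None, "MYMODEL_START_DOCSTRING": "Custom intro with paper link."}
merged = merge_assignments(parent, modular)
```

Here `merged` keeps the parent's input docstring and takes the custom start docstring from the modular side.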

new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node)
existing_protected_statements.update({str(stmt) for stmt in new_statements})
import_statements = [
Contributor Author

I added this because the previous code behaved incorrectly for "safe" imports that contain other statements in the same block, e.g. L381:395 in modeling_wav2vec2.py

if is_deepspeed_zero3_enabled():
    import deepspeed

    with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
        ...

The whole block after the import statement would be moved to the top of the new modeling script (among the import statements).
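
To find the affected blocks in an existing modeling file, something like the following could be used. It is a rough sketch built on the standard-library `ast` module, not on the converter's own machinery, and `mixed_guarded_imports` is a hypothetical helper name.

```python
import ast


def mixed_guarded_imports(source: str) -> list:
    """Return line numbers of `if` blocks whose body mixes imports with other statements."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.If):
            continue
        is_import = [isinstance(stmt, (ast.Import, ast.ImportFrom)) for stmt in node.body]
        # Flag only blocks that contain at least one import AND one non-import.
        if any(is_import) and not all(is_import):
            hits.append(node.lineno)
    return hits


SAMPLE = '''
if is_deepspeed_zero3_enabled():
    import deepspeed

    with deepspeed.zero.GatheredParameters(weight, modifier_rank=0):
        pass

if is_deepspeed_zero3_enabled():
    import deepspeed
'''

flagged = mixed_guarded_imports(SAMPLE)
```

Only the first block in `SAMPLE` is flagged; the second one contains nothing but the import, so it is a pure "safe" import.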

Member

Yes, it's one of the current limitations. However, removing everything else does not seem like a good solution either. I could not come up with a nice rule for this. For now, the best option is maybe to patch the original modeling file to separate the safe import from the other logic? Would that require a lot of changes?

Contributor Author

yeah let's do it like this, thanks

Contributor Author

@nikosanto13 nikosanto13 Jan 30, 2025

However, in the example above it would be better to move:

if is_deepspeed_zero3_enabled():
    import deepspeed

outside of the constructor, because in the current state the newly created module (before ruff runs inside the modular converter) would contain two such statements, and the first one would become:

if is_deepspeed_zero3_enabled():
    pass

after the run_ruff call.

But if we move it to the top of the file, deepspeed would no longer be lazily imported. I think this is not a problem, right?

@Rocketknight1
Member

cc @ArthurZucker @qubvel

Member

@Cyrilvallez Cyrilvallez left a comment

Hey! Thanks for the contribution! I just looked at the modular part, let me know if something is unclear!! 🤗

Comment on lines +58 to +70
+ # Exclude names to prevent edge cases where we want to keep a name that may
+ # exist in the mapping, e.g. `Wav2Vec2BaseModelOutput` where `Wav2Vec2` is
+ # a "base" model identifier but we want the type to pass as is in the produced modeling file
+ EXCLUDE_NAMES = ["Wav2Vec2BaseModelOutput"]


  def preserve_case_replace(text, patterns: dict, default_name: str):
      # Create a regex pattern to match all variations
      regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
-     compiled_regex = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)

+     # Create exclude pattern
+     exclude_pattern = "|".join(re.escape(key) for key in EXCLUDE_NAMES)
+     compiled_regex = re.compile(f"(?<![a-z0-9])(?!{exclude_pattern})({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
Member

Definitely not a fan of having exclusions here. And the regex is already way too complicated 🥲 Moreover, I don't think we actually want an output type from another model, do we?
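
For readers following the regex discussion, here is a stripped-down sketch of what the negative lookahead does. It is not the converter's `preserve_case_replace`: the `rename` helper is hypothetical, the trailing `(.|$)` group is dropped, and the replacement is a fixed string rather than a case-preserving one.

```python
import re

EXCLUDE_NAMES = ["Wav2Vec2BaseModelOutput"]
OLD_NAMES = ["wav2vec2"]  # assumed key of the rename mapping, for illustration

regex_pattern = "|".join(re.escape(name) for name in OLD_NAMES)
exclude_pattern = "|".join(re.escape(name) for name in EXCLUDE_NAMES)

# The lookahead (?!...) rejects a match that would start an excluded name;
# the lookbehind rejects matches glued to a preceding letter or digit.
compiled_regex = re.compile(
    f"(?<![a-z0-9])(?!{exclude_pattern})({regex_pattern})", re.IGNORECASE
)


def rename(text: str, new_name: str = "WavLM") -> str:
    """Replace the old model prefix unless it begins an excluded name."""
    return compiled_regex.sub(new_name, text)


result = rename("class Wav2Vec2Model: ...  # returns Wav2Vec2BaseModelOutput")
```

In `result`, `Wav2Vec2Model` becomes `WavLMModel` while `Wav2Vec2BaseModelOutput` is left untouched, which is the edge case the exclusion list was introduced for.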

Contributor Author

@nikosanto13 nikosanto13 Jan 29, 2025

Yeah you're right, it felt bad while doing it 😂 Unfortunately we need output types from other models in the files I introduced (almost all of them need the Wav2Vec2BaseModelOutput).

But it could be done more cleanly with "type aliasing": e.g. for the WavLM model, which needs Wav2Vec2BaseModelOutput, we could add
WavLMBaseOutput = Wav2Vec2BaseModelOutput
inside the modular file.

What do you think?
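
A minimal sketch of the aliasing idea, using a stand-in dataclass rather than the real transformers class. The empty-subclass variant is an alternative added here for comparison and was not proposed in the thread.

```python
from dataclasses import dataclass


@dataclass
class Wav2Vec2BaseModelOutput:
    """Stand-in for the real output class; the field is illustrative only."""

    last_hidden_state: list


# Option A: plain alias, as proposed above. Both names refer to one class.
WavLMBaseOutput = Wav2Vec2BaseModelOutput


# Option B (assumption, not from the thread): an empty subclass with its own name.
class WavLMBaseOutputSubclass(Wav2Vec2BaseModelOutput):
    pass


alias_name = WavLMBaseOutput.__name__
subclass_name = WavLMBaseOutputSubclass.__name__
output = WavLMBaseOutput(last_hidden_state=[0.1, 0.2])
```

One thing to keep in mind with a plain alias is that `alias_name` is still `"Wav2Vec2BaseModelOutput"`, so reprs and generated docs would show the Wav2Vec2 name, whereas the subclass carries the WavLM name.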

Comment on lines +1019 to +1032

+ # Keep return annotation in `modular_xxx.py` if any, else original return annotation
+ new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns

  if not re.match(
      r"\ndef .*\(.*\):\n raise.*Error\(.*",
      mapper.python_module.code_for_node(updated_methods[name]),
  ):
-     func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
+     func = func.with_changes(
+         body=updated_methods[name].body,
+         params=new_params,
+         decorators=new_decorators,
+         returns=new_return_annotation,
+     )
Member

Love this one! Nice!
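
The fallback rule in this change (use the modular file's return annotation if present, otherwise the original one) can be illustrated with the standard-library `ast` module. This is only a sketch: the converter itself works on syntax-tree nodes via `with_changes`, and `merge_function` is a hypothetical helper. It requires Python 3.9+ for `ast.unparse`.

```python
import ast


def merge_function(original_src: str, modular_src: str) -> str:
    """Take the body from the modular function; keep the original return
    annotation unless the modular function declares its own."""
    original = ast.parse(original_src).body[0]
    modular = ast.parse(modular_src).body[0]
    original.body = modular.body
    original.returns = modular.returns if modular.returns else original.returns
    return ast.unparse(original)


ORIGINAL = "def forward(self, x) -> torch.Tensor:\n    return x\n"
MODULAR_NO_ANNOTATION = "def forward(self, x):\n    return x * 2\n"
MODULAR_WITH_ANNOTATION = "def forward(self, x) -> Tuple:\n    return (x,)\n"

kept = merge_function(ORIGINAL, MODULAR_NO_ANNOTATION)
overridden = merge_function(ORIGINAL, MODULAR_WITH_ANNOTATION)
```

`kept` retains `-> torch.Tensor` with the new body, while `overridden` uses the `-> Tuple` annotation from the modular side.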
