diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3f0267affb..529a4c57b4 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -303,6 +303,24 @@ class AttentionParams:
     fp8: bool = False
     fp8_meta: Union[Dict[str, Any], None] = None
 
+    def __eq__(self, other):
+        """
+        Override dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
+        since all other entries of fp8_meta are unused in get_attention_backend.
+        """
+        if not isinstance(other, self.__class__):
+            return NotImplemented
+        for field in fields(self):
+            fname = field.name
+            sf = getattr(self, fname)
+            of = getattr(other, fname)
+            if fname != "fp8_meta":
+                if sf != of:
+                    return False
+            elif sf["recipe"] != of["recipe"]:
+                return False
+        return True
+
 
 _alibi_cache = {
     "_num_heads": None,
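For reference, a minimal standalone sketch of the comparison semantics this override gives. `Params` is a hypothetical stand-in for `AttentionParams`, trimmed to a few fields, and the `fp8_meta` keys and values are made up for illustration; only the `__eq__` logic mirrors the patch:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Union


@dataclass
class Params:
    """Hypothetical stand-in for AttentionParams, trimmed to a few fields."""

    num_heads: int = 16
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None

    # Same logic as the patch. dataclass skips generating __eq__ when the
    # class body already defines one, so this override takes effect.
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            fname = field.name
            sf = getattr(self, fname)
            of = getattr(other, fname)
            if fname != "fp8_meta":
                if sf != of:
                    return False
            elif sf["recipe"] != of["recipe"]:
                return False
        return True


# fp8_meta dicts that differ everywhere except "recipe" (keys are made up;
# both sides must carry fp8_meta dicts, as the patched code assumes):
a = Params(fp8=True, fp8_meta={"recipe": "delayed_scaling", "scale_fwd": 1.0})
b = Params(fp8=True, fp8_meta={"recipe": "delayed_scaling", "scale_fwd": 2.0})
c = Params(fp8=True, fp8_meta={"recipe": "current_scaling", "scale_fwd": 1.0})

assert a == b  # only fp8_meta["recipe"] is compared
assert a != c  # a different recipe breaks equality
```

The practical effect, as the docstring states, is that two parameter objects built from the same recipe compare equal even when the rest of their `fp8_meta` state differs, so a backend selection keyed on the params can be reused instead of being recomputed.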