diff --git a/README.md b/README.md index 4ae25dc..d2bbf4d 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,7 @@ GENERATION_CONFIG = GenerationConfig( outputs = model.generate( inputs.input_ids.to("cuda:0"), GENERATION_CONFIG, - attention_mask=pos_inputs.attention_mask.to("cuda:0"), + attention_mask=inputs.attention_mask.to("cuda:0"), ) outputs = outputs[:, inputs.input_ids.shape[-1]:]