Skip to content

Commit

Permalink
Update GPT Fast handler to add option to get GPU ID from context (#2872)
Browse files Browse the repository at this point in the history
* Update GPT Fast handler to add option to get GPU ID from context

* Fix lint checks

* Fix lint checks

* Fix lint checks

* Update handler

---------

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
sachanub and Ubuntu authored Jan 4, 2024
1 parent 2eae8e5 commit 39ea211
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion examples/large_models/gpt_fast/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@ def __init__(self):

def initialize(self, ctx):
self.context = ctx
properties = self.context.system_properties
gpu_id = properties.get("gpu_id")
if gpu_id is not None and int(gpu_id) < 0:
raise ValueError("Invalid gpu_id")
rank = maybe_init_dist()

self.local_rank = rank if rank is not None else 0
self.local_rank = rank if rank is not None else int(gpu_id)

if torch.cuda.is_available():
self.map_location = "cuda"
Expand Down

0 comments on commit 39ea211

Please sign in to comment.