Skip to content

Commit

Permalink
feat(backend): support device args (#55)
Browse files Browse the repository at this point in the history
* support device args

* support device args

* support device args
  • Loading branch information
lwaekfjlk authored Oct 6, 2024
1 parent c4b4ef0 commit df396fb
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 10 deletions.
7 changes: 6 additions & 1 deletion examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gpu_bartender.server import (
DataArgs,
DeviceArgs,
FinetuningArgs,
ModelArgs,
OptimizerArgs,
Expand Down Expand Up @@ -36,12 +37,16 @@
is_fsdp=True
)

device_args = DeviceArgs(
gpu_num=4,
)

calculator = VRAMCalculator(
model_args=model_args,
finetuning_args=finetuning_args,
optimizer_args=optimizer_args,
data_args=data_args,
num_gpus=4,
device_args=device_args,
unit="MiB"
)

Expand Down
2 changes: 1 addition & 1 deletion frontend/dist/bundle.js

Large diffs are not rendered by default.

5 changes: 0 additions & 5 deletions frontend/src/static/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
<body>
<div class="container">
<h1>GPU Usage Calculator for AI Model Training</h1>
<img
src="https://images.pexels.com/photos/17483874/pexels-photo-17483874.png?auto=compress&cs=tinysrgb&h=350"
alt="AI Neural Network"
class="hero-image"
/>

<div class="calculator">
<div class="input-section">
Expand Down
4 changes: 3 additions & 1 deletion gpu_bartender/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .calculator import VRAMCalculator
from .data_args import DataArgs
from .device_args import DeviceArgs
from .finetuning_args import FinetuningArgs
from .model_args import ModelArgs
from .optimizer_args import OptimizerArgs
Expand All @@ -9,5 +10,6 @@
'ModelArgs',
'FinetuningArgs',
'OptimizerArgs',
'DataArgs'
'DataArgs',
'DeviceArgs'
]
6 changes: 4 additions & 2 deletions gpu_bartender/server/calculator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict

from data_args import DataArgs
from device_args import DeviceArgs
from finetuning_args import FinetuningArgs
from model_args import ModelArgs
from optimizer_args import OptimizerArgs
Expand All @@ -13,14 +14,15 @@ def __init__(
finetuning_args: FinetuningArgs,
optimizer_args: OptimizerArgs,
data_args: DataArgs,
num_gpus: int = 1,
device_args: DeviceArgs,
unit: str = "MiB"
):
self.model_args = model_args
self.finetuning_args = finetuning_args
self.optimizer_args = optimizer_args
self.data_args = data_args
self.num_gpus = num_gpus
self.device_args = device_args
self.num_gpus = 1
self.unit = unit
self.divisor = 2 ** 20 if unit == "MiB" else 2 ** 30
self.precision = 0 if unit == "MiB" else 3
Expand Down
9 changes: 9 additions & 0 deletions gpu_bartender/server/device_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dataclasses import dataclass, field


@dataclass
class DeviceArgs:
gpu_num: int = field(default=1)
node_num: int = field(default=1)
gpu_memory_limit: int = field(default=0)
gpu_type: str = field(default='A100')
2 changes: 2 additions & 0 deletions gpu_bartender/server/model_args.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal


@dataclass
Expand All @@ -10,3 +11,4 @@ class ModelArgs:
num_key_value_heads: int = field(default=1)
intermediate_size: int = field(default=1)
num_layers: int = field(default=1)
qquantization: Literal['float32', 'float16', 'bfloat16', 'int8'] = field(default='float32')

0 comments on commit df396fb

Please sign in to comment.