Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 21, 2024
1 parent 06c6307 commit 2ba24ac
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 12 deletions.
8 changes: 7 additions & 1 deletion examples/example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from gpu_bartender import VRAMCalculator, ModelArgs, DataArgs, OptimizerArgs, FinetuningArgs
from gpu_bartender import (
DataArgs,
FinetuningArgs,
ModelArgs,
OptimizerArgs,
VRAMCalculator,
)

# Example usage
model_args = ModelArgs(
Expand Down
6 changes: 3 additions & 3 deletions gpu_bartender/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .calculator import VRAMCalculator
from .model_args import ModelArgs
from .data_args import DataArgs
from .finetuning_args import FinetuningArgs
from .model_args import ModelArgs
from .optimizer_args import OptimizerArgs
from .data_args import DataArgs

__all__ = [
'VRAMCalculator',
'ModelArgs',
'FinetuningArgs',
'OptimizerArgs',
'DataArgs'
]
]
15 changes: 8 additions & 7 deletions gpu_bartender/calculator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from dataclasses import dataclass, field
from typing import Optional

from .data_args import DataArgs
from .finetuning_args import FinetuningArgs
from .model_args import ModelArgs
from .optimizer_args import OptimizerArgs


class VRAMCalculator:
def __init__(
self,
model_args: ModelArgs,
finetuning_args: FinetuningArgs,
optimizer_args: OptimizerArgs,
data_args: DataArgs,
num_gpus: int = 1,
self,
model_args: ModelArgs,
finetuning_args: FinetuningArgs,
optimizer_args: OptimizerArgs,
data_args: DataArgs,
num_gpus: int = 1,
unit: str = "MiB"
):
self.model_args = model_args
Expand Down
1 change: 1 addition & 0 deletions gpu_bartender/data_args.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field


@dataclass
class DataArgs:
batch_size: int = field(default=4)
Expand Down
1 change: 1 addition & 0 deletions gpu_bartender/finetuning_args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class LoraArgs:
lora_alpha: Optional[int] = field(default=None)
Expand Down
3 changes: 2 additions & 1 deletion gpu_bartender/model_args.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field


@dataclass
class ModelArgs:
num_params: int = field(default=1)
Expand All @@ -8,4 +9,4 @@ class ModelArgs:
num_attention_heads: int = field(default=1)
num_key_value_heads: int = field(default=1)
intermediate_size: int = field(default=1)
num_layers: int = field(default=1)
num_layers: int = field(default=1)
1 change: 1 addition & 0 deletions gpu_bartender/optimizer_args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class OptimizerArgs:
optimizer: str = field(default="adam")
Expand Down

0 comments on commit 2ba24ac

Please sign in to comment.