Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[usability] deps streamlining #905

Merged
merged 25 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
60ee14b
[usability] deps streamlining
Sep 30, 2024
019a361
[fix] merge lora fix
Oct 15, 2024
2f460ec
Merge pull request #909 from OptimalScale/yizhenjia-mergelora
research4pan Oct 15, 2024
7fd0b80
[usability] temporarily change default version to 0.0.8
Oct 20, 2024
b8efbcf
Merge pull request #911 from OptimalScale/yizhenjia-readme-upd
research4pan Oct 20, 2024
f7a51df
[doc] readme update wandb guide
Oct 25, 2024
efac11e
Merge pull request #912 from OptimalScale/yizhenjia-readme-upd
research4pan Oct 25, 2024
dc41519
[temp] temporarily restrict transformers version
Nov 4, 2024
ba8d98d
[temp] transformers version
Nov 4, 2024
8a3953f
Merge pull request #914 from OptimalScale/yizhenjia-req-upd
research4pan Nov 4, 2024
5e586ce
[usability] add flash attn detect
Nov 5, 2024
9942954
[dev] init toml
Nov 5, 2024
b13fe84
[usability] update setup
Nov 5, 2024
a79e733
[usability] versioning update
Nov 5, 2024
8f85dfb
[usability] versioning update
Nov 5, 2024
c683af3
[usability] `use_auth_token` deprecation update
Nov 5, 2024
a2e376b
[usability] deps streamlining
Sep 30, 2024
ed879eb
[usability] add flash attn detect
Nov 5, 2024
7a0c799
[dev] init toml
Nov 5, 2024
2409a0b
[usability] update setup
Nov 5, 2024
f623fb6
[usability] versioning update
Nov 5, 2024
97e1d3a
[usability] versioning update
Nov 5, 2024
e23ad92
[usability] `use_auth_token` deprecation update
Nov 5, 2024
b17a211
Merge branch 'yizhenjia-streamline' of https://github.com/OptimalScal…
Nov 5, 2024
fb14e16
[usability] setup update
Nov 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ data/

# output models
output_models/
adapter_model/

# Distribution / packaging
.Python
Expand Down
27 changes: 26 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,34 @@ cd LMFlow
conda create -n lmflow python=3.9 -y
conda activate lmflow
conda install mpi4py
bash install.sh
pip install -e .
```

> [!TIP]
> We use WandB to track and visualize the training process by default. Before running the training scripts, users may need to log in to WandB using the command:
>```bash
>wandb login
>```
> For detailed instructions, refer to the [WandB Quickstart Guide](https://docs.wandb.ai/quickstart/). Step 1 (registration) and Step 2 (login using your WandB API key) should be sufficient to set up your environment.
>
> <details><summary>Disabling wandb</summary>
>
> One can disable wandb by either:
>
> 1. Setting the following environment variable before running the training command:
>
>```bash
>export WANDB_MODE=disabled
>```
>
> 2. Or, specifying which integrations the results and logs should be reported to. In the training script, add:
>
>```bash
>--report_to none \
>```
>
> </details>

### Prepare Dataset

Please refer to our [doc](https://optimalscale.github.io/LMFlow/examples/DATASETS.html).
Expand Down
26 changes: 16 additions & 10 deletions contrib/long-context/sft_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,27 @@
from colorama import Fore,init
from typing import Optional, List

from trl.commands.cli_utils import TrlParser
import torch
from datasets import load_dataset
from dataclasses import dataclass, field
from tqdm.rich import tqdm
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback
from trl import (
ModelConfig,
SFTTrainer,
DataCollatorForCompletionOnlyLM,
SFTConfig,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)

from lmflow.utils.versioning import is_trl_available

if is_trl_available():
from trl import (
ModelConfig,
SFTTrainer,
DataCollatorForCompletionOnlyLM,
SFTConfig,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
from trl.commands.cli_utils import TrlParser
else:
raise ImportError("Please install trl package to use sft_summarizer.py")

@dataclass
class UserArguments:
Expand Down
14 changes: 10 additions & 4 deletions examples/chatbot_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,28 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple shell chatbot implemented with lmflow APIs.
"""
from dataclasses import dataclass, field
import logging
import json
import os
import sys
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import torch
from typing import Optional
import warnings
import gradio as gr
from dataclasses import dataclass, field

import torch
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
from lmflow.utils.versioning import is_gradio_available

if is_gradio_available():
import gradio as gr
else:
raise ImportError("Gradio is not available. Please install it via `pip install gradio`.")

MAX_BOXES = 20

Expand Down
1 change: 1 addition & 0 deletions examples/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def main():
device=merge_lora_args.device,
ds_config=merge_lora_args.ds_config
)
model.activate_model_for_inference()
model.merge_lora_weights()
model.save(merge_lora_args.output_model_path, save_full_model=True)

Expand Down
22 changes: 12 additions & 10 deletions examples/vis_chatbot_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,30 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple Multimodal chatbot implemented with lmflow APIs.
"""
import logging
from dataclasses import dataclass, field
import json
import logging
import time

from PIL import Image
from lmflow.pipeline.inferencer import Inferencer
import warnings
from typing import Optional

import numpy as np
import os
import sys
from PIL import Image
import torch
import warnings
import gradio as gr
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import (VisModelArguments, DatasetArguments, \
InferencerArguments, AutoArguments)
from lmflow.utils.versioning import is_gradio_available

if is_gradio_available():
import gradio as gr
else:
raise ImportError("Gradio is not available. Please install it via `pip install gradio`.")


MAX_BOXES = 20

Expand Down
8 changes: 0 additions & 8 deletions install.sh

This file was deleted.

18 changes: 18 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools >= 64"]
build-backend = "setuptools.build_meta"

[tool.ruff]
target-version = "py39"
indent-width = 4

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["lmflow"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"
14 changes: 2 additions & 12 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,13 @@ datasets==2.14.6
tokenizers>=0.13.3
peft>=0.10.0
torch>=2.0.1
wandb==0.14.0
wandb
deepspeed>=0.14.4
trl==0.8.0
sentencepiece
transformers>=4.31.0
flask
flask_cors
icetk
cpm_kernels==1.0.11
evaluate==0.4.0
scikit-learn==1.2.2
lm-eval==0.3.0
dill<0.3.5
bitsandbytes>=0.40.0
pydantic
gradio
accelerate>=0.27.2
einops>=0.6.1
vllm>=0.4.3
ray>=2.22.0
einops>=0.6.1
21 changes: 12 additions & 9 deletions service/app.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from dataclasses import dataclass, field
import json
import torch
import os
from typing import Optional

from flask import Flask, request, stream_with_context
from flask import render_template
from flask_cors import CORS
from accelerate import Accelerator
from dataclasses import dataclass, field
import torch
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.args import ModelArguments
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
from lmflow.utils.versioning import is_flask_available

if is_flask_available():
from flask import Flask, request, stream_with_context
from flask import render_template
from flask_cors import CORS
else:
raise ImportError("Flask is not available. Please install flask and flask_cors.")

WINDOW_LENGTH = 512

Expand Down
25 changes: 19 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
import os
from setuptools import find_packages
from setuptools import setup
import subprocess

folder = os.path.dirname(__file__)
version_path = os.path.join(folder, "src", "lmflow", "version.py")

__version__ = None
with open(version_path) as f:
exec(f.read(), globals())
exec(f.read(), globals())

req_path = os.path.join(folder, "requirements.txt")
install_requires = []
if os.path.exists(req_path):
with open(req_path) as fp:
install_requires = [line.strip() for line in fp]
with open(req_path) as fp:
install_requires = [line.strip() for line in fp]

extra_require = {
"multimodal": ["Pillow"],
"vllm": ["vllm>=0.4.3"],
"ray": ["ray>=2.22.0"],
"gradio": ["gradio"],
"flask": ["flask", "flask_cors"],
"flash_attn": ["flash-attn>=2.0.2"],
"trl": ["trl==0.8.0"]
}

readme_path = os.path.join(folder, "README.md")
readme_contents = ""
if os.path.exists(readme_path):
with open(readme_path, encoding='utf-8') as fp:
readme_contents = fp.read().strip()
with open(readme_path, encoding="utf-8") as fp:
readme_contents = fp.read().strip()

setup(
name="lmflow",
Expand All @@ -33,6 +42,7 @@
packages=find_packages("src"),
package_data={},
install_requires=install_requires,
extras_require=extra_require,
classifiers=[
"Intended Audience :: Science/Research/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand All @@ -41,3 +51,6 @@
],
requires_python=">=3.9",
)

# optionals
# lm-eval==0.3.0
21 changes: 12 additions & 9 deletions src/lmflow/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
)
from transformers.utils.versions import require_version

from lmflow.utils.versioning import is_flash_attn_available

MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

Expand Down Expand Up @@ -88,9 +90,8 @@ class ModelArguments:
a string representing the specific model version to use (can be a
branch name, tag name, or commit id).

use_auth_token : bool
a boolean indicating whether to use the token generated when running
huggingface-cli login (necessary to use this script with private models).
token : Optional[str]
Necessary when accessing a private model/dataset.

torch_dtype : str
a string representing the dtype to load the model under. If auto is
Expand Down Expand Up @@ -178,13 +179,10 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
token: Optional[str] = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
"help": ("Necessary to specify when accessing a private model/dataset.")
},
)
trust_remote_code: bool = field(
Expand Down Expand Up @@ -357,6 +355,11 @@ def __post_init__(self):
if not self.use_lora:
logger.warning("use_qlora is set to True, but use_lora is not set to True. Setting use_lora to True.")
self.use_lora = True

if self.use_flash_attention:
if not is_flash_attn_available():
self.use_flash_attention = False
logger.warning("Flash attention is not available in the current environment. Disabling flash attention.")


@dataclass
Expand Down
6 changes: 5 additions & 1 deletion src/lmflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,9 @@
The `Dataset` class includes methods for loading datasets from a dictionary and a Hugging
Face dataset, mapping datasets, and retrieving the backend dataset and arguments.
"""
from lmflow.utils.versioning import is_multimodal_available


from lmflow.datasets.dataset import Dataset
from lmflow.datasets.multi_modal_dataset import CustomMultiModalDataset
if is_multimodal_available():
from lmflow.datasets.multi_modal_dataset import CustomMultiModalDataset
Loading
Loading