High CrossEntropy and Z Loss variance after loading from checkpoint #776

Open
abhijangda opened this issue Jan 6, 2025 · 0 comments
Labels
type/bug An issue about a bug

abhijangda commented Jan 6, 2025

🐛 Describe the bug

I have been experimenting with configs/official-1124/OLMo-7B-stage1.yaml, training on the dataset listed in that YAML file, and have run into a strange issue: after resuming from a checkpoint, the variance of the CrossEntropy and Z Loss increases dramatically. For example, I ran the first training run up to step 5600 and then re-ran training from the step-4400 checkpoint. Here are the loss graphs from wandb:

[Image: loss graphs from wandb]

You can clearly see that the variance is higher after step 4400.
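To put a number on "the variance is higher", one option is to take the per-step loss exported from wandb and compare its variance in equal-sized windows before and after the resume step. This is a minimal sketch, assuming the losses are already available as a list of (step, loss) pairs; the function names `window_variance` and `compare_around` and the default window of 200 steps are hypothetical choices, not part of OLMo or wandb.

```python
from statistics import pvariance


def window_variance(history, start, end):
    """Population variance of the logged loss over steps in [start, end)."""
    values = [loss for step, loss in history if start <= step < end]
    if len(values) < 2:
        raise ValueError(f"need at least 2 points in [{start}, {end})")
    return pvariance(values)


def compare_around(history, resume_step, window=200):
    """Compare loss variance in equal-sized windows before and after a resume.

    history: iterable of (step, loss) pairs (hypothetical export format).
    Returns (variance_before, variance_after, ratio).
    """
    history = list(history)  # materialize so it can be scanned twice
    before = window_variance(history, resume_step - window, resume_step)
    after = window_variance(history, resume_step, resume_step + window)
    ratio = float("inf") if before == 0 else after / before
    return before, after, ratio
```

Running it over the same step window on both the uninterrupted run and the resumed run helps separate a genuine post-resume change from ordinary step-to-step noise.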

I have tried this on the following two systems, and both show the same problem:

  • 64 AMD MI300X GPUs across 8 nodes, using ROCm 6.1, PyTorch 2.5.1, and Python 3.11
  • 64 NVIDIA A100 GPUs across 8 nodes, using CUDA 12.4, PyTorch 2.5.1, and Python 3.11

I have tried setting the number of heads and layers in OLMo-7B-stage1.yaml to 16 and to 32, but both configurations show the same issue.
I have been using the OLMo Core checkpointer and load checkpoints as follows:

  1. First, collect the tensors from all nodes in the model, train, and optim folders of the checkpoint into a single folder that all nodes can access.
  2. Then, set --load_path= to that folder.
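The two collection steps above can be scripted. This is a minimal sketch, assuming each node holds its own copy of the checkpoint directory with model, train, and optim subfolders, and that the destination lives on a filesystem every node can read. The paths, the helper name `collect_shards`, and the "identical duplicates are fine" rule are hypothetical; the sketch does not know OLMo Core's shard file names and only refuses to silently overwrite two different files that share a name.

```python
import filecmp
import shutil
from pathlib import Path

SUBDIRS = ("model", "train", "optim")  # checkpoint subfolders named in the issue


def collect_shards(node_checkpoints, destination):
    """Merge per-node checkpoint folders into one shared folder (hypothetical helper).

    node_checkpoints: iterable of paths, one per node, each containing
    model/, train/ and optim/ subfolders.
    Returns the number of files copied.
    """
    destination = Path(destination)
    copied = 0
    for node_dir in map(Path, node_checkpoints):
        for sub in SUBDIRS:
            src_dir = node_dir / sub
            if not src_dir.is_dir():
                continue
            for src in src_dir.rglob("*"):
                if not src.is_file():
                    continue
                dst = destination / sub / src.relative_to(src_dir)
                if dst.exists():
                    # Byte-identical duplicates (e.g. a file present on every node)
                    # are skipped; different contents under one name are an error.
                    if not filecmp.cmp(src, dst, shallow=False):
                        raise FileExistsError(f"conflicting copies of {dst}")
                    continue
                dst.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(src, dst)
                copied += 1
    return copied
```

After collection, the merged folder's file count per subfolder can be compared against what each node wrote, which catches a node whose shards were missed.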

Below is the config I used (I removed the dataset URLs):

run_name: OLMo2-7B-stage1
seed: 6198
dry_run: false

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  weight_tying: false
  alibi: false
  rope: true
  rope_theta: 500000
  flash_attention: true
  attention_dropout: 0.0
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm: true
  attention_layer_norm_with_affine: true
  norm_after: true
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 100278
  embedding_size: 100352
  eos_token_id: 100257
  pad_token_id: 100277
  init_device: meta
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

softmax_auxiliary_loss: true
auxiliary_loss_multiplier: 1e-5
fused_loss: true

compile: null

wandb:
  project: "llm-kron"
  entity: "abhijangda-microsoft"
  log_interval: 1
  group: "7B"

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
  - 0.9
  - 0.95
  metrics_log_interval: 1

scheduler:
  name: cosine_with_warmup
  units: tokens
  t_warmup: 8388608000
  t_max: 5e12
  alpha_f: 0.1
  warmup_min_lr: 0.0

tokenizer:
  identifier: tokenizers/allenai_dolma2.json
  truncate_direction: right

save_overwrite: false

save_interval: 1000
save_interval_ephemeral: 250
save_num_checkpoints_to_keep: -1
sharded_checkpointer: olmo_core

save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
global_train_batch_size: 1024
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: by_block_and_size
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 1

gen1_gc_interval: 1

eval_interval: 1000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
data:
  pad_direction: right
  # generate_doc_lengths: true
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  memmap_dtype: uint32
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
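Because the scheduler above runs in `units: tokens`, it can help to check what learning rate a correctly resumed run should use at step 4400. This is a minimal sketch, assuming every step consumes `global_train_batch_size × max_sequence_length` tokens (nothing subtracted for padding or filtered instances) and assuming the usual linear-warmup-then-cosine shape with `alpha_f` as the final LR fraction. It is a reconstruction for comparison against the logged learning rate, not OLMo's scheduler code; `lr_at_step` and `grad_accum_steps` are hypothetical names.

```python
import math

# Values copied from the config above.
GLOBAL_BATCH = 1024          # global_train_batch_size
SEQ_LEN = 4096               # model.max_sequence_length
LR = 3.0e-4                  # optimizer.learning_rate
T_WARMUP = 8_388_608_000     # scheduler.t_warmup (tokens)
T_MAX = 5e12                 # scheduler.t_max (tokens)
ALPHA_F = 0.1                # scheduler.alpha_f
WARMUP_MIN_LR = 0.0          # scheduler.warmup_min_lr

# Assumption: every instance is a full-length sequence.
TOKENS_PER_STEP = GLOBAL_BATCH * SEQ_LEN


def lr_at_step(step):
    """Assumed cosine-with-warmup schedule, evaluated in token units."""
    t = step * TOKENS_PER_STEP
    if t < T_WARMUP:
        # Linear warmup from warmup_min_lr to the peak LR.
        return WARMUP_MIN_LR + (LR - WARMUP_MIN_LR) * t / T_WARMUP
    eta_min = LR * ALPHA_F
    progress = (t - T_WARMUP) / (T_MAX - T_WARMUP)
    return eta_min + (LR - eta_min) * (1 + math.cos(math.pi * progress)) / 2


def grad_accum_steps(num_gpus, microbatch=8):
    """Microbatches per optimizer step on each GPU (device_train_microbatch_size)."""
    per_gpu = GLOBAL_BATCH // num_gpus
    return per_gpu // microbatch
```

Under these assumptions, each step is 4,194,304 tokens, warmup ends at step 2000, and step 4400 sits just below the 3.0e-4 peak. On 64 GPUs, each optimizer step is 16 instances per GPU split into 2 microbatches of 8. If the resumed run's logged LR at the same step differs noticeably from the uninterrupted run, that suggests the scheduler or step counter was not restored as expected.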

Any idea what could be the issue here?

Versions

absl-py==2.1.0
accelerate==0.18.0
-e git+ssh://git@github.com/abhijangda/OLMo.git@77e47c6d84c018fc33a5eda086056c1402f74381#egg=ai2_olmo
ai2-olmo-core==0.1.0
aiofiles==23.2.1
aiohappyeyeballs==2.4.3
aiohttp==3.11.3
aioshutil==1.5
aiosignal==1.3.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.6.2.post1
anykeystore==0.2
apex==0.1
appdirs==1.4.4
asttokens==2.4.1
astunparse==1.6.3
attrs==24.2.0
autocommand==2.2.2
backoff==2.2.1
backports.tarfile==1.2.0
beaker-gantry==1.10.0
beaker-py==1.32.3
beautifulsoup4==4.12.3
bitsandbytes==0.44.1
black==23.12.1
boltons==21.0.0
boto3==1.35.84
botocore==1.35.84
bracex==2.5.post1
Brotli==1.1.0
build==1.2.2.post1
cached_path==1.6.5
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.0
click==8.1.7
click-help-colors==0.9.4
click-option-group==0.5.6
cmake==3.31.0.1
codeshield==1.0.1
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.3.1
cryptacular==1.6.2
cryptography==43.0.3
cupy==13.3.0
cxxfilt==0.3.0
cycler==0.12.1
dataclasses-json==0.6.7
datasets==3.2.0
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.15
dill==0.3.6
distro==1.9.0
docker==7.1.0
docker-pycreds==0.4.0
docutils==0.21.2
effdet==0.4.1
einops==0.8.0
emoji==2.14.0
eval_type_backport==0.2.0
evaluate==0.4.3
exceptiongroup==1.2.2
executing==2.1.0
expecttest==0.2.1
face==24.0.0
fastapi==0.115.5
fastrlock==0.8.2
ffmpy==0.4.0
filelock==3.16.1
filetype==1.2.0
fire==0.7.0
flash_attn @ file:///home/aiscuser/ajangda/flash-attention
flatbuffers==24.3.25
fonttools==4.55.0
frozenlist==1.5.0
fsspec==2023.9.2
ftfy==6.3.1
gitdb==4.0.11
GitPython==3.1.43
glom==22.1.0
google-api-core==2.23.0
google-auth==2.36.0
google-cloud-core==2.4.1
google-cloud-storage==2.19.0
google-cloud-vision==3.8.1
google-crc32c==1.6.0
google-resumable-media==2.7.2
googleapis-common-protos==1.66.0
gradio==5.6.0
gradio_client==1.4.3
greenlet==3.1.1
grpcio==1.68.0
grpcio-status==1.62.3
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
huggingface-hub==0.26.5
humanfriendly==10.0
hupper==1.12.1
hypothesis==6.119.2
idna==3.10
importlib_metadata==7.1.0
importlib_resources==6.4.0
inflate64==1.0.0
inflect==7.3.1
iniconfig==2.0.0
iopath==0.1.10
ipython==8.29.0
isort==5.12.0
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.1
jaraco.text==3.12.1
jedi==0.19.2
jeepney==0.8.0
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonpatch==1.33
jsonpath-python==1.0.6
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
keyring==25.5.0
kiwisolver==1.4.7
langchain==0.2.17
langchain-community==0.2.19
langchain-core==0.2.43
langchain-openai==0.1.20
langchain-text-splitters==0.2.4
langdetect==1.0.9
langsmith==0.1.143
layoutparser==0.3.4
lightning-utilities==0.11.9
lintrunner==0.12.5
loralib==0.1.2
lxml==5.3.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.23.1
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
more-itertools==10.3.0
mpi4py @ file:///work/ci_py311/mpi4py_1676858691457/work
mpmath==1.3.0
mscclpp @ file:///home/ajangda/mscclpp
msgspec==0.18.6
multidict==6.1.0
multiprocess==0.70.14
multivolumefile==0.2.3
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
nest-asyncio==1.6.0
netifaces==0.11.0
networkx==3.4.2
nh3==0.2.20
ninja==1.11.1.1
nltk==3.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
omegaconf==2.3.0
onnx==1.17.0
onnxruntime==1.20.0
openai==1.39.0
opencv-python==4.10.0.84
opentelemetry-api==1.25.0
opentelemetry-exporter-otlp-proto-common==1.25.0
opentelemetry-exporter-otlp-proto-http==1.25.0
opentelemetry-instrumentation==0.46b0
opentelemetry-instrumentation-requests==0.46b0
opentelemetry-proto==1.25.0
opentelemetry-sdk==1.25.0
opentelemetry-semantic-conventions==0.46b0
opentelemetry-util-http==0.46b0
optimum==1.23.3
optree==0.13.1
ordered-set==4.1.0
orjson==3.10.11
packaging==24.2
pandas==2.2.3
parso==0.8.4
PasteDeploy==3.1.0
pathspec==0.12.1
pbkdf2==1.3
pdf2image==1.17.0
pdfminer.six==20231228
pdfplumber==0.11.4
peewee==3.17.8
peft==0.13.2
petname==2.6
pexpect==4.9.0
pi_heif==0.20.0
pikepdf==9.4.2
pillow==11.0.0
pkginfo==1.12.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==4.2.2
pluggy==1.5.0
portalocker==3.0.0
prettytable==3.12.0
prompt_toolkit==3.0.48
propcache==0.2.0
proto-plus==1.25.0
protobuf==4.25.5
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
py7zr==0.22.0
pyarrow==18.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pybcj==1.0.2
pybind11==2.13.6
pybind11_global==2.13.6
pycocotools==2.0.8
pycparser==2.22
pycryptodomex==3.21.0
pydantic==2.9.2
pydantic_core==2.23.4
pydub==0.25.1
pyfastkron @ file:///home/aiscuser/ajangda/OLMo/pyfastkron-1.0.1-py3-none-any.whl#sha256=600f33c84967e12106e7e2b25f583422bf4a1a1f8dc887b5e8df54fa9bba2082
Pygments==2.18.0
pyparsing==3.2.0
pypdf==5.1.0
pypdfium2==4.30.0
pyppmd==1.1.0
pyproject_hooks==1.2.0
pyramid==2.0.2
pyramid-mailer==0.15.1
pytest==8.3.4
pytest-sphinx==0.6.3
python-dateutil==2.8.2
python-iso639==2024.10.22
python-magic==0.4.27
python-multipart==0.0.12
python3-openid==3.2.0
pytorch-triton-rocm==3.1.0
pytz==2024.2
PyYAML==6.0.1
pyzstd==0.16.2
RapidFuzz==3.10.1
readme_renderer==44.0
referencing==0.35.1
regex==2024.11.6
repoze.sendmail==4.4.1
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.11.0
responses==0.18.0
rfc3986==2.0.0
rich==13.5.3
rouge_score==0.1.2
rpds-py==0.21.0
rsa==4.9
ruamel.yaml==0.17.40
ruamel.yaml.clib==0.2.12
ruff==0.7.4
s3transfer==0.10.4
safehttpx==0.1.1
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
SecretStorage==3.3.3
semantic-version==2.10.0
semgrep==1.96.0
sentence-transformers==3.3.1
sentencepiece==0.2.0
sentry-sdk==2.19.2
setproctitle==1.3.4
shellingham==1.5.4
six==1.16.0
smart-open==7.1.0
smashed==0.21.5
smmap==5.0.1
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.6
SQLAlchemy==2.0.36
stack-data==0.6.3
starlette==0.41.3
sympy==1.13.1
tabulate==0.9.0
tenacity==8.5.0
termcolor==2.5.0
texttable==1.7.0
threadpoolctl==3.5.0
tiktoken==0.8.0
timm==1.0.11
tokenize_rt==6.1.0
tokenizers==0.13.3
tomli==2.0.1
tomlkit==0.12.0
torch==2.5.1+rocm6.1
torchaudio==2.5.1+rocm6.1
torchmetrics==1.6.0
torchvision==0.20.1+rocm6.1
tqdm==4.67.1
traitlets==5.14.3
transaction==5.0
transformers==4.28.1
translationstring==1.4
triton==3.1.0
trouting==0.3.3
twine==6.0.1
typeguard==4.3.0
typer==0.13.1
types-dataclasses==0.6.6
types-setuptools==75.6.0.20241126
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
unstructured==0.15.8
unstructured-client==0.27.0
unstructured-inference==0.7.36
unstructured.pytesseract==0.3.13
urllib3==2.2.3
uvicorn==0.32.0
velruse==1.1.1
venusian==3.1.1
wandb==0.19.1
wcmatch==8.5.2
wcwidth==0.2.13
WebOb==1.8.9
websockets==12.0
wrapt==1.16.0
WTForms==3.2.1
wtforms-recaptcha==0.3.2
xxhash==3.5.0
yarl==1.17.2
zipp==3.19.2
zope.deprecation==5.0
zope.interface==7.2
zope.sqlalchemy==3.1

@abhijangda abhijangda added the type/bug An issue about a bug label Jan 6, 2025