diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..8890ad9 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 The Meerkat Team. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index be5daca..a460772 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
-    <img src="banner.png" alt="Meerkat logo"/>
+    <img src="assets/banner.png" alt="Meerkat logo"/>
 
 [![GitHub](https://img.shields.io/github/license/HazyResearch/meerkat)](https://img.shields.io/github/license/HazyResearch/meerkat)
 [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
@@ -9,7 +9,7 @@
-Zoology provides machine learning researchers with a simple playground for understanding and testing language model architectures on synthetic tasks. This repository can be used to reproduce the results in our paper *[Zoology: Measuring and Improving Recall in Efficient Language Models](https://arxiv.org/abs/2312.04927)*.
+Zoology provides machine learning researchers with a simple playground for understanding and testing language model architectures on synthetic tasks. This repository can be used to reproduce the results in our paper *[Zoology: Measuring and Improving Recall in Efficient Language Models](https://arxiv.org/abs/2312.04927)*. See the section on [reproducing paper experiments](#reproducing-paper-experiments) for details.
 
 ---
 
@@ -30,7 +30,7 @@
 pip install -e .[extra,analysis]
 ```
 If you want to keep this install as lightweight as possible, the only required dependencies are: `torch, einops, tqdm, pydantic, wandb`. There is some extra functionality (*e.g.* launching sweeps in parallel with Ray) that requires additional dependencies. To install without the optional dependencies, run `pip install -e .`.
 
-Then, try running an example experiments with:
+Then, try running an example experiment with:
 ```
 python -m zoology.launch zoology/experiments/examples/basic.py
 ```
@@ -40,6 +40,24 @@
 python -m zoology.launch zoology/experiments/examples/basic_sweep.py
 ```
 If you have access to multiple GPUs, you can run the sweep in parallel by adding the `-p` flag.
+
+## Reproducing paper experiments
+In this section, we'll show how to reproduce the results in our paper *[Zoology: Measuring and Improving Recall in Efficient Language Models](https://arxiv.org/abs/2312.04927)* and [blog post](https://hazyresearch.stanford.edu/blog/2023-12-11-zoology1-analysis).
+
+The main synthetic data results in our work are summarized in Figure 2. The x-axis is the model dimension and the y-axis is accuracy on MQAR. Increasing the sequence length increases task difficulty. The results shown are the maximum performance for each model over four learning rates.
+
+<div align="center">
+    <img src="assets/figure2.png" alt="Figure 2"/>
+</div>
+
+To reproduce these results, ensure you have WandB set up to log all the results, and then run the command:
+```
+python -m zoology.launch zoology/experiments/paper/figure2.py -p
+```
+Note that there are 448 model/data configurations in this sweep, so it takes a while to run. We ran most of our experiments on an 8xA100 with the `-p` flag, which launches configurations in parallel. To run a smaller-scale experiment, you can modify the loops in the `figure2.py` file to only include a subset of the configurations you're interested in (*e.g.* you can drop some models, sequence lengths, or learning rates). For more details on how the experiments are configured, see the [configuration section](#configuration-experiments-and-sweeps).
+
+To produce the plot after the run, see the plotting code in `zoology/analysis/paper/figure2.py`.
+
 ## Configuration, Experiments, and Sweeps
 
 In this section, we'll walk through how to configure an experiment and launch sweeps.
@@ -169,8 +187,6 @@ When you launch an experiment with this configuration, the `my_data_builder` function will be
 
 **Caching dataset creation.** Sometimes it's useful to cache the dataset creation process, especially if it's expensive. To do so, you can pass a `cache_dir` to the `DataConfig`: `DataConfig(..., cache_dir="my_cache_dir")`.
-
-
 ## About
 
 This repo is being developed by members of the HazyResearch group.
diff --git a/banner.png b/assets/banner.png
similarity index 100%
rename from banner.png
rename to assets/banner.png
diff --git a/assets/figure2.png b/assets/figure2.png
new file mode 100644
index 0000000..b794671
Binary files /dev/null and b/assets/figure2.png differ
diff --git a/zoology/analysis/paper/figure2.py b/zoology/analysis/paper/figure2.py
new file mode 100644
index 0000000..904f816
--- /dev/null
+++ b/zoology/analysis/paper/figure2.py
@@ -0,0 +1,71 @@
+import os
+
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from zoology.analysis.utils import fetch_wandb_runs
+
+
+def plot(
+    df: pd.DataFrame,
+    max_seq_len: int = 512,
+):
+    plot_df = df.groupby([
+        "model.sequence_mixer.name",
+        "model.d_model",
+        "data.input_seq_len",
+    ])["valid/accuracy"].max().reset_index()
+
+    sns.set_theme(style="whitegrid")
+    g = sns.relplot(
+        data=plot_df[plot_df["data.input_seq_len"] <= max_seq_len],
+        y="valid/accuracy",
+        col="data.input_seq_len",
+        x="model.d_model",
+        hue="model.sequence_mixer.name",
+        kind="line",
+        marker="o",
+        height=2.25,
+        aspect=1,
+    )
+    g.set(xscale="log", ylabel="Accuracy", xlabel="Model dimension")
+
+    # Set custom x-ticks
+    ticks = [64, 128, 256, 512]  # Modify this list as needed
+    for ax in g.axes.flat:
+        ax.set_xticks(ticks)
+        # keep the tick labels as plain integers rather than scientific notation
+        ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())
+
+    # Set custom y-ticks
+    y_ticks = [0, 0.25, 0.5, 0.75, 1.0]
+    for ax in g.axes.flat:
+        ax.set_yticks(y_ticks)
+
+    for ax, title in zip(g.axes.flat, g.col_names):
+        ax.set_title(f"Sequence Length: {title}")
+
+
+if __name__ == "__main__":
+    df = fetch_wandb_runs(
+        launch_id=[
+            "default-2023-10-25-22-20-38",
+            "default-2023-10-26-19-09-31",
+            "default-2023-10-27-04-13-56",
+            "default-2023-10-29-17-31-26",
+            "default-2023-11-12-00-31-44",
+            "default-2023-11-13-00-31-15",
+            "default-2023-11-13-00-42-27"
+        ],
+        project_name="zoology"
+    )
+
+    df["data.input_seq_len"] = df["data.input_seq_len"].fillna(df["data.0.input_seq_len"])
+    
plot(df=df, max_seq_len=1024) + plt.savefig("results.pdf") diff --git a/zoology/data/associative_recall.py b/zoology/data/associative_recall.py index 66361bb..8695e4e 100755 --- a/zoology/data/associative_recall.py +++ b/zoology/data/associative_recall.py @@ -136,7 +136,7 @@ def _ar( inputs[:, 0:context_size] = kvs # create a matrix of indices, which is needed to index correctly below - rows = np.tile(np.arange(num_examples), (3, 1)).T + rows = np.tile(np.arange(num_examples), (num_queries, 1)).T # sample random kv pairs to use for the queries kv_idx_choices = np.arange(0, num_kv_pairs) diff --git a/zoology/experiments/mqar_d_model.py b/zoology/experiments/mqar_dmodel.py similarity index 96% rename from zoology/experiments/mqar_d_model.py rename to zoology/experiments/mqar_dmodel.py index 1d10684..3d37a84 100644 --- a/zoology/experiments/mqar_d_model.py +++ b/zoology/experiments/mqar_dmodel.py @@ -72,6 +72,12 @@ "l_max": input_seq_len, }, ), + "rwkv5": dict( + name="zoology.mixers.rwkv5.RWKVTimeMixer", + kwargs={ + "l_max": input_seq_len, + }, + ), "base_conv": dict( name="zoology.mixers.base_conv.BaseConv", kwargs={ diff --git a/zoology/experiments/paper/figure2.py b/zoology/experiments/paper/figure2.py new file mode 100644 index 0000000..3e93171 --- /dev/null +++ b/zoology/experiments/paper/figure2.py @@ -0,0 +1,159 @@ +import uuid +import numpy as np +from zoology.config import TrainConfig, ModelConfig, DataConfig, LoggerConfig + + +sweep_id = uuid.uuid4().hex[:6] +sweep_name = "figure2" + sweep_id + + +VOCAB_SIZE = 8_192 + + +configs = [] +for input_seq_len, num_kv_pairs in [ + (64, 4), + (128, 8), + (256, 16), + (512, 64), +]: + if input_seq_len == 1024: + batch_size = 64 + elif input_seq_len == 512: + batch_size = 128 + elif input_seq_len == 256: + batch_size = 256 + else: + batch_size = 512 + + data = DataConfig( + num_train_examples=100_000, + num_test_examples=3_000, + vocab_size=VOCAB_SIZE, + input_seq_len=input_seq_len, + batch_size=batch_size, + # cache_dir="", # TODO: add a directory to cache your results! 
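+        # (note: each configuration generates 100k training examples from scratch;
+        # pointing cache_dir at a directory lets repeated launches of this sweep
+        # reuse the cached datasets instead of regenerating them; see the
+        # "Caching dataset creation" note in the README)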
+ builder={ + "name": "zoology.data.associative_recall.multiquery_ar", + "kwargs": { + "num_kv_pairs": num_kv_pairs, + "train_power_a": 0.01, + "test_power_a": 0.01, + "random_non_queries": False + } + } + ) + + for d_model in [ + 64, + 128, + 256, + 512 + ]: + for lr in np.logspace(-4, -2, 4): + + MIXERS = { + "attention": dict( + name="zoology.mixers.attention.MHA", + kwargs={ + "dropout": 0.1, + "num_heads": 1 + }, + ), + "hyena": dict( + name="zoology.mixers.hyena.Hyena", + kwargs={ + "l_max": input_seq_len + }, + ), + "rwkv": dict( + name="zoology.mixers.rwkv.RWKVTimeMixer", + kwargs={ + "l_max": input_seq_len, + }, + ), + "base_conv": dict( + name="zoology.mixers.base_conv.BaseConv", + kwargs={ + "l_max": input_seq_len, + # pass a list of kernel sizes for each of four layers + "kernel_size": [3, -1, 3, -1] + } + ), + "h3": dict( + name="zoology.mixers.h3.H3", + kwargs={ + "l_max": input_seq_len, + "d_state": input_seq_len, # makes it mathematically equivalent to Hyena + "head_dim": 2 + } + ), + "based": dict( + name="zoology.mixers.hybrid.Hybrid", + kwargs={ + "configs": [ + dict( + name="zoology.mixers.base_conv.BaseConv", + kwargs={ + "l_max": input_seq_len, + # pass a list of kernel sizes for each of four layers + "kernel_size": 3, + "implicit_long_conv": True, + } + ), + dict( + name="zoology.mixers.based.Based", + kwargs={ + "l_max": input_seq_len, + "feature_dim": 8, + "num_key_value_heads": 1, + "num_heads": 1, + "feature_name": "taylor_exp" + } + ) + ] + } + ), + "mamba": dict( + name="zoology.mixers.mamba.Mamba", + kwargs={} + ), + } + + for sequence_mixer in [ + "attention", + "hyena", + "rwkv", + "base_conv", + "h3", + "based", + "mamba" + ]: + + if 'mamba' in sequence_mixer: + block_type = "MambaBlock" + else: + block_type = "TransformerBlock" + + model = ModelConfig( + d_model=d_model, + n_layers=4 if sequence_mixer != "attention" else 2, + block_type=block_type, + max_position_embeddings=input_seq_len if sequence_mixer == "attention" else 0, + vocab_size=VOCAB_SIZE, + sequence_mixer=MIXERS[sequence_mixer], + state_mixer=dict(name="torch.nn.Identity", kwargs={}) + ) + config = TrainConfig( + model=model, + data=data, + learning_rate=lr, + max_epochs=64, + run_id=f"{sequence_mixer}-seqlen{input_seq_len}-dmodel{d_model}-lr{lr}-kv{num_kv_pairs}", + logger=LoggerConfig( + project_name="zoology", + entity="hazy-research" + ) + + ) + configs.append(config) \ No newline at end of file diff --git a/zoology/mixers/based.py b/zoology/mixers/based.py index 4f6ee3e..76b35e3 100644 --- a/zoology/mixers/based.py +++ b/zoology/mixers/based.py @@ -69,11 +69,11 @@ def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor: -> Assume x.shape is (batch_size, n_heads, seq_len, head_dim) """ # Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first? 
-        x2 = oe.contract('...m,...n->...mn', x, x) / self.input_dim
+        x2 = oe.contract('...m,...n->...mn', x, x) / self.rd
         x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2
         x2 = x2[..., self.tril_indices[0], self.tril_indices[1]]
         x = torch.cat([torch.ones(x[..., :1].shape).to(x.device),
-                       x / self.rd, x2d, x2], dim=-1)
+                       x / self.rrd, x2d, x2], dim=-1)
         return x
 
diff --git a/zoology/mixers/rwkv.py b/zoology/mixers/rwkv.py
index ec2429e..b4d7c70 100644
--- a/zoology/mixers/rwkv.py
+++ b/zoology/mixers/rwkv.py
@@ -38,9 +38,10 @@ def backward(ctx, grad_output):
 # # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
 from torch.utils.cpp_extension import load
-wkv_cuda = load(name="wkv", sources=
-    ["/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_op.cpp",
-    "/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_cuda.cu"],
+dir_path = os.path.dirname(os.path.realpath(__file__))
+wkv_cuda = load(name="wkv", sources=[
+    os.path.join(dir_path, "./rwkv/v4/wkv_op.cpp"),
+    os.path.join(dir_path, "./rwkv/v4/wkv_cuda.cu")],
     verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
 
 class WKV(torch.autograd.Function):
diff --git a/zoology/mixers/rwkv/v4/wkv_cuda.cu b/zoology/mixers/rwkv/v4/wkv_cuda.cu
new file mode 100644
index 0000000..a4522cb
--- /dev/null
+++ b/zoology/mixers/rwkv/v4/wkv_cuda.cu
@@ -0,0 +1,125 @@
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    F p = 0, q = 0, o = MIN_VALUE;
+    // p and q are running sums divided by exp(o) (to avoid overflows)
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+
+        F no = max(o, u + k[ii]);
+        F A = exp(o - no);
+        F B = exp(u + k[ii] - no);
+        y[ii] = (A * p + B * v[ii]) / (A * q + B);
+
+        no = max(w + o, k[ii]);
+        A = exp(w + o - no);
+        B = exp(k[ii] - no);
+        p = A * p + B * v[ii];
+        q = A * q + B;
+        o = no;
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F y[Tmax], z[Tmax], zexp[Tmax];
+
+    F gw = 0, gu = 0;
+    F p = 0, q = 0;
+    F dpdw = 0, dqdw = 0;
+    F o = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        F no = max(o, k[ii] + u);
+        F A = exp(o - no);
+        F B = exp(k[ii] + u - no);
+
+        F num = A * p + B * v[ii];
+        F iden = 1 / (A * q + B);
+
+        y[i] = num * iden;
+        z[i] = iden;
+        zexp[i] = k[ii] + u - no;
+
+        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
+        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
+
+        no = max(w + o, k[ii]);
+        A = exp(w + o - no);
+        B = exp(k[ii] - no);
+        dpdw = A * (p + dpdw);
+        dqdw = A * (q + dqdw);
+        p = A * p + B * v[ii];
+        q = A * q + B;
+        o = no;
+    }
+
+    F gp = 0, gq = 0;
+    o = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        F A = gy[ii] * z[i] * exp(zexp[i]);
+        F B = exp(k[ii] + o);
+        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
+        gv[ii] = A + B * gp;
+
+        F no = max(w + o, zexp[i] - k[ii] - u);
+        A = exp(w + o - no);
+        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
+        gp = A * gp + B;
+        gq = A * gq - B * y[i];
+        o = no;
+    }
+
+    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] += gw * _w[_c];
+    _gu[_offsetBC] += gu;
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv/v4/wkv_op.cpp b/zoology/mixers/rwkv/v4/wkv_op.cpp
new file mode 100644
index 0000000..e59a515
--- /dev/null
+++ b/zoology/mixers/rwkv/v4/wkv_op.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv/v5/wkv5_cuda.cu b/zoology/mixers/rwkv/v5/wkv5_cuda.cu
new file mode 100644
index 0000000..18bc730
--- /dev/null
+++ b/zoology/mixers/rwkv/v5/wkv5_cuda.cu
@@ -0,0 +1,202 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _w += h*_N_;
+    _u += h*_N_;
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+    float state[_N_] = {0};
+
+    __syncthreads();
+    w[i] = _w[i];
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            y += r_.x * (u_.x * x.x + s.x);
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            s.x = s.x * w_.x + x.x;
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C, const int H,
+                                const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _w += h*_N_;
+    _u += h*_N_;
+    __w += h*_N_;
+
+    __shared__ float w_[_N_], u_[_N_];
+    __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
+    __syncthreads();
+    w_[i] = _w[i];
+    u_[i] = float(_u[i]);
+    __syncthreads();
+
+    const float w = w_[i];
+    const float ww = __w[i];
+    const float u = u_[i];
+
+    float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
+
+    float gw = 0, gu = 0;
+    const int t000 = b*T*C + h*_N_ + i;
+    const int t111 = (b+1)*T*C + h*_N_ + i;
+    const int t222 = t111 - 2*C;
+
+    for (int t = t000; t < t111; t += C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        float gr = 0, gu_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            float x = k * v[j];
+
+            gr += (u * x + s) * gy[j];
+            gu_ += x * gy[j];
+            s = s * w + x;
+        }
+        _gr[t] = F(gr);
+        gu += float(_r[t]) * gu_;
+    }
+    _gu[b*C + h*_N_ + i] = F(gu);
+
+    for (int t = t000; t < t222; t += C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t + 2*C]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        float gw_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = saaaa[j];
+            float& s2 = sbbbb[j];
+            float x = k * v[j];
+
+            float tmp = w * (x + s);
+            s = tmp;
+            s2 = tmp + w * s2;
+            gw_ += s2 * gy[j];
+        }
+        gw += float(_r[t + 2*C]) * gw_;
+    }
+    _gw[b*C + h*_N_ + i] = F(ww * gw);
+
+    for (int t = t111 - C; t >= t000; t -= C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float rr = float(_r[t]);
+        float gk = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = rr * gy[j];
+
+            gk += (u * x + s) * v[j];
+            s = x + s * w;
+        }
+        _gk[t] = F(gk);
+    }
+
+    for (int t = t111 - C; t >= t000; t -= C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float gyy = float(_gy[t]);
+        float gv = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = sdddd[j];
+            float x = gyy * r[j];
+
+            gv += (u_[j] * x + s) * k[j];
+            s = x + s * w_[j];
+        }
+        _gv[t] = F(gv);
+    }
+}
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
+}
+
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv/v5/wkv5_op.cpp b/zoology/mixers/rwkv/v5/wkv5_op.cpp
new file mode 100644
index 0000000..aef76a4
--- /dev/null
+++ b/zoology/mixers/rwkv/v5/wkv5_op.cpp
@@ -0,0 +1,22 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
+
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv5 forward");
+    m.def("backward", &backward, "wkv5 backward");
+}
+
+TORCH_LIBRARY(wkv5, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv5.py b/zoology/mixers/rwkv5.py
new file mode 100644
index 0000000..5aa3972
--- /dev/null
+++ b/zoology/mixers/rwkv5.py
@@ -0,0 +1,616 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.strategies import DeepSpeedStrategy
+if importlib.util.find_spec('deepspeed'):
+    import deepspeed
+    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+
+# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
+
+try:
+    print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
+except:
+    
os.environ["RWKV_MY_TESTING"] = '' + +def __nop(ob): + return ob + + +MyModule = nn.Module +MyFunction = __nop +os.environ['RWKV_JIT_ON'] = '1' +if os.environ["RWKV_JIT_ON"] == "1": + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + +os.environ['RWKV_FLOAT_MODE'] = 'bf16' + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + + +from torch.utils.cpp_extension import load + +# HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"]) +HEAD_SIZE = 64 +dir_path = os.path.dirname(os.path.realpath(__file__)) +wkv5_cuda = load(name="wkv5", sources=[ + os.path.join(dir_path, "./rwkv/v5/wkv5_op.cpp"), + os.path.join(dir_path, "./rwkv/v5/wkv5_cuda.cu")], + verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) + +class WKV_5(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + ew = (-torch.exp(w.float())).contiguous() + eew = (torch.exp(ew)).contiguous() + ctx.save_for_backward(r, k, v, eew, ew, u) + y = torch.empty((B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1) + wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, eew, ew, u = ctx.saved_tensors + gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1) + gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1) + gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1) + gw = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1) + gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1) + wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu) + gw = torch.sum(gw, 0).view(H, C//H) + gu = torch.sum(gu, 0).view(H, C//H) + return (None, None, None, None, gr, gk, gv, gw, gu) + +def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u): + return WKV_5.apply(B, T, C, H, r, k, v, w, u) + +######################################################################################################## + +class RWKVTimeMixer(MyModule): + def __init__(self, l_max: int, d_model: int = 512, n_layer: int=12, layer_idx: int=-1): + + super().__init__() + self.layer_idx = layer_idx + self.ctx_len = l_max + self.d_model = d_model + dim_att = d_model + + # self.head_size = args.head_size_a + self.head_size = 64 + assert HEAD_SIZE == self.head_size # change 
HEAD_SIZE to match args.head_size_a + self.n_head = dim_att // self.head_size + assert dim_att % self.n_head == 0 + # self.head_size_divisor = args.head_size_divisor + self.head_size_divisor = 8 + + with torch.no_grad(): + ratio_0_to_1 = layer_idx / (n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_idx / n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, d_model) + for i in range(d_model): + ddd[0, 0, i] = i / d_model + + # fancy time_mix + self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + + # fancy time_decay + decay_speed = torch.ones(dim_att) + for n in range(dim_att): + decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size)) + # print(layer_idx, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) + + tmp = torch.zeros(dim_att) + for n in range(dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (dim_att - 1))) + zigzag + + self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.receptance = nn.Linear(d_model, dim_att, bias=False) + self.key = nn.Linear(d_model, dim_att, bias=False) + + self.value = nn.Linear(d_model, dim_att, bias=False) + self.output = nn.Linear(dim_att, d_model, bias=False) + self.gate = nn.Linear(d_model, dim_att, bias=False) + self.ln_x = nn.GroupNorm(self.n_head, dim_att) + + @MyFunction + def jit_func(self, x): + B, T, C = x.size() + + xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + xg = x * self.time_mix_g + xx * (1 - self.time_mix_g) + + r = self.receptance(xr) + k = self.key(xk) + v = self.value(xv) + g = F.silu(self.gate(xg)) + + return r, k, v, g + + @MyFunction + def jit_func_2(self, x, g): + B, T, C = x.size() + x = x.view(B * T, C) + + x = self.ln_x(x / self.head_size_divisor).view(B, T, C) + x = self.output(x * g) + return x + + def forward(self, x: torch.Tensor): + B, T, C = x.size() + H = self.n_head + + r, k, v, g = self.jit_func(x) + + # Convert tensors to bfloat16 for CUDA operation + r_bf16 = r.to(dtype=torch.bfloat16) + k_bf16 = k.to(dtype=torch.bfloat16) + v_bf16 = v.to(dtype=torch.bfloat16) + w_bf16 = self.time_decay.to(dtype=torch.bfloat16) + u_bf16 = self.time_faaaa.to(dtype=torch.bfloat16) + + x = RUN_CUDA_RWKV5(B, T, C, H, r_bf16, k_bf16, v_bf16, w_bf16, u_bf16) + + # Convert back to Float for subsequent operations + x_float = x.to(dtype=torch.float32) + + return self.jit_func_2(x_float, g) + +######################################################################################################## + +class RWKV_ChannelMix(MyModule): + def __init__(self, args, d_model=512, n_layer=12, layer_idx=1): + super().__init__() + self.args = args + self.layer_idx = layer_idx + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_idx / n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, d_model) + for i in range(d_model): + ddd[0, 0, i] = i / d_model + self.time_mix_k = 
nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+
+        # dim_ffn is never set in this vendored copy (upstream RWKV-LM reads it
+        # from `args`); assume the RWKV-LM default of 4x the model width.
+        self.dim_ffn = 4 * d_model
+
+        self.key = nn.Linear(d_model, self.dim_ffn, bias=False)
+        self.receptance = nn.Linear(d_model, d_model, bias=False)
+        self.value = nn.Linear(self.dim_ffn, d_model, bias=False)
+
+    @MyFunction
+    def forward(self, x):
+        xx = self.time_shift(x)
+        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+        k = self.key(xk)
+        k = torch.relu(k) ** 2
+        kv = self.value(k)
+        return torch.sigmoid(self.receptance(xr)) * kv
+
+class MishGLU(MyModule):
+    def __init__(self, args, d_model, layer_idx, n_layer):
+        super().__init__()
+        self.args = args
+        self.layer_idx = layer_idx
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        with torch.no_grad():
+            ratio_1_to_almost0 = 1.0 - (layer_idx / n_layer)
+
+            x = torch.ones(1, 1, d_model)
+            for i in range(d_model):
+                x[0, 0, i] = i / d_model
+
+            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.aa = nn.Linear(d_model, args.dim_ffn, bias=False)
+            self.bb = nn.Linear(d_model, args.dim_ffn, bias=False)
+            self.value = nn.Linear(args.dim_ffn, d_model, bias=False)
+
+    @MyFunction
+    def forward(self, x):
+        xx = self.time_shift(x)
+        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+        a = self.aa(xa)
+        b = self.bb(xb)
+        return self.value(a * F.mish(b))
+
+########################################################################################################
+# The RWKV Model with our blocks
+########################################################################################################
+# NOTE: the classes below (Block, RWKV) are vendored from RWKV-LM and still expect
+# its argparse-style `args` namespace; the experiment configs in this PR only
+# instantiate RWKVTimeMixer above.
+
+
+class Block(nn.Module):
+    def __init__(self, args, d_model, layer_idx):
+        super().__init__()
+        self.args = args
+        self.layer_idx = layer_idx
+
+        self.ln1 = nn.LayerNorm(d_model)
+        self.ln2 = nn.LayerNorm(d_model)
+
+        if self.layer_idx == 0:
+            self.ln0 = nn.LayerNorm(d_model)
+            if args.my_pos_emb > 0:
+                self.pos_emb_x = nn.Parameter(torch.zeros((1, args.my_pos_emb, d_model)))
+                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb, 1, d_model)))
+
+        if self.layer_idx == 0 and self.args.pre_ffn > 0:
+            self.ffnPre = RWKV_ChannelMix(args, 0)
+        else:
+            self.att = RWKVTimeMixer(args, layer_idx)
+
+        if 'g' in os.environ["RWKV_MY_TESTING"]:
+            self.ffn = MishGLU(args, layer_idx)
+        else:
+            self.ffn = RWKV_ChannelMix(args, layer_idx)
+
+        if args.tiny_att_dim > 0 and self.layer_idx == args.tiny_att_layer:
+            self.tiny_ln = nn.LayerNorm(d_model)
+            self.tiny_q = nn.Linear(d_model, args.tiny_att_dim, bias=False)
+            self.tiny_k = nn.Linear(d_model, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(d_model, d_model, bias=False)
+            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+
+        if args.dropout > 0:
+            self.drop0 = nn.Dropout(p = args.dropout)
+            self.drop1 = nn.Dropout(p = args.dropout)
+
+    def forward(self, x, x_emb=None):
+        args = self.args
+        B, T, C = x.size()
+        if self.layer_idx == 0:
+            x = self.ln0(x)
+            if args.my_pos_emb > 0:
+                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+                x = x + pos_emb
+
+        if self.args.dropout == 0:
+            if self.layer_idx == 0 and args.pre_ffn > 0:
+                x = x + self.ffnPre(self.ln1(x))
+            else:
+                x = x + self.att(self.ln1(x))
+            x = x + self.ffn(self.ln2(x))
+        else:
+            if self.layer_idx == 0 and args.pre_ffn > 0:
+                x = self.drop0(x + self.ffnPre(self.ln1(x)))
+            else:
+                x = self.drop0(x + 
self.att(self.ln1(x))) + x = self.drop1(x + self.ffn(self.ln2(x))) + + if args.tiny_att_dim > 0 and self.layer_idx == args.tiny_att_layer: + xx = self.tiny_ln(x) + q = self.tiny_q(xx)[:, :T, :] + k = self.tiny_k(xx)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) + c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) + x = x + c @ self.tiny_v(x_emb) + return x + + +class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y): + ctx.save_for_backward(y) + return loss + + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) + + +class RWKV(pl.LightningModule): + def __init__(self, args, vocab_size, d_model, n_layer): + super().__init__() + self.args = args + if not hasattr(args, 'dim_att'): + args.dim_att = args.n_embd + if not hasattr(args, 'dim_ffn'): + args.dim_ffn = args.n_embd * 4 + if not hasattr(args, 'tiny_att_layer'): + args.tiny_att_layer = -1 + if not hasattr(args, 'tiny_att_dim'): + args.tiny_att_dim = -1 + assert args.n_embd % 32 == 0 + assert args.dim_att % 32 == 0 + assert args.dim_ffn % 32 == 0 + + self.emb = nn.Embedding(vocab_size,d_model) + + self.blocks = nn.ModuleList([Block(args, i) for i in range(n_layer)]) + + self.ln_out = nn.LayerNorm(d_model) + self.head = nn.Linear(d_model, vocab_size, bias=False) + + if args.head_qk > 0: + self.head_q = nn.Linear(d_model, args.head_qk, bias=False) + self.head_k = nn.Linear(d_model, args.head_qk, bias=False) + self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + if args.dropout > 0: + self.drop0 = nn.Dropout(p = args.dropout) + + def configure_optimizers(self): + args = self.args + + lr_decay = set() + lr_1x = set() + lr_2x = set() + lr_3x = set() + for n, p in self.named_parameters(): + if ("time_mix" in n) and (args.layerwise_lr > 0): + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif ("time_decay" in n) and (args.layerwise_lr > 0): + if args.my_pile_stage == 2: + lr_3x.add(n) + else: + lr_2x.add(n) + elif ("time_faaaa" in n) and (args.layerwise_lr > 0): + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif ("time_first" in n) and (args.layerwise_lr > 0): + lr_3x.add(n) + elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0): + lr_decay.add(n) + else: + lr_1x.add(n) + + lr_decay = sorted(list(lr_decay)) + lr_1x = sorted(list(lr_1x)) + lr_2x = sorted(list(lr_2x)) + lr_3x = sorted(list(lr_3x)) + # print('decay', lr_decay) + # print('1x', lr_1x) + # print('2x', lr_2x) + # print('3x', lr_3x) + param_dict = {n: p for n, p in self.named_parameters()} + + if args.layerwise_lr > 0: + if args.my_pile_stage == 2: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + ] + else: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] + else: + optim_groups = 
[{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}] + + if args.weight_decay > 0: + optim_groups += [{"params": [param_dict[n] for n in lr_decay], "weight_decay": args.weight_decay, "my_lr_scale": 1.0}] + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=True, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=True, amsgrad=False) + else: + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + + @property + def deepspeed_offload(self) -> bool: + strategy = self.trainer.strategy + if isinstance(strategy, DeepSpeedStrategy): + cfg = strategy.config["zero_optimization"] + return cfg.get("offload_optimizer") or cfg.get("offload_param") + return False + + def forward(self, idx): + args = self.args + B, T = idx.size() + assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." + + x = self.emb(idx) + x_emb = x + + if args.dropout > 0: + x = self.drop0(x) + if args.tiny_att_dim > 0: + for block in self.blocks: + if args.grad_cp == 1: + x = deepspeed.checkpointing.checkpoint(block, x, x_emb) + else: + x = block(x, x_emb) + else: + for block in self.blocks: + if args.grad_cp == 1: + x = deepspeed.checkpointing.checkpoint(block, x) + else: + x = block(x) + + x = self.ln_out(x) + + if args.head_qk > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + if "32" in os.environ["RWKV_FLOAT_MODE"]: + c = c @ F.one_hot(idx, num_classes=args.vocab_size) + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() + + x = self.head(x) + c + else: + x = self.head(x) + + return x + + def training_step(self, batch, batch_idx): + args = self.args + if args.my_qa_mask != 1: + idx, targets = batch + logits = self(idx) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + # if '0' in os.environ["RWKV_MY_TESTING"]: + # print('logits', logits) + # torch.set_printoptions(threshold=10000) + # print('idx', idx) + # exit(0) + else: + idx, targets, mask = batch + mask = mask.view(-1) + sum_mask = torch.sum(mask).item() + # if sum_mask == 0: + # return torch.tensor([0.0], requires_grad=True) + + logits = self(idx) + if sum_mask == mask.shape[0]: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + # print('rank', self.global_rank, 'loss', loss.item()) + else: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') + # loss_raw = loss + loss = torch.sum(loss * mask) / sum_mask + + # torch.set_printoptions(threshold=10000) + # if True: #self.global_rank == 1: + # tmp = '' + # sss = 0 + # ccc = 0 + # 
for i in range(mask.shape[0]): + # if mask[i] > 0: + # tmp += str(idx.view(-1)[i].item()) + ',' + # sss += loss_raw.view(-1)[i].float().item() + # ccc += 1 + # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx) + + return L2Wrap.apply(loss, logits) + + def training_step_end(self, batch_parts): + if pl.__version__[0]!='2': + all = self.all_gather(batch_parts) + if self.trainer.is_global_zero: + self.trainer.my_loss_all = all + + def generate_init_weight(self): + print( + f""" +############################################################################ +# +# Init model weight (slow for large models)... +# +############################################################################ +""" + ) + m = {} + for n in self.state_dict(): + p = self.state_dict()[n] + shape = p.shape + + gain = 1.0 + scale = 1.0 + if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n: + if 'ln_x.weight' in n: + layer_scale = (1+int(n.split('.')[1])) / self.n_layer + m[n] = (p * 0.0) + (layer_scale ** 0.7) + else: + m[n] = p + else: + if n == "emb.weight": + scale = -1 * self.args.lr_init + else: + if shape[0] > shape[1]: + gain = math.sqrt(shape[0] / shape[1]) + + zero = [".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.'] + + for kk in zero: + if kk in n: + scale = 0 + if n == "head.weight": + scale = 0.5 + if "head_k." in n: + scale = 0.1 + if "head_q." in n: + scale = 0 + + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}") + + if self.args.accelerator.upper() == "GPU": + m[n] = torch.empty((shape[0], shape[1]), device="cuda") + else: + m[n] = torch.empty((shape[0], shape[1])) + + if scale == 0: + nn.init.zeros_(m[n]) + elif scale < 0: + nn.init.uniform_(m[n], a=scale, b=-scale) + else: + nn.init.orthogonal_(m[n], gain=gain * scale) + + m[n] = m[n].cpu() + if os.environ["RWKV_FLOAT_MODE"] == "fp16": + m[n] = m[n].half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + m[n] = m[n].bfloat16() + + # if n == "emb.weight": + # print(m[n]) + + gc.collect() + torch.cuda.empty_cache() + return m
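
As a quick end-to-end check of the new `RWKVTimeMixer`, the following is a minimal sketch (not part of the PR): it assumes a CUDA GPU, since importing `zoology.mixers.rwkv5` JIT-compiles the `wkv5` kernel and the mixer casts its inputs to bfloat16, it assumes the optional dependencies the module imports at top level (*e.g.* `pytorch_lightning`) are installed, and it relies on the module's fixed `HEAD_SIZE = 64`, so `d_model` must be a multiple of 64.
```
# Hypothetical standalone smoke test for the new RWKV-v5 time mixer: checks that
# the mixer runs and preserves the (batch, seq_len, d_model) shape. Requires a
# CUDA device; the forward pass casts r/k/v to bfloat16 internally before
# invoking the wkv5 CUDA kernel.
import torch
from zoology.mixers.rwkv5 import RWKVTimeMixer

d_model, seq_len = 128, 64  # d_model must be a multiple of HEAD_SIZE (64)
mixer = RWKVTimeMixer(l_max=seq_len, d_model=d_model, n_layer=4, layer_idx=0).cuda()

x = torch.randn(2, seq_len, d_model, device="cuda")
y = mixer(x)
assert y.shape == x.shape  # the time mixer is shape-preserving
print("RWKVTimeMixer OK:", tuple(y.shape))
```
A passing run only validates the wiring and shapes; task accuracy comes from the `figure2.py` sweep above.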