diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000..8890ad9
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021 The Meerkat Team.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index be5daca..a460772 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-![Meerkat logo](banner.png)
+![Meerkat logo](assets/banner.png)
[![GitHub](https://img.shields.io/github/license/HazyResearch/meerkat)](https://img.shields.io/github/license/HazyResearch/meerkat)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
@@ -9,7 +9,7 @@
-Zoology provides machine learning researchers with a simple playground for understanding and testing language model architectures on synthetic tasks. This repository can be used to reproduce the results in our paper *[Zoology: Measuring and Improving Recall in Efficient Language Models](https://arxiv.org/abs/2312.04927)*.
+Zoology provides machine learning researchers with a simple playground for understanding and testing language model architectures on synthetic tasks. This repository can be used to reproduce the results in our paper *[Zoology: Measuring and Improving Recall in Efficient Language Models](https://arxiv.org/abs/2312.04927)*. See the section on [reproducing paper experiments](#reproducing-paper-experiments) for details.
---
@@ -30,7 +30,7 @@ pip install -e .[extra,analysis]
```
If you want to keep this install as lightweight as possible, note that the only required dependencies are: `torch, einops, tqdm, pydantic, wandb`. Some extra functionality (*e.g.* launching sweeps in parallel with Ray) requires additional dependencies. To install without the optional dependencies, run `pip install -e .`.
-Then, try running an example experiments with:
+Then, try running an example experiment with:
```
python -m zoology.launch zoology/experiments/examples/basic.py
```
@@ -40,6 +40,24 @@ python -m zoology.launch zoology/experiments/examples/basic_sweep.py
```
If you have access to multiple GPUs, you can run the sweep in parallel by adding the `-p` flag.
+
+## Reproducing paper experiments
+In this section, we'll show how to reproduce the results from our paper *[Zoology: Measuring and Improving Recall in Efficient Language Models](https://arxiv.org/abs/2312.04927)* and the accompanying [blogpost](https://hazyresearch.stanford.edu/blog/2023-12-11-zoology1-analysis).
+
+The main synthetic-data results in our work are summarized in Figure 2. The x-axis is the model dimension and the y-axis is accuracy on MQAR (multi-query associative recall); increasing the sequence length makes the task harder. The results shown are the maximum performance for each model over four learning rates.
+
+![Figure 2](assets/figure2.png)
+
+
+To reproduce these results, ensure you have WandB set up to log the results, and then run:
+```
+python -m zoology.launch zoology/experiments/paper/figure2.py -p
+```
+Note that there are 448 model/data configurations in this sweep, so it takes a while to run. We ran most of our experiments on an 8xA100 node with the `-p` flag, which launches configurations in parallel. To run a smaller-scale experiment, you can modify the loops in the `figure2.py` file to include only the subset of configurations you're interested in (*e.g.* you can drop some models, sequence lengths, or learning rates), as in the sketch below. For more details on how the experiments are configured, see the [configuration section](#configuration-experiments-and-sweeps).
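+
+For example, a minimal sketch of such an edit (hypothetical values; the loop variables are the ones defined in `figure2.py`):
+```
+import numpy as np
+
+# Narrowed sweep: one sequence length, two model widths, two learning rates, two mixers.
+for input_seq_len, num_kv_pairs in [(64, 4)]:
+    for d_model in [64, 128]:
+        for lr in np.logspace(-4, -2, 2):
+            for sequence_mixer in ["attention", "based"]:
+                ...  # build the DataConfig / ModelConfig / TrainConfig exactly as in the full file
+```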
+
+To produce the plot after the run, see the plotting code in `zoology/analysis/paper/figure2.py`.
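+
+Assuming your runs logged to the same WandB project, running that file directly (*e.g.* `python -m zoology.analysis.paper.figure2`) should regenerate the figure; you will likely need to replace the `launch_id` list in its `__main__` block with the ids of your own sweep.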
+
## Configuration, Experiments, and Sweeps
In this section, we'll walk through how to configure an experiment and launch sweeps.
@@ -169,8 +187,6 @@ When you launch an experiment with this configuration, the `my_data_builder` fun
**Caching dataset creation.** Sometimes it's useful to cache the dataset creation process, especially if it's expensive. To do so you can pass a `cache_dir` to the `DataConfig`: `DataConfig(..., cache_dir="my_cache_dir")`.
-
-
## About
This repo is being developed by members of the HazyResearch group.
diff --git a/banner.png b/assets/banner.png
similarity index 100%
rename from banner.png
rename to assets/banner.png
diff --git a/assets/figure2.png b/assets/figure2.png
new file mode 100644
index 0000000..b794671
Binary files /dev/null and b/assets/figure2.png differ
diff --git a/zoology/analysis/paper/figure2.py b/zoology/analysis/paper/figure2.py
new file mode 100644
index 0000000..904f816
--- /dev/null
+++ b/zoology/analysis/paper/figure2.py
@@ -0,0 +1,71 @@
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from zoology.analysis.utils import fetch_wandb_runs
+
+
+def plot(
+ df: pd.DataFrame,
+ max_seq_len: int = 512,
+):
+
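+    # Best validation accuracy per (mixer, model width, sequence length) cell,
+    # maximizing over the swept learning rates.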
+ plot_df = df.groupby([
+ "model.sequence_mixer.name",
+ "model.d_model",
+ "data.input_seq_len",
+ ])["valid/accuracy"].max().reset_index()
+
+ sns.set_theme(style="whitegrid")
+ g = sns.relplot(
+ data=plot_df[plot_df["data.input_seq_len"] <= max_seq_len],
+ y="valid/accuracy",
+ col="data.input_seq_len",
+ x="model.d_model",
+ hue="model.sequence_mixer.name",
+ kind="line",
+ marker="o",
+ height=2.25,
+ aspect=1,
+ )
+ g.set(xscale="log", ylabel="Accuracy", xlabel="Model dimension")
+
+ # Set custom x-ticks
+ ticks = [64, 128, 256, 512] # Modify this list as needed
+ for ax in g.axes.flat:
+ ax.set_xticks(ticks)
+        ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())  # keep tick labels as plain integers rather than scientific notation
+
+ # Set custom y-ticks
+ y_ticks = [0, 0.25, 0.5, 0.75, 1.0]
+ for ax in g.axes.flat:
+ ax.set_yticks(y_ticks)
+
+ for ax, title in zip(g.axes.flat, g.col_names):
+ ax.set_title(f"Sequence Length: {title}")
+
+
+if __name__ == "__main__":
+ df = fetch_wandb_runs(
+ launch_id=[
+ "default-2023-10-25-22-20-38",
+ "default-2023-10-26-19-09-31",
+ "default-2023-10-27-04-13-56",
+ "default-2023-10-29-17-31-26",
+ "default-2023-11-12-00-31-44",
+ "default-2023-11-13-00-31-15",
+ "default-2023-11-13-00-42-27"
+ ],
+ project_name="zoology"
+ )
+
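+    # Some runs log the sequence length under "data.0.input_seq_len"; merge the two columns.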
+ df["data.input_seq_len"] = df["data.input_seq_len"].fillna(df["data.0.input_seq_len"])
+ plot(df=df, max_seq_len=1024)
+ plt.savefig("results.pdf")
diff --git a/zoology/data/associative_recall.py b/zoology/data/associative_recall.py
index 66361bb..8695e4e 100755
--- a/zoology/data/associative_recall.py
+++ b/zoology/data/associative_recall.py
@@ -136,7 +136,7 @@ def _ar(
inputs[:, 0:context_size] = kvs
# create a matrix of indices, which is needed to index correctly below
- rows = np.tile(np.arange(num_examples), (3, 1)).T
+ rows = np.tile(np.arange(num_examples), (num_queries, 1)).T
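+    # shape (num_examples, num_queries): one row of example indices per query, so the fancy indexing below lines up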
# sample random kv pairs to use for the queries
kv_idx_choices = np.arange(0, num_kv_pairs)
diff --git a/zoology/experiments/mqar_d_model.py b/zoology/experiments/mqar_dmodel.py
similarity index 96%
rename from zoology/experiments/mqar_d_model.py
rename to zoology/experiments/mqar_dmodel.py
index 1d10684..3d37a84 100644
--- a/zoology/experiments/mqar_d_model.py
+++ b/zoology/experiments/mqar_dmodel.py
@@ -72,6 +72,12 @@
"l_max": input_seq_len,
},
),
+ "rwkv5": dict(
+ name="zoology.mixers.rwkv5.RWKVTimeMixer",
+ kwargs={
+ "l_max": input_seq_len,
+ },
+ ),
"base_conv": dict(
name="zoology.mixers.base_conv.BaseConv",
kwargs={
diff --git a/zoology/experiments/paper/figure2.py b/zoology/experiments/paper/figure2.py
new file mode 100644
index 0000000..3e93171
--- /dev/null
+++ b/zoology/experiments/paper/figure2.py
@@ -0,0 +1,159 @@
+import uuid
+import numpy as np
+from zoology.config import TrainConfig, ModelConfig, DataConfig, LoggerConfig
+
+
+sweep_id = uuid.uuid4().hex[:6]
+sweep_name = "figure2" + sweep_id
+
+
+VOCAB_SIZE = 8_192
+
+
+configs = []
+for input_seq_len, num_kv_pairs in [
+ (64, 4),
+ (128, 8),
+ (256, 16),
+ (512, 64),
+]:
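+    # Use larger batches at shorter sequence lengths, scaling down as the sequence grows to keep memory use manageable.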
+ if input_seq_len == 1024:
+ batch_size = 64
+ elif input_seq_len == 512:
+ batch_size = 128
+ elif input_seq_len == 256:
+ batch_size = 256
+ else:
+ batch_size = 512
+
+ data = DataConfig(
+ num_train_examples=100_000,
+ num_test_examples=3_000,
+ vocab_size=VOCAB_SIZE,
+ input_seq_len=input_seq_len,
+ batch_size=batch_size,
+ # cache_dir="", # TODO: add a directory to cache your results!
+ builder={
+ "name": "zoology.data.associative_recall.multiquery_ar",
+ "kwargs": {
+ "num_kv_pairs": num_kv_pairs,
+ "train_power_a": 0.01,
+ "test_power_a": 0.01,
+ "random_non_queries": False
+ }
+ }
+ )
+
+ for d_model in [
+ 64,
+ 128,
+ 256,
+ 512
+ ]:
+ for lr in np.logspace(-4, -2, 4):
+
+ MIXERS = {
+ "attention": dict(
+ name="zoology.mixers.attention.MHA",
+ kwargs={
+ "dropout": 0.1,
+ "num_heads": 1
+ },
+ ),
+ "hyena": dict(
+ name="zoology.mixers.hyena.Hyena",
+ kwargs={
+ "l_max": input_seq_len
+ },
+ ),
+ "rwkv": dict(
+ name="zoology.mixers.rwkv.RWKVTimeMixer",
+ kwargs={
+ "l_max": input_seq_len,
+ },
+ ),
+ "base_conv": dict(
+ name="zoology.mixers.base_conv.BaseConv",
+ kwargs={
+ "l_max": input_seq_len,
+ # pass a list of kernel sizes for each of four layers
+ "kernel_size": [3, -1, 3, -1]
+ }
+ ),
+ "h3": dict(
+ name="zoology.mixers.h3.H3",
+ kwargs={
+ "l_max": input_seq_len,
+ "d_state": input_seq_len, # makes it mathematically equivalent to Hyena
+ "head_dim": 2
+ }
+ ),
+ "based": dict(
+ name="zoology.mixers.hybrid.Hybrid",
+ kwargs={
+ "configs": [
+ dict(
+ name="zoology.mixers.base_conv.BaseConv",
+ kwargs={
+ "l_max": input_seq_len,
+                                # a single (scalar) kernel size shared across layers
+ "kernel_size": 3,
+ "implicit_long_conv": True,
+ }
+ ),
+ dict(
+ name="zoology.mixers.based.Based",
+ kwargs={
+ "l_max": input_seq_len,
+ "feature_dim": 8,
+ "num_key_value_heads": 1,
+ "num_heads": 1,
+ "feature_name": "taylor_exp"
+ }
+ )
+ ]
+ }
+ ),
+ "mamba": dict(
+ name="zoology.mixers.mamba.Mamba",
+ kwargs={}
+ ),
+ }
+
+ for sequence_mixer in [
+ "attention",
+ "hyena",
+ "rwkv",
+ "base_conv",
+ "h3",
+ "based",
+ "mamba"
+ ]:
+
+ if 'mamba' in sequence_mixer:
+ block_type = "MambaBlock"
+ else:
+ block_type = "TransformerBlock"
+
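+            # Attention uses 2 layers; every other mixer uses 4 (see n_layers below).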
+ model = ModelConfig(
+ d_model=d_model,
+ n_layers=4 if sequence_mixer != "attention" else 2,
+ block_type=block_type,
+ max_position_embeddings=input_seq_len if sequence_mixer == "attention" else 0,
+ vocab_size=VOCAB_SIZE,
+ sequence_mixer=MIXERS[sequence_mixer],
+ state_mixer=dict(name="torch.nn.Identity", kwargs={})
+ )
+ config = TrainConfig(
+ model=model,
+ data=data,
+ learning_rate=lr,
+ max_epochs=64,
+ run_id=f"{sequence_mixer}-seqlen{input_seq_len}-dmodel{d_model}-lr{lr}-kv{num_kv_pairs}",
+ logger=LoggerConfig(
+ project_name="zoology",
+ entity="hazy-research"
+ )
+ )
+ configs.append(config)
\ No newline at end of file
diff --git a/zoology/mixers/based.py b/zoology/mixers/based.py
index 4f6ee3e..76b35e3 100644
--- a/zoology/mixers/based.py
+++ b/zoology/mixers/based.py
@@ -69,11 +69,11 @@ def forward_mem_save(self, x: torch.Tensor) -> torch.Tensor:
-> Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
# Slow but memory-saving way to compute 2nd-order terms; how do w/o outer-product first?
- x2 = oe.contract('...m,...n->...mn', x, x) / self.input_dim
+ x2 = oe.contract('...m,...n->...mn', x, x) / self.rd
x2d = torch.diagonal(x2, dim1=-2, dim2=-1) / self.r2
x2 = x2[..., self.tril_indices[0], self.tril_indices[1]]
x = torch.cat([torch.ones(x[..., :1].shape).to(x.device),
- x / self.rd, x2d, x2], dim=-1)
+ x / self.rrd, x2d, x2], dim=-1)
return x
diff --git a/zoology/mixers/rwkv.py b/zoology/mixers/rwkv.py
index ec2429e..b4d7c70 100644
--- a/zoology/mixers/rwkv.py
+++ b/zoology/mixers/rwkv.py
@@ -38,9 +38,10 @@ def backward(ctx, grad_output):
# # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
from torch.utils.cpp_extension import load
-wkv_cuda = load(name="wkv", sources=
- ["/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_op.cpp",
- "/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_cuda.cu"],
+dir_path = os.path.dirname(os.path.realpath(__file__))
+wkv_cuda = load(name="wkv", sources=[
+ os.path.join(dir_path, "./rwkv/v4/wkv_op.cpp"),
+ os.path.join(dir_path, "./rwkv/v4/wkv_cuda.cu")],
verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
diff --git a/zoology/mixers/rwkv/v4/wkv_cuda.cu b/zoology/mixers/rwkv/v4/wkv_cuda.cu
new file mode 100644
index 0000000..a4522cb
--- /dev/null
+++ b/zoology/mixers/rwkv/v4/wkv_cuda.cu
@@ -0,0 +1,125 @@
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+ const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+ F *__restrict__ const _y) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+
+ F p = 0, q = 0, o = MIN_VALUE;
+ // p and q are running sums divided by exp(o) (to avoid overflows)
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+
+ F no = max(o, u + k[ii]);
+ F A = exp(o - no);
+ F B = exp(u + k[ii] - no);
+ y[ii] = (A * p + B * v[ii]) / (A * q + B);
+
+ no = max(w + o, k[ii]);
+ A = exp(w + o - no);
+ B = exp(k[ii] - no);
+ p = A * p + B * v[ii];
+ q = A * q + B;
+ o = no;
+ }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+ const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
+ F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ const F *__restrict__ const gy = _gy + _offset;
+
+ F *__restrict__ const gk = _gk + _offset;
+ F *__restrict__ const gv = _gv + _offset;
+
+ F y[Tmax], z[Tmax], zexp[Tmax];
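+    // Per-timestep caches from the forward scan: y = outputs, z = 1/denominator, zexp = exponent shift; reused by the reverse-direction loop below.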
+
+ F gw = 0, gu = 0;
+ F p = 0, q = 0;
+ F dpdw = 0, dqdw = 0;
+ F o = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ F no = max(o, k[ii] + u);
+ F A = exp(o - no);
+ F B = exp(k[ii] + u - no);
+
+ F num = A * p + B * v[ii];
+ F iden = 1 / (A * q + B);
+
+ y[i] = num * iden;
+ z[i] = iden;
+ zexp[i] = k[ii] + u - no;
+
+ gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
+ gu += gy[ii] * (v[ii] - y[i]) * B * iden;
+
+ no = max(w + o, k[ii]);
+ A = exp(w + o - no);
+ B = exp(k[ii] - no);
+ dpdw = A * (p + dpdw);
+ dqdw = A * (q + dqdw);
+ p = A * p + B * v[ii];
+ q = A * q + B;
+ o = no;
+ }
+
+ F gp = 0, gq = 0;
+ o = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ F A = gy[ii] * z[i] * exp(zexp[i]);
+ F B = exp(k[ii] + o);
+ gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
+ gv[ii] = A + B * gp;
+
+ F no = max(w + o, zexp[i] - k[ii] - u);
+ A = exp(w + o - no);
+ B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
+ gp = A * gp + B;
+ gq = A * gq - B * y[i];
+ o = no;
+ }
+
+ // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] += gw * _w[_c];
+ _gu[_offsetBC] += gu;
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv/v4/wkv_op.cpp b/zoology/mixers/rwkv/v4/wkv_op.cpp
new file mode 100644
index 0000000..e59a515
--- /dev/null
+++ b/zoology/mixers/rwkv/v4/wkv_op.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &forward, "wkv forward");
+ m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv, m) {
+ m.def("forward", forward);
+ m.def("backward", backward);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv/v5/wkv5_cuda.cu b/zoology/mixers/rwkv/v5/wkv5_cuda.cu
new file mode 100644
index 0000000..18bc730
--- /dev/null
+++ b/zoology/mixers/rwkv/v5/wkv5_cuda.cu
@@ -0,0 +1,202 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+ const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+ F *__restrict__ const _y)
+{
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+ _w += h*_N_;
+ _u += h*_N_;
+
+ __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+ float state[_N_] = {0};
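+    // Each thread i holds one row of the per-head state matrix S[i][j] (value channel i, all key channels j).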
+
+ __syncthreads();
+ w[i] = _w[i];
+ u[i] = float(_u[i]);
+ __syncthreads();
+
+ for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+ {
+ __syncthreads();
+ r[i] = float(_r[t]);
+ k[i] = float(_k[t]);
+ __syncthreads();
+
+ const float v = float(_v[t]);
+ float y = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j+=4)
+ {
+ const float4& r_ = (float4&)(r[j]);
+ const float4& k_ = (float4&)(k[j]);
+ const float4& w_ = (float4&)(w[j]);
+ const float4& u_ = (float4&)(u[j]);
+ float4& s = (float4&)(state[j]);
+ float4 x;
+
+ x.x = k_.x * v;
+ x.y = k_.y * v;
+ x.z = k_.z * v;
+ x.w = k_.w * v;
+
+ y += r_.x * (u_.x * x.x + s.x);
+ y += r_.y * (u_.y * x.y + s.y);
+ y += r_.z * (u_.z * x.z + s.z);
+ y += r_.w * (u_.w * x.w + s.w);
+
+ s.x = s.x * w_.x + x.x;
+ s.y = s.y * w_.y + x.y;
+ s.z = s.z * w_.z + x.z;
+ s.w = s.w * w_.w + x.w;
+ }
+ _y[t] = F(y);
+ }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C, const int H,
+ const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+ F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
+{
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+ _w += h*_N_;
+ _u += h*_N_;
+ __w += h*_N_;
+
+ __shared__ float w_[_N_], u_[_N_];
+ __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
+ __syncthreads();
+ w_[i] = _w[i];
+ u_[i] = float(_u[i]);
+ __syncthreads();
+
+ const float w = w_[i];
+ const float ww = __w[i];
+ const float u = u_[i];
+
+ float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
+
+ float gw = 0, gu = 0;
+ const int t000 = b*T*C + h*_N_ + i;
+ const int t111 = (b+1)*T*C + h*_N_ + i;
+ const int t222 = t111 - 2*C;
+
+ for (int t = t000; t < t111; t += C)
+ {
+ __syncthreads();
+ v[i] = float(_v[t]);
+ gy[i] = float(_gy[t]);
+ __syncthreads();
+
+ const float k = float(_k[t]);
+ float gr = 0, gu_ = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = state[j];
+ float x = k * v[j];
+
+ gr += (u * x + s) * gy[j];
+ gu_ += x * gy[j];
+ s = s * w + x;
+ }
+ _gr[t] = F(gr);
+ gu += float(_r[t]) * gu_;
+ }
+ _gu[b*C + h*_N_ + i] = F(gu);
+
+ for (int t = t000; t < t222; t += C)
+ {
+ __syncthreads();
+ v[i] = float(_v[t]);
+ gy[i] = float(_gy[t + 2*C]);
+ __syncthreads();
+
+ const float k = float(_k[t]);
+ float gw_ = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = saaaa[j];
+ float& s2 = sbbbb[j];
+ float x = k * v[j];
+
+ float tmp = w * (x + s);
+ s = tmp;
+ s2 = tmp + w * s2;
+ gw_ += s2 * gy[j];
+ }
+ gw += float(_r[t + 2*C]) * gw_;
+ }
+ _gw[b*C + h*_N_ + i] = F(ww * gw);
+
+ for (int t = t111 - C; t >= t000; t -= C)
+ {
+ __syncthreads();
+ v[i] = float(_v[t]);
+ gy[i] = float(_gy[t]);
+ __syncthreads();
+
+ const float rr = float(_r[t]);
+ float gk = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = scccc[j];
+ float x = rr * gy[j];
+
+ gk += (u * x + s) * v[j];
+ s = x + s * w;
+ }
+ _gk[t] = F(gk);
+ }
+
+ for (int t = t111 - C; t >= t000; t -= C)
+ {
+ __syncthreads();
+ r[i] = float(_r[t]);
+ k[i] = float(_k[t]);
+ __syncthreads();
+
+ const float gyy = float(_gy[t]);
+ float gv = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = sdddd[j];
+ float x = gyy * r[j];
+
+ gv += (u_[j] * x + s) * k[j];
+ s = x + s * w_[j];
+ }
+ _gv[t] = F(gv);
+ }
+}
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+ assert(H*_N_ == C);
+ assert(_N_%4 == 0);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
+}
+
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+{
+ assert(H*_N_ == C);
+ assert(_N_%4 == 0);
+    kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv/v5/wkv5_op.cpp b/zoology/mixers/rwkv/v5/wkv5_op.cpp
new file mode 100644
index 0000000..aef76a4
--- /dev/null
+++ b/zoology/mixers/rwkv/v5/wkv5_op.cpp
@@ -0,0 +1,22 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
+
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &forward, "wkv5 forward");
+ m.def("backward", &backward, "wkv5 backward");
+}
+
+TORCH_LIBRARY(wkv5, m) {
+ m.def("forward", forward);
+ m.def("backward", backward);
+}
\ No newline at end of file
diff --git a/zoology/mixers/rwkv5.py b/zoology/mixers/rwkv5.py
new file mode 100644
index 0000000..5aa3972
--- /dev/null
+++ b/zoology/mixers/rwkv5.py
@@ -0,0 +1,616 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.strategies import DeepSpeedStrategy
+if importlib.util.find_spec('deepspeed'):
+ import deepspeed
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+
+# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
+
+try:
+ print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
+except KeyError:
+ os.environ["RWKV_MY_TESTING"] = ''
+
+def __nop(ob):
+ return ob
+
+
+MyModule = nn.Module
+MyFunction = __nop
+os.environ['RWKV_JIT_ON'] = '1'
+if os.environ["RWKV_JIT_ON"] == "1":
+ MyModule = torch.jit.ScriptModule
+ MyFunction = torch.jit.script_method
+
+os.environ['RWKV_FLOAT_MODE'] = 'bf16'
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+
+from torch.utils.cpp_extension import load
+
+# HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
+HEAD_SIZE = 64
+dir_path = os.path.dirname(os.path.realpath(__file__))
+wkv5_cuda = load(name="wkv5", sources=[
+ os.path.join(dir_path, "./rwkv/v5/wkv5_op.cpp"),
+ os.path.join(dir_path, "./rwkv/v5/wkv5_cuda.cu")],
+ verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
+
+class WKV_5(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, B, T, C, H, r, k, v, w, u):
+ with torch.no_grad():
+ assert r.dtype == torch.bfloat16
+ assert k.dtype == torch.bfloat16
+ assert v.dtype == torch.bfloat16
+ assert w.dtype == torch.bfloat16
+ assert u.dtype == torch.bfloat16
+ assert HEAD_SIZE == C // H
+ ctx.B = B
+ ctx.T = T
+ ctx.C = C
+ ctx.H = H
+ assert r.is_contiguous()
+ assert k.is_contiguous()
+ assert v.is_contiguous()
+ assert w.is_contiguous()
+ assert u.is_contiguous()
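+            # time_decay is stored in log space: ew = -exp(w), and eew = exp(ew) is the per-channel decay factor the kernel actually uses.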
+ ew = (-torch.exp(w.float())).contiguous()
+ eew = (torch.exp(ew)).contiguous()
+ ctx.save_for_backward(r, k, v, eew, ew, u)
+ y = torch.empty((B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
+ return y
+
+ @staticmethod
+ def backward(ctx, gy):
+ with torch.no_grad():
+ assert gy.dtype == torch.bfloat16
+ B = ctx.B
+ T = ctx.T
+ C = ctx.C
+ H = ctx.H
+ assert gy.is_contiguous()
+ r, k, v, eew, ew, u = ctx.saved_tensors
+ gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gw = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
+ gw = torch.sum(gw, 0).view(H, C//H)
+ gu = torch.sum(gu, 0).view(H, C//H)
+ return (None, None, None, None, gr, gk, gv, gw, gu)
+
+def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
+ return WKV_5.apply(B, T, C, H, r, k, v, w, u)
+
+########################################################################################################
+
+class RWKVTimeMixer(MyModule):
+ def __init__(self, l_max: int, d_model: int = 512, n_layer: int=12, layer_idx: int=-1):
+
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.ctx_len = l_max
+ self.d_model = d_model
+ dim_att = d_model
+
+ # self.head_size = args.head_size_a
+ self.head_size = 64
+ assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
+ self.n_head = dim_att // self.head_size
+ assert dim_att % self.n_head == 0
+ # self.head_size_divisor = args.head_size_divisor
+ self.head_size_divisor = 8
+
+ with torch.no_grad():
+ ratio_0_to_1 = layer_idx / (n_layer - 1) # 0 to 1
+ ratio_1_to_almost0 = 1.0 - (layer_idx / n_layer) # 1 to ~0
+ ddd = torch.ones(1, 1, d_model)
+ for i in range(d_model):
+ ddd[0, 0, i] = i / d_model
+
+ # fancy time_mix
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+ self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+
+ # fancy time_decay
+ decay_speed = torch.ones(dim_att)
+ for n in range(dim_att):
+ decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+ self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))
+ # print(layer_idx, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
+
+ tmp = torch.zeros(dim_att)
+ for n in range(dim_att):
+ zigzag = ((n + 1) % 3 - 1) * 0.1
+ tmp[n] = ratio_0_to_1 * (1 - (n / (dim_att - 1))) + zigzag
+
+ self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+ self.receptance = nn.Linear(d_model, dim_att, bias=False)
+ self.key = nn.Linear(d_model, dim_att, bias=False)
+
+ self.value = nn.Linear(d_model, dim_att, bias=False)
+ self.output = nn.Linear(dim_att, d_model, bias=False)
+ self.gate = nn.Linear(d_model, dim_att, bias=False)
+ self.ln_x = nn.GroupNorm(self.n_head, dim_att)
+
+ @MyFunction
+ def jit_func(self, x):
+ B, T, C = x.size()
+
+ xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+ xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
+
+ r = self.receptance(xr)
+ k = self.key(xk)
+ v = self.value(xv)
+ g = F.silu(self.gate(xg))
+
+ return r, k, v, g
+
+ @MyFunction
+ def jit_func_2(self, x, g):
+ B, T, C = x.size()
+ x = x.view(B * T, C)
+
+ x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
+ x = self.output(x * g)
+ return x
+
+ def forward(self, x: torch.Tensor):
+ B, T, C = x.size()
+ H = self.n_head
+
+ r, k, v, g = self.jit_func(x)
+
+ # Convert tensors to bfloat16 for CUDA operation
+ r_bf16 = r.to(dtype=torch.bfloat16)
+ k_bf16 = k.to(dtype=torch.bfloat16)
+ v_bf16 = v.to(dtype=torch.bfloat16)
+ w_bf16 = self.time_decay.to(dtype=torch.bfloat16)
+ u_bf16 = self.time_faaaa.to(dtype=torch.bfloat16)
+
+ x = RUN_CUDA_RWKV5(B, T, C, H, r_bf16, k_bf16, v_bf16, w_bf16, u_bf16)
+
+ # Convert back to Float for subsequent operations
+ x_float = x.to(dtype=torch.float32)
+
+ return self.jit_func_2(x_float, g)
+
+########################################################################################################
+
+class RWKV_ChannelMix(MyModule):
+ def __init__(self, args, d_model=512, n_layer=12, layer_idx=1):
+ super().__init__()
+ self.args = args
+ self.layer_idx = layer_idx
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+ with torch.no_grad(): # fancy init of time_mix
+ ratio_1_to_almost0 = 1.0 - (layer_idx / n_layer) # 1 to ~0
+ ddd = torch.ones(1, 1, d_model)
+ for i in range(d_model):
+ ddd[0, 0, i] = i / d_model
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+
+        self.key = nn.Linear(d_model, args.dim_ffn, bias=False)
+        self.receptance = nn.Linear(d_model, d_model, bias=False)
+        self.value = nn.Linear(args.dim_ffn, d_model, bias=False)
+
+ @MyFunction
+ def forward(self, x):
+ xx = self.time_shift(x)
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+ k = self.key(xk)
+ k = torch.relu(k) ** 2
+ kv = self.value(k)
+ return torch.sigmoid(self.receptance(xr)) * kv
+
+class MishGLU(MyModule):
+ def __init__(self, args, d_model, layer_idx, n_layer):
+ super().__init__()
+ self.args = args
+ self.layer_idx = layer_idx
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+ with torch.no_grad():
+ ratio_1_to_almost0 = 1.0 - (layer_idx / n_layer)
+
+            x = torch.ones(1, 1, d_model)
+            for i in range(d_model):
+                x[0, 0, i] = i / d_model
+
+ self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ self.aa = nn.Linear(d_model, args.dim_ffn, bias=False)
+ self.bb = nn.Linear(d_model, args.dim_ffn, bias=False)
+        self.value = nn.Linear(args.dim_ffn, d_model, bias=False)
+
+ @MyFunction
+ def forward(self, x):
+ xx = self.time_shift(x)
+ xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+ xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+ a = self.aa(xa)
+ b = self.bb(xb)
+ return self.value(a * F.mish(b))
+
+########################################################################################################
+# The RWKV Model with our blocks
+########################################################################################################
+
+
+class Block(nn.Module):
+    def __init__(self, args, d_model, n_layer, layer_idx):
+ super().__init__()
+ self.args = args
+ self.layer_idx = layer_idx
+
+ self.ln1 = nn.LayerNorm(d_model)
+ self.ln2 = nn.LayerNorm(d_model)
+
+ if self.layer_idx == 0:
+ self.ln0 = nn.LayerNorm(d_model)
+ if args.my_pos_emb > 0:
+ self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,d_model)))
+ self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,d_model)))
+
+        if self.layer_idx == 0 and self.args.pre_ffn > 0:
+            self.ffnPre = RWKV_ChannelMix(args, d_model, n_layer, layer_idx=0)
+        else:
+            self.att = RWKVTimeMixer(args.ctx_len, d_model=d_model, n_layer=n_layer, layer_idx=layer_idx)
+
+        if 'g' in os.environ["RWKV_MY_TESTING"]:
+            self.ffn = MishGLU(args, d_model, layer_idx, n_layer)
+        else:
+            self.ffn = RWKV_ChannelMix(args, d_model, n_layer, layer_idx=layer_idx)
+
+ if args.tiny_att_dim > 0 and self.layer_idx == args.tiny_att_layer:
+ self.tiny_ln = nn.LayerNorm(d_model)
+ self.tiny_q = nn.Linear(d_model, args.tiny_att_dim, bias=False)
+ self.tiny_k = nn.Linear(d_model, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(d_model, d_model, bias=False)
+ self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+
+ if args.dropout > 0:
+            self.drop0 = nn.Dropout(p=args.dropout)
+            self.drop1 = nn.Dropout(p=args.dropout)
+
+ def forward(self, x, x_emb=None):
+ args = self.args
+ B, T, C = x.size()
+ if self.layer_idx == 0:
+ x = self.ln0(x)
+ if args.my_pos_emb > 0:
+ pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
+ x = x + pos_emb
+
+ if self.args.dropout == 0:
+ if self.layer_idx == 0 and args.pre_ffn > 0:
+ x = x + self.ffnPre(self.ln1(x))
+ else:
+ x = x + self.att(self.ln1(x))
+ x = x + self.ffn(self.ln2(x))
+ else:
+ if self.layer_idx == 0 and args.pre_ffn > 0:
+ x = self.drop0(x + self.ffnPre(self.ln1(x)))
+ else:
+ x = self.drop0(x + self.att(self.ln1(x)))
+ x = self.drop1(x + self.ffn(self.ln2(x)))
+
+ if args.tiny_att_dim > 0 and self.layer_idx == args.tiny_att_layer:
+ xx = self.tiny_ln(x)
+ q = self.tiny_q(xx)[:, :T, :]
+ k = self.tiny_k(xx)[:, :T, :]
+ c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
+ c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
+ x = x + c @ self.tiny_v(x_emb)
+ return x
+
+
+class L2Wrap(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, loss, y):
+ ctx.save_for_backward(y)
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ y = ctx.saved_tensors[0]
+ # to encourage the logits to be close to 0
+ factor = 1e-4 / (y.shape[0] * y.shape[1])
+ maxx, ids = torch.max(y, -1, keepdim=True)
+ gy = torch.zeros_like(y)
+ gy.scatter_(-1, ids, maxx * factor)
+ return (grad_output, gy)
+
+
+class RWKV(pl.LightningModule):
+ def __init__(self, args, vocab_size, d_model, n_layer):
+        super().__init__()
+        self.args = args
+        self.n_layer = n_layer
+        if not hasattr(args, 'dim_att'):
+            args.dim_att = d_model
+        if not hasattr(args, 'dim_ffn'):
+            args.dim_ffn = d_model * 4
+        if not hasattr(args, 'tiny_att_layer'):
+            args.tiny_att_layer = -1
+        if not hasattr(args, 'tiny_att_dim'):
+            args.tiny_att_dim = -1
+        assert d_model % 32 == 0
+        assert args.dim_att % 32 == 0
+        assert args.dim_ffn % 32 == 0
+
+        self.emb = nn.Embedding(vocab_size, d_model)
+
+        self.blocks = nn.ModuleList([Block(args, d_model, n_layer, i) for i in range(n_layer)])
+
+ self.ln_out = nn.LayerNorm(d_model)
+ self.head = nn.Linear(d_model, vocab_size, bias=False)
+
+ if args.head_qk > 0:
+ self.head_q = nn.Linear(d_model, args.head_qk, bias=False)
+ self.head_k = nn.Linear(d_model, args.head_qk, bias=False)
+ self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+ if args.dropout > 0:
+            self.drop0 = nn.Dropout(p=args.dropout)
+
+ def configure_optimizers(self):
+ args = self.args
+
+ lr_decay = set()
+ lr_1x = set()
+ lr_2x = set()
+ lr_3x = set()
+ for n, p in self.named_parameters():
+ if ("time_mix" in n) and (args.layerwise_lr > 0):
+ if args.my_pile_stage == 2:
+ lr_2x.add(n)
+ else:
+ lr_1x.add(n)
+ elif ("time_decay" in n) and (args.layerwise_lr > 0):
+ if args.my_pile_stage == 2:
+ lr_3x.add(n)
+ else:
+ lr_2x.add(n)
+ elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
+ if args.my_pile_stage == 2:
+ lr_2x.add(n)
+ else:
+ lr_1x.add(n)
+ elif ("time_first" in n) and (args.layerwise_lr > 0):
+ lr_3x.add(n)
+ elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
+ lr_decay.add(n)
+ else:
+ lr_1x.add(n)
+
+ lr_decay = sorted(list(lr_decay))
+ lr_1x = sorted(list(lr_1x))
+ lr_2x = sorted(list(lr_2x))
+ lr_3x = sorted(list(lr_3x))
+ # print('decay', lr_decay)
+ # print('1x', lr_1x)
+ # print('2x', lr_2x)
+ # print('3x', lr_3x)
+ param_dict = {n: p for n, p in self.named_parameters()}
+
+ if args.layerwise_lr > 0:
+ if args.my_pile_stage == 2:
+ optim_groups = [
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+ {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
+ {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
+ ]
+ else:
+ optim_groups = [
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+ {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+ {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
+ ]
+ else:
+ optim_groups = [{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}]
+
+ if args.weight_decay > 0:
+ optim_groups += [{"params": [param_dict[n] for n in lr_decay], "weight_decay": args.weight_decay, "my_lr_scale": 1.0}]
+ if self.deepspeed_offload:
+ return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=True, amsgrad=False)
+ return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=True, amsgrad=False)
+ else:
+ if self.deepspeed_offload:
+ return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+ return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+ # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+
+ @property
+ def deepspeed_offload(self) -> bool:
+ strategy = self.trainer.strategy
+ if isinstance(strategy, DeepSpeedStrategy):
+ cfg = strategy.config["zero_optimization"]
+ return cfg.get("offload_optimizer") or cfg.get("offload_param")
+ return False
+
+ def forward(self, idx):
+ args = self.args
+ B, T = idx.size()
+ assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+
+ x = self.emb(idx)
+ x_emb = x
+
+ if args.dropout > 0:
+ x = self.drop0(x)
+ if args.tiny_att_dim > 0:
+ for block in self.blocks:
+ if args.grad_cp == 1:
+ x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+ else:
+ x = block(x, x_emb)
+ else:
+ for block in self.blocks:
+ if args.grad_cp == 1:
+ x = deepspeed.checkpointing.checkpoint(block, x)
+ else:
+ x = block(x)
+
+ x = self.ln_out(x)
+
+ if args.head_qk > 0:
+ q = self.head_q(x)[:, :T, :]
+ k = self.head_k(x)[:, :T, :]
+ c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
+ c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
+
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size)
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
+
+ x = self.head(x) + c
+ else:
+ x = self.head(x)
+
+ return x
+
+ def training_step(self, batch, batch_idx):
+ args = self.args
+ if args.my_qa_mask != 1:
+ idx, targets = batch
+ logits = self(idx)
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+ # if '0' in os.environ["RWKV_MY_TESTING"]:
+ # print('logits', logits)
+ # torch.set_printoptions(threshold=10000)
+ # print('idx', idx)
+ # exit(0)
+ else:
+ idx, targets, mask = batch
+ mask = mask.view(-1)
+ sum_mask = torch.sum(mask).item()
+ # if sum_mask == 0:
+ # return torch.tensor([0.0], requires_grad=True)
+
+ logits = self(idx)
+ if sum_mask == mask.shape[0]:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+ # print('rank', self.global_rank, 'loss', loss.item())
+ else:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+ # loss_raw = loss
+ loss = torch.sum(loss * mask) / sum_mask
+
+ # torch.set_printoptions(threshold=10000)
+ # if True: #self.global_rank == 1:
+ # tmp = ''
+ # sss = 0
+ # ccc = 0
+ # for i in range(mask.shape[0]):
+ # if mask[i] > 0:
+ # tmp += str(idx.view(-1)[i].item()) + ','
+ # sss += loss_raw.view(-1)[i].float().item()
+ # ccc += 1
+ # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
+
+ return L2Wrap.apply(loss, logits)
+
+ def training_step_end(self, batch_parts):
+        if pl.__version__[0] != '2':
+ all = self.all_gather(batch_parts)
+ if self.trainer.is_global_zero:
+ self.trainer.my_loss_all = all
+
+ def generate_init_weight(self):
+ print(
+ f"""
+############################################################################
+#
+# Init model weight (slow for large models)...
+#
+############################################################################
+"""
+ )
+ m = {}
+ for n in self.state_dict():
+ p = self.state_dict()[n]
+ shape = p.shape
+
+ gain = 1.0
+ scale = 1.0
+ if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
+ if 'ln_x.weight' in n:
+ layer_scale = (1+int(n.split('.')[1])) / self.n_layer
+ m[n] = (p * 0.0) + (layer_scale ** 0.7)
+ else:
+ m[n] = p
+ else:
+ if n == "emb.weight":
+ scale = -1 * self.args.lr_init
+ else:
+ if shape[0] > shape[1]:
+ gain = math.sqrt(shape[0] / shape[1])
+
+ zero = [".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']
+
+ for kk in zero:
+ if kk in n:
+ scale = 0
+ if n == "head.weight":
+ scale = 0.5
+ if "head_k." in n:
+ scale = 0.1
+ if "head_q." in n:
+ scale = 0
+
+ print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
+
+ if self.args.accelerator.upper() == "GPU":
+ m[n] = torch.empty((shape[0], shape[1]), device="cuda")
+ else:
+ m[n] = torch.empty((shape[0], shape[1]))
+
+ if scale == 0:
+ nn.init.zeros_(m[n])
+ elif scale < 0:
+ nn.init.uniform_(m[n], a=scale, b=-scale)
+ else:
+ nn.init.orthogonal_(m[n], gain=gain * scale)
+
+ m[n] = m[n].cpu()
+ if os.environ["RWKV_FLOAT_MODE"] == "fp16":
+ m[n] = m[n].half()
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+ m[n] = m[n].bfloat16()
+
+ # if n == "emb.weight":
+ # print(m[n])
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ return m