Contributing to the JAX version #1972
Thanks! See #1825
Since some chapters of the book are out of sync with the `master` branch, which version should be implemented in JAX? For example, the code in Linear Regression from Scratch:

```python
def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise."""
    X = np.random.normal(0, 1, (num_examples, len(w)))
    y = np.dot(X, w) + b
    y += np.random.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = np.array([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
```

is different from the corresponding code on the `master` branch:

```python
%%tab all
class LinearRegressionScratch(d2l.Module):  #@save
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1))
            self.b = d2l.zeros(1)
            self.w.attach_grad()
            self.b.attach_grad()
        if tab.selected('pytorch'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1), requires_grad=True)
            self.b = d2l.zeros(1, requires_grad=True)
        if tab.selected('tensorflow'):
            w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
            b = tf.zeros(1)
            self.w = tf.Variable(w, trainable=True)
            self.b = tf.Variable(b, trainable=True)
```

In this case, should I manually modify the one from the `master` branch, like this:

```diff
         if tab.selected('tensorflow'):
             w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
             b = tf.zeros(1)
             self.w = tf.Variable(w, trainable=True)
             self.b = tf.Variable(b, trainable=True)
+        if tab.selected('jax'):
+            key = random.PRNGKey(42)
+            self.w = random.normal(key, (num_inputs, 1)) * 0.01 + 0
+            self.b = jnp.zeros(1)
```

Or would it be better to make the D2L module work with JAX first? What do you suggest? @astonzhang
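For comparison, here is a minimal JAX sketch of both pieces, assuming randomness is threaded through explicit PRNG keys instead of being created inside `__init__`. `init_params` is a hypothetical helper, not part of `d2l`; it only mirrors what the other framework tabs do (draw `w` from a normal distribution with standard deviation `sigma`, set `b` to zero).

```python
import jax
import jax.numpy as jnp


def synthetic_data(key, w, b, num_examples):
    """Generate y = Xw + b + noise, drawing all randomness from `key`."""
    # Split the key so the features and the noise use independent streams.
    key_x, key_noise = jax.random.split(key)
    X = jax.random.normal(key_x, (num_examples, w.shape[0]))
    y = X @ w + b
    # JAX arrays are immutable; this builds a new array with the noise added.
    y = y + 0.01 * jax.random.normal(key_noise, y.shape)
    return X, y.reshape((-1, 1))


def init_params(key, num_inputs, sigma=0.01):
    """Hypothetical helper (not a d2l API): w ~ N(0, sigma^2), b = 0."""
    w = sigma * jax.random.normal(key, (num_inputs, 1))
    b = jnp.zeros(1)
    return w, b


# One root key, split into separate keys for the data and the parameters.
data_key, param_key = jax.random.split(jax.random.PRNGKey(42))
true_w = jnp.array([2.0, -3.4])
true_b = 4.2
features, labels = synthetic_data(data_key, true_w, true_b, 1000)
w, b = init_params(param_key, num_inputs=2)
```

Hardcoding `random.PRNGKey(42)` inside `__init__`, as in the proposed diff, would give every model instance with the same `num_inputs` identical initial weights. Passing a key in keeps runs reproducible while making that choice visible to the caller; this is only a suggestion, not something decided in this thread.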
Hey @AnirudhDagar, could you please merge I'm working on a JAX version for the v1 release of the book and have a few chapters ready for review. I just don't want to open any PRs until the Or should I give it a go?
Hi @atgctg, thanks for your interest in the JAX port. I've synced the branch. I'm almost done with chapter 3 and didn't know that you were working on it as well, but feel free to raise a PR for the chapter. In the future, to avoid duplicated effort with two people working on the same thing, let's use a tracker.
Thanks! I'll open a PR then. It would be great to standardize the API first, so other chapters can build on that. I would love to see your approach as well.
Would you be so kind as to wait until this evening before sending a PR? I'd like to fix some CI issues first, which might affect JAX development. I'll let you know once that is fixed. Thanks! :)
Of course!
I'd like to contribute to the JAX version. I see there's already a branch called `jax` with ~20 commits, but it is quite far behind the main branch.
Could you give me some more details on the current JAX effort and how I can contribute?