
Contributing to the jax version. #1972

Open
ashutosh-dwivedi-e3502 opened this issue Nov 29, 2021 · 8 comments

Comments

@ashutosh-dwivedi-e3502

I want to contribute to the effort on the JAX version. I see that there's already a branch called jax with ~20 commits, but it is quite far behind the main branch.
Can you give me some more details on the current JAX effort and how I can contribute?

@astonzhang
Member

Thanks! See #1825

@Roy-Kid

Roy-Kid commented Feb 6, 2022

I also want to contribute to the JAX version, so I pulled the jax branch. But I only find a few changes in this branch; see:
[screenshot]
Did I pull the right branch?

@ghost

ghost commented May 2, 2022

Since some chapters of the book are out of sync with the master branch, which version should be implemented in JAX?

For example, the code in Linear Regression from Scratch:

from mxnet import np, npx
npx.set_np()  # NumPy-compatible semantics, as in the book's MXNet tab

def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise."""
    X = np.random.normal(0, 1, (num_examples, len(w)))
    y = np.dot(X, w) + b
    y += np.random.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = np.array([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

This is different from the corresponding code in linear-regression-scratch.md:

%%tab all
class LinearRegressionScratch(d2l.Module):  #@save
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1))
            self.b = d2l.zeros(1)
            self.w.attach_grad()
            self.b.attach_grad()
        if tab.selected('pytorch'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1), requires_grad=True)
            self.b = d2l.zeros(1, requires_grad=True)
        if tab.selected('tensorflow'):
            w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
            b = tf.zeros(1)
            self.w = tf.Variable(w, trainable=True)
            self.b = tf.Variable(b, trainable=True)

In this case, should I manually modify the version from the master branch, like this:

        if tab.selected('tensorflow'):
            w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
            b = tf.zeros(1)
            self.w = tf.Variable(w, trainable=True)
            self.b = tf.Variable(b, trainable=True)
+       if tab.selected('jax'):
+           key = random.PRNGKey(42)  # JAX needs an explicit PRNG key
+           self.w = random.normal(key, (num_inputs, 1)) * sigma
+           self.b = jnp.zeros(1)

Or would it be better to make the d2l module work with JAX first?

What do you suggest? @astonzhang
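For comparison, here is a minimal sketch of how the older `synthetic_data` function and the weight initialization from the diff might look in JAX. It assumes the caller passes PRNG keys in explicitly, since JAX has no global random state; `init_params` and the dict-of-arrays layout are hypothetical illustrations, not the d2l API.

```python
import jax.numpy as jnp
from jax import random


def synthetic_data(key, w, b, num_examples):
    """Generate y = Xw + b + noise, drawing randomness from an explicit PRNG key."""
    # Split the key so X and the noise come from independent random streams.
    key_x, key_noise = random.split(key)
    X = random.normal(key_x, (num_examples, w.shape[0]))
    y = jnp.dot(X, w) + b
    y = y + 0.01 * random.normal(key_noise, y.shape)
    return X, y.reshape((-1, 1))


def init_params(key, num_inputs, sigma=0.01):
    """Hypothetical helper: weights ~ N(0, sigma^2), bias = 0."""
    w = sigma * random.normal(key, (num_inputs, 1))
    b = jnp.zeros(1)
    return {'w': w, 'b': b}


# One root key, split into independent keys for the data and the parameters.
data_key, param_key = random.split(random.PRNGKey(42))
true_w = jnp.array([2.0, -3.4])
true_b = 4.2
features, labels = synthetic_data(data_key, true_w, true_b, 1000)
params = init_params(param_key, num_inputs=2)
```

Passing the key in, rather than calling `random.PRNGKey(42)` inside `__init__`, lets the caller decide whether two models share an initialization; with a hard-coded key, every instance starts from identical weights.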

@atgctg
Contributor

atgctg commented Aug 3, 2022

Hey @AnirudhDagar, could you please merge master into the jax branch?

I'm working on a JAX version for the v1 release of the book and have a few chapters ready for review. I just don't want to open any PRs until the jax branch is up to date with v1.

Or should I give it a go?

@AnirudhDagar
Member

Hi @atgctg, thanks for your interest in the JAX port. I've synced the branch. I'm almost done with chapter 3 and didn't know that you were working on it as well, but feel free to raise a PR for the chapter. In the future, to avoid duplicated effort with two people working on the same thing, let's use a tracker.

@atgctg
Contributor

atgctg commented Aug 3, 2022

Thanks! I'll open a PR then.

It would be great to standardize the API first, so other chapters can build on that. I would love to see your approach as well.
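One direction such an API could take, sketched here under the assumption that parameters live in a plain dict (a JAX pytree) passed explicitly to pure functions rather than stored as mutable attributes with `attach_grad`/`requires_grad`. The names `linreg`, `squared_loss`, and `sgd_step` are hypothetical, not the d2l API; the point is that `jax.grad` and `jax.jit` transform functions of `params`.

```python
import jax
import jax.numpy as jnp
from jax import random


def linreg(params, X):
    """Linear model X @ w + b."""
    return jnp.dot(X, params['w']) + params['b']


def squared_loss(params, X, y):
    """Mean of 0.5 * (y_hat - y)^2 over the minibatch."""
    y_hat = linreg(params, X)
    return jnp.mean(0.5 * (y_hat - y) ** 2)


@jax.jit
def sgd_step(params, X, y, lr):
    """One SGD update; returns new params because JAX arrays are immutable."""
    grads = jax.grad(squared_loss)(params, X, y)
    # Apply the same update rule to every leaf of the params pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Synthetic data: y = Xw + b + small Gaussian noise.
key_x, key_noise, key_w = random.split(random.PRNGKey(0), 3)
true_w = jnp.array([[2.0], [-3.4]])
true_b = 4.2
X = random.normal(key_x, (1000, 2))
y = X @ true_w + true_b + 0.01 * random.normal(key_noise, (1000, 1))

params = {'w': 0.01 * random.normal(key_w, (2, 1)), 'b': jnp.zeros(1)}
batch_size = 10
for epoch in range(3):
    for i in range(0, 1000, batch_size):
        params = sgd_step(params, X[i:i + batch_size], y[i:i + batch_size], 0.03)
```

Because the update is a pure function of `(params, X, y, lr)`, it can be compiled with `jax.jit` once and reused for every minibatch; a JAX port of `d2l.Module` would likely need to expose parameters in a similar explicit form.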

@AnirudhDagar
Member

Would you be so kind as to wait until this evening before sending a PR? I'd like to fix some CI issues first, which might affect JAX development. I'll let you know once that is fixed. Thanks! :)

@atgctg
Contributor

atgctg commented Aug 3, 2022

Of course!

Labels
None yet
Projects
None yet
Development

No branches or pull requests

5 participants