Introduce JAX OLS Class #790

Draft: wants to merge 8 commits into base: master


Member

@s3alfisc s3alfisc commented Jan 14, 2025

  • Move all jax code to a dedicated JAX module.
  • Add a basic OLSJAX class.

    import jax
    import jax.numpy as jnp
    from pyfixest.estimation.jax.OLSJAX import OLSJAX

    N = 1000
    k = 3

    # Split the key so that X, the noise, and the fixed effects are drawn independently.
    key_x, key_u, key_fe = jax.random.split(jax.random.PRNGKey(0), 3)

    X = jax.random.normal(key_x, shape=(N, k))
    beta = jnp.array([1.0, 2.0, 3.0])
    Y = (X @ beta + jax.random.normal(key_u, shape=(N,))).reshape(-1, 1)
    fe = jax.random.randint(key_fe, minval=0, maxval=5, shape=(N, 1))

    ols = OLSJAX(Y=Y, X=X, fe=fe, vcov="HC1")
    ols.fit()
    ols.tidy()

Idea: implement all computational steps in pure JAX. Any method that OLSJAX does not implement itself (e.g. tidy, inference, ritest) can be inherited from Feols.
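
As a rough illustration of what "all computational steps in pure JAX" could look like, here is a minimal sketch that demeans by a single fixed effect, fits OLS, and computes an HC1 covariance matrix. The helper names `demean_one_way` and `fit_ols_hc1` are hypothetical and are not taken from this PR. The sketch assumes a flat integer `fe` vector, and its HC1 scaling `n / (n - k)` ignores the absorbed fixed effects, which is a simplification.

```python
import jax
import jax.numpy as jnp


def demean_one_way(A, groups, n_groups):
    """Subtract group means from each column of A (single fixed effect)."""
    sums = jax.ops.segment_sum(A, groups, num_segments=n_groups)
    counts = jax.ops.segment_sum(jnp.ones(A.shape[0]), groups, num_segments=n_groups)
    means = sums / counts[:, None]
    return A - means[groups]


def fit_ols_hc1(Y, X):
    """Return (beta, vcov) using an HC1 sandwich estimator."""
    n, k = X.shape
    beta = jnp.linalg.lstsq(X, Y)[0]
    uhat = Y - X @ beta                 # residuals, shape (n, 1)
    bread = jnp.linalg.inv(X.T @ X)
    scores = X * uhat                   # row i is x_i * u_i
    meat = scores.T @ scores
    # Assumption: degrees-of-freedom correction ignores the fixed effects.
    vcov = n / (n - k) * bread @ meat @ bread
    return beta, vcov


# Simulated data with independent random keys.
key_x, key_u, key_fe = jax.random.split(jax.random.PRNGKey(0), 3)
N, k, G = 1000, 3, 5
X = jax.random.normal(key_x, (N, k))
fe = jax.random.randint(key_fe, (N,), 0, G)
beta_true = jnp.array([1.0, 2.0, 3.0])
Y = (X @ beta_true + fe + jax.random.normal(key_u, (N,))).reshape(-1, 1)

Yd = demean_one_way(Y, fe, G)
Xd = demean_one_way(X, fe, G)
beta_hat, vcov = fit_ols_hc1(Yd, Xd)
```

Both helpers are plain functions on arrays, so they could in principle be wrapped in `jax.jit`, with `n_groups` passed as a static argument.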


@s3alfisc s3alfisc marked this pull request as draft January 14, 2025 20:48
@s3alfisc s3alfisc mentioned this pull request Jan 14, 2025
Comment on lines +10 to +11
    class OLSJAX:
        def __init__(
Contributor

docstrings?

Contributor

@juanitorduz juanitorduz left a comment

LGTM up to the docstrings ;)

Member

this file should probably be deleted in this PR

Member

@apoorvalal apoorvalal left a comment

this probably supersedes the numpy API I requested a while ago, provided we keep the methods+signatures consistent with statsmodels on the base OLSJAX class and wrap them appropriately through pf.feols.

@s3alfisc
Member Author

provided we keep the methods+signatures consistent with statsmodels on the base OLSJAX class and wrap them appropriately through pf.feols.

Yes, that was exactly the idea! I think some minor refactoring of the different OLS classes will be required (for example, we might have to move all the method calls from FixestMulti into the get_fit() methods, though that's something I wanted to do anyway for the numpy API). Then I was thinking about adding an interface class that is called from FixestMulti, inherits from Feols, and has its get_fit method overridden by the JAX method. Might be a bad design though? It for sure sounds more complex than it will be. I'll push a PR in the next few days; then it will be easier to see what I'd like to do.
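
The interface-class idea described above can be sketched roughly as follows. Every class here is a hypothetical stand-in: `FeolsStandIn` is not pyfixest's `Feols`, and `ArrayBackend` only plays the role that `OLSJAX` would have. The backend uses plain Python arithmetic to keep the sketch dependency-free. The point is only the shape of the design: post-estimation methods are inherited, while `get_fit` delegates to an array-only backend.

```python
class FeolsStandIn:
    """Hypothetical stand-in for pyfixest's Feols; not the real class."""

    def __init__(self, y, x):
        self.y, self.x = list(y), list(x)
        self.slope = None

    def get_fit(self):
        raise NotImplementedError("each backend supplies its own fit")

    def predict(self, x_new):
        # Post-estimation method shared by every backend via inheritance.
        return [self.slope * v for v in x_new]


class ArrayBackend:
    """Hypothetical array-only estimator, playing the role of OLSJAX."""

    def __init__(self, y, x):
        self.y, self.x = y, x
        self.slope = None

    def fit(self):
        # No-intercept least squares slope: sum(x * y) / sum(x * x).
        sxy = sum(a * b for a, b in zip(self.x, self.y))
        sxx = sum(a * a for a in self.x)
        self.slope = sxy / sxx
        return self


class BackendInterface(FeolsStandIn):
    """Inherits post-estimation methods; only get_fit is overridden."""

    def get_fit(self):
        backend = ArrayBackend(self.y, self.x).fit()
        self.slope = backend.slope


model = BackendInterface(y=[2.0, 4.0, 6.0], x=[1.0, 2.0, 3.0])
model.get_fit()
predictions = model.predict([4.0])
```

One trade-off of this layout is that the interface has to copy results from the backend onto itself, so the set of attributes the inherited methods rely on becomes an implicit contract between the two classes.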

    self.Y, self.X = self.demean(
        Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights
    )
    self.beta = jnp.linalg.lstsq(self.X, self.Y)[0]
Member

curious whether using lx.linear_solve gets speed-gains here; example

Member Author

Speed gains look superb, how do you feel about adding lineax as an optional dependency to the JAX env?

Member

i'm game; Kidger's stuff is very high quality and stable, so I doubt that it will break anytime soon.

Member

i'll test out lineax's solver on this PR over the weekend; expect a commit
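
To make the solver discussion concrete without depending on lineax, the sketch below compares `jnp.linalg.lstsq` with a Cholesky solve of the normal equations. This is not the lineax approach linked in the thread and makes no claim about relative speed; it only shows that swapping the solver is a local change. The function names are hypothetical, and the Cholesky route assumes `X` has full column rank.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def ols_lstsq(X, Y):
    """Least squares via SVD-based lstsq, as in the snippet under review."""
    return jnp.linalg.lstsq(X, Y)[0]


def ols_cholesky(X, Y):
    """Solve the normal equations X'X b = X'Y with a Cholesky factorization."""
    # Assumption: X'X is positive definite (no collinear columns).
    factor = cho_factor(X.T @ X)
    return cho_solve(factor, X.T @ Y)


key_x, key_u = jax.random.split(jax.random.PRNGKey(1))
X = jax.random.normal(key_x, (500, 3))
Y = (X @ jnp.array([1.0, 2.0, 3.0]) + jax.random.normal(key_u, (500,))).reshape(-1, 1)

b_lstsq = ols_lstsq(X, Y)
b_chol = ols_cholesky(X, Y)
```

On well-conditioned data the two estimates should agree up to floating-point error; with near-collinear regressors the normal-equations route is the less stable of the two.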

    @property
    def residuals(self):
        self.uhat = self.Y - self.X @ self.beta
        return self.uhat
Member

if residuals and scores methods are simply setting the properties on this class, do they need return statements? alternatively, if they do return arrays, set them inside fit as self.uhat = self.residuals ?

Member Author

Yes, I agree with you, I think simply doing something like

    @property
    def residuals(self):
        return self.Y - self.X @ self.beta

and then potentially setting a uhat attribute in fit() would be cleaner? Ideally we can drop the uhat attribute completely; I don't think it is needed per se (except for some compatibility with the current implementation).

Member

makes sense. re: not setting it as an attribute - agree; i err on the side of not attaching too many things to the model class because it ends up being a memory-consumption headache. Given that the benefits of jax are primarily for big data ™️, making the model class as lean as possible by not attaching any data (that's already built into pf.feols, right?) would make sense to me
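
A minimal sketch of the pattern agreed on in this thread: `residuals` and `scores` are computed on access and nothing like `uhat` is stored on the instance. `LeanOLS` is a hypothetical class written for illustration, not the PR's `OLSJAX`.

```python
import jax.numpy as jnp


class LeanOLS:
    """Hypothetical minimal estimator; derived arrays are never cached."""

    def __init__(self, Y, X):
        self.Y = Y
        self.X = X
        self.beta = None

    def fit(self):
        self.beta = jnp.linalg.lstsq(self.X, self.Y)[0]
        return self

    @property
    def residuals(self):
        # Recomputed on every access instead of being stored as self.uhat.
        return self.Y - self.X @ self.beta

    @property
    def scores(self):
        # Row i is x_i * u_i; broadcasting (n, k) * (n, 1).
        return self.X * self.residuals


X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
Y = jnp.array([[1.0], [2.0], [2.0], [4.0]])
model = LeanOLS(Y, X).fit()
resid = model.residuals
```

The cost of this choice is that each access repeats an `X @ beta` product, so callers that need the residuals several times should bind them to a local variable.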

@s3alfisc
Member Author

2c92c3d implements a POC integration of the OLSJAX class into the pyfixest stream. There are still plenty of things to polish and implement (vcov options, feols API, multicollinearity checks etc). But it should give you an idea of how we could generally integrate "proper numpy APIs" into the current pyfixest internal design (irrespective of JAX integration).

I am not sure if what I put together is a masterclass in software engineering, so I'd love to hear your opinions @apoorvalal @juanitorduz. Moving all other estimation classes to such a design would be a bit of work; as a benefit, we'd get portable estimation classes with numpy / JAX APIs. I still have to ponder this a bit.

For now, I will open a new PR that simply creates the JAX module to unblock @iamlemec's PR to add his performance-improved implementation of the demeaning algo, as merging the OLSJAX class and an appropriate interface will require a little bit more work & thinking.

@s3alfisc s3alfisc changed the title Introduce JAX Module Introduce JAX OLS Class Jan 19, 2025
@juanitorduz
Contributor

Personally, I think it's better to aim for "good" instead of "perfect" and work on smaller iterations. So, I do not have any objections to merging an initial version and creating issues to work on iterations. Just food for thought ;) (we just need to be conscious of the technical debt, but I think at this early stage the risk is low)

@s3alfisc
Member Author

Personally, I think it's better to aim for "good" instead of "perfect" and work on smaller iterations. So, I do not have any objections to merging an initial version and creating issues to work on iterations

Oh I'm glad to hear! For me there's not much worse than a never ending PR that eventually needs multiple rebases to align with the main branch. Horror! 😄

So I'll proceed step by step:

  • finalize the OLS class
  • write a wrapper class and make OLSJAX callable via feols() / FixestMulti
  • port the multicollinearity checks to JAX

Once this is set up, we can assess whether the wrapper class actually looks decent or whether it will be hard to maintain in the long term, and then iterate?
