-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce JAX OLS Class #790
base: master
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
class OLSJAX: | ||
def __init__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstrings?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM up to the docstrings ;)
benchmarks/gpu_pyfixest_errors.ipynb
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this file should probably be deleted in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this probably supercedes the numpy API I requested a while ago, provided we keep the methods+signatures consistent with statsmodels on the base OLSJAX
class and wrap them appropriately through pf.feols
.
Yes, that was exactly the idea! I think some minor refactoring of the different OLS classes will be required (for example, we might have to move all the method calls from |
self.Y, self.X = self.demean( | ||
Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights | ||
) | ||
self.beta = jnp.linalg.lstsq(self.X, self.Y)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious whether using lx.linear_solve
gets speed-gains here; example
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Speed gains look superb, how do you feel about adding lineax as an optional dependency to the JAX env?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm game; Kidger's stuff is very high quality and stable, so I doubt that it will break anytime soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'll test out lineax's solver on this PR over the weekend; expect a commit
pyfixest/estimation/jax/OLSJAX.py
Outdated
@property | ||
def residuals(self): | ||
self.uhat = self.Y - self.X @ self.beta | ||
return self.uhat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if residuals
and scores
methods are simply setting the properties on this class, do they need return statements? alternatively, if they do return arrays, set them inside fit
as self.uhat = self.residuals
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I agree with you, I think simply doing something as
@property
def residuals(self):
return self.Y - self.X @ self.beta
and then potentially setting an uhat
attribute in fit()
would be cleaner? Optimally we can completely drop the uhat
attribute, I don't think it is needed per se (except for some compatibility with the current implementation).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense. re: not setting it as an attribute - agree; i err on the side of not attaching too many things to the model class because it ends up being a memory-consumption headache. Given that the benefits of jax are primarily for big data ™️, making the model class as lean as possible by not attaching any data (that's already built into pf.feols
right) would make sense to me
2c92c3d implements a POC integration of the OLSJAX class into the pyfixest stream. There are still plenty of things to polish and implement (vcov options, feols API, multicollinearity checks etc). But it should provide you with an idea how we could generally integrate "proper numpy APIs" into the current pyfixest internal design (irrespective of JAX integration). I am not sure if what I put together is a masterclass in software engineering, so I'd love to hear your opinions @apoorvalal @juanitorduz . Moving all other estimation classes to such a design would be a bit of work - as a benefit, we'd get a portable estimation classes with numpy / JAX APIs. I still have to ponder about this a bit. For now, I will open a new PR that simply creates the JAX module to unblock @iamlemec's PR to add his performance-improved implementation of the demeaning algo, as merging the OLSJAX class and an appropriate interface will require a little bit more work & thinking. |
Personally, I think it's better to aim for "good" instead of "perfect" and work on smaller iterations. So, I do not have any objections to merging an initial version and creating issues to work on iterations. Just food for though ;) (we just need to be conscious of the technical debt, but I think at his earlier stage the risk is low) |
Oh I'm glad to hear! For me there's not much worse than a never ending PR that eventually needs multiple rebases to align with the main branch. Horror! 😄 So I'll proceed step by step:
After this is set up, we can consider if the wrapper class actually looks decent, or if it will be hard to maintain in the long term and iterate? |
Idea: implement all computational steps in pure JAX. All non-directly implemented methods can be inherited from Feols (i.e. tidy, inference, ritest, etc unless directly implemented).