Integration with probabilistic_model #277
Hi Tom,
One interesting integration with cirkit could be rewriting your code so that it converts a tree-shaped PGM into our symbolic circuit representation. I think this could also be very useful for us to understand whether we need to refine or extend our symbolic representation.
I also find it very nice how you deal with random variables, e.g., that you can name them. Right now, we only have a scope data structure in which variables are non-negative integers, which I guess is too low-level a representation and would benefit from your approach.
On the functional vs. object-oriented design choice: I think our library benefits from a functional design because, in my view, it is a language and a compiler, and since it can be quite hard to implement, we care about having code whose correctness we can check more easily. In our case, for instance, this is done by having immutable objects that are always consistent once initialized.
Thanks
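To make the two ideas above concrete (named random variables, and immutable objects validated once at initialization), here is a minimal Python sketch. The `Variable` and `Scope` classes are hypothetical illustrations, not the actual API of cirkit or probabilistic_model: each variable keeps a human-readable name but is backed by a non-negative integer column index, and every "modification" returns a new, re-validated object instead of mutating the old one.

```python
from dataclasses import dataclass
from typing import FrozenSet, Iterable, Tuple


@dataclass(frozen=True)
class Variable:
    """Hypothetical named random variable bound to a data column."""
    name: str
    index: int  # non-negative column index used by the computational graph

    def __post_init__(self) -> None:
        # Validate once at construction; a frozen instance cannot change afterwards.
        if self.index < 0:
            raise ValueError(f"index must be non-negative, got {self.index}")


@dataclass(frozen=True)
class Scope:
    """Hypothetical immutable scope, checked for consistency on initialization."""
    variables: FrozenSet[Variable]

    def __post_init__(self) -> None:
        names = [v.name for v in self.variables]
        indices = [v.index for v in self.variables]
        # Reject scopes where one name maps to two columns (or vice versa).
        if len(set(names)) != len(names) or len(set(indices)) != len(indices):
            raise ValueError("variable names and indices must be unique within a scope")

    @classmethod
    def of(cls, variables: Iterable[Variable]) -> "Scope":
        return cls(frozenset(variables))

    def indices(self) -> Tuple[int, ...]:
        # The low-level circuit only needs the sorted column indices.
        return tuple(sorted(v.index for v in self.variables))

    def union(self, other: "Scope") -> "Scope":
        # Returns a new object; __post_init__ re-validates the combined scope.
        return Scope(self.variables | other.variables)
```

For example, `Scope.of([Variable("x", 0)]).union(Scope.of([Variable("y", 1)])).indices()` yields `(0, 1)`, while combining two variables both named `"x"` but bound to different columns raises a `ValueError`. The names live at the user-facing level, and the integer indices are what a compiled circuit would consume.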
Hey,
The PGMs in my package get converted to networkx circuits for inference anyway. The nx circuits can then be compiled to JAX circuits for a 4000-times speed-up in log-likelihood evaluation. From your torch implementation, I understood that you also have a networkx-like structure for edges; am I wrong?
Regarding the scope data structure: I like your approach there. For the computational graph, it is not really necessary to know a variable's metadata; the column index is enough. I actually adopted your approach in this data structure (https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_circuit/jax/inner_layer/index.html#probabilistic_model.probabilistic_circuit.jax.inner_layer.InputLayer).
Regarding OOP vs. functional: JAX actually follows the same paradigm there. However, since that is not convenient for Python developers, the community wrote a lot of wrappers around it (equinox, flax, etc.). These do not allow mutation of objects in the classical sense. I strongly recommend giving them a look.
Hope to hear from you soon!
Best,
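As a minimal sketch of the claim that the computational graph only needs column indices, the following evaluates the log-likelihood of a tiny circuit (a two-component mixture of factorized Gaussians) in topological order. It assumes a hypothetical dictionary-based DAG standing in for a networkx graph, so it runs with the standard library only; it is not the data structure of either library. Input units store only a column index plus their distribution parameters.

```python
import math
from graphlib import TopologicalSorter  # standard library, Python 3.9+


def gaussian_logpdf(x: float, mean: float, std: float) -> float:
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical circuit: node id -> spec. Input nodes know only their column index.
circuit = {
    "x0_a": {"type": "input", "column": 0, "mean": 0.0, "std": 1.0},
    "x1_a": {"type": "input", "column": 1, "mean": 0.0, "std": 1.0},
    "x0_b": {"type": "input", "column": 0, "mean": 2.0, "std": 1.0},
    "x1_b": {"type": "input", "column": 1, "mean": 2.0, "std": 1.0},
    "prod_a": {"type": "product", "children": ["x0_a", "x1_a"]},
    "prod_b": {"type": "product", "children": ["x0_b", "x1_b"]},
    "root": {"type": "sum", "children": ["prod_a", "prod_b"], "weights": [0.3, 0.7]},
}


def log_likelihood(circuit, sample, root="root"):
    # TopologicalSorter takes node -> predecessors; using children as predecessors
    # guarantees every child is evaluated before its parent.
    graph = {node: spec.get("children", []) for node, spec in circuit.items()}
    values = {}
    for node in TopologicalSorter(graph).static_order():
        spec = circuit[node]
        if spec["type"] == "input":
            values[node] = gaussian_logpdf(sample[spec["column"]], spec["mean"], spec["std"])
        elif spec["type"] == "product":
            # Product of densities = sum of log-densities.
            values[node] = sum(values[c] for c in spec["children"])
        else:
            # Weighted sum computed stably in log space.
            values[node] = logsumexp(
                [math.log(w) + values[c] for w, c in zip(spec["weights"], spec["children"])]
            )
    return values[root]
```

Calling `log_likelihood(circuit, [0.0, 0.0])` returns the mixture's log-density at that sample; no variable names or other metadata are needed at evaluation time.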
Merry Christmas,
I wanted to follow up since you updated the API docs. Perhaps a good plan would be to merge these projects altogether?
I checked the API docs again; however, I am still unsure how to interpret the data structures that you provide. I ask for your help in parsing your compiled circuits, and I offer to implement, maintain, and test the parser myself.
I am looking forward to this integration and am very happy to see the circuit community grow through this project.
Best,
Hi Tom, sorry for the late reply.
It would be very nice to see whether a JAX backend would simplify and automate the optimizations we currently do in the PyTorch compiler (e.g., folding and einsum optimizations). That's definitely something worth looking into.
I personally do not have the bandwidth to implement a parser of compiled circuits between the two libraries. However, I am interested in understanding what makes implementing your inference routines in cirkit hard (e.g., a poor data structure design that we can improve on our side?).
Well met!
I am a researcher at the Institute for Artificial Intelligence at the University of Bremen. My area of study is tractable probabilistic cognitive robot plans. For this purpose, I have written the package probabilistic_model (https://github.com/tomsch420/probabilistic_model), which contains PGMs in networkx and PCs in networkx, JAX, and torch.
I talked with Antonio about my progress and a potential integration with your framework.
I would be very happy to do so, since we can drastically reduce the amount of duplicated work while combining our knowledge to build a better framework.
My architecture is inspired by the one that Anji Liu presented (https://arxiv.org/pdf/2406.00766) but relies on sparse tensors. Currently, my work indicates that JAX has the best support for them. I do not have an extensive benchmark yet, but I tested a model with ~90k parameters on 200k samples. In that setting, with extremely sparse sum layers, one forward pass took approximately 50 ms with JAX.
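To illustrate why sparse sum layers pay off, here is a minimal pure-Python sketch of a sum layer evaluated in log space from weights stored in coordinate (COO) format. The function name and the COO layout are assumptions for illustration, not probabilistic_model's actual JAX implementation; the point is that only stored (nonzero) edges are visited, so the cost scales with the number of nonzero weights rather than with outputs × inputs.

```python
import math
from collections import defaultdict


def sparse_sum_layer(log_inputs, rows, cols, weights, num_outputs):
    """Hypothetical sparse sum layer in log space.

    Edge k connects input cols[k] to output rows[k] with weight weights[k] > 0.
    Only stored edges are visited, so the work is O(number of nonzeros).
    """
    terms = defaultdict(list)
    for r, c, w in zip(rows, cols, weights):
        terms[r].append(math.log(w) + log_inputs[c])
    out = []
    for r in range(num_outputs):
        ts = terms[r]
        if not ts:
            out.append(float("-inf"))  # no incoming edges: probability zero
            continue
        m = max(ts)
        if m == float("-inf"):
            out.append(m)  # all children have zero probability
            continue
        # Stable log-sum-exp over this output's nonzero terms only.
        out.append(m + math.log(sum(math.exp(t - m) for t in ts)))
    return out
```

For example, with input probabilities `[0.5, 0.2, 0.1, 0.8]` and edges `out0 <- in0 (0.4)`, `out0 <- in3 (0.6)`, `out1 <- in1 (1.0)`, output 0 is `log(0.4 * 0.5 + 0.6 * 0.8) = log(0.68)`. A JAX version would vectorize the same gather-and-reduce pattern with array operations instead of Python loops.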
Furthermore, my architecture is completely object-oriented. In your code I saw a lot of functional design and factories, which Python is not built for (as also indicated by #276).
I would be happy to discuss the architectures in detail and assist you in building a framework that is fit for all kinds of users.
Greetings,
Tom