Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyO3: Add pytorch like .to() operator to candle.Tensor #1100

Merged
merged 3 commits into from
Oct 19, 2023

Conversation

LLukas22
Copy link
Contributor

Allows casting or moving (or both at once) of a tensor via .to(). I get why we don't want this on the rust side of things but for the wrapper it's just way more convenient as it's heavily used in existing pytorch models.

candle-pyo3/src/lib.rs Outdated Show resolved Hide resolved
candle-pyo3/src/lib.rs Outdated Show resolved Hide resolved
@LLukas22
Copy link
Contributor Author

Alright, i changed it to allow a dtype or device to only be provided once, otherwise there will be an ValueError. I also added the option to pass in another other tensor as an parameter, similar to pytorch. I also added tests for these cases. I'm not to happy with how the implementation turned out, but i currently don't know how to improve it.

@LaurentMazare LaurentMazare merged commit 6684b71 into huggingface:main Oct 19, 2023
@LaurentMazare
Copy link
Collaborator

Thanks, agreed that the code doesn't look super nice with all this mutability, certainly feels error prone but let's start with this and clean this up later if it sticks.

EricLBuehler pushed a commit to EricLBuehler/candle that referenced this pull request Oct 25, 2023
…ce#1100)

* add `.to()` operator

* Only allow each value to be provided once via `args` or `kwargs`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants