[python] Support custom obs encoders #1191

Merged
merged 30 commits into main on Jul 6, 2024
Conversation

@ebezzi (Member) commented Jun 12, 2024

Add support for custom obs encoders. Example use case: when multiple obs columns need to be batched together before tensor creation.
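As a rough illustration of that use case (class and method names below are assumptions for this sketch, not the API merged in this PR), a custom encoder might fit on, and transform, the join of several obs columns:

```python
import pandas as pd

# Illustrative sketch only: ConcatEncoder, register, and transform are
# assumed names for this example, not the merged API.
class ConcatEncoder:
    """Encode the concatenation of several obs columns as one integer code."""

    def __init__(self, cols):
        self.cols = cols
        self._mapping = {}

    def register(self, obs: pd.DataFrame) -> None:
        # Learn the vocabulary from the joined column values.
        joined = obs[self.cols].astype(str).agg("_".join, axis=1)
        self._mapping = {v: i for i, v in enumerate(sorted(joined.unique()))}

    def transform(self, obs: pd.DataFrame) -> list:
        joined = obs[self.cols].astype(str).agg("_".join, axis=1)
        return [self._mapping[v] for v in joined]

obs = pd.DataFrame(
    {"cell_type": ["B", "T", "B"], "tissue": ["lung", "lung", "liver"]}
)
enc = ConcatEncoder(["cell_type", "tissue"])
enc.register(obs)
codes = enc.transform(obs)  # one integer code per row
```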

@atolopko-czi (Collaborator) left a comment

Noted two issues that need to be addressed (see ⚠️ comments). Otherwise, just some nits.

@@ -164,6 +212,8 @@ def __next__(self) -> _SOMAChunk:
)
assert obs_batch.shape[0] == obs_joinids_chunk.shape[0]

# print("obs_batch", obs_batch)

Collaborator

rm

@@ -700,7 +766,8 @@ def obs_encoders(self) -> Encoders:
self._init()
assert self._encoders is not None

return self._encoders
# return self._encoders
Collaborator

rm

for enc in self._custom_encoders:
enc.register(obs)
encoders.append(enc)
else:
Collaborator

⚠️ Should this still register default encoders for columns that are not the source of any custom encoder? This complicates matters, I realize, since each encoder would then have to publicize its source columns explicitly. But without it, any object-typed columns that do not have a custom encoder will not be usable in a Tensor. Minimally, the class should present a clear error if after transform, any non-numeric columns remain (with a hint to specify an encoder for each unencoded column).

Member Author

My idea here is that, if the user specifies custom encoders, they are responsible for ensuring that all the columns are transformed properly. With the new changes (see lines 346-347), if a column isn't explicitly registered by an encoder, it won't be added to the final obs_tensor. I believe this addresses the concern, but feel free to let me know if I am still missing something 😄

That said, I think we should expand the docstrings to better explain the consequences of using custom encoders.

Collaborator

In that case, I suppose the obs_columns and encoders could be mutually exclusive. That is, it would be nice if one could specify just the encoders without having to also specify the source obs_columns, since this could lead to errors. But here again, the Encoder would need to be able to report what columns it depends on so that the ExpDataPipe can retrieve them correctly.

But I also agree this could just be handled via clear documentation and useful error messages.

Member Author

I do agree that, if using custom encoders, obs_columns could be omitted. There is a slight performance benefit to specifying them, though, since it avoids fetching the whole obs. Maybe we can just make them optional?

Collaborator

My 2 cents is that this should error if both are passed, and the encoders should be able to report which columns they are going to grab. This links onto my other comment about adopting a scikit-learn like (or maybe even compatible) API, where column transformers and metadata routing are used to retrieve information like column names for their transformers.

Collaborator

In my heart of hearts, I would probably even allow specifying an encoder for X. E.g. instead of specifying use_sparse_X you could have a DenseEncoder().
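A rough sketch of that idea (DenseEncoder is purely hypothetical; nothing like it exists in this PR):

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical X-encoder: instead of a use_sparse_X flag, an encoder
# object decides how X is materialized. DenseEncoder is illustrative only.
class DenseEncoder:
    def transform(self, X):
        # Densify sparse input; pass dense input through as an ndarray.
        return X.toarray() if sp.issparse(X) else np.asarray(X)

X = sp.csr_matrix(np.eye(3))
dense = DenseEncoder().transform(X)
```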

Member Author

I'm OK with implementing the logic where either parameter can be defined. I also think that using an encoder for X is a great idea; I'll create a ticket for that, as it goes out of scope for this PR.

Collaborator

Looks like obs_columns and encoders are now mutually exclusive, so this comment is nearly addressed. Does the X encoder idea still need an issue created for it? (Didn't see one.)

Member Author

Added: #1219

        self._encoder = LabelEncoder()
        self.col = col

    def register(self, obs: pd.DataFrame) -> None:
Collaborator

Could this be called fit? This would make the API very similar to a scikit-learn transformer, which I believe is close to what this is replacing.

It could be nice to be able to use some of their existing transformers here, or at least provide a familiar API
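For reference, the scikit-learn transformer convention being suggested (this is sklearn's real LabelEncoder, shown only to illustrate the fit/transform naming that register would mirror):

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(["B cell", "T cell", "B cell"])       # learn the label vocabulary
codes = le.transform(["T cell", "B cell"])   # labels -> integer codes
labels = le.inverse_transform(codes)         # codes -> labels
```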

Collaborator

There is also an existing concept of a ColumnEncoder which keeps track of column names

Member Author

I agree that fit is a better name. Will change it.

Comment on lines 96 to 99
    @property
    def name(self) -> str:
        """Name of the encoder."""
        return self.col
Collaborator

Since the user might provide multiple encoders, I think it would be nice if they could provide multiple encoders of the same class. This would mean that the name probably has to be dynamic.

It could also be nice to provide each encoder as encoders={name: Encoder(), ...} and then be able to access the encoded values at training time as batch[name]. I think it's a little error prone right now that the positions of the encoded values are set by the positions of the encoders, even though the code that needs to stay in sync isn't in the same place. E.g. if the first encoder is removed, the training loop may need to change each y_batch[:, i]-style statement.
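A sketch of the name-keyed alternative (the dict-based API and the CodeEncoder class here are assumptions for illustration, not code from this PR):

```python
import pandas as pd

# Stand-in encoder; the real encoder classes would differ.
class CodeEncoder:
    def __init__(self, col):
        self.col = col

    def transform(self, obs):
        # factorize assigns integer codes in order of first appearance
        return list(pd.factorize(obs[self.col])[0])

obs = pd.DataFrame(
    {"cell_type": ["B", "T", "B"], "tissue": ["lung", "liver", "lung"]}
)

# Name-keyed: each encoded column is addressed by name, so removing one
# encoder cannot silently shift the meaning of a positional y_batch[:, i].
encoders = {"cell_type": CodeEncoder("cell_type"), "tissue": CodeEncoder("tissue")}
batch = {name: enc.transform(obs) for name, enc in encoders.items()}
y_cell_type = batch["cell_type"]
```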


@@ -419,6 +421,7 @@ def test_distributed__returns_data_partition_for_rank(
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
encoders=[DefaultEncoder("soma_joinid"), DefaultEncoder("label")],
@ebezzi (Member Author) commented Jun 25, 2024

These tests use the soma_joinid to assert positional conditions, so this is how you force soma_joinid to be part of the encoded values.

Collaborator

Ah, then this point really should be explained in the docstring for encoders param.

Collaborator

Also, would obs_column_names=["soma_joinid", "label"] be equivalent?

Member Author

Not the same: soma_joinid is ignored because, with the default behavior, only string columns are encoded. Using custom encoders overrides this.
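To make the distinction concrete (the selection rule below is a simplification of the default behavior described above, not the library's actual code):

```python
import pandas as pd

obs = pd.DataFrame({"soma_joinid": [10, 11, 12], "label": ["a", "b", "a"]})

# Default path: only string/object columns get an encoder, so
# soma_joinid is skipped even when listed in obs_column_names.
default_encoded = [c for c in obs.columns if obs[c].dtype == object]

# Custom-encoder path: every column you register is encoded, numeric or not.
custom_encoded = ["soma_joinid", "label"]
```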

codecov bot commented Jun 28, 2024

Codecov Report

Attention: Patch coverage is 95.10490% with 7 lines in your changes missing coverage. Please review.

Project coverage is 91.30%. Comparing base (f775282) to head (f165f7f).
Report is 2 commits behind head on main.

Files Patch % Lines
...s/src/cellxgene_census/experimental/ml/encoders.py 89.70% 7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1191      +/-   ##
==========================================
+ Coverage   91.19%   91.30%   +0.11%     
==========================================
  Files          77       80       +3     
  Lines        5971     6256     +285     
==========================================
+ Hits         5445     5712     +267     
- Misses        526      544      +18     
Flag Coverage Δ
unittests 91.30% <95.10%> (+0.11%) ⬆️


@atolopko-czi (Collaborator) left a comment

LGTM w/a couple of minor comments and maybe a couple more tests:

The various existing tests do exercise both obs_columns and encoders, but the encoders param doesn't have an explicit test. So we could add a couple of tests, perhaps:

  • A test similar to test_encoders, but for the encoders param
  • A custom encoder that reads from multiple columns (since that was a motivation for this work)

if obs_column_names and encoders:
    raise ValueError(
        "Cannot specify both `obs_column_names` and `encoders`. If `encoders` are specified, columns will be inferred automatically."
    )
Collaborator

👍


if len(encoders) != len({enc.name for enc in encoders}):
    raise ValueError("Encoders must have unique names")

self.obs_column_names = list(itertools.chain(*[enc.columns for enc in encoders]))
Collaborator

What happens if two encoders rely on the same column?

My guess is that it would error since this errors: query.obs(column_names=["soma_joinid", "soma_joinid"]).concat().to_pandas()

Member Author

I'll add a dedup here.
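One way to dedup while keeping the encoder-declared column order (a sketch only; the merged fix may differ):

```python
import itertools

# Minimal stand-in with the .columns attribute used in the snippet above.
class Enc:
    def __init__(self, columns):
        self.columns = columns

# Two encoders sharing a source column.
encoders = [Enc(["cell_type"]), Enc(["cell_type", "tissue"])]

all_cols = itertools.chain.from_iterable(enc.columns for enc in encoders)
# dict.fromkeys preserves insertion order, so duplicates collapse in place.
obs_column_names = list(dict.fromkeys(all_cols))
```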

@ivirshup (Collaborator) left a comment

LGTM. I'm not totally sure about the API (e.g. what happens if an encoder returns a non-1d value?), but this can be revisited in future.

@ebezzi force-pushed the ebezzi/support-custom-obs-encoders branch from 498b8a0 to 96e0650 on July 3, 2024 20:27
@pablo-gar (Contributor) left a comment

Notebook looks good!

@ebezzi ebezzi merged commit 8457e3f into main Jul 6, 2024
17 checks passed
@ebezzi ebezzi deleted the ebezzi/support-custom-obs-encoders branch July 6, 2024 01:38