Skip to content

Commit

Permalink
Cache the untransformed response dict
Browse files Browse the repository at this point in the history
  • Loading branch information
stewarthe6 committed Dec 13, 2024
1 parent 23bd84f commit 61c18fd
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions atomsci/ddm/pipeline/model_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ def __init__(self, params, featurization):
self.subset_response_dict = {}
# Cache for subset-specific response values matched to IDs, used by k-fold CV code
self.subset_weight_dict = {}
# Cache for untransformed response values matched to IDs, used by k-fold CV code
self.untransformed_response_dict = {}


# ****************************************************************************************
def load_full_dataset(self):
Expand Down Expand Up @@ -718,10 +721,11 @@ def get_untransformed_responses(self, ids):
""" Returns a numpy array of untransformed response values
"""
response_vals = np.zeros((len(ids), self.untransformed_dataset.y.shape[1]))
response_dict = dict([(id, y) for id, y in zip(self.untransformed_dataset.ids, self.untransformed_dataset.y)])
if len(self.untransformed_response_dict) == 0:
self.untransformed_response_dict = dict(zip(self.untransformed_dataset.ids, self.untransformed_dataset.y))

for i, id in enumerate(ids):
response_vals[i] = response_dict[id]
response_vals[i] = self.untransformed_response_dict[id]

# we need to double check that all responses_vals we asked for were found
assert len(response_vals) == len(set(ids))
Expand Down

0 comments on commit 61c18fd

Please sign in to comment.