Skip to content

Commit

Permalink
fix: fixed enzyme optmization with Kcat fitness function
Browse files Browse the repository at this point in the history
Signed-off-by: yvesnana <[email protected]>
  • Loading branch information
yvesnana committed Apr 24, 2024
1 parent e73bac8 commit 4d4c6fc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/gt4sd/frameworks/enzeptional/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,24 +347,24 @@ def get_mutations(

class Scorer:
def __init__(self, scorer_filepath: str, scaler_filepath: Optional[str] = None):
'''Initialize the scorer.
"""Initialize the scorer.
Args:
scorer_filepath (str): Pickled scorer filepath.
scaler_filepath (Optional[str], optional): Pickled scaler filepath. Defaults to None.
'''
scaler_filepath (Optional[str], optional): Pickled scaler filepath. Defaults to None.
"""
self.scorer_filepath = scorer_filepath
self.scorer = load(scorer_filepath)
if scaler_filepath is not None:
self.scaler = load(scaler_filepath)

def predict_proba(self, feature_vector):
return self.scorer.predict_proba(feature_vector)

def predict(self, feature_vector):
if self.scaler is not None:
feature_vector = self.scaler.transform(feature_vector)
return self.scorer.predict(xgb.DMatrix(feature_vector))


class EnzymeOptimizer:
"""
Expand Down Expand Up @@ -623,16 +623,16 @@ def score_sequence(self, sequence: str) -> Dict[str, Any]:
]
combined_embedding = np.concatenate(ordered_embeddings)
combined_embedding = combined_embedding.reshape(1, -1)

if self.use_xgboost_scorer:
if self.scaler is not None:
combined_embedding = self.scaler.transform(combined_embedding)
score = self.scorer.predict(xgb.DMatrix(combined_embedding))[0]
else:
score = self.scorer.predict_proba(combined_embedding)[0][1]

return {"sequence": sequence, "score": score}

def score_sequences(self, sequences: List[str]) -> List[Dict[str, float]]:
"""
Scores a list of protein sequences.
Expand All @@ -656,7 +656,7 @@ def score_sequences(self, sequences: List[str]) -> List[Dict[str, float]]:
]
ordered_embeddings = [
embeddings[self.concat_order.index(item)] for item in self.concat_order
]
]
combined_embedding = np.concatenate(ordered_embeddings)
combined_embedding = combined_embedding.reshape(1, -1)

Expand Down

0 comments on commit 4d4c6fc

Please sign in to comment.