diff --git a/src/gt4sd/frameworks/enzeptional/core.py b/src/gt4sd/frameworks/enzeptional/core.py
index 9e528e7bb..a6be1710e 100644
--- a/src/gt4sd/frameworks/enzeptional/core.py
+++ b/src/gt4sd/frameworks/enzeptional/core.py
@@ -347,11 +347,11 @@ def get_mutations(
 class Scorer:
     def __init__(self, scorer_filepath: str, scaler_filepath: Optional[str] = None):
-        '''Initialize the scorer.
+        """Initialize the scorer.
 
         Args:
             scorer_filepath (str): Pickled scorer filepath.
-            scaler_filepath (Optional[str], optional): Pickled scaler filepath. Defaults to None. 
-        '''
+            scaler_filepath (Optional[str], optional): Pickled scaler filepath. Defaults to None.
+        """
         self.scorer_filepath = scorer_filepath
         self.scorer = load(scorer_filepath)
         if scaler_filepath is not None:
@@ -359,12 +359,12 @@ def __init__(self, scorer_filepath: str, scaler_filepath: Optional[str] = None):
 
     def predict_proba(self, feature_vector):
         return self.scorer.predict_proba(feature_vector)
-    
+
     def predict(self, feature_vector):
         if self.scaler is not None:
             feature_vector = self.scaler.transform(feature_vector)
         return self.scorer.predict(xgb.DMatrix(feature_vector))
-    
+
 
 class EnzymeOptimizer:
     """
@@ -623,16 +623,16 @@ def score_sequence(self, sequence: str) -> Dict[str, Any]:
         ]
         combined_embedding = np.concatenate(ordered_embeddings)
         combined_embedding = combined_embedding.reshape(1, -1)
-        
+
         if self.use_xgboost_scorer:
             if self.scaler is not None:
                 combined_embedding = self.scaler.transform(combined_embedding)
             score = self.scorer.predict(xgb.DMatrix(combined_embedding))[0]
         else:
             score = self.scorer.predict_proba(combined_embedding)[0][1]
-        
+
         return {"sequence": sequence, "score": score}
-    
+
     def score_sequences(self, sequences: List[str]) -> List[Dict[str, float]]:
         """
         Scores a list of protein sequences.
@@ -656,7 +656,7 @@ def score_sequences(self, sequences: List[str]) -> List[Dict[str, float]]:
             ]
             ordered_embeddings = [
                 embeddings[self.concat_order.index(item)] for item in self.concat_order
-            ]
+            ]
             combined_embedding = np.concatenate(ordered_embeddings)
             combined_embedding = combined_embedding.reshape(1, -1)
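
For context on the `Scorer` class touched by the first two hunks, here is a minimal usage sketch. The artifact paths are hypothetical placeholders, and it assumes `load` in core.py is a pickle/joblib-style loader; `predict` routes the features through the optional scaler and then wraps them in `xgb.DMatrix` for an XGBoost booster, as the patched code shows.

```python
from gt4sd.frameworks.enzeptional.core import Scorer

# Hypothetical artifact paths; in practice these point to pickled
# scorer/scaler objects produced during training.
scorer = Scorer(
    scorer_filepath="artifacts/scorer.pkl",
    scaler_filepath="artifacts/scaler.pkl",  # optional, may be omitted
)

# predict() applies the scaler (when one was loaded) and converts the
# features to an xgb.DMatrix, matching the XGBoost booster predict API.
feature_vector = [[0.12, 0.53, 0.98, 0.40]]  # illustrative feature row
score = scorer.predict(feature_vector)[0]
```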
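The embedding-combination step that the `score_sequence` and `score_sequences` hunks reformat can also be reproduced standalone. This is a sketch with made-up embedding names and dimensions, assuming NumPy as imported in core.py; it mirrors the list comprehension from the patch and shows how the selected embeddings are flattened into a single feature row for the scorer.

```python
import numpy as np

# Stand-in embeddings, one per entry of concat_order; the role names
# and the dimension of 4 are illustrative, not from the source.
concat_order = ["substrate", "sequence"]
embeddings = [np.full(4, 0.5), np.full(4, 1.0)]

# Mirror of the comprehension in score_sequence(s): pick each embedding
# by its position in concat_order, then concatenate into one row vector.
ordered_embeddings = [
    embeddings[concat_order.index(item)] for item in concat_order
]
combined_embedding = np.concatenate(ordered_embeddings).reshape(1, -1)
print(combined_embedding.shape)  # (1, 8): a single sample row for the scorer
```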