Skip to content

Commit

Permalink
fixed docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Niklas Abraham - INFlux committed Jan 16, 2025
1 parent 9e171e2 commit 110fe4c
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 155 deletions.
145 changes: 53 additions & 92 deletions src/pyeed/analysis/embedding_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Literal
import matplotlib.pyplot as plt
import scipy.spatial as sp
from scipy.spatial.distance import cosine

from pyeed.main import Pyeed
from pyeed.dbconnect import DatabaseConnector
Expand Down Expand Up @@ -42,9 +43,16 @@ def get_embedding(self, sequence_id: str, db: DatabaseConnector):

def _get_single_embedding_last_hidden_state(self, sequence, model, tokenizer, device):
"""
Generates embeddings for a single sequence.
And return the last hidden state. As a numpy array. Not the mean of the last hidden state.
Allows analysis of the single token in the sequences.
Generates embeddings for a single sequence using the last hidden state.
Args:
sequence (str): The protein sequence to embed
model: The transformer model to use
tokenizer: The tokenizer for the model
device: The device to run the model on (CPU/GPU)
Returns:
numpy.ndarray: Normalized embeddings for each token in the sequence
"""

with torch.no_grad():
Expand All @@ -59,7 +67,14 @@ def _get_single_embedding_last_hidden_state(self, sequence, model, tokenizer, de

def calculate_single_sequence_embedding(self, sequence: str, model_name: str = "facebook/esm2_t33_650M_UR50D"):
"""
Calculates an embedding for a single sequence.
Calculates an embedding for a single sequence using a specified model.
Args:
sequence (str): The protein sequence to embed
model_name (str): Name of the pretrained model to use. Defaults to "facebook/esm2_t33_650M_UR50D"
Returns:
numpy.ndarray: The normalized per-token embeddings (last hidden state) for the sequence, not a single pooled vector
"""
model, tokenizer, device = load_model_and_tokenizer(model_name)
return self._get_single_embedding_last_hidden_state(sequence, model, tokenizer, device)
Expand Down Expand Up @@ -108,16 +123,23 @@ def find_closest_matches_simple(self, start_sequence_id: str, db: DatabaseConnec

return distances[:n]

def visualization_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 50, n_iter = 1000, ids_list = None, ids_list_labels = None):
def calculate_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 50, n_iter = 1000, ids_list = None, ids_list_labels = None, random_state = 42):
"""
This function will perform a 2D projection of the embeddings using t-SNE and visualize it.
Parameters:
db (DatabaseConnector): The database connector object
perplexity (int): The perplexity parameter for the t-SNE algorithm. Default is 30
n_iter (int): The number of iterations for the t-SNE algorithm. Default is 1000
ids_list (list): A list of sequence ids for which we want to visualize the embeddings. Default is None
ids_list_labels (list): A list of labels for the sequence ids. Default is None
Performs a 2D projection of the embeddings using t-SNE and prepares visualization data.
Args:
db (DatabaseConnector): The database connector object
perplexity (int): The perplexity parameter for t-SNE. Defaults to 50
n_iter (int): Number of iterations for t-SNE. Defaults to 1000
ids_list (list[str], optional): List of sequence IDs to visualize. If None, uses all sequences
ids_list_labels (dict[str, str], optional): Dictionary mapping sequence IDs to their labels
random_state (int): Seed passed to t-SNE so the projection is reproducible. Defaults to 42
Returns:
tuple: A tuple containing:
- list[str]: List of protein IDs
- numpy.ndarray: 2D projection coordinates
- list[str]: Labels for each point
- list[str]: Colors for each point
"""
if ids_list is None:
# get all the accession_ids
Expand All @@ -136,9 +158,7 @@ def visualization_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 5
RETURN p.accession_id AS protein_id, p.embedding AS embedding
"""
result = db.execute_read(query, {"ids": ids_list})
print(f"Number of proteins fetched with embeddings: {len(result)}")
print(f"Number of proteins in id list: {len(ids_list)}")


# prepare data for visualization
protein_ids, labels = [], []
embeddings = []
Expand Down Expand Up @@ -174,89 +194,28 @@ def visualization_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 5

# perform t-SNE
from sklearn.manifold import TSNE
tsne = TSNE(n_components = 2, perplexity = perplexity, max_iter = n_iter, random_state=42)
tsne = TSNE(n_components = 2, perplexity = perplexity, max_iter = n_iter, random_state=random_state)

embeddings_2d = tsne.fit_transform(embeddings)

return protein_ids, embeddings_2d, labels, colors

def visualization_2d_projection_umap(self, db: DatabaseConnector, n_neighbors = 15, min_dist = 0.1, metric = 'cosine', ids_list = None, ids_list_labels = None):
"""
This function will perform a 2D projection of the embeddings using UMAP and visualize it.
Parameters:
db (DatabaseConnector): The database connector object
n_neighbors (int): The number of neighbors parameter for the UMAP algorithm. Default is 15
min_dist (float): The minimum distance parameter for the UMAP algorithm. Default is 0.1
metric (str): The metric to use for the distance calculation. Default is 'cosine'
ids_list (list): A list of sequence ids for which we want to visualize the embeddings. Default is None
ids_list_labels (list): A list of labels for the sequence ids. Default is None
"""
if ids_list is None:
# get all the accession_ids
query = """
MATCH (p:Protein)
WHERE p.accession_id IS NOT NULL
RETURN p.accession_id AS protein_id
"""
ids_list = [record['protein_id'] for record in db.execute_read(query)]


# get the embeddings for the proteins based in the ids list
query = """
MATCH (p:Protein)
WHERE p.accession_id IN $ids
RETURN p.accession_id AS protein_id, p.embedding AS embedding
def plot_matrix_comparison(self, distance_matrix_1, distance_matrix_2, protein_ids_1, protein_ids_2, label_1, label_2, number_of_points_goal):
"""
result = db.execute_read(query, {"ids": ids_list})
print(f"Number of proteins fetched with embeddings: {len(result)}")
print(f"Number of proteins in id list: {len(ids_list)}")

# prepare data for visualization
protein_ids, labels = [], []
embeddings = []

for record in result:
protein_ids.append(record["protein_id"])
print(type(embeddings))
print(type(record["embedding"]))
embeddings.append(record["embedding"])
# the label is either None or the label from the ids_list_labels
if ids_list_labels is not None:
labels.append(ids_list_labels[record["protein_id"]])
else:
labels.append('None')

embeddings = np.array(embeddings)

# assign each label a color and create the color list with the corresponding colors
# the default colo for 'None' is black
colors = []
color_label_dict: dict[str, str] = {}
import matplotlib.pyplot as plt
cycle_colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
Plots a comparison between two distance matrices.
Args:
distance_matrix_1 (numpy.ndarray): First distance matrix to compare
distance_matrix_2 (numpy.ndarray): Second distance matrix to compare
protein_ids_1 (list[str]): Protein IDs corresponding to first matrix
protein_ids_2 (list[str]): Protein IDs corresponding to second matrix
label_1 (str): Label for the first matrix
label_2 (str): Label for the second matrix
number_of_points_goal (int): Target number of points to plot
for label in labels:
if label not in color_label_dict.keys():
color_label_dict[label] = cycle_colors[len(color_label_dict) % len(cycle_colors)]


# assign the colors
for label in labels:
if label == None:
colors.append('black')
else:
colors.append(color_label_dict[label])

# perform UMAP
import umap
umap_model = umap.UMAP(n_neighbors = n_neighbors, min_dist = min_dist, metric = metric)
embeddings_2d = umap_model.fit_transform(embeddings)

return protein_ids, embeddings_2d, labels, colors

def plot_matrix_comparison(self, distance_matrix_1, distance_matrix_2, protein_ids_1, protein_ids_2, label_1, label_2, number_of_points_goal):
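Returns:
matplotlib.figure.Figure: The figure containing the comparison plot; it is returned instead of being shown, so the caller decides whether to display or save it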
"""
# general plot function
# we want to plot the two distance matrices against each other
fig = plt.figure(figsize=(15, 10))
Expand Down Expand Up @@ -292,7 +251,9 @@ def plot_matrix_comparison(self, distance_matrix_1, distance_matrix_2, protein_i
plt.ylabel(label_2)
plt.grid()
plt.tight_layout()
plt.show()

# return the plot so we can show it
return fig

def calculate_similarity(
self,
Expand Down
58 changes: 42 additions & 16 deletions src/pyeed/analysis/mutation_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,39 @@

class MutationDetection:
def __init__(self):
"""
Initialize the MutationDetection class.
"""
None

def get_mutations_between_sequences(self, sequence_id1: str, sequence_id2: str, db: DatabaseConnector, standard_numbering_tool_name: str, print_debug: bool = False, save_to_db: bool = True):
"""
This function will get the number of mutations between two sequences.
The mutations are detected in the following way:
- the standard numbering tool is used to get the positions of the sequences
- the mutations are then detected by comparing the positions of the sequences (using the standard numbering tool)
- the mutations are then returned as a dictionary with the following structure:
- from_positions: the positions of the mutations in the original sequence
- to_positions: the positions of the mutations in the mutated sequence
- from_monomers: the original monomers at the mutation positions
- to_monomers: the mutated monomers at the mutation positions
- the position numbers are the point in the sequence not the positions from the standard numbering tool
Get the mutations between two sequences using a standard numbering tool.
Parameters:
sequence_id1 (str): The sequence id for the first sequence
sequence_id2 (str): The sequence id for the second sequence
db (DatabaseConnector): The database connector object
sequence_id1 (str): The accession ID of the first sequence
sequence_id2 (str): The accession ID of the second sequence
db (DatabaseConnector): The database connector object
standard_numbering_tool_name (str): Name of the standard numbering tool to use
print_debug (bool, optional): Whether to print debug information. Defaults to False.
save_to_db (bool, optional): Whether to save mutations to database. Defaults to True.
Returns:
dict: A dictionary containing mutation information with the following structure:
- sequence_id1: ID of the first sequence
- sequence_id2: ID of the second sequence
- from_positions: List of positions in the first sequence where mutations occur (1-based)
- to_positions: List of positions in the second sequence where mutations occur (1-based)
- from_monomers: List of original monomers at mutation positions
- to_monomers: List of mutated monomers at mutation positions
Raises:
ValueError: If standard numbering positions cannot be found for both sequences
Note:
The mutations are detected by comparing sequence positions aligned using the
standard numbering tool. Only positions that exist in both sequences (common positions)
are compared for mutations.
"""
# Get the standard numbering positions for both sequences
query = f"""
Expand Down Expand Up @@ -105,7 +117,21 @@ def get_mutations_between_sequences(self, sequence_id1: str, sequence_id2: str,

def save_mutations_to_db(self, mutations: dict, db: DatabaseConnector):
"""
This function will save the mutations to the database
Save detected mutations to the database as relationships between proteins.
Parameters:
mutations (dict): Dictionary containing mutation information with the structure:
- sequence_id1: ID of the first sequence
- sequence_id2: ID of the second sequence
- from_positions: List of positions in the first sequence
- to_positions: List of positions in the second sequence
- from_monomers: List of original monomers
- to_monomers: List of mutated monomers
db (DatabaseConnector): The database connector object
Note:
Creates HAS_MUTATION relationships between proteins in the database,
with properties storing the mutation details (positions and monomers).
"""

# create the mutation relationship between the proteins
Expand Down
Loading

0 comments on commit 110fe4c

Please sign in to comment.