Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge fix niklas #122

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 54 additions & 92 deletions src/pyeed/analysis/embedding_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# a lot of code was provided by Tim Panzer in his bachelor thesis results
import torch
import numpy as np
from typing import Literal
import matplotlib.pyplot as plt
import scipy.spatial as sp
from scipy.spatial.distance import cosine

from pyeed.main import Pyeed
from pyeed.dbconnect import DatabaseConnector
Expand Down Expand Up @@ -41,9 +43,16 @@ def get_embedding(self, sequence_id: str, db: DatabaseConnector):

def _get_single_embedding_last_hidden_state(self, sequence, model, tokenizer, device):
"""
Generates embeddings for a single sequence.
Returns the last hidden state as a numpy array, not the mean of the last hidden state.
This allows analysis of the individual tokens in the sequence.
Generates embeddings for a single sequence using the last hidden state.

Args:
sequence (str): The protein sequence to embed
model: The transformer model to use
tokenizer: The tokenizer for the model
device: The device to run the model on (CPU/GPU)

Returns:
numpy.ndarray: Normalized embeddings for each token in the sequence
"""

with torch.no_grad():
Expand All @@ -58,7 +67,14 @@ def _get_single_embedding_last_hidden_state(self, sequence, model, tokenizer, de

def calculate_single_sequence_embedding(self, sequence: str, model_name: str = "facebook/esm2_t33_650M_UR50D"):
"""
Calculates an embedding for a single sequence.
Calculates an embedding for a single sequence using a specified model.

Args:
sequence (str): The protein sequence to embed
model_name (str): Name of the pretrained model to use. Defaults to "facebook/esm2_t33_650M_UR50D"

Returns:
numpy.ndarray: The normalized embedding for the sequence
"""
model, tokenizer, device = load_model_and_tokenizer(model_name)
return self._get_single_embedding_last_hidden_state(sequence, model, tokenizer, device)
Expand Down Expand Up @@ -107,16 +123,23 @@ def find_closest_matches_simple(self, start_sequence_id: str, db: DatabaseConnec

return distances[:n]

def visualization_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 50, n_iter = 1000, ids_list = None, ids_list_labels = None):
def calculate_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 50, n_iter = 1000, ids_list = None, ids_list_labels = None, random_state = 42):
"""
This function will perform a 2D projection of the embeddings using t-SNE and visualize it.

Parameters:
db (DatabaseConnector): The database connector object
perplexity (int): The perplexity parameter for the t-SNE algorithm. Default is 30
n_iter (int): The number of iterations for the t-SNE algorithm. Default is 1000
ids_list (list): A list of sequence ids for which we want to visualize the embeddings. Default is None
ids_list_labels (list): A list of labels for the sequence ids. Default is None
Performs a 2D projection of the embeddings using t-SNE and prepares visualization data.

Args:
db (DatabaseConnector): The database connector object
perplexity (int): The perplexity parameter for t-SNE. Defaults to 50
n_iter (int): Number of iterations for t-SNE. Defaults to 1000
ids_list (list[str], optional): List of sequence IDs to visualize. If None, uses all sequences
ids_list_labels (dict[str, str], optional): Dictionary mapping sequence IDs to their labels

Returns:
tuple: A tuple containing:
- list[str]: List of protein IDs
- numpy.ndarray: 2D projection coordinates
- list[str]: Labels for each point
- list[str]: Colors for each point
"""
if ids_list is None:
# get all the accession_ids
Expand All @@ -135,9 +158,7 @@ def visualization_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 5
RETURN p.accession_id AS protein_id, p.embedding AS embedding
"""
result = db.execute_read(query, {"ids": ids_list})
print(f"Number of proteins fetched with embeddings: {len(result)}")
print(f"Number of proteins in id list: {len(ids_list)}")


# prepare data for visualization
protein_ids, labels = [], []
embeddings = []
Expand Down Expand Up @@ -173,89 +194,28 @@ def visualization_2d_projection_tsne(self, db: DatabaseConnector, perplexity = 5

# perform t-SNE
from sklearn.manifold import TSNE
tsne = TSNE(n_components = 2, perplexity = perplexity, max_iter = n_iter, random_state=42)
tsne = TSNE(n_components = 2, perplexity = perplexity, max_iter = n_iter, random_state=random_state)

embeddings_2d = tsne.fit_transform(embeddings)

return protein_ids, embeddings_2d, labels, colors

def visualization_2d_projection_umap(self, db: DatabaseConnector, n_neighbors = 15, min_dist = 0.1, metric = 'cosine', ids_list = None, ids_list_labels = None):
"""
This function will perform a 2D projection of the embeddings using UMAP and visualize it.

Parameters:
db (DatabaseConnector): The database connector object
n_neighbors (int): The number of neighbors parameter for the UMAP algorithm. Default is 15
min_dist (float): The minimum distance parameter for the UMAP algorithm. Default is 0.1
metric (str): The metric to use for the distance calculation. Default is 'cosine'
ids_list (list): A list of sequence ids for which we want to visualize the embeddings. Default is None
ids_list_labels (list): A list of labels for the sequence ids. Default is None
"""
if ids_list is None:
# get all the accession_ids
query = """
MATCH (p:Protein)
WHERE p.accession_id IS NOT NULL
RETURN p.accession_id AS protein_id
"""
ids_list = [record['protein_id'] for record in db.execute_read(query)]


# get the embeddings for the proteins based on the ids list
query = """
MATCH (p:Protein)
WHERE p.accession_id IN $ids
RETURN p.accession_id AS protein_id, p.embedding AS embedding
def plot_matrix_comparison(self, distance_matrix_1, distance_matrix_2, protein_ids_1, protein_ids_2, label_1, label_2, number_of_points_goal):
"""
result = db.execute_read(query, {"ids": ids_list})
print(f"Number of proteins fetched with embeddings: {len(result)}")
print(f"Number of proteins in id list: {len(ids_list)}")

# prepare data for visualization
protein_ids, labels = [], []
embeddings = []

for record in result:
protein_ids.append(record["protein_id"])
print(type(embeddings))
print(type(record["embedding"]))
embeddings.append(record["embedding"])
# the label is either None or the label from the ids_list_labels
if ids_list_labels is not None:
labels.append(ids_list_labels[record["protein_id"]])
else:
labels.append('None')

embeddings = np.array(embeddings)

# assign each label a color and create the color list with the corresponding colors
# the default color for 'None' is black
colors = []
color_label_dict: dict[str, str] = {}
import matplotlib.pyplot as plt
cycle_colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
Plots a comparison between two distance matrices.

Args:
distance_matrix_1 (numpy.ndarray): First distance matrix to compare
distance_matrix_2 (numpy.ndarray): Second distance matrix to compare
protein_ids_1 (list[str]): Protein IDs corresponding to first matrix
protein_ids_2 (list[str]): Protein IDs corresponding to second matrix
label_1 (str): Label for the first matrix
label_2 (str): Label for the second matrix
number_of_points_goal (int): Target number of points to plot

for label in labels:
if label not in color_label_dict.keys():
color_label_dict[label] = cycle_colors[len(color_label_dict) % len(cycle_colors)]


# assign the colors
for label in labels:
if label == None:
colors.append('black')
else:
colors.append(color_label_dict[label])

# perform UMAP
import umap
umap_model = umap.UMAP(n_neighbors = n_neighbors, min_dist = min_dist, metric = metric)
embeddings_2d = umap_model.fit_transform(embeddings)

return protein_ids, embeddings_2d, labels, colors

def plot_matrix_comparison(self, distance_matrix_1, distance_matrix_2, protein_ids_1, protein_ids_2, label_1, label_2, number_of_points_goal):
Returns:
None: Displays the plot using matplotlib
"""
# general plot function
# we want to plot the two distance matrices against each other
fig = plt.figure(figsize=(15, 10))
Expand Down Expand Up @@ -291,7 +251,9 @@ def plot_matrix_comparison(self, distance_matrix_1, distance_matrix_2, protein_i
plt.ylabel(label_2)
plt.grid()
plt.tight_layout()
plt.show()

# return the plot so we can show it
return fig

def calculate_similarity(
self,
Expand Down
58 changes: 42 additions & 16 deletions src/pyeed/analysis/mutation_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,39 @@

class MutationDetection:
def __init__(self):
"""
Initialize the MutationDetection class.
"""
None

def get_mutations_between_sequences(self, sequence_id1: str, sequence_id2: str, db: DatabaseConnector, standard_numbering_tool_name: str, print_debug: bool = False, save_to_db: bool = True):
"""
This function will get the number of mutations between two sequences.
The mutations are detected in the following way:
- the standard numbering tool is used to get the positions of the sequences
- the mutations are then detected by comparing the positions of the sequences (using the standard numbering tool)
- the mutations are then returned as a dictionary with the following structure:
- from_positions: the positions of the mutations in the original sequence
- to_positions: the positions of the mutations in the mutated sequence
- from_monomers: the original monomers at the mutation positions
- to_monomers: the mutated monomers at the mutation positions
- the position numbers refer to positions within the sequence itself, not to the positions assigned by the standard numbering tool
Get the mutations between two sequences using a standard numbering tool.


Parameters:
sequence_id1 (str): The sequence id for the first sequence
sequence_id2 (str): The sequence id for the second sequence
db (DatabaseConnector): The database connector object

sequence_id1 (str): The accession ID of the first sequence
sequence_id2 (str): The accession ID of the second sequence
db (DatabaseConnector): The database connector object
standard_numbering_tool_name (str): Name of the standard numbering tool to use
print_debug (bool, optional): Whether to print debug information. Defaults to False.
save_to_db (bool, optional): Whether to save mutations to database. Defaults to True.

Returns:
dict: A dictionary containing mutation information with the following structure:
- sequence_id1: ID of the first sequence
- sequence_id2: ID of the second sequence
- from_positions: List of positions in the first sequence where mutations occur (1-based)
- to_positions: List of positions in the second sequence where mutations occur (1-based)
- from_monomers: List of original monomers at mutation positions
- to_monomers: List of mutated monomers at mutation positions

Raises:
ValueError: If standard numbering positions cannot be found for both sequences

Note:
The mutations are detected by comparing sequence positions aligned using the
standard numbering tool. Only positions that exist in both sequences (common positions)
are compared for mutations.
"""
# Get the standard numbering positions for both sequences
query = f"""
Expand Down Expand Up @@ -105,7 +117,21 @@ def get_mutations_between_sequences(self, sequence_id1: str, sequence_id2: str,

def save_mutations_to_db(self, mutations: dict, db: DatabaseConnector):
"""
This function will save the mutations to the database
Save detected mutations to the database as relationships between proteins.

Parameters:
mutations (dict): Dictionary containing mutation information with the structure:
- sequence_id1: ID of the first sequence
- sequence_id2: ID of the second sequence
- from_positions: List of positions in the first sequence
- to_positions: List of positions in the second sequence
- from_monomers: List of original monomers
- to_monomers: List of mutated monomers
db (DatabaseConnector): The database connector object

Note:
Creates HAS_MUTATION relationships between proteins in the database,
with properties storing the mutation details (positions and monomers).
"""

# create the mutation relationship between the proteins
Expand Down
Loading
Loading