Skip to content
This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

add NME_SC method for clustering #1

Open
wants to merge 12 commits into
base: max_speaker
Choose a base branch
from
75 changes: 75 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import soundfile as sf
Jeronymous marked this conversation as resolved.
Show resolved Hide resolved
import matplotlib.pyplot as plt
import os,sys,time


from simple_diarizer.diarizer import Diarizer
from simple_diarizer.utils import combined_waveplot

# Command-line driver: diarize one audio file and print per-speaker statistics.
t0 = time.time()

# Expected arguments: WAV_FILE NUM_SPEAKERS MAX_SPEAKERS.
# NUM_SPEAKERS may be the literal string 'None', in which case the speaker
# count is left to the diarizer and MAX_SPEAKERS is forwarded as the cap.
# Checked before the model is built so a bad call fails fast with a usage line.
if len(sys.argv) != 4:
    sys.exit("usage: {} WAV_FILE NUM_SPEAKERS MAX_SPEAKERS".format(sys.argv[0]))
WAV_FILE, NUM_SPEAKERS, max_spk = sys.argv[1:]

diar = Diarizer(
    embed_model='ecapa',  # 'xvec' and 'ecapa' supported
    cluster_method='NME-sc'  # 'ahc' 'sc' and 'NME-sc' supported
)

if NUM_SPEAKERS == 'None':
    segments = diar.diarize(WAV_FILE, num_speakers=None, max_speakers=int(max_spk))
else:
    segments = diar.diarize(WAV_FILE, num_speakers=int(NUM_SPEAKERS))

# The elapsed time covers model construction as well as the diarize() call.
t1 = time.time()
feature_t = t1 - t0
print("Time used for extracting features:", feature_t)

json = {}
_segments = []
_speakers = {}
spk_i_dict = {}  # raw cluster label -> 1-based speaker number

for seg_id, seg in enumerate(segments, start=1):
    label = seg['label']

    # Ensure speaker id continuity and number speakers by order of appearance.
    if label not in spk_i_dict:
        spk_i_dict[label] = len(spk_i_dict) + 1
    spk_id = "spk" + str(spk_i_dict[label])

    _segments.append(
        {
            "seg_id": seg_id,
            "spk_id": spk_id,
            "seg_begin": round(seg['start']),
            "seg_end": round(seg['end']),
        }
    )

    # Accumulate unrounded durations; rounding happens once, after the loop.
    stats = _speakers.setdefault(
        spk_id, {"spk_id": spk_id, "duration": 0, "nbr_seg": 0}
    )
    stats["duration"] += seg['end'] - seg['start']
    stats["nbr_seg"] += 1

for spkstat in _speakers.values():
    spkstat["duration"] = round(spkstat["duration"])

json["speakers"] = list(_speakers.values())
json["segments"] = _segments


print(json["speakers"])

135 changes: 135 additions & 0 deletions simple_diarizer/Spectral_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np
import scipy
from sklearn.cluster import SpectralClustering

# NME low-level operations
# These functions are taken from the Kaldi scripts.

# Prepares a binarized (0/1) affinity matrix with p_neighbors non-zero elements in each column (column i marks the top-p entries of input row i)
def get_kneighbors_conn(X_dist, p_neighbors):
    """Build a binarized (0/1) neighbour-connectivity matrix.

    For every row ``i`` of ``X_dist`` the positions of its ``p_neighbors``
    largest entries are found, and those positions are set to 1 in
    *column* ``i`` of the output. All other entries stay 0.

    Args:
        X_dist: 2D array of pairwise scores (larger means closer).
        p_neighbors: number of top entries to keep per input row.

    Returns:
        Array with the same shape and dtype as ``X_dist`` holding 0/1 values.
        The result is not symmetric in general.
    """
    connectivity = np.zeros_like(X_dist)
    for col, scores in enumerate(X_dist):
        # argsort is ascending, so reverse it to get the highest scores first.
        top = np.argsort(scores)[::-1][:p_neighbors]
        connectivity[top, col] = 1
    return connectivity


# Thresholds affinity matrix to leave p maximum non-zero elements in each row
def Threshold(A, p):
    """Keep only the largest entries of each row of ``A`` and zero the rest.

    For each row the cut-off is the value at index ``p`` of the row sorted in
    descending order. Entries strictly greater than the cut-off are copied to
    the output; everything else is 0.

    Args:
        A: square 2D array (only ``A.shape[0]`` is used for the output size).
        p: index into the descending-sorted row that selects the cut-off.

    Returns:
        Float array of shape ``(N, N)`` with the retained values.

    Raises:
        IndexError: if ``p`` is not a valid index into a row.
    """
    size = A.shape[0]
    kept = np.zeros((size, size))
    for row_idx in range(size):
        row = A[row_idx, :]
        cutoff = sorted(row, reverse=True)[p]
        above = row > cutoff
        kept[row_idx, above] = row[above]
    return kept


# Computes Laplacian of a matrix
def Laplacian(A):
    """Return ``D - A``, where ``D`` is the diagonal degree matrix of ``A``.

    The degree of node ``i`` is the sum of row ``i`` with the diagonal entry
    ``A[i, i]`` excluded.

    Args:
        A: square 2D affinity matrix.

    Returns:
        2D array of the same shape as ``A``.
    """
    # Row sums without the self-affinity on the diagonal.
    degree = np.sum(A, axis=1) - np.diag(A)
    return np.diag(degree) - A


# Calculates eigengaps (differences between adjacent eigenvalues sorted in ascending order)
def Eigengap(S):
    """Return the gaps between consecutive eigenvalues.

    The values are sorted in ascending order first, so every returned gap is
    non-negative.

    Args:
        S: 1D sequence of eigenvalues, in any order.

    Returns:
        Array with ``len(S) - 1`` differences between neighbouring sorted values.
    """
    ascending = sorted(S)
    gaps = np.diff(ascending)
    return gaps


# Computes parameters of normalized eigenmaps for automatic thresholding selection
def ComputeNMEParameters(A, p, max_num_clusters):
    """Compute normalized-maximum-eigengap statistics for one neighbour count.

    The affinity matrix is binarized with ``p`` neighbours, symmetrized and
    turned into a Laplacian. The eigengaps of its smallest eigenvalues are
    then summarised.

    Args:
        A: square affinity matrix.
        p: neighbour count used for binarization.
        max_num_clusters: upper bound on the number of clusters;
            ``max_num_clusters + 1`` smallest eigenvalues are requested.

    Returns:
        Tuple ``(e, g, k, r)``:
            e: eigengaps of the smallest eigenvalues (ascending order).
            g: largest eigengap divided by the largest eigenvalue. This is a
               length-1 array, because ``eigsh`` returns an array for ``k=1``.
            k: index of the largest eigengap (estimated cluster count is ``k + 1``).
            r: ``p / g``; the caller picks the ``p`` with the smallest ``r``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` is not
    # guaranteed to have loaded ``scipy.sparse.linalg``.
    from scipy.sparse.linalg import eigsh

    # p-Neighbour binarization
    Ap = get_kneighbors_conn(A, p)
    # Symmetrization
    Ap = (Ap + np.transpose(Ap)) / 2
    # Laplacian matrix computation
    Lp = Laplacian(Ap)
    # Get max_num_clusters+1 smallest eigenvalues
    S = eigsh(
        Lp,
        k=max_num_clusters + 1,
        which="SA",
        tol=1e-6,
        return_eigenvectors=False,
        mode="buckling",
    )
    # Get largest eigenvalue
    Smax = eigsh(
        Lp, k=1, which="LA", tol=1e-6, return_eigenvectors=False, mode="buckling"
    )
    # Eigengap computation
    e = Eigengap(S)
    # The small constant guards against division by a zero largest eigenvalue.
    g = np.max(e[:max_num_clusters]) / (Smax + 1e-10)
    r = p / g
    k = np.argmax(e[:max_num_clusters])
    return (e, g, k, r)


"""
Performs spectral clustering with Normalized Maximum Eigengap (NME)
Parameters:
A: affinity matrix (matrix of pairwise cosine similarities or PLDA scores between speaker embeddings)
num_clusters: number of clusters to generate (if None, determined automatically)
max_num_clusters: maximum allowed number of clusters to generate
pmax: maximum count for matrix binarization (should be at least 2)
pbest: best count for matrix binarization (if 0, determined automatically)
Returns: cluster assignments for every speaker embedding
"""


def NME_SpectralClustering(
    A, num_clusters=None, max_num_clusters=10, pbest=0, pmin=3, pmax=20
):
    """Run spectral clustering with Normalized Maximum Eigengap (NME).

    Args:
        A: affinity matrix (pairwise similarities between speaker embeddings).
        num_clusters: number of clusters to generate; if None it is estimated
            from the eigengaps.
        max_num_clusters: maximum number of clusters considered when
            estimating. ``None`` is treated as the default of 10.
        pbest: neighbour count for binarization; if 0 it is searched for in
            ``[pmin, pmax]``.
        pmin: smallest neighbour count tried during the search.
        pmax: largest neighbour count tried during the search.

    Returns:
        Cluster assignment for every row of ``A``.
    """
    if max_num_clusters is None:
        # Callers may forward an unset speaker cap; fall back to the default
        # rather than failing on ``None + 1`` during the eigen decomposition.
        max_num_clusters = 10

    if pbest == 0:
        print("Selecting best number of neighbors for affinity matrix thresolding:")
        rbest = None
        kbest = None
        for p in range(pmin, pmax + 1):
            e, g, k, r = ComputeNMEParameters(A, p, max_num_clusters)
            print("p={}, g={}, k={}, r={}, e={}".format(p, g, k, r, e))
            # Keep the neighbour count with the smallest ratio r = p / g.
            if rbest is None or rbest > r:
                rbest = r
                pbest = p
                kbest = k
        print("Best number of neighbors is {}".format(pbest))
        if num_clusters is None:
            num_clusters = kbest + 1
        # Handle some edge cases in AMI SDM: a single cluster is replaced by 4.
        # NOTE(review): this also overrides an explicit num_clusters=1.
        if num_clusters == 1:
            num_clusters = 4
        return NME_SpectralClustering_sklearn(A, num_clusters, pbest)

    if num_clusters is None:
        print("Compute number of clusters to generate:")
        e, g, k, r = ComputeNMEParameters(A, pbest, max_num_clusters)
        print("Number of clusters to generate is {}".format(k + 1))
        return NME_SpectralClustering_sklearn(A, k + 1, pbest)

    return NME_SpectralClustering_sklearn(A, num_clusters, pbest)


"""
Performs spectral clustering with Normalized Maximum Eigengap (NME) with fixed threshold and number of clusters
Parameters:
A: affinity matrix (matrix of pairwise cosine similarities or PLDA scores between speaker embeddings)
OLVec: 0/1 vector denoting which segments are overlap segments
num_clusters: number of clusters to generate
pbest: best count for matrix binarization
Returns: cluster assignments for every speaker embedding
"""


def NME_SpectralClustering_sklearn(A, num_clusters, pbest):
    """Cluster a precomputed affinity matrix with a fixed cluster count.

    The affinity matrix is binarized with ``pbest`` neighbours, symmetrized
    and passed to scikit-learn's ``SpectralClustering`` as a precomputed
    affinity.

    Args:
        A: square affinity matrix.
        num_clusters: number of clusters to generate.
        pbest: neighbour count used for binarization.

    Returns:
        Labels produced by ``SpectralClustering.fit_predict``.
    """
    print("Number of speakers is {}".format(num_clusters))

    # Binarized neighbour graph (value-preserving Threshold(A, pbest) is the
    # alternative that is not used here), made symmetric by averaging it
    # with its transpose.
    neighbours = get_kneighbors_conn(A, pbest)
    affinity = (neighbours + np.transpose(neighbours)) / 2

    clusterer = SpectralClustering(
        n_clusters=num_clusters, affinity="precomputed", random_state=0
    )
    return clusterer.fit_predict(affinity)
63 changes: 57 additions & 6 deletions simple_diarizer/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.ndimage import gaussian_filter
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
from sklearn.metrics import pairwise_distances

from .Spectral_clustering import NME_SpectralClustering
Jeronymous marked this conversation as resolved.
Show resolved Hide resolved

def similarity_matrix(embeds, metric="cosine"):
    """Return the pairwise matrix between rows of ``embeds`` for ``metric``.

    NOTE(review): despite the name, the values come straight from
    ``pairwise_distances``, so they are presumably distances rather than
    similarities. Confirm against the callers.
    """
    pairwise = pairwise_distances(embeds, metric=metric)
    return pairwise
Expand Down Expand Up @@ -43,9 +43,7 @@ def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwar
# A lot of these methods are lifted from
# https://github.com/wq2012/SpectralCluster
##########################################


def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
def cluster_SC(embeds, n_clusters=None, max_speakers= None, threshold=None, enhance_sim=True, **kwargs):
"""
Cluster embeds using Spectral Clustering
"""
Expand All @@ -59,7 +57,7 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwar
if n_clusters is None:
(eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S)
# Get number of clusters.
k = compute_number_of_clusters(eigenvalues, 100, threshold)
k = compute_number_of_clusters(eigenvalues, max_speakers, threshold)

# Get spectral embeddings.
spectral_embeddings = eigenvectors[:, :k]
Expand All @@ -82,6 +80,34 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwar
return cluster_model.fit_predict(S)


def cluster_NME_SC(embeds, n_clusters=None, max_speakers=None, threshold=None, enhance_sim=True, **kwargs):
    """Cluster embeds using NME-Spectral Clustering.

    Args:
        embeds: 2D array of embeddings, one per row.
        n_clusters: number of clusters; if None it is estimated by the NME
            procedure.
        max_speakers: upper bound on the estimated cluster count. Only used
            when ``n_clusters`` is None; if it is also None, the default of
            ``NME_SpectralClustering`` applies.
        threshold: unused; accepted for signature parity with the other
            ``cluster_*`` functions.
        enhance_sim: unused; accepted for signature parity.
        **kwargs: ignored.

    Returns:
        Cluster label for every embedding.
    """
    S = cos_similarity(embeds)

    nme_args = {"num_clusters": n_clusters}
    # Forward the cap only when one was given: passing None through would
    # replace the callee's default and break its eigenvalue computation.
    if n_clusters is None and max_speakers is not None:
        nme_args["max_num_clusters"] = max_speakers

    return NME_SpectralClustering(S, **nme_args)


def diagonal_fill(A):
"""
Sets the diagonal elements of the matrix to the max of each row
Expand Down Expand Up @@ -134,7 +160,7 @@ def row_max_norm(A):
def sim_enhancement(A):
func_order = [
diagonal_fill,
gaussian_blur,

row_threshold_mult,
symmetrization,
diffusion,
Expand All @@ -144,6 +170,31 @@ def sim_enhancement(A):
A = f(A)
return A

def cos_similarity(x):
    """Compute cosine similarity matrix in CPU & memory sensitive way

    Rows are L2-normalised, then the dot products are accumulated over
    slices of the feature dimension so that the temporary ``(N, N, step)``
    product stays at roughly ``max_n_elm`` elements.

    Args:
        x (np.ndarray): embeddings, 2D array, embeddings are in rows

    Returns:
        np.ndarray: cosine similarity matrix of shape ``(N, N)``, float64.
        An input with zero rows yields an empty ``(0, 0)`` matrix.

    Raises:
        AssertionError: if ``x`` is not 2D, if a row cannot be normalised to
            unit length (e.g. an all-zero embedding), or if a similarity
            falls outside ``[-1, 1]`` beyond a small tolerance.
    """
    assert x.ndim == 2, f"x has {x.ndim} dimensions, it must be matrix"
    n_embeds = x.shape[0]
    # The tiny epsilon avoids dividing by zero for all-zero rows; such rows
    # are then rejected by the unit-norm check below.
    x = x / (np.sqrt(np.sum(np.square(x), axis=1, keepdims=True)) + 1.0e-32)
    assert np.allclose(np.ones_like(x[:, 0]), np.sum(np.square(x), axis=1))
    max_n_elm = 200000000
    # Clamp the denominator as well: with zero embeddings N * N is 0 and the
    # integer division would raise ZeroDivisionError.
    step = max(max_n_elm // max(n_embeds * n_embeds, 1), 1)
    retval = np.zeros(shape=(n_embeds, n_embeds), dtype=np.float64)
    x0 = np.expand_dims(x, 0)
    x1 = np.expand_dims(x, 1)
    # Accumulate partial dot products over slices of the feature dimension.
    for i in range(0, x.shape[1], step):
        product = x0[:, :, i : i + step] * x1[:, :, i : i + step]
        retval += np.sum(product, axis=2, keepdims=False)
    assert np.all(retval >= -1.0001), retval
    assert np.all(retval <= 1.0001), retval
    return retval


def compute_affinity_matrix(X):
"""Compute the affinity matrix from data.
Expand Down
Loading