Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Highway2Vec - clustering and similarity #4

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion 04_ml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,74 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Highway2Vec Clustering and similarity search\n",
"## Highway2Vec - Clustering and similarity search\n",
"\n",
"In this part we will see:\n",
"<!-- * How to use a pre-trained hex2vec model with srai\n",
"* How to train classification model based on srai embeddings\n",
"* How to use srai to gather training data -->"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.loaders import OSMWayLoader, OSMNetworkType\n",
"from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf\n",
"from srai.joiners import IntersectionJoiner\n",
"from srai.embedders import Highway2VecEmbedder\n",
"\n",
"area = geocode_to_region_gdf(\"Wrocław, Poland\")\n",
"nodes, edges = OSMWayLoader(OSMNetworkType.DRIVE).load(area)\n",
"regions = H3Regionalizer(resolution=9).transform(area) \n",
"joint = IntersectionJoiner().transform(regions, edges)\n",
"\n",
" \n",
"embedder = Highway2VecEmbedder()\n",
"embedder.fit(regions, edges, joint)\n",
"embeddings = embedder.transform(regions, edges, joint)\n",
"embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from clustering import (\n",
" scale_embeddings,\n",
" generate_clustering_model,\n",
" generate_linkage_matrix,\n",
" plot_dendrogram,\n",
" cluster_regions,\n",
" plot_clustered_regions_with_roads\n",
")\n",
"import matplotlib.pyplot as plt\n",
"\n",
"embeddings_scaled = scale_embeddings(embeddings)\n",
"ac_model = generate_clustering_model(\n",
" embeddings_scaled,\n",
" {\n",
" \"n_clusters\": None,\n",
" \"distance_threshold\": 0,\n",
" \"metric\": \"euclidean\",\n",
" \"linkage\": \"ward\",\n",
" },\n",
")\n",
"\n",
"linkage_matrix = generate_linkage_matrix(ac_model)\n",
"plot_dendrogram(linkage_matrix, {\"truncate_mode\": \"level\", \"p\": 3})\n",
"plt.show()\n",
"clusters = [6]\n",
"regions_clustered = cluster_regions(\n",
" linkage_matrix, embeddings, regions, clusters #[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n",
")\n",
"plot_clustered_regions_with_roads(regions_clustered, edges, area, clusters)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
140 changes: 140 additions & 0 deletions clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
This is a boilerplate pipeline 'visualizations'
generated using Kedro 0.18.7
"""
from typing import Any, Dict, List, Tuple

import contextily as ctx
import geopandas as gpd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from scipy.cluster.hierarchy import cut_tree, dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm


# Tile provider passed to contextily when drawing the basemap.
MAP_SOURCE = ctx.providers.CartoDB.Positron
# Name of the matplotlib colormap used to colour clusters.
MATPLOTLIB_COLORMAP = "tab20"
# The same colormap rendered as "rgb(...)" strings, one per colour, with each
# colour component multiplied by 255.
PLOTLY_COLORMAP = [
    f"rgb{tuple(component * 255 for component in rgb)}"
    for rgb in matplotlib.colormaps[MATPLOTLIB_COLORMAP].colors
]


def scale_embeddings(embeddings: pd.DataFrame) -> pd.DataFrame:
    """Return the embeddings transformed by a freshly fitted ``StandardScaler``.

    The scaler is fitted on ``embeddings`` itself. The result is a new
    DataFrame that reuses the index and column labels of the input.
    """
    scaler = StandardScaler()
    standardized_values = scaler.fit_transform(embeddings)
    return pd.DataFrame(
        standardized_values,
        index=embeddings.index,
        columns=embeddings.columns,
    )


def generate_clustering_model(
    embeddings: pd.DataFrame, clustering_params: Dict[str, Any]
):
    """Build an ``AgglomerativeClustering`` model and fit it on ``embeddings``.

    Args:
        embeddings: Data the model is fitted on.
        clustering_params: Must contain the keys ``"n_clusters"``,
            ``"distance_threshold"``, ``"metric"`` and ``"linkage"``; their
            values are forwarded to the estimator under the same names.

    Returns:
        The fitted model.

    Raises:
        KeyError: If any of the four keys is missing from
            ``clustering_params``.
    """
    forwarded_keys = ("n_clusters", "distance_threshold", "metric", "linkage")
    estimator_kwargs = {key: clustering_params[key] for key in forwarded_keys}

    model = AgglomerativeClustering(**estimator_kwargs)
    model.fit(embeddings)
    return model


def generate_linkage_matrix(model: AgglomerativeClustering) -> np.ndarray:
    """Assemble a float linkage matrix from a fitted clustering model.

    Each row corresponds to one merge in ``model.children_`` and holds, in
    order: the two merged node ids, the merge distance taken from
    ``model.distances_``, and the number of original samples under the
    merged node.

    Args:
        model: Object exposing ``children_``, ``distances_`` and ``labels_``.

    Returns:
        Array with one row per merge and four float columns.
    """
    merges = model.children_
    leaf_total = len(model.labels_)
    subtree_sizes = np.zeros(merges.shape[0])

    for merge_idx, children in enumerate(merges):
        # Ids below ``leaf_total`` are original samples (size 1); larger ids
        # refer to an earlier merge whose size has already been stored.
        subtree_sizes[merge_idx] = sum(
            1 if node < leaf_total else subtree_sizes[node - leaf_total]
            for node in children
        )

    stacked = np.column_stack([merges, model.distances_, subtree_sizes])
    return stacked.astype(float)


def plot_dendrogram(
    linkage_matrix: np.ndarray, dendrogram_params: Dict[str, Any]
) -> Figure:
    """Draw a dendrogram for ``linkage_matrix`` on a new 12x7 figure.

    Args:
        linkage_matrix: Linkage matrix handed to ``scipy``'s ``dendrogram``.
        dendrogram_params: Extra keyword arguments forwarded to
            ``dendrogram`` unchanged.

    Returns:
        The newly created figure. It is neither shown nor closed here.
    """
    # Only the figure is needed; the label and layout calls below go through
    # pyplot's current-figure state.
    figure = plt.subplots(figsize=(12, 7))[0]
    plt.xlabel("Number of microregions")
    dendrogram(linkage_matrix, **dendrogram_params)
    plt.tight_layout()
    return figure


def cluster_regions(
    linkage_matrix: np.ndarray,
    embeddings: pd.DataFrame,
    regions: gpd.GeoDataFrame,
    clusters: List[int],
) -> gpd.GeoDataFrame:
    """Attach cluster assignments to the regions that have embeddings.

    The hierarchy in ``linkage_matrix`` is cut once for every requested
    number of clusters. Each cut is stored in a categorical column named
    ``cluster_<n>``.

    Args:
        linkage_matrix: Linkage matrix passed to ``scipy``'s ``cut_tree``.
        embeddings: Only its index is used, to select and order the rows
            of ``regions``. Assumes the rows of ``linkage_matrix`` were built
            from data in this same order -- TODO confirm against the caller.
        regions: Regions to annotate; it is not modified.
        clusters: Numbers of clusters at which to cut the tree.

    Returns:
        A new GeoDataFrame with the selected regions plus one ``cluster_<n>``
        column per entry of ``clusters``.
    """
    # Take an explicit copy so the column assignments below never write into
    # (or warn about writing into) a selection of the caller's ``regions``.
    regions_clustered = regions.loc[embeddings.index, :].copy()

    cut_tree_results = cut_tree(linkage_matrix, n_clusters=clusters)
    # ``enumerate`` is materialised so tqdm can read a length for its bar.
    for column_idx, n_clusters in tqdm(list(enumerate(clusters))):
        regions_clustered[f"cluster_{n_clusters}"] = pd.Series(
            cut_tree_results[:, column_idx], index=regions_clustered.index
        ).astype("category")

    return regions_clustered


def plot_clustered_regions_with_roads(
    regions_clustered: gpd.GeoDataFrame,
    roads: gpd.GeoDataFrame,
    area: gpd.GeoDataFrame,
    clusters: List[int],
) -> Dict[str, Figure]:
    """Plot one map of clustered regions and roads per requested clustering.

    Args:
        regions_clustered: Regions carrying a ``cluster_<n>`` column for every
            ``n`` in ``clusters``.
        roads: Road geometries drawn on top of the regions.
        area: Area that both regions and roads are spatially joined against.
        clusters: Numbers of clusters to plot; each selects the column
            ``cluster_<n>``.

    Returns:
        Mapping from column name (``cluster_<n>``) to the figure drawn for
        it. The figures are left open; closing them is up to the caller.
    """
    clusters = list(clusters)
    plots: Dict[str, Figure] = {}
    if not clusters:
        return plots

    # The spatial joins do not depend on the cluster column, so run them once
    # instead of once per plotted clustering.
    regions_in_area = regions_clustered.sjoin(area)
    roads_in_area = roads.sjoin(area)

    for c in clusters:
        cluster_column = f"cluster_{c}"
        fig, ax = _pyplot_clustered_regions_with_roads(
            regions_in_area,
            roads_in_area,
            cluster_column,
            title=cluster_column,
        )
        ax.set_axis_off()
        plt.tight_layout()
        plots[cluster_column] = fig

    return plots


def _pyplot_clustered_regions_with_roads(
    regions: gpd.GeoDataFrame, roads: gpd.GeoDataFrame, column: str, title: str = ""
) -> Tuple[Figure, plt.Axes]:
    """Draw regions coloured by ``column`` with roads and a basemap on top.

    Both layers are reprojected to EPSG:3857 before plotting. The regions
    are drawn first, then the roads, and finally the ``MAP_SOURCE`` basemap
    is added.

    Args:
        regions: Regions to draw; must contain ``column``.
        roads: Road geometries drawn as thin black lines.
        column: Name of the column that determines region colours.
        title: Title placed on the axes.

    Returns:
        The new 10x9 figure and its axes.
    """
    plot_epsg = 3857

    fig, ax = plt.subplots(figsize=(10, 9))
    ax.set_aspect("equal")
    ax.set_title(title)

    region_style = dict(
        column=column,
        ax=ax,
        alpha=0.9,
        legend=True,
        cmap=MATPLOTLIB_COLORMAP,
        vmin=0,
        # Upper colour bound is the number of colours in the palette.
        vmax=len(PLOTLY_COLORMAP),
        linewidth=0,
    )
    regions.to_crs(epsg=plot_epsg).plot(**region_style)

    road_style = dict(ax=ax, color="black", alpha=0.5, linewidth=0.2)
    roads.to_crs(epsg=plot_epsg).plot(**road_style)

    ctx.add_basemap(ax, source=MAP_SOURCE)
    return fig, ax
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ geopandas==0.13.2
notebook>=6,<7 # RISE not compatible with version 7
RISE==5.7.1
osmnx==1.6.0
contextily==1.3.0
scikit-learn==1.3.0
tqdm==4.65.0
matplotlib==3.7.2