Skip to content

Commit

Permalink
Merge pull request #147 from kbonney/gis_plotting
Browse files Browse the repository at this point in the history
Pull GIS fixes into current MSX main
  • Loading branch information
dbhart authored Dec 16, 2024
2 parents 35e46ca + 046006a commit 0978643
Show file tree
Hide file tree
Showing 2 changed files with 519 additions and 102 deletions.
284 changes: 188 additions & 96 deletions wntr/graphics/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
water network model.
"""
import logging
import math
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.path as mpath
from matplotlib import animation
import matplotlib as mpl
import numpy as np

try:
import plotly
Expand All @@ -21,6 +25,24 @@

logger = logging.getLogger(__name__)


arrow_verts = [
(0.0, 0.0),
(0.5, 0.5),
(0.5, -0.5),
(0.0, 0.0),
]

arrow_marker = mpath.Path(arrow_verts)

def _get_angle(line, loc=0.5):
# calculate orientation angle
p1 = line.interpolate(loc-0.01, normalized=True)
p2 = line.interpolate(loc+0.01, normalized=True)
angle = math.atan2(p2.y-p1.y, p2.x - p1.x) # radians
angle = math.degrees(angle)
return angle

def _format_node_attribute(node_attribute, wn):

if isinstance(node_attribute, str):
Expand All @@ -42,12 +64,13 @@ def _format_link_attribute(link_attribute, wn):
link_attribute = dict(link_attribute)

return link_attribute

def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
node_size=20, node_range=[None,None], node_alpha=1, node_cmap=None, node_labels=False,
link_width=1, link_range=[None,None], link_alpha=1, link_cmap=None, link_labels=False,
add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link',
directed=False, ax=None, show_plot=True, filename=None):

def plot_network(
wn, node_attribute=None, link_attribute=None, title=None,
node_size=20, node_range=None, node_alpha=1, node_cmap=None, node_labels=False,
link_width=1, link_range=None, link_alpha=1, link_cmap=None, link_labels=False,
add_colorbar=True, node_colorbar_label=None, link_colorbar_label=None,
directed=False, legend=False, ax=None, show_plot=True, filename=None):
"""
Plot network graphic
Expand Down Expand Up @@ -126,7 +149,7 @@ def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
ax: matplotlib axes object, optional
Axes for plotting (None indicates that a new figure with a single
axes will be used)
show_plot: bool, optional
If True, show plot with plt.show()
Expand All @@ -137,113 +160,182 @@ def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
-------
ax : matplotlib axes object
"""

if ax is None: # create a new figure
plt.figure(facecolor='w', edgecolor='k')
ax = plt.gca()

# Graph
G = wn.to_graph()
if not directed:
G = G.to_undirected()

# Position
pos = nx.get_node_attributes(G,'pos')
if len(pos) == 0:
pos = None

# Define node properties
add_node_colorbar = add_colorbar
if title is not None:
ax.set_title(title)

aspect = "equal"

tank_marker = "D"
reservoir_marker = "s"

if link_cmap is None:
link_cmap = plt.get_cmap('Spectral_r')
if node_cmap is None:
node_cmap = plt.get_cmap('Spectral_r')

if link_range is None:
link_range = (None, None)
if node_range is None:
node_range = (None, None)

# use attribute name if no other label is provided
if node_colorbar_label is None and isinstance(node_attribute, str):
node_colorbar_label = node_attribute
if link_colorbar_label is None and isinstance(link_attribute, str):
link_colorbar_label = link_attribute

wn_gis = wn.to_gis()
# add node_type so that node assets can be plotted separately
wn_gis.junctions["node_type"] = "Junction"
wn_gis.tanks["node_type"] = "Tank"
wn_gis.reservoirs["node_type"] = "Reservoir"
link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves))
node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs))

# Node attribute
node_kwds = {}
node_cbar = add_colorbar
if node_attribute is not None:
node_gdf["_attribute"] = _format_node_attribute(node_attribute, wn)
node_kwds["column"] = "_attribute"

# handle cbar/cmap
if isinstance(node_attribute, list):
if node_cmap is None:
node_cmap = ['red', 'red']
add_node_colorbar = False

if node_cmap is None:
node_cmap = plt.get_cmap('Spectral_r')
elif isinstance(node_cmap, list):
if len(node_cmap) == 1:
node_cmap = node_cmap*2
node_cmap = custom_colormap(len(node_cmap), node_cmap)

node_attribute = _format_node_attribute(node_attribute, wn)
nodelist,nodecolor = zip(*node_attribute.items())

node_kwds["cmap"] = custom_colormap(2,["red", "red"])
node_cbar = False
elif isinstance(node_attribute, (dict, pd.Series, str)):
node_kwds["cmap"] = node_cmap

# manually extract min/max if no range is given
node_attribute_values = node_gdf[node_kwds["column"]]
if node_range[0] is None:
node_kwds["vmin"] = np.nanmin(node_attribute_values)
else:
node_kwds["vmin"] = node_range[0]
if node_range[1] is None:
node_kwds["vmax"] = np.nanmax(node_attribute_values)
else:
node_kwds["vmax"] = node_range[1]
else:
raise TypeError("attribute must be dict, Series, list, or str")
else:
nodelist = None
nodecolor = 'k'
node_kwds["color"] = "black"
node_cbar = False

node_kwds["alpha"] = node_alpha
node_kwds["markersize"] = node_size

add_link_colorbar = add_colorbar
node_cbar_kwds = {}
node_cbar_kwds["shrink"] = 0.5
node_cbar_kwds["pad"] = 0.0
node_cbar_kwds["alpha"] = node_alpha
node_cbar_kwds["label"] = node_colorbar_label

# Link attribute
link_kwds = {}
link_cbar = add_colorbar
if link_attribute is not None:
link_gdf["_attribute"] = pd.Series(_format_link_attribute(link_attribute, wn))
link_kwds["column"] = "_attribute"

# handle cbar/cmap
if isinstance(link_attribute, list):
if link_cmap is None:
link_cmap = ['red', 'red']
add_link_colorbar = False

if link_cmap is None:
link_cmap = plt.get_cmap('Spectral_r')
elif isinstance(link_cmap, list):
if len(link_cmap) == 1:
link_cmap = link_cmap*2
link_cmap = custom_colormap(len(link_cmap), link_cmap)
link_kwds["cmap"] = custom_colormap(2,["red", "red"])
link_cbar = False
elif isinstance(link_attribute, (dict, pd.Series, str)):
link_kwds["cmap"] = link_cmap

link_attribute = _format_link_attribute(link_attribute, wn)

# Replace link_attribute dictionary defined as
# {link_name: attr} with {(start_node, end_node, link_name): attr}
attr = {}
for link_name, value in link_attribute.items():
link = wn.get_link(link_name)
attr[(link.start_node_name, link.end_node_name, link_name)] = value
link_attribute = attr

linklist,linkcolor = zip(*link_attribute.items())
# manually extract min/max if no range is given
link_attribute_values = link_gdf[link_kwds["column"]]
if link_range[0] is None:
link_kwds["vmin"] = np.nanmin(link_attribute_values)
else:
link_kwds["vmin"] = link_range[0]
if link_range[1] is None:
link_kwds["vmax"] = np.nanmax(link_attribute_values)
else:
link_kwds["vmax"] = link_range[1]
else:
raise TypeError("attribute must be dict, Series, list, or str")
else:
linklist = None
linkcolor = 'k'
link_kwds["color"] = "black"
link_cbar = False

if title is not None:
ax.set_title(title)

edge_background = nx.draw_networkx_edges(G, pos, edge_color='grey',
width=0.5, ax=ax)
link_kwds["linewidth"] = link_width
link_kwds["alpha"] = link_alpha

background_link_kwds = {}
background_link_kwds["color"] = "grey"
background_link_kwds["linewidth"] = link_width / 2
background_link_kwds["alpha"] = link_alpha

link_cbar_kwds = {}
link_cbar_kwds["shrink"] = 0.5
link_cbar_kwds["pad"] = 0.05
link_cbar_kwds["label"] = link_colorbar_label
link_cbar_kwds["alpha"] = link_alpha

missing_node_kwds={"color": "black"}
missing_link_kwds={"color": "black"}

# plot nodes - each type is plotted separately to allow for different marker types
node_gdf[node_gdf.node_type == "Junction"].plot(
ax=ax, aspect=aspect, zorder=3, legend=False, label="Junction", missing_kwds=missing_node_kwds, **node_kwds)

node_kwds["markersize"] = node_size * 2.0
node_gdf[node_gdf.node_type == "Tank"].plot(
ax=ax, aspect=aspect, zorder=4, marker=tank_marker, legend=False, label="Tank", missing_kwds=missing_node_kwds, **node_kwds)

nodes = nx.draw_networkx_nodes(G, pos,
nodelist=nodelist, node_color=nodecolor, node_size=node_size,
alpha=node_alpha, cmap=node_cmap, vmin=node_range[0], vmax = node_range[1],
linewidths=0, ax=ax)
edges = nx.draw_networkx_edges(G, pos, edgelist=linklist, arrows=directed,
edge_color=linkcolor, width=link_width, alpha=link_alpha, edge_cmap=link_cmap,
edge_vmin=link_range[0], edge_vmax=link_range[1], ax=ax)
node_kwds["markersize"] = node_size * 3.0
node_gdf[node_gdf.node_type == "Reservoir"].plot(
ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, legend=False, label="Reservoir", missing_kwds=missing_node_kwds,**node_kwds)

if node_cbar:
sm = mpl.cm.ScalarMappable(cmap=node_kwds["cmap"])
sm.set_clim(node_kwds["vmin"], node_kwds["vmax"])

node_cbar = ax.figure.colorbar(sm, ax=ax, **node_cbar_kwds)

# plot links
# background
link_gdf.plot(
ax=ax, aspect=aspect, zorder=1, legend=False, **background_link_kwds)

# main plot
link_gdf.plot(
ax=ax, aspect=aspect, zorder=2, legend=False, missing_kwds=missing_link_kwds, **link_kwds)

if link_cbar:
sm = mpl.cm.ScalarMappable(cmap=link_kwds["cmap"])
sm.set_clim(link_kwds["vmin"], link_kwds["vmax"])

link_cbar = ax.figure.colorbar(sm, ax=ax, **link_cbar_kwds)

if node_labels:
labels = dict(zip(wn.node_name_list, wn.node_name_list))
nx.draw_networkx_labels(G, pos, labels, font_size=7, ax=ax)
for x, y, label in zip(node_gdf.geometry.x, node_gdf.geometry.y, node_gdf.index):
ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points")

if link_labels:
labels = {}
for link_name in wn.link_name_list:
link = wn.get_link(link_name)
labels[(link.start_node_name, link.end_node_name)] = link_name
nx.draw_networkx_edge_labels(G, pos, labels, font_size=7, ax=ax)
if add_node_colorbar and node_attribute:
clb = plt.colorbar(nodes, shrink=0.5, pad=0, ax=ax)
clb.ax.set_title(node_colorbar_label, fontsize=10)
if add_link_colorbar and link_attribute:
if link_range[0] is None:
vmin = min(link_attribute.values())
else:
vmin = link_range[0]
if link_range[1] is None:
vmax = max(link_attribute.values())
else:
vmax = link_range[1]
sm = plt.cm.ScalarMappable(cmap=link_cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
clb = plt.colorbar(sm, shrink=0.5, pad=0.05, ax=ax)
clb.ax.set_title(link_colorbar_label, fontsize=10)

midpoints = link_gdf.geometry.apply(lambda x: x.interpolate(0.5, normalized=True))
for x, y, label in zip(midpoints.geometry.x, midpoints.geometry.y, link_gdf.index):
ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points")

if directed:
link_gdf["_midpoint"] = link_gdf.geometry.interpolate(0.5, normalized=True)
link_gdf["_angle"] = link_gdf.apply(lambda row: _get_angle(row.geometry), axis=1)
for idx , row in link_gdf.iterrows():
x,y = row["_midpoint"].x, row["_midpoint"].y
angle = row["_angle"]
ax.scatter(x,y, color="black", s=50, marker=(3,0, angle-90))

if legend:
handles, labels = ax.get_legend_handles_labels()
leg = ax.legend(handles, labels, loc='upper right', title="Legend")

ax.axis('off')

if filename:
Expand Down
Loading

0 comments on commit 0978643

Please sign in to comment.