"""Plotting functions."""
import logging
import os
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib import gridspec
from matplotlib import patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from tqdm import tqdm
try:
import plotly.graph_objects as go
from plotly.offline import plot as _plot
except ImportError: # pragma: no cover
pass
from pygenstability.optimal_scales import identify_optimal_scales
# Module-level logger, named after this module.
L = logging.getLogger(name=__name__)
def plot_scan(
    all_results,
    scale_axis=True,
    figure_name="scan_results.pdf",
    use_plotly=False,
    live=True,
    plotly_filename="scan_results.html",
):
    """Plot results of pygenstability with matplotlib or plotly.

    Args:
        all_results (dict): results of pygenstability scan
        scale_axis (bool): if True, show scale values on the scale axis, otherwise scale
            indices (only used by the matplotlib backend)
        figure_name (str): name of matplotlib figure
        use_plotly (bool): use the plotly backend instead of matplotlib
        live (bool): for plotly backend, open browser with plot
        plotly_filename (str): filename of .html figure from plotly

    Returns:
        the output of ``plot_scan_plotly`` or ``plot_scan_plt``, or None when there are
        fewer than two scales to plot
    """
    # a scan curve needs at least two points; both backends also read the first and
    # last scale to set axis limits, which fails on an empty scan
    if len(all_results["scales"]) < 2:  # pragma: no cover
        L.info("Cannot plot the results if only one scale point, we display the result instead:")
        L.info(all_results)
        return None

    if use_plotly:
        return plot_scan_plotly(all_results, live=live, filename=plotly_filename)
    return plot_scan_plt(all_results, scale_axis=scale_axis, figure_name=figure_name)
def plot_scan_plotly(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
    all_results,
    live=False,
    filename="clusters.html",
):
    """Plot results of pygenstability with plotly.

    Args:
        all_results (dict): results of pygenstability scan
        live (bool): open the written .html figure in a browser
        filename (str): name of the .html file to write, or None to skip writing it

    Returns:
        tuple: the plotly ``Figure`` and its ``Layout``
    """
    scales = _get_scales(all_results, scale_axis=True)
    n_scales = len(scales)
    scale_range = [scales[0], scales[-1]]
    hovertemplate = "<b>scale</b>: %{x:.2f}, <br>%{text}<extra></extra>"

    if "NVI" in all_results:
        nvi_data = all_results["NVI"]
        nvi_opacity = 1.0
        nvi_title = "Variation of information"
        nvi_ticks = True
    else:  # pragma: no cover
        # invisible placeholder trace so that the layout stays the same
        nvi_data = np.zeros(n_scales)
        nvi_opacity = 0.0
        nvi_title = None
        nvi_ticks = False

    has_stability = "stability" in all_results
    # NaN placeholder so the hover text can still be built without a stability entry
    stability = all_results["stability"] if has_stability else np.full(n_scales, np.nan)

    text = [
        f"Number of communities: {n}, <br> Stability: {np.round(s, 3)}, "
        f"<br> Normalised Variation Information: {np.round(vi, 3)}, <br> Index: {i}"
        for i, (n, s, vi) in enumerate(
            zip(all_results["number_of_communities"], stability, nvi_data)
        )
    ]

    if "ttprime" in all_results:
        z = all_results["ttprime"]
        showscale = True
        tprime_title = "log10(scale)"
    else:  # pragma: no cover
        z = np.full((n_scales, n_scales), np.nan)
        showscale = False
        tprime_title = None

    # traces are collected in a list so that a missing stability entry simply drops
    # its trace instead of leaving an unbound name
    data = []
    if has_stability:
        data.append(
            go.Scatter(
                x=scales,
                y=stability,
                mode="lines+markers",
                hovertemplate=hovertemplate,
                text=text,
                name="Stability",
                marker_color="blue",
            )
        )
    data.append(
        go.Scatter(
            x=scales,
            y=all_results["number_of_communities"],
            mode="lines+markers",
            hovertemplate=hovertemplate,
            name="Number of communities",
            xaxis="x2",
            yaxis="y4",
            text=text,
            marker_color="red",
        )
    )
    data.append(
        go.Scatter(
            x=scales,
            y=nvi_data,
            mode="lines+markers",
            hovertemplate=hovertemplate,
            text=text,
            name="NVI",
            yaxis="y3",
            xaxis="x",
            marker_color="green",
            opacity=nvi_opacity,
        )
    )
    data.append(
        go.Heatmap(
            z=z,
            x=scales,
            y=scales,
            colorscale="YlOrBr_r",
            yaxis="y2",
            xaxis="x2",
            hoverinfo="skip",
            colorbar={"title": {"text": "VI"}, "len": 0.2, "yanchor": "middle", "y": 0.5},
            showscale=showscale,
        )
    )

    # axis titles use the nested title={"text", "font"} form; the flat "titlefont"
    # attribute is deprecated in plotly
    layout = go.Layout(
        yaxis={
            "title": {"text": "Stability", "font": {"color": "blue"}},
            "tickfont": {"color": "blue"},
            "domain": [0.0, 0.28],
        },
        yaxis2={
            "title": {"text": tprime_title, "font": {"color": "black"}},
            "tickfont": {"color": "black"},
            "domain": [0.32, 1],
            "side": "right",
            "range": scale_range,
        },
        yaxis3={
            "title": {"text": nvi_title, "font": {"color": "green"}},
            "tickfont": {"color": "green"},
            "showticklabels": nvi_ticks,
            "overlaying": "y",
            "side": "right",
        },
        yaxis4={
            "title": {"text": "Number of communities", "font": {"color": "red"}},
            "tickfont": {"color": "red"},
            "overlaying": "y2",
        },
        xaxis={"range": scale_range},
        xaxis2={"range": scale_range},
    )

    fig = go.Figure(data=data, layout=layout)
    fig.update_layout(xaxis_title="log10(scale)")
    if filename is not None:
        _plot(fig, filename=filename, auto_open=live)
    return fig, layout
def plot_single_partition(
    graph, all_results, scale_id, edge_color="0.5", edge_width=0.5, node_size=100
):
    """Plot the community structure for a given scale on the current matplotlib axes.

    If any node of ``graph`` lacks a "pos" attribute, a spring layout is computed and
    stored as the "pos" attribute of every node. ``graph`` is therefore modified in place,
    and later calls on the same graph reuse that layout.

    Args:
        graph (networkx.Graph): graph to plot
        all_results (dict): results of pygenstability scan
        scale_id (int): index of scale to plot
        edge_color (str): color of edges
        edge_width (float): width of edges
        node_size (float): size of nodes
    """
    if any("pos" not in graph.nodes[u] for u in graph):
        layout = nx.spring_layout(graph)
        for u in graph:
            graph.nodes[u]["pos"] = layout[u]
    pos = {u: graph.nodes[u]["pos"] for u in graph}

    # nodes are colored by their community id at this scale
    nx.draw_networkx_nodes(
        graph,
        pos=pos,
        node_color=all_results["community_id"][scale_id],
        node_size=node_size,
        cmap=plt.get_cmap("tab20"),
    )
    nx.draw_networkx_edges(graph, pos=pos, width=edge_width, edge_color=edge_color)
    plt.axis("off")

    log_scale = np.round(np.log10(all_results["scales"][scale_id]), 2)
    n_communities = all_results["number_of_communities"][scale_id]
    plt.title(
        r"$log_{10}(scale) =$ "
        + str(log_scale)
        + ", with "
        + str(n_communities)
        + " communities"
    )
def plot_optimal_partitions(
    graph,
    all_results,
    edge_color="0.5",
    edge_width=0.5,
    folder="optimal_partitions",
    ext=".pdf",
    show=False,
):
    """Plot the community structures at each optimal scale.

    One figure per selected partition is saved as ``<folder>/scale_<id><ext>``. If
    ``all_results`` has no "selected_partitions" entry, ``identify_optimal_scales`` is
    called on it first.

    Args:
        graph (networkx.Graph): graph to plot
        all_results (dict): results of pygenstability scan
        edge_color (str): color of edges
        edge_width (float): width of edges
        folder (str): folder to save figures (created if missing, including parents)
        ext (str): extension of figures files
        show (bool): show each plot with plt.show() or not
    """
    os.makedirs(folder, exist_ok=True)

    if "selected_partitions" not in all_results:  # pragma: no cover
        identify_optimal_scales(all_results)

    selected_scales = all_results["selected_partitions"]
    # len() rather than truthiness: this may be a numpy array
    if len(selected_scales) == 0:  # pragma: no cover
        return

    for optimal_scale_id in selected_scales:
        # one fresh figure per partition; without it every saved file would also
        # contain the partitions drawn in previous iterations
        fig = plt.figure()
        plot_single_partition(
            graph, all_results, optimal_scale_id, edge_color=edge_color, edge_width=edge_width
        )
        plt.savefig(os.path.join(folder, f"scale_{optimal_scale_id}{ext}"), bbox_inches="tight")
        if show:  # pragma: no cover
            plt.show()
        plt.close(fig)
def plot_communities(
    graph, all_results, folder="communities", edge_color="0.5", edge_width=0.5, ext=".pdf"
):
    """Plot the community structures at each scale in a folder.

    One figure per scale is saved as ``<folder>/scale_<index><ext>``. Rendering is done
    with the "Agg" backend, and the previously active backend is restored afterwards,
    also when plotting fails.

    Args:
        graph (networkx.Graph): graph to plot
        all_results (dict): results of pygenstability scan
        folder (str): folder to save figures (created if missing, including parents)
        edge_color (str): color of edges
        edge_width (float): width of edges
        ext (str): extension of figures files
    """
    os.makedirs(folder, exist_ok=True)

    mpl_backend = matplotlib.get_backend()
    matplotlib.use("Agg")
    try:
        for scale_id in tqdm(range(len(all_results["scales"]))):
            fig = plt.figure()
            plot_single_partition(
                graph, all_results, scale_id, edge_color=edge_color, edge_width=edge_width
            )
            plt.savefig(os.path.join(folder, f"scale_{scale_id}{ext}"), bbox_inches="tight")
            plt.close(fig)
    finally:
        # restore the caller's backend even if a partition failed to plot
        matplotlib.use(mpl_backend)
def plot_communities_matrix(graph, all_results, folder="communities_matrix", ext=".pdf"):
    """Plot communities at all scales in matrix form.

    For each scale, the matrix is reordered so that nodes of the same community are
    contiguous, and each community block is outlined in black. One figure per scale is
    saved as ``<folder>/scale_<index><ext>``.

    Args:
        graph (array): adjacency matrix as a numpy array/matrix
        all_results (dict): clustering results
        folder (str): folder to save figures (created if missing, including parents)
        ext (str): figure file format
    """
    os.makedirs(folder, exist_ok=True)

    for scale_id in tqdm(range(len(all_results["scales"]))):
        fig = plt.figure()
        com_ids = np.asarray(all_results["community_id"][scale_id])

        # gather node indices community by community; np.unique does not assume the
        # community ids are the contiguous range 0..k-1
        ids = []
        sizes = [0]
        for com in np.unique(com_ids):
            com_nodes = list(np.flatnonzero(com_ids == com))
            sizes.append(len(com_nodes))
            ids += com_nodes
        plt.imshow(graph[ids][:, ids], origin="lower")

        # pixel i is centred on coordinate i, so block boundaries sit at -0.5 offsets
        bounds = np.cumsum(sizes) - 0.5
        for start, end in zip(bounds[:-1], bounds[1:]):
            plt.plot([start, end, end, start, start], [start, start, end, end, start], c="k")

        plt.savefig(os.path.join(folder, "scale_" + str(scale_id) + ext), bbox_inches="tight")
        plt.close(fig)
def _get_scales(all_results, scale_axis=True):
"""Get the scale vector."""
if not scale_axis: # pragma: no cover
return np.arange(len(all_results["scales"]))
if all_results["run_params"]["log_scale"]:
return np.log10(all_results["scales"])
return all_results["scales"] # pragma: no cover
def _plot_number_comm(all_results, ax, scales):
    """Draw the number of communities per scale on ``ax`` in color "C3"."""
    n_communities = all_results["number_of_communities"]
    red = "C3"

    ax.plot(scales, n_communities, "-", c=red, label="size", lw=2.0)
    # leave 10% headroom above the largest count
    ax.set_ylim(0, 1.1 * max(n_communities))
    ax.set_ylabel("# clusters", color=red)
    ax.tick_params("y", colors=red)
def _plot_ttprime(all_results, ax, scales):
    """Draw the "ttprime" matrix as filled contours on ``ax``, with an inset colorbar."""
    first, last = scales[0], scales[-1]
    filled = ax.contourf(scales, scales, all_results["ttprime"], cmap="YlOrBr_r", extend="min")

    ax.set_ylabel(r"$log_{10}(t^\prime)$")
    ax.yaxis.tick_left()
    ax.yaxis.set_label_position("left")
    ax.axis([first, last, first, last])
    ax.set_xlabel(r"$log_{10}(t)$")

    # narrow colorbar inset near the left edge of the panel, above its vertical midpoint
    colorbar_ax = inset_axes(
        ax,
        width="3%",
        height="40%",
        loc="lower left",
        bbox_to_anchor=(0.05, 0.45, 1, 1),
        bbox_transform=ax.transAxes,
        borderpad=0,
    )
    colorbar_ax.tick_params(labelsize=7)
    plt.colorbar(filled, cax=colorbar_ax, label="NVI(t,t')")
def _plot_NVI(all_results, ax, scales):
    """Draw the NVI curve on ``ax`` in color "C2", with a dashed reference line at 1."""
    nvi = all_results["NVI"]
    green = "C2"

    ax.plot(scales, nvi, "-", lw=2.0, c=green, label="VI")
    ax.yaxis.tick_right()
    ax.tick_params("y", colors=green)
    ax.set_ylabel(r"NVI", color=green)
    ax.axhline(1, ls="--", lw=1.0, c=green)

    # y range from 0 to 10% above the largest NVI value
    upper = np.max(nvi) * 1.1
    ax.axis([scales[0], scales[-1], 0.0, upper])
    ax.set_xlabel(r"$log_{10}(t)$")
def _plot_stability(all_results, ax, scales):
    """Draw the stability curve on ``ax`` in color "C0", labelled on the left."""
    stability = all_results["stability"]
    blue = "C0"

    ax.plot(scales, stability, "-", label=r"Stability", c=blue)
    ax.tick_params("y", colors=blue)
    ax.set_ylabel("Stability", color=blue)
    # leave 10% headroom above the largest stability value
    ax.set_ylim(0, 1.1 * max(stability))
    ax.yaxis.set_label_position("left")
def _plot_optimal_scales(all_results, ax, scales, ax1, ax2):
    """Plot the block detection curve and mark the selected (optimal) scales.

    The selected scales are drawn as dots on the curve in ``ax``. They are also drawn as
    dashed vertical lines on ``ax``, ``ax1`` and ``ax2``.

    Args:
        all_results (dict): results with "block_detection_curve" and "selected_partitions"
        ax (matplotlib axes): axes for the block detection curve
        scales (array): scale-axis values
        ax1 (matplotlib axes): additional axes to mark the selected scales on
        ax2 (matplotlib axes): additional axes to mark the selected scales on
    """
    block_curve = np.asarray(all_results["block_detection_curve"])
    # integer dtype so that an empty selection is still a valid index array
    selected = np.asarray(all_results["selected_partitions"], dtype=int)
    selected_scales = np.asarray(scales)[selected]

    ax.plot(scales, block_curve, "-", lw=2.0, c="C4", label="Block NVI")
    ax.plot(
        selected_scales,
        block_curve[selected],
        "o",
        lw=2.0,
        c="C4",
        label="optimal scales",
    )
    ax.tick_params("y", colors="C4")
    ax.set_ylabel("Block NVI", color="C4")
    ax.yaxis.set_label_position("left")
    ax.set_xlabel(r"$log_{10}(t)$")

    for scale in selected_scales:
        for axis in (ax, ax1, ax2):
            axis.axvline(scale, ls="--", color="C4")
def plot_scan_plt(all_results, scale_axis=True, figure_name="scan_results.svg"):
    """Plot results of pygenstability with matplotlib.

    Draws into the current figure on a 3-row grid:
    - top: stability, with the number of communities on a twin axis;
    - middle: ttprime, with NVI on a twin axis;
    - bottom: the block detection curve.

    Panels whose data is missing from ``all_results`` are skipped.

    Args:
        all_results (dict): results of pygenstability scan
        scale_axis (bool): display scale values (True) or scale indices (False)
        figure_name (str): file to save the figure to, or None to skip saving

    Returns:
        list: the matplotlib axes that were drawn
    """
    scales = _get_scales(all_results, scale_axis=scale_axis)
    gs = gridspec.GridSpec(3, 1, height_ratios=[0.5, 1.0, 0.5])
    gs.update(hspace=0)
    axes = []

    # middle panel: ttprime, with NVI on a twin y-axis
    if "ttprime" in all_results:
        ax0 = plt.subplot(gs[1, 0])
        axes.append(ax0)
        _plot_ttprime(all_results, ax=ax0, scales=scales)
        ax1 = ax0.twinx()
    else:  # pragma: no cover
        ax1 = plt.subplot(gs[1, 0])
    axes.append(ax1)
    ax1.set_xticks([])

    if "NVI" in all_results:
        _plot_NVI(all_results, ax=ax1, scales=scales)
    if "ttprime" in all_results:
        # ttprime occupies the left y-axis, so keep the NVI axis on the right
        ax1.yaxis.tick_right()
        ax1.yaxis.set_label_position("right")

    # top panel: stability, with the number of communities on a twin y-axis
    ax2 = plt.subplot(gs[0, 0])
    if "stability" in all_results:
        _plot_stability(all_results, ax=ax2, scales=scales)
        ax2.set_xticks([])
        axes.append(ax2)
    if "number_of_communities" in all_results:
        ax3 = ax2.twinx()
        _plot_number_comm(all_results, ax=ax3, scales=scales)
        axes.append(ax3)

    # bottom panel: block detection curve and optimal scales
    if "block_detection_curve" in all_results:
        ax4 = plt.subplot(gs[2, 0])
        _plot_optimal_scales(all_results, ax=ax4, scales=scales, ax1=ax1, ax2=ax2)
        axes.append(ax4)

    for ax in axes:
        ax.set_xlim(scales[0], scales[-1])

    if figure_name is not None:
        plt.savefig(figure_name)
    return axes
def plot_clustered_adjacency(
    adjacency,
    all_results,
    scale,
    labels=None,
    figsize=(12, 10),
    cmap="Blues",
    figure_name="clustered_adjacency.pdf",
):
    """Plot the clustered adjacency matrix of the graph at a given scale.

    Nodes are reordered so that each community forms a contiguous block. Each block is
    outlined with a green square, and zero entries are masked (set to NaN) so that they
    are left blank. The input matrix itself is not modified.

    Args:
        adjacency (ndarray): adjacency matrix to plot
        all_results (dict): results of PyGenStability
        scale (int): scale index for clustering
        labels (list): node labels, or None
        figsize (tuple): figure size
        cmap (str): colormap for matrix elements
        figure_name (str): filename of the figure with extension
    """
    community_id = np.asarray(all_results["community_id"][scale])
    counts = np.unique(community_id, return_counts=True)[1]
    # a stable sort groups nodes community by community (same sorted order as
    # np.unique) and keeps the original node order within each community
    node_ids = np.argsort(community_id, kind="stable")

    # fancy indexing returns a copy; cast to float so zeros can be replaced by NaN
    # even when the input matrix has an integer dtype
    adjacency = adjacency[np.ix_(node_ids, node_ids)].astype(float)
    adjacency[adjacency == 0] = np.nan
    n_nodes = len(adjacency)

    plt.figure(figsize=figsize)
    plt.imshow(adjacency, aspect="auto", cmap=cmap)
    ax = plt.gca()

    pos = 0
    for count in counts:
        ax.add_patch(
            patches.Rectangle(
                (pos - 0.5, pos - 0.5),
                count,
                count,
                linewidth=5,
                facecolor="none",
                edgecolor="g",
            )
        )
        pos += count

    ax.set_xticks(np.arange(n_nodes))
    ax.set_yticks(np.arange(n_nodes))
    if labels is not None:  # pragma: no cover
        labels_plot = [labels[i] for i in node_ids]
        ax.set_xticklabels(labels_plot)
        ax.set_yticklabels(labels_plot)

    plt.colorbar()
    plt.xticks(rotation=90)
    plt.axis([-0.5, n_nodes - 0.5, -0.5, n_nodes - 0.5])
    plt.suptitle(
        "log10(scale) = "
        + str(np.round(np.log10(all_results["scales"][scale]), 2))
        + ", number_of_communities="
        + str(all_results["number_of_communities"][scale])
    )
    plt.savefig(figure_name, bbox_inches="tight")