Source code for pyehm.plotting

from typing import Union

import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout

from .net import EHMNet, EHM2Net, EHM2Tree
from .utils import to_nx_graph


def plot_net(net: Union[EHMNet, EHM2Net], ax: plt.Axes = None, annotate=True):
    """Plot the net.

    The net is converted to a networkx graph with :func:`to_nx_graph`, laid out with the
    graphviz ``dot`` program and drawn on ``ax``.

    Parameters
    ----------
    net : :class:`~.EHMNet` | :class:`~.EHM2Net`
        The net to plot.
    ax : :class:`matplotlib.axes.Axes`
        Axes on which to plot the net. If ``None``, a new figure and axes will be created.
    annotate : :class:`bool`
        Flag that dictates whether to draw node and edge labels on the plotted net. The default is
        ``True``
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.gca()

    g = to_nx_graph(net)
    pos = graphviz_layout(g, prog="dot")
    # node_size=0 hides the node markers so that only edges (and labels) are visible.
    nx.draw(g, pos, ax=ax, node_size=0)

    if not annotate:
        return

    # Node labels: "{track, identity}", with 'Ø' standing in for an empty identity and for
    # nodes that carry no track at all.
    labels = {}
    for n in g.nodes:
        attrs = g.nodes[n]
        t = attrs['track']
        s = str(attrs['identity']) if len(attrs['identity']) else 'Ø'
        labels[n] = '{{{}, {}}}'.format(t, s) if t is not None else 'Ø'

    # Shift the labels to the right of the node position so they do not sit on top of the edges.
    pos_labels = {node: (coords[0] + 10, coords[1]) for node, coords in pos.items()}
    nx.draw_networkx_labels(g, pos_labels, ax=ax, labels=labels, horizontalalignment='left')

    # Edge labels: the 'detections' attribute rendered as text with the braces stripped.
    edge_labels = {
        edge: str(detections).replace('{', '').replace('}', '')
        for edge, detections in nx.get_edge_attributes(g, 'detections').items()
    }
    # Pass ``ax`` explicitly: without it networkx draws on the *current* axes, which is not
    # necessarily the axes supplied by the caller.
    nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, ax=ax)
def plot_tree(tree: EHM2Tree, ax: plt.Axes = None, annotate=True):
    """Plot the tree.

    The tree is converted to a networkx graph with :func:`to_nx_graph`, laid out with the
    graphviz ``dot`` program and drawn on ``ax``.

    Parameters
    ----------
    tree : :class:`~.EHM2Tree`
        The tree to plot.
    ax : :class:`matplotlib.axes.Axes`
        Axes on which to plot the tree. If ``None``, a new figure and axes will be created.
    annotate : :class:`bool`
        Flag that dictates whether to draw node labels on the plotted tree. The default is
        ``True``
    """
    if ax is None:
        ax = plt.figure().gca()

    graph = to_nx_graph(tree)
    layout = graphviz_layout(graph, prog="dot")
    nx.draw(graph, layout, ax=ax)

    # Every node is labelled with its 'track' attribute (no leaf-only filtering is applied).
    track_labels = {}
    for node in graph.nodes:
        track_labels[node] = graph.nodes[node]['track']

    if not annotate:
        return

    # Labels are centred on the node positions and drawn in white on top of the node markers.
    label_positions = {node: (xy[0], xy[1]) for node, xy in layout.items()}
    nx.draw_networkx_labels(graph, label_positions, ax=ax, labels=track_labels, font_color='white')