Source code for pyehm.utils

from typing import Union

import networkx as nx

from _pyehm.utils import Cluster, gen_clusters  # noqa: F401
from .net import EHMNet, EHM2Net, EHM2Tree


def to_nx_graph(obj: Union[EHMNet, EHM2Net, EHM2Tree]) -> nx.Graph:
    """Get a NetworkX representation of a net or tree.

    Mainly used for plotting.

    Parameters
    ----------
    obj : :class:`~.EHMNet` | :class:`~.EHM2Net` | :class:`~.EHM2Tree`
        The object to convert to a NetworkX graph.

    Returns
    -------
    :class:`networkx.Graph`
        The NetworkX graph representation of the object.

        * For nets, nodes are keyed by the net node's ``id`` and carry ``track``
          and ``identity`` attributes; edges carry a ``detections`` attribute.
        * For trees, nodes are numbered ``1, 2, ...`` in depth-first pre-order
          and carry ``track`` and ``detections`` attributes.
    """
    # Test the more specific EHM2Net type first. If EHM2Net derives from EHMNet,
    # checking EHMNet first would route EHM2 nets through the EHMNet branch and
    # leave this branch unreachable; if the classes are unrelated, order is moot.
    if isinstance(obj, EHM2Net):
        g = nx.Graph()
        num_detections = obj.validation_matrix.shape[1]
        for parent in obj.nodes:
            # A track value of -1 is exported as None (NOTE(review): presumably
            # marks nodes with no track to process — confirm against EHM2Net).
            track = parent.track if parent.track != -1 else None
            # add_node on an existing node updates its attributes, so a node first
            # added as someone's child is refreshed here.
            g.add_node(parent.id, track=track, identity=parent.identity)
            for detection in range(num_detections):
                for child in obj.get_children_per_detection(parent, detection):
                    if not g.has_node(child.id):
                        # Apply the same -1 -> None encoding as for parents, so a
                        # child never visited as a parent is labelled consistently.
                        child_track = child.track if child.track != -1 else None
                        g.add_node(child.id, track=child_track, identity=child.identity)
                    if g.has_edge(parent.id, child.id):
                        # Several detections may lead to the same child: accumulate.
                        g[parent.id][child.id]["detections"].add(detection)
                    else:
                        g.add_edge(parent.id, child.id, detections={detection})
        return g

    if isinstance(obj, EHMNet):
        g = nx.Graph()
        # Add nodes in increasing layer order.
        for child in sorted(obj.nodes, key=lambda node: node.layer):
            parents = obj.get_parents(child)
            # Label with track index layer + 1, or None when layer + 2 >= num_layers.
            track = child.layer + 1 if child.layer + 2 < obj.num_layers else None
            g.add_node(child.id, track=track, identity=child.identity)
            for parent in parents:
                g.add_edge(parent.id, child.id, detections=obj.get_edges(parent, child))
        return g

    # Anything else is treated as an EHM2Tree and walked depth-first.
    return _traverse_tree_nx(obj, nx.Graph())
def _traverse_tree_nx(tree, g, parent=None):
    """Add ``tree`` and all of its descendants to ``g`` in depth-first pre-order.

    Each visited sub-tree becomes a node numbered ``g.number_of_nodes() + 1`` at
    the moment it is visited, carrying ``track`` and ``detections`` attributes
    copied from the sub-tree. Each node except the starting root is joined by an
    edge to the node of its parent sub-tree.

    Parameters
    ----------
    tree : :class:`~.EHM2Tree`
        Root of the (sub-)tree to add; must expose ``track``, ``detections``
        and ``children``.
    g : :class:`networkx.Graph`
        Graph to populate; modified in place.
    parent : hashable, optional
        Node of ``g`` to connect the root of ``tree`` to. If ``None`` (default),
        the root is added without an edge.

    Returns
    -------
    :class:`networkx.Graph`
        ``g`` itself.
    """
    # Iterate with an explicit stack instead of recursing, so deep trees do not
    # hit Python's recursion limit.
    stack = [(tree, parent)]
    while stack:
        sub_tree, parent_id = stack.pop()
        node_id = g.number_of_nodes() + 1
        g.add_node(node_id, track=sub_tree.track, detections=sub_tree.detections)
        # Compare against None so a falsy node id is still treated as a parent.
        if parent_id is not None:
            g.add_edge(parent_id, node_id)
        # Push children reversed so they are popped, and numbered, in their
        # original order (matching a recursive pre-order walk).
        stack.extend((child, node_id) for child in reversed(list(sub_tree.children)))
    return g