# -*- coding: utf-8 -*-
from typing import Union, List, Sequence
import networkx
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from networkx.drawing.nx_pydot import graphviz_layout
from networkx.algorithms.components.connected import connected_components
class EHMNetNode:
    """A node in the :class:`~.EHMNet` constructed by :class:`~.EHM`.

    Parameters
    ----------
    layer: :class:`int`
        Index of the network layer in which the node is placed. Since a different layer in the network is built for
        each track, this also represents the index of the track this node relates to.
    identity: :class:`set` of :class:`int`
        The identity of the node. As per Section 3.1 of [EHM1]_, "the identity for each node is an indication of how
        measurement assignments made for tracks already considered affect assignments for tracks remaining to be
        considered". A falsy value (``None`` or an empty collection) is replaced by a fresh empty :class:`set`.
    """

    def __init__(self, layer, identity=None):
        # Index of the layer (track) in the network
        self.layer = layer
        # Identity of the node; every node created without one gets its own empty set
        self.identity = identity if identity else set()
        # Index of the node when added to the network. This is set by the network and
        # should not be edited.
        self.ind = None

    def __repr__(self):
        return 'EHMNetNode(ind={}, layer={}, identity={})'.format(self.ind, self.layer, self.identity)
class EHM2NetNode(EHMNetNode):
    """A node in the :class:`~.EHMNet` constructed by :class:`~.EHM2`.

    Parameters
    ----------
    layer: :class:`int`
        Index of the network layer in which the node is placed.
    track: :class:`int`
        Index of track this node relates to. The default is ``None``.
    subnet: :class:`int`
        Index of subnet to which the node belongs. The default is ``0``.
    identity: :class:`set` of :class:`int`
        The identity of the node. As per Section 3.1 of [EHM1]_, "the identity for each node is an indication of how
        measurement assignments made for tracks already considered affect assignments for tracks remaining to be
        considered".
    """

    def __init__(self, layer, track=None, subnet=0, identity=None):
        # The base class stores the layer and identity and initialises ``ind``
        super().__init__(layer, identity)
        # Index of track this node relates to
        self.track = track
        # Index of subnet the node belongs to
        self.subnet = subnet

    def __repr__(self):
        return 'EHM2NetNode(ind={}, layer={}, track={}, subnet={}, identity={})'.format(
            self.ind, self.layer, self.track, self.subnet, self.identity
        )
class EHMNet:
    """Represents the nets constructed by :class:`~.EHM` and :class:`~.EHM2`.

    Parameters
    ----------
    nodes: :class:`list` of :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
        The nodes comprising the net. The list is stored (not copied) and each node's ``ind`` attribute is
        overwritten with its position in the list.
    validation_matrix: :class:`numpy.ndarray`
        An indicator matrix of shape (num_tracks, num_detections + 1) indicating the possible
        (aka. valid) associations between tracks and detections. The first column corresponds
        to the null hypothesis (hence contains all ones).
    edges: :class:`dict`
        A dictionary that represents the edges between nodes in the network. The dictionary keys are tuples of the form
        ``(parent, child)``, where ``parent`` and ``child`` are the source and target nodes respectively. The
        values of the dictionary are the sets of measurement indices that describe the parent-child relationship.
    """

    def __init__(self, nodes, validation_matrix, edges=None):
        self._num_layers = 0
        self.validation_matrix = validation_matrix
        # NOTE(review): edges supplied here are stored as-is; the parent/child look-ups below are only
        # filled in by add_node/add_edge.
        self.edges = edges if edges is not None else dict()
        # (child, detection) -> set of parents / (parent, detection) -> set of children
        self.parents_per_detection = dict()
        self.children_per_detection = dict()
        # Initialised empty and not populated by this class; presumably maintained by the net builders.
        self.nodes_per_track = dict()
        self.nodes_per_identity = dict()
        # (layer, subnet) -> set of EHM2NetNode objects
        self.nodes_per_layer_subnet = dict()
        # child -> set of parents / parent -> set of children
        self._parents = dict()
        self._children = dict()
        self._nodes = nodes
        for n_i, node in enumerate(nodes):
            node.ind = n_i
            self._register_node(node)

    def _register_node(self, node):
        """Update the layer count and the layer-subnet look-up for a node that is part of the net."""
        if isinstance(node, EHM2NetNode):
            if node.layer + 1 > self._num_layers:
                self._num_layers = node.layer + 1
            # Create layer-subnet-node look-up
            self.nodes_per_layer_subnet.setdefault((node.layer, node.subnet), set()).add(node)
        else:
            # Plain EHM nodes count one layer more than EHM2 nodes for the same layer index
            if node.layer + 2 > self._num_layers:
                self._num_layers = node.layer + 2

    @property
    def root(self) -> Union[EHMNetNode, EHM2NetNode]:
        """The root node of the net, i.e. the first node in :attr:`nodes`."""
        return self.nodes[0]

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the net"""
        return len(self._nodes)

    @property
    def num_layers(self) -> int:
        """Number of layers in the net"""
        return self._num_layers

    @property
    def nodes(self) -> Union[List[EHMNetNode], List[EHM2NetNode]]:
        """The nodes comprising the net"""
        return self._nodes

    @property
    def nodes_forward(self) -> Union[Sequence[EHMNetNode], Sequence[EHM2NetNode]]:
        """The net nodes, ordered by increasing layer"""
        return sorted(self.nodes, key=lambda x: x.layer)

    @property
    def nx_graph(self) -> networkx.Graph:
        """A NetworkX representation of the net. Mainly used for plotting the net.

        Graph nodes are keyed by the ``ind`` of the net nodes and carry ``track`` and ``identity`` attributes;
        graph edges carry a ``detections`` attribute holding the edge's detection set rendered as a string
        without the surrounding braces.
        """
        g = nx.Graph()
        for child in self.nodes_forward:
            if isinstance(child, EHM2NetNode):
                track = child.track
            else:
                track = child.layer + 1 if child.layer + 2 < self.num_layers else None
            g.add_node(child.ind, track=track, identity=child.identity)
            for parent in self.get_parents(child):
                label = str(self.edges[(parent, child)]).replace('{', '').replace('}', '')
                g.add_edge(parent.ind, child.ind, detections=label)
        return g

    def add_node(self, node: Union[EHMNetNode, EHM2NetNode], parent: Union[EHMNetNode, EHM2NetNode], detection: int):
        """Add a new node in the network.

        The node is appended to :attr:`nodes`, given the next free index and linked to ``parent`` through a new
        edge labelled with ``detection``.

        Parameters
        ----------
        node: :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            The node to be added.
        parent: :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            The parent of the node.
        detection: :class:`int`
            Index of measurement representing the parent child relationship.
        """
        # Set the node index
        node.ind = len(self.nodes)
        # Add node to graph
        self.nodes.append(node)
        # Create edge from parent to child
        self.edges[(parent, node)] = {detection}
        # Create parent-child-detection look-up. The node is new, so its parent entries are (re)initialised
        # rather than merged.
        self.parents_per_detection[(node, detection)] = {parent}
        self._parents[node] = {parent}
        self.children_per_detection.setdefault((parent, detection), set()).add(node)
        self._children.setdefault(parent, set()).add(node)
        # Update layer count and layer-subnet look-up
        self._register_node(node)

    def add_edge(self, parent: Union[EHMNetNode, EHM2NetNode], child: Union[EHMNetNode, EHM2NetNode], detection: int):
        """Add edge between two nodes, or update an already existing edge by adding the detection to it.

        Parameters
        ----------
        parent: :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            The parent node, i.e. the source of the edge.
        child: :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            The child node, i.e. the target of the edge.
        detection: :class:`int`
            Index of measurement representing the parent child relationship.
        """
        self.edges.setdefault((parent, child), set()).add(detection)
        self.parents_per_detection.setdefault((child, detection), set()).add(parent)
        self.children_per_detection.setdefault((parent, detection), set()).add(child)
        self._children.setdefault(parent, set()).add(child)
        self._parents.setdefault(child, set()).add(parent)

    def get_parents(self, node: Union[EHMNetNode, EHM2NetNode]) -> Union[Sequence[EHMNetNode], Sequence[EHM2NetNode]]:
        """Get the parents of a node.

        Parameters
        ----------
        node: :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            The node whose parents should be returned

        Returns
        -------
        :class:`list` of :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            List of parent nodes (empty if the node has no recorded parents)
        """
        return list(self._parents.get(node, ()))

    def get_children(self, node: Union[EHMNetNode, EHM2NetNode]) -> Union[Sequence[EHMNetNode], Sequence[EHM2NetNode]]:
        """Get the children of a node.

        Parameters
        ----------
        node: :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            The node whose children should be returned

        Returns
        -------
        :class:`list` of :class:`~.EHMNetNode` or :class:`~.EHM2NetNode`
            List of child nodes (empty if the node has no recorded children)
        """
        return list(self._children.get(node, ()))

    def plot(self, ax: plt.Axes = None, annotate=True):
        """Plot the net.

        Parameters
        ----------
        ax: :class:`matplotlib.axes.Axes`
            Axes on which to plot the net. If ``None``, a new figure is created and its current axes are used.
        annotate: :class:`bool`
            Flag that dictates whether or not to draw node and edge labels on the plotted net. The default is ``True``
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.gca()
        g = self.nx_graph
        pos = graphviz_layout(g, prog="dot")
        nx.draw(g, pos, ax=ax, node_size=0)
        if annotate:
            labels = dict()
            for n in g.nodes:
                t = g.nodes[n]['track']
                s = str(g.nodes[n]['identity']) if len(g.nodes[n]['identity']) else 'Ø'
                if t is not None:
                    labels[n] = '{{{}, {}}}'.format(t, s)
                else:
                    labels[n] = 'Ø'
            # Shift the node labels to the right of the node positions
            pos_labels = {node: (coords[0] + 10, coords[1]) for node, coords in pos.items()}
            nx.draw_networkx_labels(g, pos_labels, ax=ax, labels=labels, horizontalalignment='left')
            edge_labels = nx.get_edge_attributes(g, 'detections')
            # Pass ax explicitly; otherwise the labels are drawn on the current pyplot axes, which need
            # not be the axes the rest of the net was drawn on.
            nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, ax=ax)
class EHM2Tree:
    """Represents the track tree structure generated by :func:`~pyehm.core.EHM2.construct_tree`.

    The :class:`~.EHM2Tree` object represents both a tree as well as the root node in the tree.

    Parameters
    ----------
    track: :class:`int`
        The index of the track represented by the root node of the tree
    children: :class:`list` of :class:`~.EHM2Tree`
        Sub-trees that are children of the current tree
    detections: :class:`set` of :class:`int`
        Set of accumulated detections
    subtree: :class:`int`
        Index of subtree the current tree belongs to.
    """

    def __init__(self, track, children, detections, subtree):
        self.track = track
        self.children = children
        self.detections = detections
        self.subtree = subtree

    @property
    def depth(self) -> int:
        """The depth of the tree, i.e. 1 for a tree without children plus the depth of its deepest child"""
        return 1 + max((child.depth for child in self.children), default=0)

    @property
    def nodes(self) -> List['EHM2Tree']:
        """The nodes/subtrees in the tree, in depth-first order starting with the tree itself"""
        nodes = [self]
        for child in self.children:
            nodes.extend(child.nodes)
        return nodes

    @property
    def nx_graph(self) -> 'networkx.Graph':
        """A NetworkX representation of the tree. Mainly used for plotting the tree."""
        g = nx.Graph()
        return self._traverse_tree_nx(self, g)

    @classmethod
    def _traverse_tree_nx(cls, tree, g, parent=None):
        """Recursively add ``tree`` and its sub-trees to graph ``g``, linking ``tree`` to ``parent`` if given."""
        # Graph node ids are assigned in visiting order, starting from 1
        child = g.number_of_nodes() + 1
        g.add_node(child, track=tree.track, detections=tree.detections)
        # Compare against None explicitly so the check does not rely on node ids being non-zero
        if parent is not None:
            g.add_edge(parent, child)
        for sub_tree in tree.children:
            cls._traverse_tree_nx(sub_tree, g, child)
        return g

    def plot(self, ax: 'plt.Axes' = None):
        """Plot the tree.

        Parameters
        ----------
        ax: :class:`matplotlib.axes.Axes`
            Axes on which to plot the tree. If ``None``, a new figure is created and its current axes are used.
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.gca()
        g = self.nx_graph
        pos = graphviz_layout(g, prog="dot")
        nx.draw(g, pos, ax=ax)
        # Label every node with the index of the track it represents, centred on the node
        labels = {n: g.nodes[n]['track'] for n in g.nodes}
        pos_labels = {node: (coords[0], coords[1]) for node, coords in pos.items()}
        nx.draw_networkx_labels(g, pos_labels, ax=ax, labels=labels, font_color='white')
class Cluster:
    """A cluster of tracks sharing common detections.

    Parameters
    ----------
    tracks: :class:`list` of `int`
        Indices of tracks in cluster
    detections: :class:`list` of `int`
        Indices of detections in cluster
    validation_matrix: :class:`numpy.ndarray`
        The validation matrix for tracks and detections in the cluster
    likelihood_matrix: :class:`numpy.ndarray`
        The likelihood matrix for tracks and detections in the cluster
    """

    def __init__(self, tracks=None, detections=None, validation_matrix=None, likelihood_matrix=None):
        self.tracks = tracks
        self.detections = detections
        self.validation_matrix = validation_matrix
        self.likelihood_matrix = likelihood_matrix

    def __repr__(self):
        # The matrices are left out to keep the representation compact
        return 'Cluster(tracks={}, detections={})'.format(self.tracks, self.detections)
def gen_clusters(validation_matrix, likelihood_matrix=None):
    """Cluster tracks into groups sharing detections

    Parameters
    ----------
    validation_matrix: :class:`numpy.ndarray`
        An indicator matrix of shape (num_tracks, num_detections + 1) indicating the possible
        (aka. valid) associations between tracks and detections. The first column corresponds
        to the null hypothesis (hence contains all ones).
    likelihood_matrix: :class:`numpy.ndarray`
        A matrix of shape (num_tracks, num_detections + 1) containing the unnormalised
        likelihoods for all combinations of tracks and detections. The first column corresponds
        to the null hypothesis. The default is None, in which case the likelihood matrices of
        the generated clusters will also be None.

    Returns
    -------
    list of :class:`Cluster` objects
        A list of :class:`Cluster` objects, where each cluster contains the indices of the rows
        (tracks) and columns (detections) pertaining to the cluster. The detection indices of a
        cluster always include 0, i.e. the null hypothesis.
    list of int
        A sorted list of row (track) indices that have not been associated to any detections
    """
    # Validation matrix for all detections except null
    validation_matrix_true = validation_matrix[:, 1:]
    num_tracks, num_detections = np.shape(validation_matrix_true)

    # List of tracks gated for each detection
    v_lists = [np.flatnonzero(validation_matrix_true[:, detection]) for detection in range(num_detections)]

    # Get clusters of tracks sharing common detections
    track_clusters = list(connected_components(to_graph(v_lists)))

    # Create cluster objects that contain the indices of tracks (rows) and detections (cols)
    clusters = []
    detected_tracks = set()
    for component in track_clusters:
        tracks = sorted(component)
        detected_tracks.update(tracks)

        # Column 0 (null hypothesis) belongs to every cluster; the +1 maps indices of the
        # null-free matrix back to columns of the full matrix
        v_detections = {0}
        for track in tracks:
            v_detections |= set(np.flatnonzero(validation_matrix_true[track, :]) + 1)
        v_detections = sorted(v_detections)

        # Extract validation and likelihood matrices for cluster
        c_validation_matrix = validation_matrix[tracks, :][:, v_detections]
        if likelihood_matrix is not None:
            c_likelihood_matrix = likelihood_matrix[tracks, :][:, v_detections]
        else:
            c_likelihood_matrix = None
        clusters.append(Cluster(tracks, v_detections, c_validation_matrix, c_likelihood_matrix))

    # Get tracks (rows) that are not associated to any detections; sorted so the output does not
    # depend on set iteration order
    missed_tracks = sorted(set(range(num_tracks)) - detected_tracks)
    return clusters, missed_tracks
def to_graph(lst):
    """Build an undirected graph from a collection of node groups.

    Every element of each group in ``lst`` becomes a node, and consecutive elements of a group are
    joined by an edge (as produced by :func:`to_edges`).
    """
    graph = nx.Graph()
    for group in lst:
        # Register the members first so that single-element groups still show up as nodes
        graph.add_nodes_from(group)
        # Chain the members of the group together
        graph.add_edges_from(to_edges(group))
    return graph
def to_edges(lst):
    """Treat ``lst`` as a path and yield its edges, i.e. the pairs of consecutive elements.

    ``to_edges(['a', 'b', 'c', 'd'])`` yields ``('a', 'b')``, ``('b', 'c')``, ``('c', 'd')``.

    Any iterable is accepted, including iterators and generators. Nothing is yielded for inputs with
    fewer than two elements.
    """
    it = iter(lst)
    try:
        last = next(it)
    except StopIteration:
        # Empty input: no edges. Caught explicitly, since a StopIteration escaping a generator
        # body is turned into a RuntimeError.
        return
    for current in it:
        yield last, current
        last = current