NOTE(review): this patch was re-flowed from a copy whose line breaks and
indentation had been collapsed. Indentation (4 spaces) and the placement of
blank context lines are reconstructed -- verify against the tree before
applying. The index hashes predate the fixes made to the added lines.

diff --git a/birdvisu/annotations/__init__.py b/birdvisu/annotations/__init__.py
index 88ceb6a..bef0c90 100644
--- a/birdvisu/annotations/__init__.py
+++ b/birdvisu/annotations/__init__.py
@@ -25,7 +25,7 @@ approach."""
 from ..topo_v3 import TopologyV3, VertexID, Edge
 from collections import defaultdict
 from collections.abc import Hashable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from abc import ABC, abstractmethod
 from typing import Any
 
@@ -65,7 +65,15 @@ class AnnotatedTopology:
             del self.annotations[ann_id]
         self.running_annotations.add(ann_id)
         annotator = ann_id.annotator(ann_id.param)
+
         annotation = annotator.annotate(self)
+
+        if annotation.annotated_topology is not None and annotation.annotated_topology != self:
+            raise ValueError('Annotator claims to annotate different topology!')
+        annotation.annotated_topology = self
+        if annotation.annotator_id is not None and annotation.annotator_id != ann_id:
+            raise ValueError('Annotator fakes its ID!')
+        annotation.annotator_id = ann_id
         for v in annotation.for_vertex:
             self.vertex_annotators[v].add(ann_id)
         for e in annotation.for_edge:
@@ -78,18 +86,24 @@ class AnnotatedTopology:
 
 @dataclass(frozen=True)
 class AnnotatorID:
+    # If you create the Annotator from a factory, please be aware of the
+    # implications for the AnnotatorIDs. You might want to save the ID for
+    # yourself.
     annotator: type['Annotator']
     param: None | Hashable = None
 
 @dataclass
 class Annotation:
-    annotator_id: AnnotatorID
-    annotated_topology: AnnotatedTopology
     # Use of Any here means "something reasonable and stringifiable". We do not
     # know whether this can be specified reasonably.
-    for_vertex: dict[VertexID, Any]
-    for_edge: dict[Edge, Any]
-    for_topology: Any | None
+    for_vertex: dict[VertexID, Any] = field(default_factory=dict)
+    for_edge: dict[Edge, Any] = field(default_factory=dict)
+    for_topology: Any | None = None
+
+    # Annotators may return an annotation without the two handles; they are
+    # filled in by AnnotatedTopology.run_annotator once the annotator returns.
+    annotated_topology: AnnotatedTopology | None = None
+    annotator_id: AnnotatorID | None = None
 
 class Annotator(ABC):
     """Annotator itself.
diff --git a/birdvisu/annotations/analysis.py b/birdvisu/annotations/analysis.py
new file mode 100644
index 0000000..ad63bd1
--- /dev/null
+++ b/birdvisu/annotations/analysis.py
@@ -0,0 +1,118 @@
+from . import Annotator, Annotation
+from ..topo_v3 import MetricType
+import heapq
+from enum import Enum
+from functools import total_ordering
+
+class TopologyDifference(Annotator):
+    """Finds differences between ancestors.
+
+    Currently, we only support the "reference vs. current" comparison, since
+    that is the most useful case and it is clear which vertices are new and
+    which old."""
+    class Status(Enum):
+        Missing = 'missing'
+        New = 'new'
+        Discrepant = 'discrepant'
+
+    def __init__(self, _param):
+        # The parameter is unused; the compared ancestors are fixed.
+        self.old_label = 'reference'
+        self.new_label = 'current'
+    def annotate(self, atopo):
+        """Annotate vertices and edges that differ between the two ancestors."""
+        result = Annotation()
+        old = atopo.topology.ancestors[self.old_label]
+        new = atopo.topology.ancestors[self.new_label]
+
+        # Vertices:
+        oldv = set(old.vertices.keys())
+        newv = set(new.vertices.keys())
+        # TODO: Maybe match vertices?
+        only_old = oldv - newv
+        only_new = newv - oldv
+        common = oldv & newv
+        discrepant = set()
+        for vtxid in common:
+            o = old.vertices[vtxid]
+            n = new.vertices[vtxid]
+            # Only field that can differ is type, assuming consistent topology
+            if o.type != n.type:
+                discrepant.add(vtxid)
+        for v in only_old: result.for_vertex[v] = self.Status.Missing
+        for v in only_new: result.for_vertex[v] = self.Status.New
+        for v in discrepant: result.for_vertex[v] = self.Status.Discrepant
+
+        # Edges:
+        # NOTE(review): the code below needs old.edges/new.edges to support
+        # "-", "&" and indexing by edge; a plain dict does not -- confirm.
+        olde = old.edges
+        newe = new.edges
+        only_old = olde - newe
+        only_new = newe - olde
+        common = olde & newe
+        discrepant = set()
+        for edge in common:
+            o = old.edges[edge]
+            n = new.edges[edge]
+            if o != n: discrepant.add(edge)
+        for e in only_old: result.for_edge[e] = self.Status.Missing
+        for e in only_new: result.for_edge[e] = self.Status.New
+        for e in discrepant: result.for_edge[e] = self.Status.Discrepant
+
+        return result
+
+class ShortestPathTree(Annotator):
+    """Annotates the shortest path tree edges with True.
+
+    Takes a tuple of the starting vertex and topology ancestor identifier as
+    the parameter (None for the whole topology).
+
+    Since we have computed the distances of individual vertices, we annotate
+    the vertices too. The annotations are of form (metric_type, distance)
+
+    If the start vertex is not found, annotates the whole topology with None."""
+    def __init__(self, param):
+        vertex, ancestor = param
+        self.start_vtxid = vertex
+        self.ancestor = ancestor
+    def annotate(self, atopo):
+        result = Annotation()
+        topo = atopo.topology.ancestors[self.ancestor]
+        if self.start_vtxid not in topo.vertices:
+            result.for_topology = None
+            return result
+        # We need a simple wrapper around the edges, so they are comparable
+        @total_ordering
+        class CE:
+            def __init__(self, distance, edge):
+                self.mt = edge.metric_type
+                self.dist = distance
+                self.e = edge
+            def __lt__(self, other):
+                return (self.mt, self.dist) < (other.mt, other.dist)
+            def __eq__(self, other):
+                return (self.mt, self.dist) == (other.mt, other.dist)
+        heap = [CE(e.cost, e) for e in topo.vertices[self.start_vtxid].outgoing_edges]
+        heap.sort()
+        # We run a bit modified Dijkstra algorithm, since we want to find a DAG
+        # of shortest paths, not a tree. The heap contains edges together with
+        # target's distances and we add all the edges that see the specific
+        # target with the same distance.
+        distances = {
+            self.start_vtxid: (MetricType.Type1, 0)
+        }
+
+        while len(heap) > 0:
+            ce = heapq.heappop(heap)
+            if ce.e.target not in distances or distances[ce.e.target] == (ce.mt, ce.dist):
+                result.for_edge[ce.e] = True
+                distances[ce.e.target] = (ce.mt, ce.dist)
+                for oe in topo.vertices[ce.e.target].outgoing_edges:
+                    # Type 2 metrics are external only, so if ce.mt == 2,
+                    # this point is not reached (ce.e.target has no outgoing edges)
+                    assert ce.mt == 1
+                    new_dist = ce.dist + oe.cost if ce.mt == oe.metric_type else oe.cost
+                    heapq.heappush(heap, CE(new_dist, oe))
+        result.for_vertex = distances
+        return result