# Source code for hidet.utils.structure

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TypeVar, Dict, List, Tuple, Callable

# Type variable standing for the node type of a DirectedGraph. Nodes are used
# as dictionary keys of the adjacency list, so they must be hashable.
GraphNode = TypeVar('GraphNode')


class DirectedGraph:
    """Directed graph.

    A directed graph stored as an adjacency list: ``adj_list`` maps every node
    to the list of its successors. Nodes are used as dictionary keys and must
    therefore be hashable. Parallel edges are kept (adding the same edge twice
    stores the destination twice).
    """

    def __init__(self) -> None:
        # node -> successors, in insertion order
        self.adj_list: Dict[GraphNode, List[GraphNode]] = {}

    def __contains__(self, item: GraphNode) -> bool:
        return item in self.adj_list

    def astext(self, node2str: Callable[[GraphNode], str] = object.__str__) -> str:
        """Render the adjacency list as text, one line per node.

        Each line has the form ``<name>: <succ>, <succ>, ...`` where the node
        name is right-aligned to the width of the longest name.

        Parameters
        ----------
        node2str: Callable[[GraphNode], str]
            The function used to convert a node to its display name.

        Returns
        -------
        ret: str
            The text representation. An empty graph yields an empty string.
        """
        node_name: Dict[GraphNode, str] = {node: node2str(node) for node in self.adj_list}
        # default=0 keeps an empty graph from raising ValueError in max()
        max_length = max((len(name) for name in node_name.values()), default=0)
        lines: List[str] = []
        for u, successors in self.adj_list.items():
            head = f'{node_name[u]:>{max_length}}: '
            tail = ', '.join(node_name[v] for v in successors)
            lines.append(head + tail)
        return '\n'.join(lines)

    @staticmethod
    def from_edges(edges: List[Tuple[GraphNode, GraphNode]]) -> DirectedGraph:
        """Create a directed graph from edges.

        The edges should be a list of (src, dst) tuples, and each tuple represents an edge.

        Parameters
        ----------
        edges: List[Tuple[GraphNode, GraphNode]]
            The edges of the directed graph to be created.

        Returns
        -------
        ret: DirectedGraph
            The created directed graph.
        """
        graph = DirectedGraph()
        for u, v in edges:
            graph.add_edge(u, v)
        return graph

    def has_node(self, node: GraphNode) -> bool:
        """Whether the node has been added to the graph.

        Parameters
        ----------
        node: GraphNode
            The node to be checked.

        Returns
        -------
        ret: bool
            True if the node has been added.
        """
        return node in self.adj_list

    def has_edge(self, src: GraphNode, dst: GraphNode) -> bool:
        """Whether there is an edge (src, dst) in the graph.

        Parameters
        ----------
        src: GraphNode
            The source node of the edge.

        dst: GraphNode
            The destination node of the edge.

        Returns
        -------
        ret: bool
            True if the edge (src, dst) is in the graph.
        """
        return src in self.adj_list and dst in self.adj_list[src]

    def add_node(self, node: GraphNode) -> None:
        """Add a node.

        Adding a node that is already in the graph is a no-op; its existing
        outgoing edges are kept.

        Parameters
        ----------
        node: GraphNode
            The node to be added to the graph.
        """
        if not self.has_node(node):
            self.adj_list[node] = []

    def add_edge(self, src: GraphNode, dst: GraphNode) -> None:
        """Add an edge.

        The node `src` and `dst` will be added to the graph if they have not been added.

        Parameters
        ----------
        src: GraphNode
            The source of the edge.

        dst: GraphNode
            The destination of the edge.
        """
        # add_node ignores nodes that already exist, so no pre-check is needed
        self.add_node(src)
        self.add_node(dst)
        self.adj_list[src].append(dst)

    def topological_order(self) -> List[GraphNode]:
        """Get a topological order of the nodes in the directed graph.

        Returns
        -------
        ret: List[GraphNode]
            The nodes in the topological order.

        Raises
        ------
        ValueError
            If the directed graph is cyclic (i.e., there is a loop in the graph).
        """
        # Count incoming edges of every node (parallel edges count separately).
        in_degree: Dict[GraphNode, int] = {node: 0 for node in self.adj_list}
        for successors in self.adj_list.values():
            for v in successors:
                in_degree[v] += 1

        # Nodes without unprocessed predecessors; consumed in LIFO order.
        ready: List[GraphNode] = [node for node, degree in in_degree.items() if degree == 0]

        order: List[GraphNode] = []
        while ready:
            u = ready.pop()
            order.append(u)
            for v in self.adj_list[u]:
                in_degree[v] -= 1
                if in_degree[v] == 0:
                    ready.append(v)

        # Nodes on a cycle never reach in-degree zero and are thus never emitted.
        if len(order) != len(self.adj_list):
            raise ValueError('Loop detected during generating topological order for a directed graph.')
        return order