Source code for pgm.representation.LinkedListMN

import numpy as np
from ..helpers.common import Node
from ..helpers.search import dfs

class Graph(object):
    """Undirected graph built from linked ``Node`` objects (adjacency lists).

    Each node keeps its neighbours in ``node.nbrs``; every edge is stored in
    the ``nbrs`` lists of both endpoints. Nodes are located by a depth-first
    search starting at ``self.rootNode`` (see ``_search``), so only nodes that
    are connected to the root can be found. Nodes created here are tagged
    with ``type = 'MN'`` (presumably "Markov network" -- TODO confirm).
    """

    def __init__(self, root):
        """Create the graph around a root node.

        root: a ``Node`` instance, or an ``int``/``str`` name, in which case a
            new ``Node`` with that name and ``type = 'MN'`` is created.

        Raises ValueError if ``root`` is of any other type.
        """
        # NOTE(review): not read anywhere in this class; kept for callers.
        self.nodesPresent = []
        if isinstance(root, Node):
            self.rootNode = root
        elif isinstance(root, (str, int)):
            self.rootNode = Node()
            self.rootNode.name = root
            self.rootNode.type = 'MN'
        else:
            raise ValueError("Invalid type for root argument")

    def _search(self, name):
        """Return the node called ``name`` found by DFS from the root, or -1."""
        # -1 is the not-found sentinel that dfs(...).searchNode yields, as
        # checked by every caller in this class.
        return dfs(self.rootNode, name, type='MN').searchNode

    @staticmethod
    def _unlink(node, other):
        """Remove the first entry of ``node.nbrs`` whose name equals ``other.name``."""
        for i, nbr in enumerate(node.nbrs):
            if nbr.name == other.name:
                node.nbrs.pop(i)
                break

    def add_node(self, node, nbrNodes):
        """Add a new node and connect it to existing neighbours.

        node: name of the new node (``int`` or ``str``).
        nbrNodes: iterable of names (``int`` or ``str``) of existing nodes to
            connect the new node to.

        Raises ValueError("nbrNode not found") if any neighbour cannot be
        found; in that case the graph is left unchanged.

        NOTE(review): with an empty ``nbrNodes`` the new node is not connected
        to the root and therefore cannot be found by later searches.
        """
        # Resolve every neighbour before linking anything, so a missing
        # neighbour cannot leave the graph half-modified.
        neighbours = []
        for name in nbrNodes:
            found = self._search(name)
            if found == -1:
                raise ValueError("nbrNode not found")
            neighbours.append(found)

        new_node = Node(node)
        new_node.type = 'MN'
        for nbr in neighbours:
            new_node.nbrs.append(nbr)
            nbr.nbrs.append(new_node)

    def add_edge(self, node1, node2):
        """Add an undirected edge between two existing nodes.

        node1, node2: names (``int`` or ``str``) of nodes already in the graph.

        The edge is recorded in both nodes' ``nbrs`` lists; an edge that is
        already present is not duplicated.

        Raises ValueError if either node cannot be found.
        """
        n1 = self._search(node1)
        n2 = self._search(node2)
        if (n1 == -1) or (n2 == -1):
            raise ValueError("Node1 or Node2 not found")
        if n2 not in n1.nbrs:
            n1.nbrs.append(n2)
        if n1 not in n2.nbrs:
            n2.nbrs.append(n1)

    def delete_edge(self, node1, node2):
        """Remove the undirected edge between two existing nodes.

        node1, node2: names (``int`` or ``str``) of nodes already in the graph.

        Both directions are removed: the first neighbour of ``node1`` named
        like ``node2`` and the first neighbour of ``node2`` named like
        ``node1``.

        Raises ValueError if either node cannot be found.
        """
        n1 = self._search(node1)
        n2 = self._search(node2)
        if (n1 == -1) or (n2 == -1):
            raise ValueError("Node1 or Node2 not found")
        self._unlink(n1, n2)
        self._unlink(n2, n1)
        return

    def delete_node(self, node):
        """Remove a node together with all of its edges.

        node: name (``int`` or ``str``) of a node in the graph.

        Raises ValueError if the node cannot be found.

        NOTE(review): if ``node`` is the root, ``self.rootNode`` still refers
        to it afterwards, so the remaining nodes become unreachable by search.
        """
        target = self._search(node)
        # The original guard was inverted and raised for every existing node.
        if target == -1:
            raise ValueError("Node doesn't exist")
        # Iterate over a copy: popping from target.nbrs while enumerating it
        # (as the original did) skips every other neighbour.
        for nbr in list(target.nbrs):
            self._unlink(nbr, target)
        target.nbrs.clear()
        return

    def get_node(self, node):
        r"""Find a node by name, remember it as ``self.current_node``, and return it.

        node: name (``int`` or ``str``) of a node in the graph.

        Raises ValueError if the node cannot be found.
        """
        found = self._search(node)
        if found == -1:
            raise ValueError("Node doesn't exist")
        self.current_node = found
        return found

    def calculate_conditional(self, nodes, values):
        r"""Calculate the conditional distribution given fixed node values.

        nodes: list of ``int``, ``str`` or ``Node`` objects.
        values: node values, in the same order as ``nodes``.

        Not implemented yet.
        """
        raise NotImplementedError()

    def calculate_marginals(self, nodes):
        r"""Remove the given nodes and re-estimate the marginals.

        nodes: list of ``int``, ``str`` or ``Node`` objects.

        Not implemented yet.
        """
        raise NotImplementedError()

    def print(self, node):
        """Print each node reachable from ``node`` in breadth-first order.

        Each line has the form ``node: <name>, nbrs: [<neighbour names>]``.
        Nodes are de-duplicated by name. Does nothing when ``node`` is None.
        """
        if node is None:
            return
        queue = [node]
        visited = set()
        # The queue grows while we iterate over it, giving breadth-first order.
        for current in queue:
            if current.name in visited:
                continue
            print("node: {}, nbrs: {}".format(
                current.name, [nd.name for nd in current.nbrs]))
            queue.extend(current.nbrs)
            visited.add(current.name)