import numpy as np
from ..helpers.common import Node
from ..helpers.search import dfs
class Graph(object):
    r""" Directed graph built from linked ``Node`` objects.

    Every node keeps a ``children`` list and a ``parents`` list; the graph
    itself only stores the root node and reaches everything else by walking
    those lists (adjacency-list style).
    """

    def __init__(self, root):
        r""" creates the graph around a root node
        root: Node object, or ['int', 'str'] name from which a root
              node is generated

        raises ValueError if root is none of the above
        """
        self.nodesPresent = []
        if isinstance(root, Node):
            self.rootNode = root
        elif isinstance(root, (str, int)):
            self.rootNode = Node()
            self.rootNode.name = root
        else:
            raise ValueError("Invalid type for root argument")

    def _find(self, name, message):
        r""" searches the graph for a node called ``name``
        name: ['int', 'str'] node name handed to dfs
        message: text of the ValueError raised when nothing is found

        dfs reports a failed search through ``searchNode == -1``.
        """
        found = dfs(self.rootNode, name).searchNode
        if found == -1:
            raise ValueError(message)
        return found

    def _resolve(self, node, message):
        r""" returns ``node`` itself when it already is a Node object,
        otherwise looks it up by name (see ``_find``)
        """
        if isinstance(node, Node):
            return node
        return self._find(node, message)

    @staticmethod
    def _remove_first_named(nodes, name):
        r""" removes, in place, the first entry of ``nodes`` whose
        ``name`` attribute equals ``name``; no-op when there is none
        """
        for index, candidate in enumerate(nodes):
            if candidate.name == name:
                # Returning right after the pop keeps the iteration safe.
                nodes.pop(index)
                return

    def add_node(self, node, parentNodes):
        r""" adds node in the graph
        node: can be ['int', 'str'] or Node object
        parentNodes: a single ['int', 'str'] / Node object, or an
                     iterable of them

        raises ValueError if a parent given by name is not in the graph;
        in that case the graph is left untouched
        """
        # A lone parent is accepted as documented; without this an int or
        # Node is not iterable and a str would be split into characters.
        if isinstance(parentNodes, (Node, str, int)):
            parentNodes = [parentNodes]

        # Resolve every parent before linking so a missing parent cannot
        # leave the new node attached to only some of its parents.
        parents = [self._resolve(parent, "parentNode not found")
                   for parent in parentNodes]

        # Only wrap plain names; an existing Node object is used as is.
        if not isinstance(node, Node):
            node = Node(node)

        for parent in parents:
            node.parents.append(parent)
            parent.children.append(node)

    def add_edge(self, node1, node2):
        r""" adds edge between node1 and node2
        node1 -> node2
        node1: can be ['int', 'str'] or Node object
        node2: can be ['int', 'str'] or Node object

        raises ValueError if a node given by name is not in the graph
        """
        node1 = self._resolve(node1, "Node1 or Node2 not found")
        node2 = self._resolve(node2, "Node1 or Node2 not found")
        node1.children.append(node2)
        node2.parents.append(node1)

    def delete_edge(self, node1, node2):
        r""" removes edge between node1 and node2
        deletes directed edge from node1->node2
        node1: can be ['int', 'str'] or Node object
        node2: can be ['int', 'str'] or Node object

        raises ValueError if a node given by name is not in the graph
        """
        node1 = self._resolve(node1, "Node1 or Node2 not found")
        node2 = self._resolve(node2, "Node1 or Node2 not found")
        self._remove_first_named(node1.children, node2.name)
        self._remove_first_named(node2.parents, node1.name)

    def delete_node(self, node):
        r""" removes all the edges and node from the graph
        node: can be ['int', 'str'] or Node object

        The node is detached from every child and every parent, so it is
        no longer reachable from the rest of the graph. ``rootNode`` is
        not reassigned when the root itself is passed.

        raises ValueError if a node given by name is not in the graph
        """
        node = self._resolve(node, "Node doesn't exist")

        # Iterate over snapshots: the neighbour lists are edited below and
        # mutating a list while looping over it skips entries.
        for child in list(node.children):
            self._remove_first_named(child.parents, node.name)
        for parent in list(node.parents):
            self._remove_first_named(parent.children, node.name)

        # Empty the node's own lists in place so existing references to
        # them see the change as well.
        del node.children[:]
        del node.parents[:]

    def get_node(self, node):
        r""" search and returns the node
        node: can be ['int', 'str']

        The node found is also stored in ``self.current_node``.

        raises ValueError if the node is not in the graph
        """
        node = self._find(node, "Node doesn't exist")
        self.current_node = node
        return node

    def get_leafnodes(self, return_obj = True):
        r""" finds all leaf nodes (nodes without children) in the graph
        return_obj: bool if true returns list of node object
                    else returns numpy array of node names

        Leaves are listed in depth-first order of first discovery. A root
        without children is reported as the only leaf.
        """
        leafnodes = []
        visited = set()
        # Explicit stack instead of recursion: each node is expanded once,
        # so shared descendants are not re-walked and cycles terminate.
        stack = [self.rootNode]
        while stack:
            current = stack.pop()
            if id(current) in visited:
                continue
            visited.add(id(current))
            if len(current.children) == 0:
                # Equality-based membership test keeps the de-duplication
                # rule tied to however the node type defines ``==``.
                if current not in leafnodes:
                    leafnodes.append(current)
                continue
            # Reversed push so children are popped in their listed order.
            stack.extend(reversed(current.children))

        if return_obj:
            return leafnodes
        return np.array([leaf.name for leaf in leafnodes])

    def calculate_conditional(self, nodes, values):
        r""" calculates conditional distribution fixing the
        values of given nodes
        nodes: can be list of ['int', 'str' or Node object]
        values: [node_value in same order]

        Not implemented yet: always raises NotImplementedError.
        """
        raise NotImplementedError()

    def calculate_marginals(self, nodes):
        r""" removes the node and re-estimates marginals
        nodes: can be list of ['int', 'str' or Node object]

        Not implemented yet: always raises NotImplementedError.
        """
        raise NotImplementedError()

    def print(self, node):
        r""" prints name, children and parents of every node reachable
        from ``node`` through children links, breadth first
        node: Node object to start from; nothing is printed for None

        Each node name is printed only once.
        """
        if node is None:
            return
        queue = [node]
        visited = set()
        # The queue grows while it is being iterated; appending to a list
        # during a for loop is well defined and gives breadth-first order.
        for current in queue:
            if current.name in visited:
                continue
            print("node: {}, children: {}, parents: {}".format(current.name,
                [nd.name for nd in current.children],
                [nd.name for nd in current.parents]))
            queue.extend(current.children)
            visited.add(current.name)