Skip to content

Commit

Permalink
Draw graph (#82)
Browse files Browse the repository at this point in the history
* Add print function for graphs

* Add docstring

* Add labels options as arguments

* Add save to file option

* Add test for plot and further tests

* Fix save to file graph
  • Loading branch information
martinateruzzi authored May 20, 2021
1 parent 2651f4c commit 87d6a87
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 0 deletions.
127 changes: 127 additions & 0 deletions grape/general_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import numpy as np
import networkx as nx
import pandas as pd
from matplotlib import cm
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt

warnings.simplefilter(action='ignore', category=FutureWarning)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
Expand Down Expand Up @@ -220,6 +223,16 @@ def sources(self):

return [idx for idx in self if self.type[idx] == 'SOURCE']

@property
def hubs(self):
"""
:return: list of graph hubs.
:rtype: list
"""

return [idx for idx in self if self.type[idx] == 'HUB']

@property
def users(self):
"""
Expand Down Expand Up @@ -1215,3 +1228,117 @@ def compute_service(self):
splitting[(head, tail)] += 1./usr_per_node[head]

return computed_service, splitting

def print_graph(self, radius=None, initial_pos=None, fixed_nodes=None,
n_iter=500, thresh=0.0001, size=800, border='black', fsize=12,
fcolor='k', family='sans-serif', title='Graph', input_cmap=None,
save_to_file=None):
"""
Print the graph.
Positions of the nodes are generated from a spring layout simulation,
if not asked to be fixed during it.
Initial positions can be specified for the nodes.
Both initial positions and fixed positions can be specified just for
a subset of the nodes.
The shapes of the nodes characterize their type
(SOURCE/HUB/USER/SWITCH).
Nodes color is set as white with black borders by default. If an input
colormap is specified, different areas get colored differently (the
colormap is evenly spaced in color depending on the total number of
areas).
Edges are colored depending on the logic relation specified: black for
SINGLE edges, magenta for AND edges, brown dashed for OR edges.
The font size, family and color for labels can be also specified,
together with the title for the window figure.
:param radius: optimal distance between nodes.
:type radius: float, optional, default to 1/sqrt(n) where n is the
number of nodes in the graph
:param initial_pos: initial positions for nodes as a dictionary with
node as jeys and values as a coordinate list or tuple. If None,
then use random initial positions.
:type initial_pos: dict, optional, default to None
:param fixed_nodes: nodes to keep fixed at initial position. ValueError
raised if `fixed_nodes` specified and `initial_pos` not.
:type fixed_nodes: list, optional, default to None
:param n_iter: maximum number of iterations taken in spring layout
simulation.
:type iter: int, optional, default to 500
:param thresh: threshold for relative error in node position changes.
The iteration stops if the error is below this threshold.
:type thresh: float, optional, default to 0.0001
:param size: size of nodes.
:type size: int, optional, default to 800
:param border: color of node borders.
:type border: color, optional, default to 'black'
:param fsize: font size for text labels.
:type fsize: int, optional, default to 12
:param fcolor: font color string for labels.
:type fcolor: string, optional, default to 'k' (black)
:param ffamily: font family for labels.
:type ffamily: string, optional, default to 'sans-serif'
:param title: title for figure window.
:type title: string, optional, defaut to 'Graph'
:param cmap: colormap for coloring the different areas with different
colors. If None, all nodes are colored as white.
:type cmap: Matplotlib colormap, optional, default to None
:param save_to_file: name of the file where to save the graph drawing.
The extension is guesses from the filename.
Interactive window is rendered in any case.
:type save_to_file: string, optional, default to None
:raises: ValueError
"""

if (fixed_nodes is not None) and (initial_pos is None):
raise ValueError('Fixed requested without given initial positions')
logging.getLogger().setLevel(logging.INFO)

pos = nx.spring_layout(self, k=radius, pos=initial_pos,
fixed=fixed_nodes, iterations=n_iter, threshold=thresh,
seed=3113794652)

shapes = {'SOURCE': 'v', 'USER': '^', 'HUB': 'o', 'SWITCH': 'X'}

all_areas = list(set(self.area.values()))
if input_cmap is None:
mymap = ListedColormap(["white"])
else:
mymap = cm.get_cmap(input_cmap, len(all_areas))
area_indices = {}
for idx in range(len(all_areas)):
area_indices[all_areas[idx]] = idx*(1./len(all_areas))

for node in self:
col = mymap(area_indices[self.area[node]])
col = np.array([col])
nx.draw_networkx_nodes(self, pos, nodelist=[node], node_color=col,
node_shape=shapes[self.type[node]], node_size=size,
edgecolors=border)

or_edges = [(u, v) for (u, v, d) in self.edges(data=True)
if d['father_condition'] == 'OR']
and_edges = [(u, v) for (u, v, d) in self.edges(data=True)
if d['father_condition'] == 'AND']
single_edges = [(u, v) for (u, v, d) in self.edges(data=True)
if d['father_condition'] == 'SINGLE']

nx.draw_networkx_edges(self, pos, edgelist=or_edges, width=3, alpha=0.9,
edge_color='brown', style='dashed', node_size=size)
nx.draw_networkx_edges(self, pos, edgelist=and_edges, width=3,
alpha=0.9, edge_color='violet', node_size=size)
nx.draw_networkx_edges(self, pos, edgelist=single_edges, width=3,
alpha=0.9, edge_color='black', node_size=size)

nx.draw_networkx_labels(self, pos, labels=self.mark, font_size=fsize,
font_color=fcolor)

plt.get_current_fig_manager().canvas.set_window_title(title)
plt.tight_layout()
plt.axis('off')
if save_to_file:
plt.savefig(save_to_file, orientation='landscape', transparent=True)
else:
plt.show()
38 changes: 38 additions & 0 deletions tests/test_input_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,41 @@ def test_initial_service(self):

self.assertDictEqual(initial_service_dict, g.initial_service,
msg=" Wrong INITIAL SERVICE in input ")

def test_initial_sources(self):
"""
Unittest check for sources of GeneralGraph: correct input reading.
"""
g = GeneralGraph()
g.load("tests/TOY_graph.csv")

self.assertEqual(['1', '15'], g.sources, msg=" Wrong SOURCES in input ")

def test_initial_hubs(self):
"""
Unittest check for hubs of GeneralGraph: correct input reading.
"""
g = GeneralGraph()
g.load("tests/TOY_graph.csv")

self.assertEqual(['4', '5', '6', '7', '8', '9', '16', '17', '10', '11',
'19', '12', '14', '13'], g.hubs, msg=" Wrong HUBS in input ")

def test_initial_users(self):
"""
Unittest check for users of GeneralGraph: correct input reading.
"""
g = GeneralGraph()
g.load("tests/TOY_graph.csv")

self.assertEqual(['18'], g.users, msg=" Wrong USERS in input ")

def test_initial_switches(self):
"""
Unittest check for switches of GeneralGraph: correct input reading.
"""
g = GeneralGraph()
g.load("tests/TOY_graph.csv")

self.assertEqual(['2', '3'], g.switches,
msg=" Wrong SWITCHES in input ")
30 changes: 30 additions & 0 deletions tests/test_integer_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""TestOutputGraph to check output of GeneralGraph"""

from unittest import TestCase
from unittest import mock
import math
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from grape.general_graph import GeneralGraph
import grape.general_graph as my_module


def test_nodal_efficiency():
Expand Down Expand Up @@ -288,3 +292,29 @@ def test_service():
np.asarray(sorted(original_service.values())),
np.asarray(sorted(g.service.values())),
err_msg="ORIGINAL SERVICE failure")

@mock.patch("%s.my_module.plt" % __name__)
def test_print_graph(mock_plt):
"""
The following test checks that the number of figures has increased.
"""
g = my_module.GeneralGraph()
g.load("tests/TOY_graph.csv")
g.print_graph(radius=10./math.sqrt(len(g)), title='TOY graph',
input_cmap='viridis')

# Assert plt.show got called once
mock_plt.show.assert_called_once()

class Unittests(TestCase):

def test_clear_non_existing_attribute(self):
"""
The following test the error for trying to delete an attribute of
GeneralGraph that does not exist.
"""
g = GeneralGraph()
g.load("tests/TOY_graph.csv")

with self.assertRaises(ValueError):
g.clear_data('non_existing_attribute')

0 comments on commit 87d6a87

Please sign in to comment.