Source code for HierMat.utils

"""utils.py: Utilities for the :mod:`HMatrix` module
"""
import math

from HierMat.block_cluster_tree import BlockClusterTree
from HierMat.cluster_tree import ClusterTree
import matplotlib.pyplot as plt


def export(obj, form='xml', out_file='out'):
    """Export obj in specified format.

    :param obj: object to export; must provide ``to_xml()`` for ``'xml'``,
        ``to_dot()`` for ``'dot'`` and be picklable for ``'bin'``
    :type obj: BlockClusterTree or ClusterTree or HMat or RMat
    :param form: format specifier, one of ``'xml'``, ``'dot'``, ``'bin'``
    :type form: str
    :param out_file: path to output file (overwritten if it exists)
    :type out_file: str
    :raises NotImplementedError: if form is not supported

    .. note::

        implemented so far:

        - xml
        - dot
        - bin

    """
    if form == 'xml':
        head = '<?xml version="1.0" encoding="utf-8"?>\n'
        # build the full document first so a failing to_xml() does not
        # leave a truncated file behind
        output = head + obj.to_xml()
        # NOTE(review): the header declares utf-8 but the file is opened with
        # the platform default encoding -- confirm this is intended
        with open(out_file, "w") as out:
            out.write(output)
    elif form == 'dot':
        head = 'graph {\nnodesep=0.1;\nranksep=1.5;\n'
        tail = '}'
        output = head + obj.to_dot() + tail
        with open(out_file, "w") as out:
            out.write(output)
    elif form == 'bin':
        import pickle
        # context manager closes the handle even if pickling fails
        with open(out_file, "wb") as file_handle:
            # protocol=-1 selects the highest pickle protocol available
            pickle.dump(obj, file_handle, protocol=-1)
    else:
        raise NotImplementedError(
            'export format {0!r} is not supported'.format(form))
def plot(obj, filename=None, **kwargs):
    """Dispatch an object to its matching plot routine.

    Only :class:`BlockClusterTree` instances are supported; they are passed
    on to :func:`block_cluster_tree_plot`.

    :param obj: object to plot
    :type obj: BlockClusterTree
    :param filename: filename to save the plot to; forwarded unchanged to the
        specific plot routine
    :type filename: str
    :param kwargs: optional arguments forwarded to the specific plot command,
        see the respective documentation
    :return: whatever the specific plot routine returns
    :raises NotImplementedError: if obj is not a BlockClusterTree
    """
    # guard clause: reject everything we have no plot routine for
    if not isinstance(obj, BlockClusterTree):
        raise NotImplementedError('object can not be plotted')
    return block_cluster_tree_plot(obj, filename, **kwargs)
def block_cluster_tree_plot(obj, filename=None, ticks=False, face_color='#133f52',
                            admissible_color='#76f7a8', inadmissible_color='#ff234b'):
    """Draw a block cluster tree into a 3x3 inch matplotlib figure (dpi 400).

    The x range is taken from ``obj.left_clustertree`` and the y range from
    ``obj.right_clustertree``. The y axis is inverted and the x ticks are
    placed on top. The patches themselves are drawn by
    ``obj.plot_recursion``.

    :param obj: block cluster tree to plot
    :type obj: BlockClusterTree
    :param filename: filename to save the plot as png. If omitted (or
        otherwise falsy) nothing is saved and the figure is returned instead
    :type filename: str
    :param ticks: show ticks in the plot
    :type ticks: bool
    :param face_color: background color (see matplotlib for color specs)
    :param admissible_color: color for admissible patch
    :type admissible_color: str
    :param inadmissible_color: color for inadmissible patch
    :type inadmissible_color: str
    :return: the figure if no filename is given, otherwise None

    .. note::

        depends on :mod:`matplotlib.pyplot` and changes its global rc
        settings for ``axes``, ``xtick`` and ``ytick``
    """
    # thin frame and small labels (applied to the global pyplot rc settings)
    plt.rc('axes', linewidth=0.5, labelsize=4)
    for tick_group in ('xtick', 'ytick'):
        plt.rc(tick_group, labelsize=4)

    figure = plt.figure(figsize=(3, 3), dpi=400)
    figure.patch.set_facecolor(face_color)

    # axis limits come from the patch coordinates of the two cluster trees
    x_min, x_max = obj.left_clustertree.get_patch_coordinates()
    y_min, y_max = obj.right_clustertree.get_patch_coordinates()
    axes = plt.axes()
    axes.set_xlim(x_min, x_max + 1)
    axes.set_ylim(y_min, y_max + 1)

    if not ticks:
        axes.set_xticks([])
        axes.set_yticks([])
    else:
        # tick spacing is a divisor of the axis length: the fourth largest
        # one if there are more than four divisors, otherwise the largest
        x_divisors = list(divisor_generator(x_max + 1))
        y_divisors = list(divisor_generator(y_max + 1))
        x_step = x_divisors[-4] if len(x_divisors) > 4 else x_divisors[-1]
        y_step = y_divisors[-4] if len(y_divisors) > 4 else y_divisors[-1]
        axes.set_xticks(range(x_min, x_max + 2, x_step))
        axes.set_yticks(range(y_min, y_max + 2, y_step))

    axes.tick_params(length=2, width=0.5)
    axes.xaxis.tick_top()
    axes.invert_yaxis()
    obj.plot_recursion(axes, admissible_color=admissible_color,
                       inadmissible_color=inadmissible_color)
    figure.add_axes(axes)

    if not filename:
        return figure

    if ticks:
        # leave a margin so the tick labels stay visible
        margins = dict(left=0.1, right=0.9, top=0.9, bottom=0.1)
    else:
        # remove whitespace around the plot
        margins = dict(left=0, right=1, top=1, bottom=0)
    plt.subplots_adjust(**margins)
    plt.savefig(filename, format='png', facecolor=figure.get_facecolor(), edgecolor=None)
def load(filename):
    """Read a pickled object back from file.

    Counterpart to exporting with ``form='bin'``: the file is opened in
    binary mode and unpickled.

    :param filename: file to import
    :type filename: str
    :return: the unpickled object
    :rtype: BlockClusterTree or ClusterTree

    .. note::

        Depends on :mod:`pickle`

    .. warning::

        Unpickling can execute arbitrary code. Only load files from
        sources you trust.
    """
    import pickle
    with open(filename, 'rb') as pickled_file:
        return pickle.load(pickled_file)
def divisor_generator(n):
    """Yield the divisors of n in ascending order.

    Divisors up to the square root of n are yielded as they are found;
    their larger partners are collected and yielded afterwards in reverse,
    which keeps the overall sequence sorted.

    :param n: integer to find divisors of
    :type n: int
    :return: generator over the divisors of n
    :rtype: generator of int
    :raises ValueError: if n is negative (raised by :func:`math.sqrt`)

    .. warning::

        This is a generator! To get a list with all divisors call::

            list(divisor_generator(n))

    .. note::

        found at `StackOverflow
        <http://stackoverflow.com/questions/171765/what-is-the-best-way-to-get-all-the-divisors-of-a-number>`_
        on 2017.03.08
    """
    large_divisors = []
    # range instead of xrange: xrange does not exist on Python 3
    for i in range(1, int(math.sqrt(n) + 1)):
        if n % i == 0:
            yield i
            # skip the partner of an exact square root, it would be a duplicate
            if i * i != n:
                # floor division keeps the divisor an int on Python 3, where
                # "/" would produce a float that is unusable as a range step
                large_divisors.append(n // i)
    for divisor in reversed(large_divisors):
        yield divisor