Source code for lavaflow.utils.io

"""Utility functions.
"""
import json
import re

import numpy as np

import logging
# Module-level logger, obtained under the fixed name 'lavaflow'.
logger = logging.getLogger('lavaflow')


# -----------------------------------------------------------------------------

# Numpy encoders

class NumpyJsonEncoder(json.JSONEncoder):
    """Extension of json.JSONEncoder that supports numpy scalars and arrays.

    Numpy integer, floating and boolean scalars are converted to the matching
    built-in Python type, and numpy arrays are converted to (nested) lists.
    Anything else is delegated to the base class.
    """

    def default(self, obj):
        """Override default method.

        Called by the JSON machinery for objects it cannot serialize natively.

        Args:
            obj (object): object to be serialized

        Returns:
            obj (object): new object that is serializable
                (bool, int, float, or list)

        Raises:
            TypeError: raised by ``json.JSONEncoder.default`` when ``obj`` is
                not one of the supported numpy types.
        """
        # np.bool_ is neither an np.integer nor a Python bool, so it needs its
        # own branch; without it numpy booleans fail to serialize.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            # tolist() also converts the elements to built-in Python types.
            return obj.tolist()
        return super().default(obj)
class NumpyLineEncoder:
    """Class for serializing numpy arrays into a single line format for
    efficiency.

    Note that this is intended for numeric numpy arrays with up to three axes.
    """

    def __init__(self, separators=(',', ';', ':'), **kwargs):
        """
        Args:
            separators (list|tuple): separators to use for each axis in order
                (innermost axis first)
            kwargs (dict): keyword arguments to pass to np.array2string;
                ``precision`` defaults to 5 when not given
        """
        self.separators = separators
        self.kwargs = {'precision': 5, **kwargs}

    def encode(self, x):
        """Encode numpy array.

        The array is squeezed (length-1 axes are dropped), rendered with
        ``np.array2string`` and then the bracket structure is rewritten into
        the configured per-axis separators.

        Args:
            x (np.ndarray): numpy array

        Returns:
            s (str): encoded numpy array

        Raises:
            ValueError: if ``x`` has more than three axes.

        Examples:
            >>> encoder = NumpyLineEncoder()
            >>> x = np.array([[[1,2,3],[4,5,6],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]])
            >>> print(encoder.encode(x))
            1,2,3;4,5,6;7,8,9:1,2,3;4,5,6;7,8,9
        """
        if x.ndim > 3:
            raise ValueError('NumpyLineEncoder does not support len(x.shape) > 3')

        # Forward the user's formatting options (e.g. precision) to
        # array2string. 'separator' and 'threshold' are forced because the
        # rewriting below relies on ',' and on the array never being
        # summarized with '...'.
        format_options = {**self.kwargs, 'separator': ',', 'threshold': x.size}

        # nanstr/suppress can only be set through the global print options,
        # so restore them even when array2string raises.
        saved_options = np.get_printoptions()
        np.set_printoptions(nanstr='', suppress=True)
        try:
            s = np.array2string(np.squeeze(x), **format_options)
        finally:
            np.set_printoptions(**saved_options)

        # Drop the line breaks and padding array2string inserts.
        s = re.sub(r'[\n ]+', '', s)

        # At nesting depth d, two neighbouring sub-arrays are joined by
        # d closing brackets, a comma and d opening brackets; replace that
        # joint with the separator configured for the axis.
        for depth, (sep, _) in enumerate(zip(self.separators, x.shape)):
            pattern = r'(?<!])' + r'\]' * depth + ',' + r'\[' * depth + r'(?!\[)'
            s = re.sub(pattern, sep, s)

        # Remove the remaining (outermost) brackets.
        s = re.sub(r'[\[\]]+', '', s)

        # With precision 0 floats are rendered as e.g. '1.'; strip the
        # trailing decimal point so they read as integers.
        if self.kwargs['precision'] == 0:
            s = s.replace('.', '')
        return s