"""Utility functions.
"""
import json
import re
import numpy as np
import logging
# Shared logger for this module; all messages go to the 'lavaflow' logger.
logger = logging.getLogger(name='lavaflow')
# -----------------------------------------------------------------------------
# Numpy encoders
class NumpyJsonEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy scalars and arrays.

    numpy integer, floating-point and boolean scalars are converted to the
    matching Python ``int``/``float``/``bool``; ``np.ndarray`` objects are
    converted to (nested) lists via ``tolist()``. Any other object falls back
    to ``json.JSONEncoder.default``, which raises ``TypeError``.

    Example:
        >>> json.dumps({'a': np.arange(3)}, cls=NumpyJsonEncoder)
        '{"a": [0, 1, 2]}'
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj (object): object the standard encoder could not serialize
        Returns:
            obj (int|float|bool|list): serializable equivalent of ``obj``
        Raises:
            TypeError: if ``obj`` is not a supported numpy type
        """
        # np.bool_ is neither an np.integer nor a Python bool subclass, so it
        # needs its own branch or json would reject it.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            # tolist() also converts the element scalars to Python types.
            return obj.tolist()
        return super().default(obj)
class NumpyLineEncoder(object):
    """Serialize numeric numpy arrays into a compact single-line string.

    Every nesting level gets its own separator (innermost level first), so a
    3-D array is rendered as e.g. ``1,2;3,4:5,6;7,8``. Intended for numeric
    arrays with at most three axes; length-1 axes are squeezed away before
    formatting.
    """

    def __init__(self, separators=(',', ';', ':'), **kwargs):
        """
        Args:
            separators (list|tuple): separator for each nesting level, in
                order from the innermost level (between individual elements)
                outwards
            kwargs (dict): keyword arguments passed to ``np.array2string``;
                ``precision`` defaults to 5. ``separator`` and ``threshold``
                are set by the encoder itself and override any value given.
        """
        self.separators = separators
        self.kwargs = {'precision': 5, **kwargs}

    def encode(self, x):
        """Encode a numpy array as a single-line string.

        Args:
            x (np.ndarray): array with at most three axes (anything accepted
                by ``np.asarray`` works)
        Returns:
            s (str): encoded array; NaN values are encoded as empty strings
        Raises:
            ValueError: if ``x`` has more than three axes
        Examples:
            >>> encoder = NumpyLineEncoder()
            >>> x = np.array([[[1,2,3],[4,5,6],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]])
            >>> print(encoder.encode(x))
            1,2,3;4,5,6;7,8,9:1,2,3;4,5,6;7,8,9
        """
        x = np.asarray(x)
        if x.ndim > 3:
            raise ValueError('NumpyLineEncoder does not support len(x.shape) > 3')
        # The encoder owns 'separator' (the regexes below rely on ',') and
        # 'threshold' (summarizing with '...' would corrupt the output); all
        # other array2string options, including precision, come from kwargs.
        options = {**self.kwargs, 'separator': ',', 'threshold': x.size}
        # The context manager restores the previous print options even if
        # formatting raises.
        with np.printoptions(nanstr='', suppress=True):
            s = np.array2string(np.squeeze(x), **options)
        # array2string wraps long lines and pads values for alignment.
        s = re.sub(r'[\n ]+', '', s)
        # Level i replaces commas framed by exactly i brackets on each side:
        # ',' between elements (i=0), '],[' between rows (i=1), ']],[['
        # between planes (i=2). The lookarounds keep a lower level from
        # matching inside a higher-level boundary.
        for i, sep in zip(range(x.ndim), self.separators):
            pattern = r'(?<!\])' + r'\]' * i + ',' + r'\[' * i + r'(?!\[)'
            s = re.sub(pattern, sep, s)
        s = re.sub(r'[\[\]]+', '', s)
        if self.kwargs['precision'] == 0:
            # With precision 0 numpy prints floats as e.g. '2.'; drop the dot.
            s = s.replace('.', '')
        return s