Source code for openeo.udf.xarraydatacube



# Note: this module was initially developed under the ``openeo-udf`` project (https://github.com/Open-EO/openeo-udf)

import collections
import json
import typing
from pathlib import Path
from typing import Union

import numpy
import xarray

from openeo.udf import OpenEoUdfException
from openeo.util import dict_no_none, deep_get

if typing.TYPE_CHECKING:
    # Imports for type checking only (circular import issue at runtime).
    import matplotlib.colors

class XarrayDataCube:
    """
    This is a thin wrapper around :py:class:`xarray.DataArray`
    providing a basic "DataCube" interface for openEO UDF usage around multi-dimensional data.
    """

    def __init__(self, array: xarray.DataArray):
        """
        :param array: the data array to wrap
        :raises OpenEoUdfException: when ``array`` is not an :py:class:`xarray.DataArray`
        """
        if not isinstance(array, xarray.DataArray):
            raise OpenEoUdfException("Argument data must be of type xarray.DataArray")
        self._array = array

    def __repr__(self):
        return f"<{type(self).__name__} shape:{self._array.shape}>"

    def get_array(self) -> xarray.DataArray:
        """
        Get the :py:class:`xarray.DataArray` that contains the data and dimension definition
        """
        return self._array

    array = property(fget=get_array)

    @property
    def id(self):
        """
        Identifier of this data cube: the ``name`` of the wrapped array (may be ``None``).

        This matches the ``"id"`` field used by :py:meth:`to_dict` and :py:meth:`from_dict`.
        """
        return self._array.name

    def to_dict(self) -> dict:
        """
        Convert this hypercube into a dictionary that can be converted into
        a valid JSON representation

        The result has the keys ``"id"`` (array name), ``"data"``, ``"description"``
        (from the ``description`` attribute) and ``"dimensions"``; keys whose value
        would be ``None`` are left out.

        >>> example = {
        ...     "id": "test_data",
        ...     "data": [
        ...         [[0.0, 0.1], [0.2, 0.3]],
        ...         [[0.0, 0.1], [0.2, 0.3]],
        ...     ],
        ...     "dimensions": [
        ...         {"name": "time", "coordinates": ["2001-01-01", "2001-01-02"]},
        ...         {"name": "X", "coordinates": [50.0, 60.0]},
        ...         {"name": "Y"},
        ...     ],
        ... }
        """
        xd = self._array.to_dict()
        return dict_no_none({
            "id": xd.get("name"),
            "data": xd.get("data"),
            "description": deep_get(xd, "attrs", "description", default=None),
            "dimensions": [
                dict_no_none(
                    name=dim,
                    coordinates=deep_get(xd, "coords", dim, "data", default=None)
                )
                for dim in xd.get("dims", [])
            ]
        })

    @classmethod
    def from_dict(cls, xdc_dict: dict) -> "XarrayDataCube":
        """
        Create a :py:class:`XarrayDataCube` from a Python dictionary that was created from
        the JSON definition of the data cube

        :param xdc_dict: The dictionary that contains the data cube definition
            (same structure as produced by :py:meth:`to_dict`)
        :raises OpenEoUdfException: when the dictionary has no ``"data"`` entry
        """
        if "data" not in xdc_dict:
            raise OpenEoUdfException("Missing data in dictionary")

        data = numpy.asarray(xdc_dict["data"])

        if "dimensions" in xdc_dict:
            dims = [dim["name"] for dim in xdc_dict["dimensions"]]
            # Dimensions without "coordinates" simply get no coordinate labels.
            coords = {dim["name"]: dim["coordinates"] for dim in xdc_dict["dimensions"] if "coordinates" in dim}
        else:
            dims = None
            coords = None

        x = xarray.DataArray(data, dims=dims, coords=coords, name=xdc_dict.get("id"))

        if "description" in xdc_dict:
            x.attrs["description"] = xdc_dict["description"]

        return cls(array=x)

    @staticmethod
    def _guess_format(path: Union[str, Path]) -> str:
        """Guess file format from file name."""
        suffix = Path(path).suffix.lower()
        if suffix in [".nc", ".netcdf"]:
            return "netcdf"
        elif suffix in [".json"]:
            return "json"
        else:
            raise ValueError("Can not guess format of {p}".format(p=path))

    @classmethod
    def from_file(cls, path: Union[str, Path], fmt=None) -> "XarrayDataCube":
        """
        Load data file as :py:class:`XarrayDataCube` in memory

        :param path: the file on disk
        :param fmt: format to load from, e.g. "netcdf" or "json"
            (will be auto-detected when not specified)

        :return: loaded data cube
        :raises ValueError: when the format is unknown or can not be guessed
        """
        fmt = (fmt or cls._guess_format(path)).lower()
        if fmt == 'netcdf':
            return cls(array=XarrayIO.from_netcdf_file(path=path))
        elif fmt == 'json':
            return cls(array=XarrayIO.from_json_file(path=path))
        else:
            raise ValueError("invalid format {f}".format(f=fmt))

    def save_to_file(self, path: Union[str, Path], fmt=None):
        """
        Store :py:class:`XarrayDataCube` to file

        :param path: destination file on disk
        :param fmt: format to save as, e.g. "netcdf" or "json"
            (will be auto-detected when not specified)
        :raises ValueError: when the format is unknown or can not be guessed
        """
        fmt = (fmt or self._guess_format(path)).lower()
        if fmt == 'netcdf':
            XarrayIO.to_netcdf_file(array=self.get_array(), path=path)
        elif fmt == 'json':
            XarrayIO.to_json_file(array=self.get_array(), path=path)
        else:
            # Same message style as `from_file` for consistency.
            raise ValueError("invalid format {f}".format(f=fmt))

    def plot(
            self,
            title: str = None,
            limits=None,
            show_bandnames: bool = True,
            show_dates: bool = True,
            show_axeslabels: bool = False,
            fontsize: float = 10.,
            oversample: float = 1,
            cmap: Union[str, 'matplotlib.colors.Colormap'] = 'RdYlBu_r',
            cbartext: str = None,
            to_file: str = None,
            to_show: bool = True
    ):
        """
        Visualize a :py:class:`XarrayDataCube` with matplotlib

        The data is laid out as a grid with one row per ``t`` label and one column per
        ``bands`` label; the array must (after filling in missing ``t``/``bands``)
        have exactly the dimensions ``t``, ``bands``, ``y`` and ``x``.
        The wrapped array itself is not modified.

        :param title: title text drawn in the top left corner (default: nothing);
            only the part after the last "/" is drawn
        :param limits: range of the contour plot as a tuple(min,max)
            (default: None, in which case the min/max is computed from the data)
        :param show_bandnames: whether to plot the column names (default: True)
        :param show_dates: whether to show the dates for each row (default: True)
        :param show_axeslabels: whether to show the labels on the axes (default: False)
        :param fontsize: font size in pixels (default: 10)
        :param oversample: one value is plotted into oversample x oversample number of pixels
            (default: 1 which means each value is plotted as a single pixel)
        :param cmap: built-in matplotlib color map name or Colormap object
            (default: RdYlBu_r which is a blue-yellow-red rainbow)
        :param cbartext: text on top of the legend (default: nothing)
        :param to_file: filename to save the image to (default: None, which means no file is generated)
        :param to_show: whether to show the image in a matplotlib window (default: True);
            when False, the figure is closed instead
        :return: None
        """
        from matplotlib import pyplot

        data = self.get_array()
        if limits is None:
            vmin = data.min()
            vmax = data.max()
        else:
            vmin = limits[0]
            vmax = limits[1]

        # Fill in ``bands`` and ``t`` if missing. `expand_dims`/`assign_coords` return new
        # arrays, so the wrapped cube is never modified in place.
        if 'bands' not in data.dims:
            data = data.expand_dims(dim={'bands': ['band0']})
        if 't' not in data.dims:
            data = data.expand_dims(dim={'t': [numpy.datetime64('today')]})
        if 'bands' not in data.coords:
            data = data.assign_coords(bands=['band' + str(i) for i in range(data.sizes['bands'])])
        if 't' not in data.coords:
            data = data.assign_coords(t=[numpy.datetime64('today')] * data.sizes['t'])

        # align with plot: rows are time steps, columns are bands
        data = data.transpose('t', 'bands', 'y', 'x')
        dpi = 100
        xres = len(data.x) / dpi
        yres = len(data.y) / dpi
        fs = fontsize / oversample
        # relative space reserved for the row/column labels
        frame = 0.33

        nrow = data.shape[0]
        ncol = data.shape[1]

        fig = pyplot.figure(figsize=((ncol + frame) * xres * 1.1, (nrow + frame) * yres), dpi=int(dpi * oversample))
        gs = pyplot.GridSpec(
            nrow, ncol, wspace=0., hspace=0., top=nrow / (nrow + frame), bottom=0.,
            left=frame / (ncol + frame), right=1.
        )

        xmin = data.x.min()
        xmax = data.x.max()
        ymin = data.y.min()
        ymax = data.y.max()

        # flip around if incorrect, this is in harmony with origin='lower'
        if data.x[0] > data.x[-1]:
            data = data.reindex(x=list(reversed(data.x)))
        if data.y[0] > data.y[-1]:
            data = data.reindex(y=list(reversed(data.y)))

        extent = (data.x[0], data.x[-1], data.y[0], data.y[-1])

        for i in range(nrow):
            for j in range(ncol):
                im = data[i, j]
                ax = pyplot.subplot(gs[i, j])
                ax.set_xlim(xmin, xmax)
                ax.set_ylim(ymin, ymax)
                img = ax.imshow(im, vmin=vmin, vmax=vmax, cmap=cmap, origin='lower', extent=extent)
                ax.xaxis.set_tick_params(labelsize=fs)
                ax.yaxis.set_tick_params(labelsize=fs)
                if not show_axeslabels:
                    ax.set_axis_off()
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                if show_bandnames and i == 0:
                    # str() so non-string band labels (e.g. ints) also work
                    ax.text(
                        0.5, 1.08, str(data.bands.values[j]) + " (" + str(data.dtype) + ")",
                        size=fs, va="center", ha="center", transform=ax.transAxes
                    )
                if show_dates and j == 0:
                    ax.text(
                        -0.08, 0.5, data.t.dt.strftime("%Y-%m-%d").values[i],
                        size=fs, va="center", ha="center", rotation=90, transform=ax.transAxes
                    )

        if title is not None:
            fig.text(0., 1., title.split('/')[-1], size=fs, va="top", ha="left", weight='bold')

        cbar_ax = fig.add_axes([0.01, 0.1, 0.04, 0.5])
        if cbartext is not None:
            fig.text(0.06, 0.62, cbartext, size=fs, va="bottom", ha="center")
        cbar = fig.colorbar(img, cax=cbar_ax)
        cbar.outline.set_visible(False)

        if to_file is not None:
            pyplot.savefig(str(to_file))
        if to_show:
            pyplot.show()
        else:
            pyplot.close()
class XarrayIO:
    """
    Helpers to load/store :py:class:`xarray.DataArray` objects,
    with some conventions about expected dimensions/bands
    """

    @staticmethod
    def _reorder_dims(array: xarray.DataArray) -> xarray.DataArray:
        """
        Transpose ``array`` so that non-standard dimensions come first (in their original order),
        followed by whichever of ``t``, ``bands``, ``x``, ``y`` are present, in that order.
        """
        standard = ['t', 'bands', 'x', 'y']
        dims = [d for d in array.dims if d not in standard]
        dims += [d for d in standard if d in array.dims]
        return array.transpose(*dims)

    @classmethod
    def from_json_file(cls, path: Union[str, Path]) -> xarray.DataArray:
        """Load a :py:class:`xarray.DataArray` from a JSON file (as written by :py:meth:`to_json_file`)."""
        with Path(path).open(encoding="utf-8") as f:
            return cls.from_json(json.load(f))

    @classmethod
    def from_json(cls, d: dict) -> xarray.DataArray:
        """
        Build a :py:class:`xarray.DataArray` from a dict in :py:meth:`xarray.DataArray.to_dict` format.

        Optional ``dtype``/``shape`` entries in the top-level and per-coordinate ``attrs``
        (as added by :py:meth:`to_json_file`) are used to restore the numpy dtype and are
        then dropped. The input dict is not modified.

        :param d: dict with at least a ``"data"`` entry
        :return: data array with dimensions reordered (see :py:meth:`_reorder_dims`)
        """
        # Work on shallow copies so the caller's dict (and nested dicts) are left intact.
        d = dict(d)

        # dtype and shape are only included for helping the user: use dtype, then drop both
        attrs = dict(d.get('attrs') or {})
        d['data'] = numpy.array(d['data'], dtype=attrs.pop('dtype', None))
        attrs.pop('shape', None)
        d['attrs'] = attrs

        if d.get('coords') is not None:
            coords = {}
            for name, coord in d['coords'].items():
                coord = dict(coord)
                coord_attrs = dict(coord.get('attrs') or {})
                # prepare coordinate (e.g. parse datetime strings via the stored dtype)
                coord['data'] = numpy.array(coord['data'], dtype=coord_attrs.pop('dtype', None))
                coord_attrs.pop('shape', None)
                coord['attrs'] = coord_attrs
                coords[name] = coord
            d['coords'] = coords

        # convert to xarray
        r = xarray.DataArray.from_dict(d)
        return cls._reorder_dims(r)

    @classmethod
    def from_netcdf_file(cls, path: Union[str, Path]) -> xarray.DataArray:
        """
        Load a NetCDF file (through the ``h5netcdf`` engine) as a single data array,
        stacking the dataset variables along a ``bands`` dimension.

        The data is loaded into memory and the file is closed before returning.
        """
        # Use the dataset as a context manager so the file handle is released.
        with xarray.open_dataset(path, engine='h5netcdf') as ds:
            r = ds.to_array(dim='bands').load()
        return cls._reorder_dims(r)

    @classmethod
    def _write_readable_json(cls, f, data_structure: dict, indent: int = 1):
        """
        Write ``data_structure`` to the text file ``f`` as JSON, putting each dict key on its own line
        (nested dicts are indented one more space) while non-dict values are written compactly.
        """
        f.write("{\n")
        needs_comma = False
        for key, value in data_structure.items():
            if needs_comma:
                f.write(',\n')
            needs_comma = True
            f.write(' ' * indent + json.dumps(key) + ':')
            if isinstance(value, dict):
                cls._write_readable_json(f, value, indent + 1)
            else:
                # default=str: values JSON can't encode natively (e.g. datetimes) are stringified
                json.dump(value, f, default=str, separators=(',', ':'))
        f.write('\n' + ' ' * (indent - 1) + "}")

    @classmethod
    def to_json_file(cls, array: xarray.DataArray, path: Union[str, Path]):
        """
        Store ``array`` as a human-readable JSON file (readable with :py:meth:`from_json_file`).

        ``dtype`` and ``shape`` attributes are added to the array and each coordinate
        so the numpy types can be restored when loading.
        """
        jsonarray = array.to_dict()

        # add attributes that are needed for re-creating xarray from json
        jsonarray['attrs']['dtype'] = str(array.dtype)
        jsonarray['attrs']['shape'] = list(array.shape)
        for name, coord in array.coords.items():
            jsonarray['coords'][name]['attrs']['dtype'] = str(coord.dtype)
            jsonarray['coords'][name]['attrs']['shape'] = list(coord.shape)

        # custom print so resulting json file is humanly easy to read
        # TODO: make this human friendly JSON format optional and allow compact JSON too.
        with Path(path).open("w", encoding="utf-8") as f:
            cls._write_readable_json(f, jsonarray)

    @classmethod
    def to_netcdf_file(cls, array: xarray.DataArray, path: Union[str, Path]):
        """
        Store ``array`` as a NetCDF file (through the ``h5netcdf`` engine),
        with each ``bands`` label becoming a separate variable.

        Bands without coordinate labels are named ``band_0``, ``band_1``, ...;
        an array without a ``bands`` dimension is stored as a single ``band_0`` variable.
        """
        # temp reference to avoid modifying the original array
        result = array
        # rearrange in a basic way because older xarray versions have a bug and ellipsis don't work in xarray.transpose()
        if len(result.dims) >= 2 and result.dims[-2] == 'x' and result.dims[-1] == 'y':
            result = result.transpose(*result.dims[:-2], 'y', 'x')
        # turn it into a dataset where each band becomes a variable
        if 'bands' not in result.dims:
            result = result.expand_dims(dim=collections.OrderedDict({'bands': ['band_0']}))
        elif 'bands' not in result.coords:
            labels = ['band_' + str(i) for i in range(result.sizes['bands'])]
            result = result.assign_coords(bands=labels)
        result = result.to_dataset('bands')
        result.to_netcdf(path, engine='h5netcdf')