from __future__ import annotations
import functools
import logging
import warnings
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
import pystac
import pystac.extensions.datacube
import pystac.extensions.eo
import pystac.extensions.item_assets
from openeo.internal.jupyter import render_component
from openeo.util import Rfc3339, deep_get
from openeo.utils.normalize import normalize_resample_resolution
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
class MetadataException(Exception):
    """Raised for invalid or missing data cube metadata (e.g. a missing band dimension)."""
class DimensionAlreadyExistsException(MetadataException):
    """Raised when adding a dimension under a name that is already in use."""
# TODO: make these dimension classes immutable data classes
# TODO: align better with STAC datacube extension
# TODO: align/adapt/integrate with pystac's datacube extension implementation?
class Dimension:
    """Generic data cube dimension, identified by a type and a name."""

    def __init__(self, type: str, name: str):
        # Note: assignment order determines the field order shown by `__repr__`.
        self.type = type
        self.name = name

    def __repr__(self):
        # Render all instance attributes as `key=value` pairs.
        fields = ", ".join(f"{key!s}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def __eq__(self, other):
        # Equal when of exactly the same class and with identical attributes.
        same_class = self.__class__ == other.__class__
        return same_class and self.__dict__ == other.__dict__

    def rename(self, name) -> Dimension:
        """Build a new dimension of the same type, carrying the given name."""
        return Dimension(self.type, name)

    def rename_labels(self, target, source) -> Dimension:
        """
        Rename labels, for dimension types that support it.

        :param target: List of target labels
        :param source: Source labels, or empty list
        :return: A new dimension with modified labels, or an equivalent one if nothing is changed.
        """
        # No label information is tracked at this level: just produce an equivalent dimension.
        return Dimension(self.type, self.name)
class SpatialDimension(Dimension):
    """Dimension of type "spatial", with an extent, a coordinate reference system and an optional step."""

    # TODO: align better with STAC datacube extension: e.g. support "axis" (x or y)
    DEFAULT_CRS = 4326

    def __init__(
        self,
        name: str,
        extent: Union[Tuple[float, float], List[float]],
        crs: Union[str, int, dict] = DEFAULT_CRS,
        step=None,
    ):
        """
        :param name: Name of the dimension.
        :param extent: Extent of the dimension (tuple or list of values).
        :param crs: Coordinate reference system of the dimension (defaults to ``DEFAULT_CRS``).
        :param step: The space between the values. Use null for irregularly spaced steps.
        """
        super().__init__(type="spatial", name=name)
        self.extent = extent
        self.crs = crs
        self.step = step

    def rename(self, name) -> Dimension:
        """Create a new spatial dimension with the given name, keeping extent, crs and step."""
        return SpatialDimension(name=name, extent=self.extent, crs=self.crs, step=self.step)

    def rename_labels(self, target, source) -> Dimension:
        """
        Rename labels: no label info is managed for spatial dimensions, so nothing is changed.

        :param target: List of target labels (ignored)
        :param source: Source labels, or empty list (ignored)
        :return: A new, equal spatial dimension.
        """
        # Override the base implementation, which would degrade this to a plain `Dimension`
        # and silently drop extent, crs and step (and break `isinstance(..., SpatialDimension)` checks).
        return SpatialDimension(name=self.name, extent=self.extent, crs=self.crs, step=self.step)
class TemporalDimension(Dimension):
    """Dimension of type "temporal", described by its extent."""

    def __init__(self, name: str, extent: Union[Tuple[str, str], List[str]]):
        super().__init__("temporal", name)
        self.extent = extent

    def rename(self, name) -> Dimension:
        """Create a new temporal dimension with the given name and the same extent."""
        renamed = TemporalDimension(name, self.extent)
        return renamed

    def rename_labels(self, target, source) -> Dimension:
        """Return a new temporal dimension with unchanged name and extent (given labels are not used)."""
        # TODO should we check if the extent has changed with the new labels?
        unchanged = TemporalDimension(self.name, self.extent)
        return unchanged
class Band(NamedTuple):
    """
    Lightweight, immutable record of band metadata.

    Modelled after the STAC "eo" band object: https://github.com/stac-extensions/eo#band-object
    """

    # Only the name is mandatory, all other fields default to None.
    name: str
    common_name: Optional[str] = None
    # Wavelength, expressed in micrometer
    wavelength_um: Optional[float] = None
    aliases: Optional[List[str]] = None
    # Content of the "openeo:gsd" field (https://github.com/Open-EO/openeo-stac-extensions#GSD-Object)
    gsd: Optional[dict] = None
class BandDimension(Dimension):
    """Dimension of type "bands", holding an ordered list of :py:class:`Band` items."""

    # TODO #575 support unordered bands and avoid assumption that band order is known.
    def __init__(self, name: str, bands: List[Band]):
        super().__init__("bands", name)
        self.bands = bands

    @property
    def band_names(self) -> List[str]:
        """Names of the bands, in band order."""
        return [item.name for item in self.bands]

    @property
    def band_aliases(self) -> List[List[str]]:
        """Alias lists of the bands, in band order (an entry is whatever the band's `aliases` field holds)."""
        return [item.aliases for item in self.bands]

    @property
    def common_names(self) -> List[str]:
        """Common names of the bands, in band order (an entry is whatever the band's `common_name` field holds)."""
        return [item.common_name for item in self.bands]

    def band_index(self, band: Union[int, str]) -> int:
        """
        Translate a band name, common name, alias or index into the band index.

        :param band: band name, common name or index
        :return int: band index
        :raises ValueError: when the band can not be resolved
        """
        names = self.band_names
        if isinstance(band, str):
            # Lookup priority: common names, then regular names, then aliases.
            commons = self.common_names
            if band in commons:
                return commons.index(band)
            if band in names:
                return names.index(band)
            # Aliases allow to still support old band names
            for position, alias_list in enumerate(self.band_aliases):
                if alias_list and band in alias_list:
                    return position
        elif isinstance(band, int) and 0 <= band < len(names):
            return band
        raise ValueError(f"Invalid band name/index {band!r}. Valid names: {names!r}")

    def band_name(self, band: Union[str, int], allow_common=True) -> str:
        """
        Translate a band name, common name, alias or index into a valid (common) band name.

        :param band: band name, common name, alias or index
        :param allow_common: whether a given common name may be returned as-is
            (otherwise it is translated to the regular band name)
        :raises ValueError: when the band can not be resolved
        """
        names = self.band_names
        if isinstance(band, str):
            if band in names:
                return band
            commons = self.common_names
            if band in commons:
                return band if allow_common else names[commons.index(band)]
            if any(alias_list and band in alias_list for alias_list in self.band_aliases):
                return names[self.band_index(band)]
        elif isinstance(band, int) and 0 <= band < len(names):
            return names[band]
        raise ValueError(f"Invalid band name/index {band!r}. Valid names: {names!r}")

    def filter_bands(self, bands: List[Union[int, str]]) -> BandDimension:
        """
        Build a new BandDimension limited to the bands selected
        through the given band indices or (common) names.
        """
        selection = [self.bands[self.band_index(b)] for b in bands]
        return BandDimension(name=self.name, bands=selection)

    def append_band(self, band: Band) -> BandDimension:
        """Build a new BandDimension with the given band added at the end."""
        if band.name in self.band_names:
            raise ValueError(f"Duplicate band {band!r}")
        extended = self.bands + [band]
        return BandDimension(name=self.name, bands=extended)

    def rename_labels(self, target, source) -> Dimension:
        """
        Build a new BandDimension with renamed bands.

        :param target: List of target band names
        :param source: Band names to replace, or an empty list/None to replace all bands with `target`
        :raises ValueError: when `source` is given but differs in length from `target`
        """
        if not source:
            # No source labels: start over with fresh bands, named after the targets.
            return BandDimension(name=self.name, bands=[Band(name=label) for label in target])

        if len(target) != len(source):
            raise ValueError(
                "In rename_labels, `target` and `source` should have same number of labels, "
                f"but got: `target` {target} and `source` {source}"
            )

        renamed = self.bands.copy()
        for old_label, new_label in zip(source, target):
            # Positions are resolved against the original bands, not the partially renamed copy.
            position = self.band_index(old_label)
            previous = renamed[position]
            renamed[position] = Band(
                name=new_label,
                common_name=previous.common_name,
                wavelength_um=previous.wavelength_um,
                aliases=previous.aliases,
                gsd=previous.gsd,
            )
        return BandDimension(name=self.name, bands=renamed)

    def rename(self, name) -> Dimension:
        """Build a new BandDimension with the given name and the same bands."""
        return BandDimension(name=name, bands=self.bands)

    def contains_band(self, band: Union[int, str]) -> bool:
        """
        Tell whether the given band name or index can be resolved in this dimension.
        """
        try:
            self.band_index(band)
        except ValueError:
            return False
        return True
class GeometryDimension(Dimension):
    """Dimension of type "geometry"; only the dimension name is tracked."""

    # TODO: how to model/store labels of geometry dimension?
    def __init__(self, name: str):
        super().__init__("geometry", name)

    def rename(self, name) -> Dimension:
        """Build a new geometry dimension with the given name."""
        return GeometryDimension(name)

    def rename_labels(self, target, source) -> Dimension:
        """Return a new geometry dimension with the same name (given labels are not used)."""
        return GeometryDimension(self.name)
class CubeMetadata:
    """
    Interface for metadata of a data cube.

    Allows interaction with the cube dimensions and their labels (if available).

    The list of dimensions is optional: when it is not known (``None``),
    it is handled as an empty list by the query and builder methods.
    """

    def __init__(self, dimensions: Optional[List[Dimension]] = None):
        """
        :param dimensions: dimensions of the cube, or ``None`` when unknown
        :raises MetadataException: when a dimension of type "bands"/"temporal"
            is not a :py:class:`BandDimension`/:py:class:`TemporalDimension`
        """
        # Original collection metadata (actual cube metadata might be altered through processes)
        # Note: kept as given (possibly `None`), as `__eq__` compares it directly.
        self._dimensions = dimensions
        self._band_dimension = None
        self._temporal_dimension = None

        if dimensions is not None:
            for dim in self._dimensions:
                # TODO: here we blindly pick last bands or temporal dimension if multiple. Let user choose?
                # TODO: add spatial dimension handling?
                if dim.type == "bands":
                    if isinstance(dim, BandDimension):
                        self._band_dimension = dim
                    else:
                        raise MetadataException("Invalid band dimension {d!r}".format(d=dim))
                if dim.type == "temporal":
                    if isinstance(dim, TemporalDimension):
                        self._temporal_dimension = dim
                    else:
                        raise MetadataException("Invalid temporal dimension {d!r}".format(d=dim))

    def __eq__(self, o: Any) -> bool:
        return isinstance(o, type(self)) and self._dimensions == o._dimensions

    def __str__(self) -> str:
        bands = self.band_names if self.has_band_dimension() else "no bands dimension"
        return f"CubeMetadata({bands} - {self.dimension_names()})"

    def _clone_and_update(self, dimensions: Optional[List[Dimension]] = None, **kwargs) -> CubeMetadata:
        """Create a new instance (of same class) with copied/updated fields."""
        cls = type(self)
        if dimensions is None:
            dimensions = self._dimensions
        return cls(dimensions=dimensions, **kwargs)

    def dimension_names(self) -> List[str]:
        """Names of the dimensions, in order (empty list when the dimensions are unknown)."""
        return [d.name for d in (self._dimensions or [])]

    def assert_valid_dimension(self, dimension: str) -> str:
        """
        Make sure given dimension name is valid.

        :return: the given dimension name
        :raises ValueError: when there is no dimension with the given name
        """
        names = self.dimension_names()
        if dimension not in names:
            raise ValueError(f"Invalid dimension {dimension!r}. Should be one of {names}")
        return dimension

    def has_band_dimension(self) -> bool:
        return isinstance(self._band_dimension, BandDimension)

    @property
    def band_dimension(self) -> BandDimension:
        """Dimension corresponding to spectral/logic/thematic "bands"."""
        if not self.has_band_dimension():
            raise MetadataException("No band dimension")
        return self._band_dimension

    def has_temporal_dimension(self) -> bool:
        return isinstance(self._temporal_dimension, TemporalDimension)

    @property
    def temporal_dimension(self) -> TemporalDimension:
        """The temporal dimension (raises :py:class:`MetadataException` if there is none)."""
        if not self.has_temporal_dimension():
            raise MetadataException("No temporal dimension")
        return self._temporal_dimension

    @property
    def spatial_dimensions(self) -> List[SpatialDimension]:
        """All spatial dimensions, in order (empty list if there are none)."""
        return [d for d in (self._dimensions or []) if isinstance(d, SpatialDimension)]

    def has_geometry_dimension(self) -> bool:
        return any(isinstance(d, GeometryDimension) for d in (self._dimensions or []))

    @property
    def geometry_dimension(self) -> GeometryDimension:
        """The first geometry dimension (raises :py:class:`MetadataException` if there is none)."""
        for d in self._dimensions or []:
            if isinstance(d, GeometryDimension):
                return d
        raise MetadataException("No geometry dimension")

    @property
    def bands(self) -> List[Band]:
        """Get band metadata as list of Band metadata tuples"""
        return self.band_dimension.bands

    @property
    def band_names(self) -> List[str]:
        """Get band names of band dimension"""
        return self.band_dimension.band_names

    @property
    def band_common_names(self) -> List[str]:
        """Get common names of the bands of the band dimension"""
        return self.band_dimension.common_names

    def get_band_index(self, band: Union[int, str]) -> int:
        # TODO: eliminate this shortcut for smaller API surface
        return self.band_dimension.band_index(band)

    def filter_bands(self, band_names: List[Union[int, str]]) -> CubeMetadata:
        """
        Create new `CubeMetadata` with filtered band dimension

        :param band_names: list of band names/indices to keep
        :return: new metadata object with the filtered band dimension
        :raises MetadataException: when there is no band dimension
        """
        # Explicit check (instead of an `assert`, which is stripped when running with `-O`).
        if not self.has_band_dimension():
            raise MetadataException("No band dimension")
        return self._clone_and_update(
            dimensions=[d.filter_bands(band_names) if isinstance(d, BandDimension) else d for d in self._dimensions]
        )

    def append_band(self, band: Band) -> CubeMetadata:
        """
        Create new `CubeMetadata` with given band added to band dimension.

        :raises MetadataException: when there is no band dimension
        """
        # Explicit check (instead of an `assert`, which is stripped when running with `-O`).
        if not self.has_band_dimension():
            raise MetadataException("No band dimension")
        return self._clone_and_update(
            dimensions=[d.append_band(band) if isinstance(d, BandDimension) else d for d in self._dimensions]
        )

    def rename_labels(self, dimension: str, target: list, source: Optional[list] = None) -> CubeMetadata:
        """
        Renames the labels of the specified dimension from source to target.

        :param dimension: Dimension name
        :param target: The new names for the labels.
        :param source: The names of the labels as they are currently in the data cube.
        :return: Updated metadata
        :raises ValueError: when there is no dimension with the given name
        """
        self.assert_valid_dimension(dimension)
        return self._clone_and_update(
            dimensions=[
                d.rename_labels(target=target, source=source) if d.name == dimension else d for d in self._dimensions
            ]
        )

    def rename_dimension(self, source: str, target: str) -> CubeMetadata:
        """
        Rename source dimension into target, preserving other properties

        :raises ValueError: when there is no dimension named `source`
        """
        self.assert_valid_dimension(source)
        return self._clone_and_update(
            dimensions=[d.rename(name=target) if d.name == source else d for d in self._dimensions]
        )

    def reduce_dimension(self, dimension_name: str) -> CubeMetadata:
        """
        Create new CubeMetadata object by collapsing/reducing a dimension.

        :raises ValueError: when there is no dimension with the given name
        """
        # TODO: option to keep reduced dimension (with a single value)?
        # TODO: rename argument to `name` for more internal consistency
        # TODO: merge with drop_dimension (which does the same).
        self.assert_valid_dimension(dimension_name)
        loc = self.dimension_names().index(dimension_name)
        dimensions = self._dimensions[:loc] + self._dimensions[loc + 1 :]
        return self._clone_and_update(dimensions=dimensions)

    def reduce_spatial(self) -> CubeMetadata:
        """Create new CubeMetadata object by reducing the spatial dimensions."""
        dimensions = [d for d in (self._dimensions or []) if not isinstance(d, SpatialDimension)]
        return self._clone_and_update(dimensions=dimensions)

    def add_dimension(self, name: str, label: Union[str, float], type: Optional[str] = None) -> CubeMetadata:
        """
        Create new CubeMetadata object with added dimension

        :param name: name of the new dimension
        :param label: label of the new dimension (used as band name or as both ends of the extent,
            depending on the dimension type)
        :param type: dimension type: "bands", "spatial", "temporal", "geometry" or something else
            (``None`` is handled as "other")
        :raises DimensionAlreadyExistsException: when the dimension name is already in use
        """
        existing = self._dimensions or []
        if any(d.name == name for d in existing):
            raise DimensionAlreadyExistsException(f"Dimension with name {name!r} already exists")
        if type == "bands":
            dim = BandDimension(name=name, bands=[Band(name=label)])
        elif type == "spatial":
            dim = SpatialDimension(name=name, extent=[label, label])
        elif type == "temporal":
            dim = TemporalDimension(name=name, extent=[label, label])
        elif type == "geometry":
            dim = GeometryDimension(name=name)
        else:
            dim = Dimension(type=type or "other", name=name)
        return self._clone_and_update(dimensions=existing + [dim])

    def _ensure_band_dimension(
        self, *, name: Optional[str] = None, bands: List[Union[Band, str]], warning: str
    ) -> CubeMetadata:
        """
        Create new CubeMetadata object, ensuring a band dimension with given bands.
        This will override any existing band dimension, and is intended for
        special cases where pragmatism necessitates to ignore the original metadata.
        For example, to overrule badly/incomplete detected band names from STAC metadata.

        .. note::
            It is required to specify a warning message as this method is only intended
            to be used as temporary stop-gap solution for use cases that are possibly not future-proof.
            Enforcing a warning should make that clear and avoid that users unknowingly depend on
            metadata handling behavior that is not guaranteed to be stable.
        """
        _log.warning(warning or "ensure_band_dimension: overriding band dimension metadata with user-defined bands.")
        if name is None:
            # Preserve original band dimension name if possible
            name = self.band_dimension.name if self.has_band_dimension() else "bands"
        bands = [b if isinstance(b, Band) else Band(name=b) for b in bands]
        band_dimension = BandDimension(name=name, bands=bands)
        # Existing band dimensions are dropped and the new one is added at the end.
        return self._clone_and_update(
            dimensions=[d for d in (self._dimensions or []) if not isinstance(d, BandDimension)] + [band_dimension]
        )

    def drop_dimension(self, name: Optional[str] = None) -> CubeMetadata:
        """
        Create new CubeMetadata object without dropped dimension with given name

        :raises ValueError: when there is no dimension with the given name
        """
        dimension_names = self.dimension_names()
        if name not in dimension_names:
            raise ValueError("No dimension named {n!r} (valid names: {ns!r})".format(n=name, ns=dimension_names))
        return self._clone_and_update(dimensions=[d for d in self._dimensions if not d.name == name])

    def resample_spatial(
        self,
        resolution: Union[float, Tuple[float, float], List[float]] = 0.0,
        projection: Union[int, str, None] = None,
    ) -> CubeMetadata:
        """
        Create new CubeMetadata object with spatial dimensions updated for the given resolution and projection.

        :param resolution: new step of the spatial dimensions
            (a value of 0 keeps the existing step of the dimension)
        :param projection: new crs of the spatial dimensions (when not set: keep the existing crs)
        :raises MetadataException: when there are not exactly two spatial dimensions
        """
        # NOTE(review): presumably normalizes to one resolution value per spatial dimension
        # (see the length check below) -- confirm against `normalize_resample_resolution`.
        resolution = normalize_resample_resolution(resolution)
        if self._dimensions is None:
            # Best-effort fallback to work with
            dimensions = [
                SpatialDimension(name="x", extent=[None, None]),
                SpatialDimension(name="y", extent=[None, None]),
            ]
        else:
            # Make sure to work with a copy (to edit in-place)
            dimensions = list(self._dimensions)

        # Find and replace spatial dimensions
        spatial_indices = [i for i, d in enumerate(dimensions) if isinstance(d, SpatialDimension)]
        if len(spatial_indices) != 2:
            raise MetadataException(f"Expected two spatial dimensions but found {spatial_indices=}")
        # Internal sanity check on the normalized resolution (not user input validation).
        assert len(resolution) == 2
        for i, r in zip(spatial_indices, resolution):
            dim: SpatialDimension = dimensions[i]
            dimensions[i] = SpatialDimension(
                name=dim.name,
                extent=dim.extent,
                crs=projection or dim.crs,
                step=r if r != 0.0 else dim.step,
            )

        return self._clone_and_update(dimensions=dimensions)

    def resample_cube_spatial(self, target: CubeMetadata) -> CubeMetadata:
        """Create new CubeMetadata object with the spatial dimensions replaced by the ones of `target`."""
        # Replace spatial dimensions with ones from target, but keep other dimensions
        dimensions = [d for d in (self._dimensions or []) if not isinstance(d, SpatialDimension)]
        dimensions.extend(target.spatial_dimensions)
        return self._clone_and_update(dimensions=dimensions)
def metadata_from_stac(url: str) -> CubeMetadata:
    """
    Reads the band metadata from a static STAC catalog or a STAC API Collection and returns it as a :py:class:`CubeMetadata`

    :param url: The URL to a static STAC catalog (STAC Item, STAC Collection, or STAC Catalog) or a specific STAC API Collection
    :return: A :py:class:`CubeMetadata` containing the DataCube band metadata from the url.
    :raises ValueError: when the resource at the url is not a STAC Item, Collection or Catalog
    """

    # TODO move these nested functions and other logic to _StacMetadataParser

    def get_band_metadata(eo_bands_location: dict) -> List[Band]:
        # TODO: return None iso empty list when no metadata?
        return [
            Band(name=band["name"], common_name=band.get("common_name"), wavelength_um=band.get("center_wavelength"))
            for band in eo_bands_location.get("eo:bands", [])
        ]

    def is_band_asset(asset: pystac.Asset) -> bool:
        return "eo:bands" in asset.extra_fields

    stac_object = pystac.read_file(href=url)

    if isinstance(stac_object, pystac.Item):
        item = stac_object
        if "eo:bands" in item.properties:
            eo_bands_location = item.properties
        else:
            # Fall back on the summaries of the parent collection (if any).
            # TODO: Also do asset based band detection (like below)?
            parent_collection = item.get_collection()
            eo_bands_location = parent_collection.summaries.lists if parent_collection is not None else {}
        bands = get_band_metadata(eo_bands_location)

    elif isinstance(stac_object, pystac.Collection):
        collection = stac_object
        bands = get_band_metadata(collection.summaries.lists)

        # Summaries is not a required field in a STAC collection, so also check the assets
        # Track the band names seen so far, to only add each band once (in order of first appearance).
        seen_names = {band.name for band in bands}
        for itm in collection.get_items():
            for asset in itm.get_assets().values():
                if not is_band_asset(asset):
                    continue
                for asset_band in get_band_metadata(asset.extra_fields):
                    if asset_band.name not in seen_names:
                        seen_names.add(asset_band.name)
                        bands.append(asset_band)

        if _PYSTAC_1_9_EXTENSION_INTERFACE and collection.ext.has("item_assets"):
            # TODO #575 support unordered band names and avoid conversion to a list.
            item_asset_bands = list(_StacMetadataParser().get_bands_from_item_assets(collection.ext.item_assets))
            # Only let "item_assets" overrule when it actually provides band info,
            # to avoid throwing away the bands detected above from summaries and items.
            if item_asset_bands:
                bands = item_asset_bands

    elif isinstance(stac_object, pystac.Catalog):
        catalog = stac_object
        bands = get_band_metadata(catalog.extra_fields.get("summaries", {}))
    else:
        raise ValueError(stac_object)

    # At least assume there are spatial dimensions
    # TODO #743: are there conditions in which we even should not assume the presence of spatial dimensions?
    dimensions = [
        SpatialDimension(name="x", extent=[None, None]),
        SpatialDimension(name="y", extent=[None, None]),
    ]

    # TODO #743: conditionally include band dimension when there was actual indication of band metadata?
    band_dimension = BandDimension(name="bands", bands=bands)
    dimensions.append(band_dimension)

    # TODO: is it possible to derive the actual name of temporal dimension that the backend will use?
    temporal_dimension = _StacMetadataParser().get_temporal_dimension(stac_object)
    if temporal_dimension:
        dimensions.append(temporal_dimension)

    metadata = CubeMetadata(dimensions=dimensions)
    return metadata
# Sniff for PySTAC extension API since version 1.9.0 (which is not available below Python 3.9)
# TODO: remove this once support for Python 3.7 and 3.8 is dropped
# True when `pystac.Item` has an `ext` attribute. Code in this module checks this flag
# before using the `.ext` accessor of STAC objects.
_PYSTAC_1_9_EXTENSION_INTERFACE = hasattr(pystac.Item, "ext")
class _StacMetadataParser:
    """
    Helper for deriving openEO metadata (bands, temporal dimension) from a STAC metadata resource.
    """

    def __init__(self):
        # TODO: toggles for how to handle strictness, warnings, logging, etc
        pass

    def _get_band_from_eo_bands_item(self, eo_band: Union[dict, pystac.extensions.eo.Band]) -> Band:
        """
        Convert a single `eo:bands` entry to a :py:class:`Band`.

        :param eo_band: band object, as `pystac.extensions.eo.Band` instance or as dict with at least a "name" field
        :raises ValueError: for anything else
        """
        if isinstance(eo_band, pystac.extensions.eo.Band):
            name = eo_band.name
            common_name = eo_band.common_name
            wavelength = eo_band.center_wavelength
        elif isinstance(eo_band, dict) and "name" in eo_band:
            name = eo_band["name"]
            common_name = eo_band.get("common_name")
            wavelength = eo_band.get("center_wavelength")
        else:
            raise ValueError(eo_band)
        return Band(name=name, common_name=common_name, wavelength_um=wavelength)

    def get_bands_from_eo_bands(self, eo_bands: List[Union[dict, pystac.extensions.eo.Band]]) -> List[Band]:
        """
        Convert a STAC `eo:bands` array to a list of bands (in the same order).

        :param eo_bands: List of band objects, as dict or `pystac.extensions.eo.Band` instances
        """
        # TODO: option to skip bands that failed to parse in some way?
        return list(map(self._get_band_from_eo_bands_item, eo_bands))

    def _get_bands_from_item_asset(
        self, item_asset: pystac.extensions.item_assets.AssetDefinition, *, _warn: Callable[[str], None] = _log.warning
    ) -> Union[List[Band], None]:
        """Extract bands from a STAC 'item_assets' asset definition (`None` when there is no band info)."""
        if _PYSTAC_1_9_EXTENSION_INTERFACE and item_asset.ext.has("eo"):
            # The "eo" extension is declared: only consult the extension interface.
            declared_bands = item_asset.ext.eo.bands
            if declared_bands is None:
                return None
            return self.get_bands_from_eo_bands(declared_bands)

        if "eo:bands" not in item_asset.properties:
            return None
        # TODO: skip this in strict mode?
        if _PYSTAC_1_9_EXTENSION_INTERFACE:
            _warn("Extracting band info from 'eo:bands' metadata, but 'eo' STAC extension was not declared.")
        return self.get_bands_from_eo_bands(item_asset.properties["eo:bands"])

    def get_bands_from_item_assets(
        self, item_assets: Dict[str, pystac.extensions.item_assets.AssetDefinition]
    ) -> Set[Band]:
        """
        Collect the bands found in "item_assets" objects (defined by "item-assets" extension,
        in combination with "eo" extension) at STAC Collection top-level.

        Note that "item_assets" in STAC is a mapping, so the band order is undefined,
        which is why a set of bands is returned here.

        :param item_assets: a STAC `item_assets` mapping
        """
        # Trick to just warn once per collection: identical messages are only logged the first time.
        warn_once = functools.lru_cache()(_log.warning)
        collected = set()
        for definition in item_assets.values():
            found = self._get_bands_from_item_asset(definition, _warn=warn_once)
            collected.update(found or ())
        return collected

    def get_temporal_dimension(self, stac_obj: pystac.STACObject) -> Union[TemporalDimension, None]:
        """
        Derive the temporal dimension of a STAC Collection/Item (`None` if it can not be determined)
        """
        # TODO: also extract temporal dimension from assets?
        if not _PYSTAC_1_9_EXTENSION_INTERFACE:
            # No `.ext` accessor available: look at the raw "cube:dimensions" data.
            if isinstance(stac_obj, pystac.Item):
                cube_dimensions = stac_obj.properties.get("cube:dimensions", {})
            elif isinstance(stac_obj, pystac.Collection):
                cube_dimensions = stac_obj.extra_fields.get("cube:dimensions", {})
            else:
                cube_dimensions = {}
            candidates = [
                (dim_name, dim.get("extent", [None, None]))
                for (dim_name, dim) in cube_dimensions.items()
                if dim.get("type") == "temporal"
            ]
        elif stac_obj.ext.has("cube") and hasattr(stac_obj.ext, "cube"):
            candidates = [
                (dim_name, dim.extent or [None, None])
                for (dim_name, dim) in stac_obj.ext.cube.dimensions.items()
                if dim.dim_type == pystac.extensions.datacube.DimensionType.TEMPORAL
            ]
        elif isinstance(stac_obj, pystac.Collection) and stac_obj.extent.temporal:
            # No explicit "cube:dimensions": build fallback from "extent.temporal",
            # with dimension name "t" (openEO API recommendation).
            normalizer = Rfc3339(propagate_none=True)
            extent = [normalizer.normalize(d) for d in stac_obj.extent.temporal.intervals[0]]
            return TemporalDimension(name="t", extent=extent)
        else:
            return None

        # Only accept an unambiguous result: exactly one temporal dimension.
        if len(candidates) == 1:
            name, extent = candidates[0]
            return TemporalDimension(name=name, extent=extent)
        return None