Source code for openeo_udf.functions.datacube_sklearn_ml

# -*- coding: utf-8 -*-
import numpy
import pandas
import xarray

from openeo_udf.api.datacube import DataCube
from openeo_udf.api.udf_data import UdfData

__license__ = "Apache License, Version 2.0"
__author__ = "Soeren Gebbert"
__copyright__ = "Copyright 2018, Soeren Gebbert"
__maintainer__ = "Soeren Gebbert"
__email__ = "soerengebbert@googlemail.com"


def rct_sklearn_ml(udf_data: UdfData):
    """Apply a pre-trained sklearn machine learn model on RED and NIR data cubes.

    The model must be a sklearn model that has a prediction method
    ``m.predict(X)`` which accepts a ``pandas.DataFrame`` with the columns
    ``"red"`` and ``"nir"`` as input.

    Data cubes whose ids contain "red" and "nir" (case-insensitive) are
    required. The first machine learn model of ``udf_data`` is applied to
    every pixel of the two input cubes, and the prediction is reshaped back
    to the shape of the red cube.

    Args:
        udf_data (UdfData): The UDF data object that contains the input data
            cubes and the machine learn model.

    Returns:
        This function will not return anything, the UdfData object
        "udf_data" must be used to store the resulting data. Its data cube
        list is replaced by a single cube holding the prediction.

    Raises:
        Exception: If the red or the nir data cube is missing in the input.
        ValueError: If the red and nir data cubes have different shapes.
        IndexError: If ``udf_data`` provides no machine learn model.
    """
    red = None
    nir = None

    # Scan every cube without breaking out early: when several ids match,
    # the last matching cube is the one that gets used.
    for cube in udf_data.get_datacube_list():
        cube_id = cube.id.lower()
        if "red" in cube_id:
            red = cube
        if "nir" in cube_id:
            nir = cube

    if red is None:
        raise Exception("Red data cube is missing in input")
    if nir is None:
        raise Exception("Nir data cube is missing in input")

    # Both cubes are flattened and paired element by element, so they must
    # have the same shape. Without this check, cubes of equal size but
    # different layout would silently pair unrelated pixels.
    cube_shape = red.array.shape
    if nir.array.shape != cube_shape:
        raise ValueError(
            "Red and nir data cubes must have the same shape, got %s and %s"
            % (cube_shape, nir.array.shape)
        )

    # This is the input data of the model. It must be trained with a
    # DataFrame using the same column names.
    features = pandas.DataFrame({
        "red": red.array.values.reshape(-1),
        "nir": nir.array.values.reshape(-1),
    })

    # Use the first model; fail with a clear message when there is none.
    models = udf_data.get_ml_model_list()
    if not models:
        raise IndexError("No machine learn model found in input")
    model = models[0].get_model()

    # Predict, then restore the shape of the input cube. asarray keeps this
    # working when predict() returns a list-like instead of an ndarray.
    prediction = numpy.asarray(model.predict(features)).reshape(cube_shape)

    # NOTE(review): the "_pytorch" suffix looks like a copy-paste leftover in
    # this sklearn function; it is kept unchanged because downstream code may
    # look the result up by this name -- confirm before renaming.
    result = xarray.DataArray(
        data=prediction,
        dims=red.array.dims,
        coords=red.array.coords,
        name=red.id + "_pytorch",
    )

    # Insert the new data cube in the input object. It replaces all of the
    # original input cubes.
    udf_data.set_datacube_list([DataCube(array=result)])