# Source code for openeo_udf.functions.datacube_pytorch_ml
# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
import xarray
import torch
from openeo_udf.api.datacube import DataCube
from openeo_udf.api.udf_data import UdfData
__license__ = "Apache License, Version 2.0"
__author__ = "Soeren Gebbert"
__copyright__ = "Copyright 2018, Soeren Gebbert"
__maintainer__ = "Soeren Gebbert"
__email__ = "soerengebbert@googlemail.com"
def hyper_pytorch_ml(udf_data: UdfData):
    """Apply a pre-trained PyTorch machine learning model to a hypercube.

    The first data cube of ``udf_data`` is converted into a tensor of
    PyTorch's default floating point dtype. That tensor is passed to the model
    of the first machine learning model entry of ``udf_data``. The model output
    is wrapped into a new data cube that reuses the dimensions and coordinates
    of the input cube. The model must therefore return a tensor with the same
    shape as its input.

    The object returned by ``get_model()`` must be callable with a
    ``torch.Tensor``. Since PyTorch 0.4, ``torch.autograd.Variable`` is a
    deprecated alias of ``torch.Tensor``, so models written for Variables keep
    working.

    Args:
        udf_data (UdfData): The UDF data object that contains the hypercubes
            and the machine learning models.

    Returns:
        None. The result is stored in ``udf_data``: its data cube list is
        replaced by a single cube named ``<input cube id>_pytorch``.
    """
    cube = udf_data.get_datacube_list()[0]

    # torch.tensor always copies, so the model cannot modify the cube's
    # underlying numpy buffer in place. Using the default dtype matches the
    # legacy torch.Tensor(...) constructor used previously.
    input_tensor = torch.tensor(cube.array.values, dtype=torch.get_default_dtype())

    # Use the first model attached to the UDF data.
    model = udf_data.get_ml_model_list()[0].get_model()

    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        prediction = model(input_tensor)

    # .cpu() is a no-op for CPU tensors but lets .numpy() succeed if the
    # model returns its output on another device.
    result = xarray.DataArray(
        data=prediction.detach().cpu().numpy(),
        dims=cube.array.dims,
        coords=cube.array.coords,
        name=cube.id + "_pytorch",
    )

    # Replace the cube list of the input object with the prediction cube.
    udf_data.set_datacube_list([DataCube(array=result)])