#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""OpenEO Python UDF interface"""
import os
from typing import Optional, Dict
from openeo_udf.server.config import UdfConfiguration
__license__ = "Apache License, Version 2.0"
__author__ = "Soeren Gebbert"
__copyright__ = "Copyright 2018, Soeren Gebbert"
__maintainer__ = "Soeren Gebbert"
__email__ = "soerengebbert@googlemail.com"
class MachineLearnModelConfig:
    """This class represents a machine learn model. The model will be loaded
    at construction, based on the machine learn framework.

    The following frameworks are supported:

        - sklearn models that are created with sklearn.externals.joblib
        - pytorch models that are created with torch.save

    The framework name is compared case-insensitively. For any other framework
    name no loader is invoked and :meth:`get_model` returns None.

    >>> from sklearn.ensemble import RandomForestRegressor
    >>> from sklearn.externals import joblib
    >>> model = RandomForestRegressor(n_estimators=10, max_depth=2, verbose=0)
    >>> path = '/tmp/test.pkl.xz'
    >>> dummy = joblib.dump(value=model, filename=path, compress=("xz", 3))
    >>> m = MachineLearnModelConfig(framework="sklearn", name="test",
    ...                             description="Machine learn model", path=path)
    >>> m.get_model()# doctest: +ELLIPSIS
    ...                   # doctest: +NORMALIZE_WHITESPACE
    RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2,
               max_features='auto', max_leaf_nodes=None,
               min_impurity_decrease=0.0, min_impurity_split=None,
               min_samples_leaf=1, min_samples_split=2,
               min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
               oob_score=False, random_state=None, verbose=0, warm_start=False)
    >>> m.to_dict() # doctest: +ELLIPSIS
    ...                   # doctest: +NORMALIZE_WHITESPACE
    {'description': 'Machine learn model', 'name': 'test', 'framework': 'sklearn', 'path': '/tmp/test.pkl.xz', 'md5_hash': None}
    >>> d = {'description': 'Machine learn model', 'name': 'test', 'framework': 'sklearn',
    ...      'path': '/tmp/test.pkl.xz', "md5_hash": None}
    >>> m = MachineLearnModelConfig.from_dict(d)
    >>> m.get_model() # doctest: +ELLIPSIS
    ...                   # doctest: +NORMALIZE_WHITESPACE
    RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2,
               max_features='auto', max_leaf_nodes=None,
               min_impurity_decrease=0.0, min_impurity_split=None,
               min_samples_leaf=1, min_samples_split=2,
               min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
               oob_score=False, random_state=None, verbose=0, warm_start=False)

    >>> import torch
    >>> import torch.nn as nn
    >>> model = nn.Module
    >>> path = '/tmp/test.pt'
    >>> torch.save(model, path)
    >>> m = MachineLearnModelConfig(framework="pytorch", name="test",
    ...                             description="Machine learn model", path=path)
    >>> m.get_model()# doctest: +ELLIPSIS
    ...                   # doctest: +NORMALIZE_WHITESPACE
    <class 'torch.nn.modules.module.Module'>
    >>> m.to_dict() # doctest: +ELLIPSIS
    ...                   # doctest: +NORMALIZE_WHITESPACE
    {'description': 'Machine learn model', 'name': 'test', 'framework': 'pytorch', 'path': '/tmp/test.pt', 'md5_hash': None}
    >>> d = {'description': 'Machine learn model', 'name': 'test', 'framework': 'pytorch',
    ...      'path': '/tmp/test.pt', "md5_hash": None}
    >>> m = MachineLearnModelConfig.from_dict(d)
    >>> m.get_model() # doctest: +ELLIPSIS
    ...                   # doctest: +NORMALIZE_WHITESPACE
    <class 'torch.nn.modules.module.Module'>

    """

    def __init__(self, framework: str, name: str, description: str,
                 path: Optional[str] = None, md5_hash: Optional[str] = None):
        """The constructor to create a machine learn model object

        The model file is loaded immediately via :meth:`load_model`.

        Args:
            framework: The name of the framework, pytorch and sklearn are supported
            name: The name of the model
            description: The description of the model
            path: The path to the pre-trained machine learn model that should be applied
            md5_hash: The md5 hash of the machine learn model that is located in the local storage

        Raises:
            Exception: If no model file exists at the resolved location.
        """
        self.framework = framework
        self.name = name
        self.description = description
        self.path = path
        self.md5_hash = md5_hash
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the machine learn model from the path or md5 hash.

        If an md5 hash is set, the file is looked up under
        ``UdfConfiguration.machine_learn_storage_path``; otherwise ``path``
        is used.

        Supported model:

            - sklearn models that are created with sklearn.externals.joblib
            - pytorch models that are created with torch.save

        For any other framework name the model attribute is left untouched.

        Raises:
            Exception: If neither a path nor an md5 hash was provided, or the
                resolved location is not an existing regular file.
        """
        if self.md5_hash is not None:
            filepath = os.path.join(UdfConfiguration.machine_learn_storage_path, self.md5_hash)
        else:
            filepath = self.path

        # Guard against a missing location: os.path.isfile(None) would raise an
        # opaque TypeError instead of the error below. isfile() already implies
        # that the path exists.
        if filepath is None or not os.path.isfile(filepath):
            raise Exception(f"Unable to find the specified machine learn model at path {filepath}")

        # Compare for equality. A substring test such as `name in "sklearn"`
        # would also match "", "s" or "learn".
        framework = self.framework.lower()

        # NOTE: joblib.load and torch.load are pickle based and can execute
        # arbitrary code; only model files from trusted sources must be loaded.
        if framework == "sklearn":
            try:
                # scikit-learn >= 0.23 no longer ships sklearn.externals.joblib
                import joblib
            except ImportError:
                from sklearn.externals import joblib
            self.model = joblib.load(filepath)
        elif framework == "pytorch":
            import torch
            self.model = torch.load(filepath)

    def get_model(self):
        """Get the loaded machine learn model. This function will return None if the model was not loaded

        :return: the loaded model
        """
        return self.model

    def to_dict(self) -> Dict:
        """Convert the model configuration into a dictionary.

        The loaded model object itself is not part of the dictionary.

        :return: dictionary with the keys description, name, framework, path and md5_hash
        """
        return dict(description=self.description, name=self.name,
                    framework=self.framework, path=self.path, md5_hash=self.md5_hash)

    @classmethod
    def from_dict(cls, machine_learn_model: Dict) -> "MachineLearnModelConfig":
        """Create a model configuration from a dictionary and load the model.

        :param machine_learn_model: dictionary with the required keys description,
            name and framework and the optional keys path and md5_hash
        :return: the new model configuration with the model loaded
        :raises KeyError: if one of the required keys is missing
        """
        return cls(description=machine_learn_model["description"],
                   name=machine_learn_model["name"],
                   framework=machine_learn_model["framework"],
                   path=machine_learn_model.get("path"),
                   md5_hash=machine_learn_model.get("md5_hash"))
if __name__ == "__main__":
    # Executed as a script: run the doctest examples embedded in the
    # docstrings of this module.
    from doctest import testmod

    testmod()