Commit d3d617a9 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Added Quadratic Regression as possible algorithm for spectral homogenization.

Added more Ridge Regression classifiers for different alpha values.
parent 144a5e5e
......@@ -27,6 +27,8 @@ import traceback
from sklearn.cluster import k_means_ # noqa F401 # flake8 issue
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline, Pipeline
from geoarray import GeoArray # noqa F401 # flake8 issue
from ..options.config import GMS_config as CFG
......@@ -215,6 +217,7 @@ class SpectralHomogenizer(object):
:param method: machine learning approach to be used for spectral bands prediction
'LR': Linear Regression
'RR': Ridge Regression
'QR': Quadratic Regression
:param src_satellite: source satellite, e.g., 'Landsat-8'
:param src_sensor: source sensor, e.g., 'OLI_TIRS'
:param src_LBA: source LayerBandsAssignment
......@@ -1360,15 +1363,15 @@ class Classifier_Generator(object):
get_LayerBandsAssignment(L1C_GMSid, no_pan=True, sort_by_cwl=False), # L1C_noPan_alphabetical
]
def create_classifiers(self, outDir, method='LR', *args, **kwargs):
def create_classifiers(self, outDir, method='LR', **kwargs):
"""Create classifiers for all combinations of the reference cubes given in __init__().
:param outDir: output directory for the created classifier collections
:param method: type of machine learning classifiers to be included in classifier collections
'LR': Linear Regression
'RR': Ridge Regression
:param args: arguments to be passed to the fit() function of the machine learners
:param kwargs: keyword arguments to be passed to the fit() function of machine learners
'QR': Quadratic Regression
:param kwargs: keyword arguments to be passed to the __init__() function of machine learners
:return:
"""
for src_cube in self.refcubes:
......@@ -1400,8 +1403,8 @@ class Classifier_Generator(object):
test_size=0.4, shuffle=True, random_state=0)
# train the model
ML = specHomoApproaches[method]()
ML.fit(train_X, train_Y, *args, **kwargs)
ML = get_machine_learner(method, **kwargs)
ML.fit(train_X, train_Y)
def mean_absolute_percentage_error(y_true, y_pred):
y_true, y_pred = np.array(y_true), np.array(y_pred)
......@@ -1444,10 +1447,23 @@ class Classifier_Generator(object):
dill.dump(cls_collection.to_dict(), outF)
specHomoApproaches = dict(
LR=LinearRegression,
RR=Ridge
)
def get_machine_learner(method='LR', **init_params):
# type: (str, dict) -> Union[LinearRegression, Ridge, Pipeline]
"""Get an instance of a machine learner.
:param method: 'LR': Linear regression
'RR': Ridge regression
'QR': Quadratic regression
:param init_params: parameters to be passed to __init__() function of the returned machine learner model.
"""
if method == 'LR':
return LinearRegression(**init_params)
elif method == 'RR':
return Ridge(**init_params)
elif method == 'QR':
return make_pipeline(PolynomialFeatures(degree=2), LinearRegression(**init_params))
else:
raise ValueError("Unknown machine learner method code '%s'." % method)
# def get_classifier_filename(method, src_satellite, src_sensor, src_LBA_name, tgt_satellite, tgt_sensor, tgt_LBA_name):
......@@ -1456,6 +1472,9 @@ specHomoApproaches = dict(
def get_filename_classifier_collection(method, src_satellite, src_sensor):
if method == 'RR':
method = method + '_alpha1.0' # TODO add to config
return '__'.join([method, src_satellite, src_sensor]) + '.dill'
......@@ -1519,10 +1538,10 @@ class RSImage_Predictor(object):
tgt_satellite, tgt_sensor, tgt_LBA)
# validation
exp = specHomoApproaches[self.method]
if not isinstance(ML_instance, exp):
expected_type = type(get_machine_learner(self.method))
if not isinstance(ML_instance, expected_type):
raise ValueError('The given dillFile %s does not contain an instance of %s but %s.'
% (os.path.basename(fName_cls), exp.__name__, type(ML_instance)))
% (os.path.basename(fName_cls), expected_type.__name__, type(ML_instance)))
return ML_instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment