Commit 216a8711 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Improved structure of RSImage_ClusterPredictor.predict().


Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 67b5eb7d
......@@ -401,6 +401,8 @@ class RSImage_ClusterPredictor(object):
# assign each input pixel to a cluster (compute classification with cluster centers as endmembers)
if self.classif_map is None:
if self.n_clusters > 1:
self.logger.info(f'Assigning material-specific regressors to each image pixel.')
t0 = time.time()
kw_clf = dict(classif_alg=self.classif_alg,
in_nodataVal=image.nodata,
......@@ -452,37 +454,16 @@ class RSImage_ClusterPredictor(object):
self.distance_metrics = np.zeros_like(self.classif_map,
np.float32)
####################
# apply prediction #
####################
self.logger.info(f'Starting prediction with {self.method} regressor, {self.n_clusters} clusters, '
f'{self.classif_alg}.')
# adjust classifier for multiprocessing
if self.CPUs is None or self.CPUs > 1:
# FIXME does not work -> parallelize with https://github.com/ajtulloch/sklearn-compiledtrees?
classifier.n_jobs = cpu_count() if self.CPUs is None else self.CPUs
# get an empty GeoArray for the prediction result
t0 = time.time()
out_nodataVal = out_nodataVal if out_nodataVal is not None else image.nodata
image_predicted = GeoArray(np.empty((image.rows,
image.cols,
classifier.tgt_n_bands),
dtype=image.dtype),
geotransform=image.gt,
projection=image.prj,
nodata=out_nodataVal,
bandnames=['B%s' % i
if len(i) == 2
else 'B0%s' % i
for i in classifier.tgt_LBA])
##############################
# compute prediction weights #
##############################
# compute the weights (only needed in case of multiple kNN classifiers)
if classifier.n_clusters > 1 and\
self.classif_map.ndim > 2:
self.logger.info(f'Computing prediction weights per pixel for each regressor.')
if self.classif_alg == 'kNN_SAM':
# scale SAM values between 0 and 15 degrees spectral angle
dist_min, dist_max = 0, 15
......@@ -510,6 +491,33 @@ class RSImage_ClusterPredictor(object):
# print(self.distance_metrics[0, 0, :])
# print(weights[0, 0, :])
####################
# apply prediction #
####################
self.logger.info(f'Starting prediction with {self.method} regressor, {self.n_clusters} clusters, '
f'{self.classif_alg}.')
# adjust classifier for multiprocessing
if self.CPUs is None or self.CPUs > 1:
# FIXME does not work -> parallelize with https://github.com/ajtulloch/sklearn-compiledtrees?
classifier.n_jobs = cpu_count() if self.CPUs is None else self.CPUs
# get an empty GeoArray for the prediction result
t0 = time.time()
out_nodataVal = out_nodataVal if out_nodataVal is not None else image.nodata
image_predicted = GeoArray(np.empty((image.rows,
image.cols,
classifier.tgt_n_bands),
dtype=image.dtype),
geotransform=image.gt,
projection=image.prj,
nodata=out_nodataVal,
bandnames=['B%s' % i
if len(i) == 2
else 'B0%s' % i
for i in classifier.tgt_LBA])
# set image_predicted to nodata at nodata positions of the input image
if out_nodataVal is not None:
image_predicted[~image.mask_nodata[:]] = out_nodataVal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment