Commit 55bdd607 authored by Daniel Scheffler
Browse files

Revised RSImage_ClusterPredictor.predict() to improve speed (reduced...


Revised RSImage_ClusterPredictor.predict() to improve speed (reduced processing time to 50-70% of the previous version).
Signed-off-by: Daniel Scheffler <danschef@gfz-potsdam.de>
parent 5b27e29d
......@@ -27,12 +27,13 @@
"""Main module."""
import os
import numpy as np
import logging # noqa F401 # flake8 issue
from typing import Union, Tuple # noqa F401 # flake8 issue
from multiprocessing import cpu_count
import traceback
import time
from tqdm import tqdm
import numpy as np
from geoarray import GeoArray # noqa F401 # flake8 issue
from specclassify import classify_image
# from specclassify import kNN_MinimumDistance_Classifier
......@@ -41,6 +42,7 @@ from .classifier import Cluster_Learner
from .exceptions import ClassifierNotAvailableError
from .logging import SpecHomo_Logger
from .options import options
from .utils import spectra2im, im2spectra
__author__ = 'Daniel Scheffler'
......@@ -449,12 +451,12 @@ class RSImage_ClusterPredictor(object):
# apply prediction #
####################
# adjust classifier
# adjust classifier for multiprocessing
if self.CPUs is None or self.CPUs > 1:
# FIXME does not work -> parallelize with https://github.com/ajtulloch/sklearn-compiledtrees?
classifier.n_jobs = cpu_count() if self.CPUs is None else self.CPUs
# NOTE: prediction is applied in 1000 x 1000 tiles to save memory (because classifier.predict returns float32)
# get an empty GeoArray for the prediction result
t0 = time.time()
out_nodataVal = out_nodataVal if out_nodataVal is not None else image.nodata
image_predicted = GeoArray(np.empty((image.rows,
......@@ -469,6 +471,7 @@ class RSImage_ClusterPredictor(object):
else 'B0%s' % i
for i in classifier.tgt_LBA])
# compute the weights (only needed in case of multiple kNN classifiers)
if classifier.n_clusters > 1 and\
self.classif_map.ndim > 2:
......@@ -499,11 +502,32 @@ class RSImage_ClusterPredictor(object):
# print(self.distance_metrics[0, 0, :])
# print(weights[0, 0, :])
# set image_predicted to nodata at nodata positions of the input image
image_predicted[~image.mask_nodata[:]] = out_nodataVal
# NOTE:
# - prediction now only runs on the remaining pixels (that contain data)
# - computation is running in chunks of 50,000 spectra to save memory
# (classifier.predict returns float32) and speed up processing
# ----------------------------------------------------------------------
# get all spectra at pixels that really contain data
spectra_at_datapos = image[image.mask_nodata[:]]
n_spectra = spectra_at_datapos.shape[0]
# we need these spectra + weights + classification map as 3D image arrays
# (as expected by classifier.predict())
spectra_as_im = GeoArray(spectra2im(spectra_at_datapos, n_spectra, 1))
weights_datapos = spectra2im(weights[image.mask_nodata[:]], n_spectra, 1)
classif_map_datapos = spectra2im(self.classif_map[image.mask_nodata[:]], n_spectra, 1)
# spectra_predicted will be filled while looping over chunks
spectra_predicted = np.empty((n_spectra, image_predicted.bands), image_predicted.dtype)
n_saturated_px = 0
for ((rS, rE), (cS, cE)), im_tile in image.tiles(tilesize=(1000, 1000)):
self.logger.info('Predicting tile ((%s, %s), (%s, %s))...' % (rS, rE, cS, cE))
classif_map_tile = self.classif_map[rS: rE + 1, cS: cE + 1] # integer array
for ((rS, rE), (cS, cE)), im_tile in tqdm(spectra_as_im.tiles(tilesize=(50000, 1))):
classif_map_tile = classif_map_datapos[rS: rE + 1, cS: cE + 1] # integer array
# predict!
if self.classif_map.ndim == 2:
......@@ -514,7 +538,7 @@ class RSImage_ClusterPredictor(object):
cmap_unclassifiedVal=unclassified_pixVal)
else:
weights_tile = weights[rS: rE + 1, cS: cE + 1] # float array
weights_tile = weights_datapos[rS: rE + 1, cS: cE + 1] # float array
im_tile_pred = \
classifier.predict_weighted_averages(im_tile, classif_map_tile, weights_tile,
......@@ -523,6 +547,7 @@ class RSImage_ClusterPredictor(object):
cmap_unclassifiedVal=unclassified_pixVal)
# set saturated pixels (exceeding the output data range with respect to the data type) to no-data
# NOTE: this is computed on the chunks to save memory
if isinstance(image_predicted.dtype, np.integer):
out_dTMin, out_dTMax = np.iinfo(image_predicted.dtype).min,\
np.iinfo(image_predicted.dtype).max
......@@ -536,7 +561,10 @@ class RSImage_ClusterPredictor(object):
n_saturated_px += np.sum(mask_saturated)
im_tile_pred[mask_saturated] = out_nodataVal
image_predicted[rS:rE + 1, cS:cE + 1] = im_tile_pred.astype(image_predicted.dtype)
spectra_predicted[rS:rE + 1, :] = im2spectra(im_tile_pred) # [n_spectra x n_tgt_bands]
# fill in the predicted spectra
image_predicted[image.mask_nodata[:]] = spectra_predicted
if n_saturated_px:
self.logger.warning("%.2f %% of the predicted pixels are saturated and set to no-data."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment