Commit f2787388 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Classification maps are now returned as int16.


Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 6d8ed7ea
Pipeline #3893 failed with stage
in 2 minutes and 1 second
......@@ -15,8 +15,8 @@ from geoarray import GeoArray
from py_tools_ds.numeric.array import get_array_tilebounds
global_shared_endmembers = None
global_shared_im2classify = None
global_shared_endmembers = None # type: Union[None, np.ndarray]
global_shared_im2classify = None # type: Union[None, GeoArray]
def initializer(endmembers, im2classify):
......@@ -59,12 +59,11 @@ class _ImageClassifier(object):
"""
self._cmap_nodataVal = cmap_nodataVal
dtype_cmap = np.array(self.train_labels).dtype
dtype_cmap = np.int16
if cmap_nodataVal is not None and not np.can_cast(cmap_nodataVal, dtype_cmap):
dtype_cmap = np.find_common_type(np.array(self.train_labels), np.array([cmap_nodataVal]))
image_cube_gA = GeoArray(image_cube, nodata=in_nodataVal)
# image_cube_gA.to_mem()
image_cube_gA = GeoArray(image_cube, nodata=in_nodataVal) # lazily read in tiles to save memory
bounds_alltiles = get_array_tilebounds(image_cube_gA.shape, tiledims)
......@@ -222,7 +221,7 @@ class MinimumDistance_Classifier(_ImageClassifier):
dist = self.compute_euclidian_distance(tileimdata.astype(np.float32), cmap, self._cmap_nodataVal)
return tilepos, cmap, dist
return tilepos, cmap.astype(np.int16), dist
def label_unclassified_pixels(self, label_unclassified, threshold):
# type: (int, Union[str, int, float]) -> GeoArray
......@@ -253,7 +252,9 @@ class kNN_Classifier(_ImageClassifier):
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
spectra = tileimdata.reshape((tileimdata.shape[0] * tileimdata.shape[1], tileimdata.shape[2]))
return tilepos, self.clf.predict(spectra).reshape(*tileimdata.shape[:2]), None
cmap = self.clf.predict(spectra).reshape(*tileimdata.shape[:2])
return tilepos, cmap.astype(np.int16), None
class SAM_Classifier(_ImageClassifier):
......@@ -276,7 +277,7 @@ class SAM_Classifier(_ImageClassifier):
(rS, rE), (cS, cE) = tilepos
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
endmembers = global_shared_endmembers # type: np.ndarray
endmembers = global_shared_endmembers
if not tileimdata.shape[2] == self.train_spectra.shape[1]:
raise RuntimeError('Matrix dimensions are not aligned. Input image has %d bands but input spectra '
......@@ -302,7 +303,7 @@ class SAM_Classifier(_ImageClassifier):
cmap = self.overwrite_cmap_at_nodata_positions(cmap, tileimdata,
self._cmap_nodataVal, global_shared_im2classify.nodata)
return tilepos, cmap, angles_min
return tilepos, cmap.astype(np.int16), angles_min
@staticmethod
def _calc_sam(s1_norm, s2_norm, axis=0):
......@@ -349,7 +350,7 @@ class SID_Classifier(_ImageClassifier):
(rS, rE), (cS, cE) = tilepos
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
endmembers = global_shared_endmembers # type: np.ndarray
endmembers = global_shared_endmembers
if not tileimdata.shape[2] == self.train_spectra.shape[1]:
raise RuntimeError('Matrix dimensions are not aligned. Input image has %d bands but input spectra '
......@@ -375,7 +376,7 @@ class SID_Classifier(_ImageClassifier):
cmap = self.overwrite_cmap_at_nodata_positions(cmap, tileimdata,
self._cmap_nodataVal, global_shared_im2classify.nodata)
return tilepos, cmap, sid_min
return tilepos, cmap.astype(np.int16), sid_min
@staticmethod
def _calc_sid(s1_norm, s2_norm, axis=0):
......@@ -431,7 +432,7 @@ class RF_Classifier(_ImageClassifier):
cmap = self.overwrite_cmap_at_nodata_positions(cmap, tileimdata,
self._cmap_nodataVal, global_shared_im2classify.nodata)
return tilepos, cmap, None
return tilepos, cmap.astype(np.int16), None
def classify_image(image, train_spectra, train_labels, classif_alg, in_nodataVal=None, cmap_nodataVal=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment