Commit a893af35 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Added _show_distances_histogram() and subclass methods. Bugfix.


Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 1da8e603
Pipeline #3886 failed with stage
in 1 minute and 59 seconds
......@@ -11,6 +11,7 @@ from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MaxAbsScaler
from numba import jit
from matplotlib import pyplot as plt
from geoarray import GeoArray
from py_tools_ds.numeric.array import get_array_tilebounds
......@@ -67,6 +68,7 @@ class _ImageClassifier(object):
# use a local variable to avoid pickling in multiprocessing
cmap = GeoArray(np.empty((image_cube_gA.rows, image_cube_gA.cols), dtype=dtype_cmap), nodata=cmap_nodataVal)
cmap.unclassified_val = None
dist = np.empty((image_cube_gA.rows, image_cube_gA.cols), dtype=np.float32)
print('Performing %s image classification...' % self.clf_name)
......@@ -105,12 +107,13 @@ class _ImageClassifier(object):
pass
elif isinstance(threshold, str) and threshold.endswith('%'):
percent = float(threshold.split('%')[0].strip())
dists = distances[cmap != cmap_nodataVal] if cmap_nodataVal is not None else distances
dists = distances[cmap[:] != cmap_nodataVal] if cmap_nodataVal is not None else distances
threshold = np.nanpercentile(dists, 100 - percent)
else:
raise ValueError(threshold)
cmap[distances > threshold] = label_unclassified
cmap.unclassified_val = label_unclassified
return cmap
......@@ -118,6 +121,20 @@ class _ImageClassifier(object):
if self.cmap:
self.cmap.show()
@staticmethod
def _show_distances_histogram(distances, cmap, figsize=(10, 5), bins=100, normed=False):
# noinspection PyProtectedMember
mask_gooddata = cmap[:] != cmap._nodata
if cmap.unclassified_val is not None:
mask_gooddata = mask_gooddata & (cmap[:] != cmap.unclassified_val)
distances = distances[mask_gooddata]
plt.figure(figsize=figsize)
plt.hist(list(distances), density=normed, bins=bins, color='gray')
plt.xlabel('Pixel value')
plt.ylabel('Probabilty' if normed else 'Count')
plt.show()
class MinimumDistance_Classifier(_ImageClassifier):
"""Classifier computing the n-dimensional euclidian distance of each pixel vector to each cluster mean vector.
......@@ -164,9 +181,10 @@ class MinimumDistance_Classifier(_ImageClassifier):
spectra = tileimdata.reshape((tileimdata.shape[0] * tileimdata.shape[1], tileimdata.shape[2]))
cmap = self.clf.predict(spectra).reshape(*tileimdata.shape[:2])
dist = self.compute_euclidian_distance_jit(tileimdata.astype(np.float32), cmap)
# dist = self.compute_euclidian_distance_jit(tileimdata.astype(np.float32), cmap)
return tilepos, cmap, dist
return tilepos, cmap
# return tilepos, cmap, dist
def label_unclassified_pixels(self, label_unclassified, threshold):
# type: (int, Union[str, int, float]) -> GeoArray
......@@ -174,6 +192,9 @@ class MinimumDistance_Classifier(_ImageClassifier):
self.cmap, label_unclassified, threshold, self.euclidian_distance
)
def show_distances_histogram(self, figsize=(10, 5), bins=100, normed=False):
self._show_distances_histogram(self.euclidian_distance, self.cmap, figsize=figsize, bins=bins, normed=normed)
class kNN_Classifier(_ImageClassifier):
def __init__(self, train_spectra, train_labels, CPUs=1, **kwargs):
......@@ -260,6 +281,9 @@ class SAM_Classifier(_ImageClassifier):
self.cmap, label_unclassified, threshold, self.angles_deg
)
def show_angles_histogram(self, figsize=(10, 5), bins=100, normed=False):
self._show_distances_histogram(self.angles_deg, self.cmap, figsize=figsize, bins=bins, normed=normed)
class SID_Classifier(_ImageClassifier):
def __init__(self, train_spectra, CPUs=1):
......@@ -325,6 +349,9 @@ class SID_Classifier(_ImageClassifier):
self.cmap, label_unclassified, threshold, self.sid
)
def show_sid_histogram(self, figsize=(10, 5), bins=100, normed=False):
self._show_distances_histogram(self.sid, self.cmap, figsize=figsize, bins=bins, normed=normed)
class RF_Classifier(_ImageClassifier):
"""Random forest classifier."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment