Commit 6f7ae30f authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Replaced implementation of SAM classifier by own implementation.

parent 5198ff20
Pipeline #3229 passed with stage
in 17 minutes and 39 seconds
......@@ -15,7 +15,7 @@ from geoarray import GeoArray
class _ImageClassifier(object):
"""Base class for GMS image classifiers."""
def __init__(self, train_spectra, train_labels, CPUs=1):
# type: (np.ndarray, Union[np.ndarray, List[int]], int) -> None
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None]) -> None
self.CPUs = CPUs
self.train_spectra = train_spectra
self.train_labels = train_labels
......@@ -60,7 +60,7 @@ class MinimumDistance_Classifier(_ImageClassifier):
NOTE: distance equation: D² = sqrt(sum((Xvi - Xvj)²)
"""
def __init__(self, train_spectra, train_labels, CPUs=1):
# type: (np.ndarray, Union[np.ndarray, List[int]], int) -> None
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None]) -> None
super(MinimumDistance_Classifier, self).__init__(train_spectra, train_labels, CPUs=CPUs)
self.clf = NearestCentroid()
......@@ -73,7 +73,7 @@ class MinimumDistance_Classifier(_ImageClassifier):
class kNN_Classifier(_ImageClassifier):
def __init__(self, train_spectra, train_labels, CPUs=1, n_neighbors=10):
# type: (np.ndarray, Union[np.ndarray, List[int]], int, int) -> None
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None], int) -> None
super(kNN_Classifier, self).__init__(train_spectra, train_labels, CPUs=CPUs)
self.clf = KNeighborsClassifier(n_neighbors=n_neighbors, n_jobs=CPUs)
......@@ -84,16 +84,16 @@ class kNN_Classifier(_ImageClassifier):
return tilepos, self.clf.predict(spectra).reshape(*tileimdata.shape[:2])
class SAM_Classifier(_ImageClassifier):
class SAM_Classifier_OLD(_ImageClassifier):
def __init__(self, train_spectra, threshold=0.1, CPUs=1):
# type: (np.ndarray, Union[np.ndarray, List[int]], int) -> None
super(SAM_Classifier, self).__init__(train_spectra, np.array(range(train_spectra.shape[0])), CPUs=CPUs)
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None]) -> None
super(SAM_Classifier_OLD, self).__init__(train_spectra, np.array(range(train_spectra.shape[0])), CPUs=CPUs)
self.clf = SAM()
self.threshold = threshold
def _predict(self, tilepos, tileimdata):
return self.clf.classify(tileimdata, self.train_spectra, self.threshold)
return tilepos, self.clf.classify(tileimdata, self.train_spectra, self.threshold)
def classify(self, image_cube, nodataVal=None, tiledims=(1000, 1000), mask=None):
image_cube_gA = GeoArray(image_cube, nodata=nodataVal)
......@@ -106,7 +106,7 @@ class SAM_Classifier(_ImageClassifier):
else:
image_cube_gA[image_cube_gA.mask_nodata.astype(np.int8) == 0] = np.max(image_cube_gA)
cmap = super(SAM_Classifier, self).classify(image_cube_gA, nodataVal=nodataVal, tiledims=tiledims)
cmap = super(SAM_Classifier_OLD, self).classify(image_cube_gA, nodataVal=nodataVal, tiledims=tiledims)
if mask:
cmap[mask] = -9999
......@@ -114,6 +114,56 @@ class SAM_Classifier(_ImageClassifier):
return cmap
class SAM_Classifier(_ImageClassifier):
def __init__(self, train_spectra, threshold=0.1, CPUs=1):
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None]) -> None
super(SAM_Classifier, self).__init__(train_spectra, np.array(range(train_spectra.shape[0])), CPUs=CPUs)
self.clf = SAM()
self.threshold = threshold
def _predict(self, tilepos, tileimdata):
angles = np.zeros((tileimdata.shape[0], tileimdata.shape[1], self.n_samples), np.float)
# if np.std(tileimdata) == 0:
tileimdata_norm = self._normalize(tileimdata, axis=2)
for n_sample in range(self.n_samples):
train_spectrum = self.train_spectra[n_sample, :].reshape(1, 1, self.n_features)
angles[:, :, n_sample] = self._calc_sam(tileimdata_norm, train_spectrum, axis=2, s1_normed=True)
cmap = np.argmin(angles, axis=2)
return tilepos, cmap
def _calc_sam(self, s1, s2, axis=0, s1_normed=False, s2_normed=False):
"""Compute the spectral angle mapper between two vectors or images (in radians)."""
norm_s1 = s1 if s1_normed else self._normalize(s1, axis=axis)
norm_s2 = s2 if s2_normed else self._normalize(s2, axis=axis)
upper = np.sum(norm_s1 * norm_s2, axis=axis)
lower = np.sqrt(np.sum(norm_s1 * norm_s1, axis=axis)) * np.sqrt(np.sum(norm_s2 * norm_s2, axis=axis))
if lower.ndim > 1:
lower[lower == 0] = 1e-10
else:
lower = lower or 1e-10
return np.arccos(upper / lower)
@staticmethod
def _normalize(x, axis=0):
if x.ndim > 2:
upper = (x - np.min(x, axis=axis)[:, :, np.newaxis]).astype(np.float)
lower = (np.max(x, axis=axis) - np.min(x, axis=axis)).astype(np.float)[:, :, np.newaxis]
lower[lower == 0] = 1e-10
else:
upper = (x - np.min(x, axis=axis)).astype(np.float)
lower = (np.max(x, axis=axis) - np.min(x, axis=axis)).astype(np.float)
lower = lower or 1e-10
return upper / lower
def classify_image(image, train_spectra, train_labels, classif_alg,
kNN_n_neighbors=10, nodataVal=None, tiledims=(1000, 1000), CPUs=None):
# type: (Union[np.ndarray, GeoArray], np.ndarray, Union[np.ndarray, List[int]], str, int, ...) -> GeoArray
......
......@@ -19,6 +19,7 @@ from geoarray import GeoArray
from gms_preprocessing import set_config
from gms_preprocessing.algorithms.classification import MinimumDistance_Classifier
from gms_preprocessing.algorithms.classification import kNN_Classifier
from gms_preprocessing.algorithms.classification import SAM_Classifier_OLD
from gms_preprocessing.algorithms.classification import SAM_Classifier
from . import db_host
......@@ -60,7 +61,16 @@ class Test_kNN_Classifier(unittest.TestCase):
class Test_SAM_Classifier(unittest.TestCase):
def test_classify(self):
SC = SAM_Classifier(cluster_centers)
cmap = SC.classify(test_gA, nodataVal=-9999)
cmap = SC.classify(test_gA, nodataVal=-9999, tiledims=(400, 200))
self.assertIsInstance(cmap, np.ndarray)
self.assertEqual(cmap.shape, (1010, 1010))
class Test_SAM_Classifier_OLD(unittest.TestCase):
def test_classify(self):
SC = SAM_Classifier_OLD(cluster_centers, CPUs=None)
cmap = SC.classify(test_gA, nodataVal=-9999, tiledims=(400, 200))
self.assertIsInstance(cmap, np.ndarray)
self.assertEqual(cmap.shape, (1010, 1010))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment