Commit 6ef40c33 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Refactored 'k' parameter o 'n_neighbors'.


Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 5f4ec877
Pipeline #4100 failed with stage
in 2 minutes and 11 seconds
......@@ -323,12 +323,12 @@ class SAM_Classifier(_ImageClassifier):
class kNN_SAM_Classifier(SAM_Classifier):
def __init__(self, train_spectra, k=3, CPUs=1):
def __init__(self, train_spectra, n_neighbors=3, CPUs=1):
# type: (np.ndarray, int, Union[int, None]) -> None
super(kNN_SAM_Classifier, self).__init__(train_spectra, CPUs=CPUs)
self.clf_name = 'k-nearest neighbour spectral angle mapper (SAM)'
self.k = k
self.n_neighbors = n_neighbors
def _predict(self, tilepos):
assert global_shared_endmembers is not None and global_shared_im2classify is not None
......@@ -337,8 +337,8 @@ class kNN_SAM_Classifier(SAM_Classifier):
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
angles = self._calc_sam(tileimdata, global_shared_endmembers)
cmap = np.argpartition(angles, self.k, axis=2)[:, :, :self.k].astype(np.int16)
angles_min_k = np.partition(angles, self.k, axis=2)[:, :, :self.k].astype(np.float32)
cmap = np.argpartition(angles, self.n_neighbors, axis=2)[:, :, :self.n_neighbors].astype(np.int16)
angles_min_k = np.partition(angles, self.n_neighbors, axis=2)[:, :, :self.n_neighbors].astype(np.float32)
if global_shared_im2classify.nodata is not None and self._cmap_nodataVal is not None:
cmap = self.overwrite_cmap_at_nodata_positions(cmap, tileimdata,
......
......@@ -121,24 +121,24 @@ class Test_SAM_Classifier(unittest.TestCase):
SC.label_unclassified_pixels(label_unclassified=-1, threshold='10%')
class Test_KNN_SAM_Classifier(unittest.TestCase):
class Test_kNN_SAM_Classifier(unittest.TestCase):
def setUp(self) -> None:
self.k = 3
self.n_neighbors = 3
def test_classify(self):
SC = kNN_SAM_Classifier(cluster_centers, k=self.k, CPUs=1)
SC = kNN_SAM_Classifier(cluster_centers, n_neighbors=self.n_neighbors, CPUs=1)
cmap_sp = SC.classify(test_gA, in_nodataVal=-9999, cmap_nodataVal=-9999, tiledims=(400, 200))
self.assertIsInstance(cmap_sp, GeoArray)
self.assertEqual(cmap_sp.shape, (1010, 1010, self.k))
self.assertEqual(cmap_sp.shape, (1010, 1010, self.n_neighbors))
SC = kNN_SAM_Classifier(cluster_centers, k=self.k, CPUs=None)
SC = kNN_SAM_Classifier(cluster_centers, n_neighbors=self.n_neighbors, CPUs=None)
cmap_mp = SC.classify(test_gA, in_nodataVal=-9999, cmap_nodataVal=-9999, tiledims=(400, 200))
self.assertIsInstance(cmap_mp, GeoArray)
self.assertEqual(cmap_mp.shape, (1010, 1010, self.k))
self.assertEqual(cmap_mp.shape, (1010, 1010, self.n_neighbors))
self.assertTrue(np.array_equal(cmap_sp, cmap_mp))
SC = kNN_SAM_Classifier(cluster_centers, k=self.k, CPUs=None)
SC = kNN_SAM_Classifier(cluster_centers, n_neighbors=self.n_neighbors, CPUs=None)
cmap_mp = SC.classify(test_gA_pure_endmembers, in_nodataVal=-9999, cmap_nodataVal=-9999)
for i, cl in enumerate(cluster_labels):
......@@ -146,12 +146,12 @@ class Test_KNN_SAM_Classifier(unittest.TestCase):
# self.assertTrue(np.array_equal(cmap_mp.flatten(), cluster_labels)) # TODO sort cmap by SC.angles_deg
def test_label_unclassified_pixels_absolute_th(self):
SC = kNN_SAM_Classifier(cluster_centers, k=self.k, CPUs=None)
SC = kNN_SAM_Classifier(cluster_centers, n_neighbors=self.n_neighbors, CPUs=None)
SC.classify(test_gA, in_nodataVal=-9999, cmap_nodataVal=-9999, tiledims=(400, 200))
SC.label_unclassified_pixels(label_unclassified=-1, threshold=10)
def test_label_unclassified_pixels_relative_th(self):
SC = kNN_SAM_Classifier(cluster_centers, self.k, CPUs=None)
SC = kNN_SAM_Classifier(cluster_centers, self.n_neighbors, CPUs=None)
SC.classify(test_gA, in_nodataVal=-9999, cmap_nodataVal=-9999, tiledims=(400, 200))
SC.label_unclassified_pixels(label_unclassified=-1, threshold='10%')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment