Commit 4f27c904 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files
parent a0127bf7
Pipeline #4102 failed with stage
in 2 minutes and 8 seconds
......@@ -78,7 +78,7 @@ class _ImageClassifier(object):
# use a local variable to avoid pickling in multiprocessing
cmap_dist_shape = (image_cube_gA.rows, image_cube_gA.cols) if tiles_results[0][1].ndim == 2 else \
(image_cube_gA.rows, image_cube_gA.cols, tiles_results[0][1].ndim)
(image_cube_gA.rows, image_cube_gA.cols, tiles_results[0][1].shape[2])
cmap = GeoArray(np.empty(cmap_dist_shape, dtype=dtype_cmap), nodata=cmap_nodataVal)
cmap.unclassified_val = None
dist = np.empty(cmap_dist_shape, dtype=np.float32)
......@@ -327,7 +327,7 @@ class kNN_SAM_Classifier(SAM_Classifier):
# type: (np.ndarray, int, Union[int, None]) -> None
super(kNN_SAM_Classifier, self).__init__(train_spectra, CPUs=CPUs)
self.clf_name = 'k-nearest neighbour spectral angle mapper (SAM) - %d neighbors' % n_neighbors
self.clf_name = 'k-nearest neighbour spectral angle mapper (kNN_SAM; k=%d)' % n_neighbors
self.n_neighbors = n_neighbors
def _predict(self, tilepos):
......@@ -338,8 +338,13 @@ class kNN_SAM_Classifier(SAM_Classifier):
angles = self._calc_sam(tileimdata, global_shared_endmembers)
k = self.n_neighbors if self.n_neighbors <= angles.shape[2] else angles.shape[2]
cmap = np.argpartition(angles, k, axis=2)[:, :, :k].astype(np.int16)
angles_min_k = np.partition(angles, k, axis=2)[:, :, :k].astype(np.float32)
if self.n_neighbors < angles.shape[2]:
cmap = np.argpartition(angles, k, axis=2)[:, :, :k].astype(np.int16)
angles_min_k = np.partition(angles, k, axis=2)[:, :, :k].astype(np.float32)
else:
cmap = np.tile(np.arange(angles.shape[2]).reshape(1, 1, -1), (*angles.shape[:2], 1))
angles_min_k = angles
if global_shared_im2classify.nodata is not None and self._cmap_nodataVal is not None:
cmap = self.overwrite_cmap_at_nodata_positions(cmap, tileimdata,
......@@ -545,7 +550,7 @@ def classify_image(image, train_spectra, train_labels, classif_alg, in_nodataVal
CPUs=CPUs)
elif classif_alg == 'kNN_SAM':
kw = dict(k=kwargs['k']) if 'k' in kwargs else dict()
kw = dict(n_neighbors=kwargs['n_neighbors']) if 'n_neighbors' in kwargs else dict()
clf = kNN_SAM_Classifier(
train_spectra,
CPUs=CPUs,
......
......@@ -123,7 +123,7 @@ class Test_SAM_Classifier(unittest.TestCase):
class Test_kNN_SAM_Classifier(unittest.TestCase):
def setUp(self) -> None:
self.n_neighbors = 3
self.n_neighbors = 5
def test_classify(self):
SC = kNN_SAM_Classifier(cluster_centers, n_neighbors=self.n_neighbors, CPUs=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment