Commit f029234f authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Converted _calc _sam to staticmethod.


Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 118c49b4
Pipeline #4119 failed with stage
in 2 minutes and 15 seconds
......@@ -332,20 +332,23 @@ class SAM_Classifier(_ImageClassifier):
def angles_deg(self):
return np.rad2deg(self._distance_metrics) if self._distance_metrics is not None else None
def _calc_sam(self, image, endmembers):
if not image.shape[2] == self.train_spectra.shape[1]:
@staticmethod
def calc_sam(image, endmembers):
n_samples, n_features = endmembers.shape
if not image.shape[2] == endmembers.shape[1]:
raise RuntimeError('Matrix dimensions are not aligned. Input image has %d bands but input spectra '
'have %d.' % (image.shape[2], self.train_spectra.shape[1]))
'have %d.' % (image.shape[2], endmembers.shape[1]))
# normalize input data because SAM asserts only data between -1 and 1
train_spectra_norm, tileimdata_norm = normalize_endmembers_image(endmembers, image)
angles = np.zeros((image.shape[0], image.shape[1], self.n_samples), np.float)
angles = np.zeros((image.shape[0], image.shape[1], n_samples), np.float)
# if np.std(tileimdata) == 0: # skip tiles that only contain the same value
# loop over all training spectra and compute spectral angle for each pixel
for n_sample in range(self.n_samples):
train_spectrum = train_spectra_norm[n_sample, :].reshape(1, 1, self.n_features)
for n_sample in range(n_samples):
train_spectrum = train_spectra_norm[n_sample, :].reshape(1, 1, n_features)
angles[:, :, n_sample] = calc_sam(tileimdata_norm, train_spectrum, axis=2)
return angles
......@@ -356,7 +359,7 @@ class SAM_Classifier(_ImageClassifier):
(rS, rE), (cS, cE) = tilepos
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
angles = self._calc_sam(tileimdata, global_shared_endmembers)
angles = self.calc_sam(tileimdata, global_shared_endmembers)
angles_min = np.min(angles, axis=2).astype(np.float32)
cmap = np.argmin(angles, axis=2).astype(np.int16)
......@@ -393,7 +396,7 @@ class kNN_SAM_Classifier(SAM_Classifier):
(rS, rE), (cS, cE) = tilepos
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
angles = self._calc_sam(tileimdata, global_shared_endmembers)
angles = self.calc_sam(tileimdata, global_shared_endmembers)
k = self.n_neighbors if self.n_neighbors <= angles.shape[2] else angles.shape[2]
if self.n_neighbors < angles.shape[2]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment