Commit e08321f2 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

KMeansRSImage: Added functions and properties to apply clustering, plot...

KMeansRSImage: Added functions and properties to apply clustering, plot cluster centers, plot cluster histogram, plot clustered image + Tests.
parent 799e34de
Pipeline #1422 failed with stage
in 7 minutes and 59 seconds
......@@ -8,7 +8,7 @@ from scipy.interpolate import interp1d
import scipy as sp
import matplotlib.pyplot as plt
from logging import Logger
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans, k_means_
from geoarray import GeoArray # noqa F401 # flake8 issue
......@@ -167,21 +167,105 @@ class SpectralResampler(object):
class KMeansRSImage(object):
_clusters = None
_im_clust = None
def __init__(self, im, n_clusters):
# type: (GeoArray, int) -> None = im
self.n_clusters = n_clusters
def clusters(self):
# type: () -> k_means_.KMeans
if not self._clusters:
self._clusters = self.compute_clusters()
return self._clusters
def clusters(self, clusters):
self._clusters = clusters
def im_clust(self):
if self._im_clust is None:
self._im_clust = self.clusters.labels_.reshape((,
return self._im_clust
def compute_clusters(self):
# implement like this:
pixels2d = *,
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0)
out =
self.clusters =
return self.clusters
def apply_clusters(self, image):
image = GeoArray(image)
pixels2d = * image.cols, image.bands))
labels = self.clusters.predict(pixels2d)
return labels
def plot_cluster_centers(self, figsize=(15, 5)):
# type: (tuple) -> None
"""Show a plot of the cluster center signatures.
:param figsize: figure size (inches)
for i, center_signature in enumerate(self.clusters.cluster_centers_):
plt.plot(range(1, + 1), center_signature, label='Cluster #%s' % (i + 1))
plt.title('KMeans cluster centers for %s clusters' % self.n_clusters)
plt.xlabel('Spectral band')
plt.ylabel('Pixel values')
plt.legend(loc='upper right')
def plot_cluster_histogram(self, figsize=(15, 5)):
# type: (tuple) -> None
"""Show a histogram indicating the proportion of each cluster label in percent.
:param figsize: figure size (inches)
# grab the number of different clusters and create a histogram
# based on the number of pixels assigned to each cluster
numLabels = np.arange(0, len(np.unique(self.clusters.labels_)) + 1)
hist, bins = np.histogram(self.clusters.labels_, bins=numLabels)
# normalize the histogram, such that it sums to 100
hist = hist.astype("float")
hist /= hist.sum() / 100
# plot the histogram as bar plot
plt.figure(figsize=figsize)[:-1], hist, width=1)
plt.title('Proportion of cluster labels (%s clusters)' % self.n_clusters)
plt.xlabel('# Cluster')
plt.ylabel('Proportion [%]')
def plot_clustered_image(self, figsize=(15, 15)):
# type: (tuple) -> None
"""Show a the clustered image.
:param figsize: figure size (inches)
rows, cols = self.im_clust.shape[:2]
plt.imshow(self.im_clust,, interpolation='none', extent=(0, cols, rows, 0))
def get_random_spectra_from_each_cluster(self, samplesize=50):
"""Returns a given number of spectra randomly selected within each cluster.
E.g., 50 spectra of belonging to cluster 1, 50 spectra of belonging to cluster 2 and so on."""
......@@ -9,14 +9,19 @@ Tests for gms_preprocessing.algorithms.L2B_P.KMeansRSImage
import unittest
import numpy as np
import os
import matplotlib
import numpy as np
from sklearn.cluster import k_means_
from geoarray import GeoArray
matplotlib.use('Template') # disables matplotlib figure popups
from gms_preprocessing import __file__
from gms_preprocessing.config import set_config
from gms_preprocessing.algorithms.L2B_P import KMeansRSImage
from geoarray import GeoArray # noqa E402 module level import not at top of file
from gms_preprocessing import __file__ # noqa E402 module level import not at top of file
from gms_preprocessing.config import set_config # noqa E402 module level import not at top of file
from gms_preprocessing.algorithms.L2B_P import KMeansRSImage # noqa E402 module level import not at top of file
testdata = os.path.join(os.path.dirname(__file__),
......@@ -36,3 +41,22 @@ class Test_KMeansRSImage(unittest.TestCase):
def test_compute_clusters(self):
self.assertIsInstance(self.kmeans.clusters, k_means_.KMeans)
def test_apply_clusters(self):
labels = self.kmeans.apply_clusters(self.geoArr)
self.assertIsInstance(labels, np.ndarray)
self.assertTrue(labels.size == self.geoArr.rows * self.geoArr.cols)
def test_get_random_spectra_from_each_cluster(self):
def test_plot_cluster_centers(self):
def test_plot_cluster_histogram(self):
def test_plot_clustered_image(self):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment