Commit 2a3bb492 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

KMeansRSImage: Added functions and properties to apply clustering, plot...

KMeansRSImage: Added functions and properties to apply clustering, plot cluster centers, plot cluster histogram, plot clustered image + Tests.
Former-commit-id: e08321f2
Former-commit-id: c98938a4
parent b7ac5722
......@@ -8,7 +8,7 @@ from scipy.interpolate import interp1d
import scipy as sp
import matplotlib.pyplot as plt
from logging import Logger
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans, k_means_
from geoarray import GeoArray # noqa F401 # flake8 issue
......@@ -167,21 +167,105 @@ class SpectralResampler(object):
class KMeansRSImage(object):
_clusters = None
_im_clust = None
def __init__(self, im, n_clusters):
# type: (GeoArray, int) -> None
self.im = im
self.n_clusters = n_clusters
@property
def clusters(self):
# type: () -> k_means_.KMeans
if not self._clusters:
self._clusters = self.compute_clusters()
return self._clusters
@clusters.setter
def clusters(self, clusters):
self._clusters = clusters
@property
def im_clust(self):
if self._im_clust is None:
self._im_clust = self.clusters.labels_.reshape((self.im.rows, self.im.cols))
return self._im_clust
def compute_clusters(self):
# implement like this: https://www.pyimagesearch.com/2014/05/26/opencv-python-k-means-color-clustering/
pixels2d = self.im.reshape((self.im.rows * self.im.cols, self.im.bands))
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0)
out = kmeans.fit(pixels2d)
self.clusters = kmeans.fit(pixels2d)
return self.clusters
def apply_clusters(self, image):
image = GeoArray(image)
pixels2d = self.im.reshape((image.rows * image.cols, image.bands))
labels = self.clusters.predict(pixels2d)
return labels
def plot_cluster_centers(self, figsize=(15, 5)):
# type: (tuple) -> None
"""Show a plot of the cluster center signatures.
:param figsize: figure size (inches)
"""
plt.figure(figsize=figsize)
for i, center_signature in enumerate(self.clusters.cluster_centers_):
plt.plot(range(1, self.im.bands + 1), center_signature, label='Cluster #%s' % (i + 1))
plt.title('KMeans cluster centers for %s clusters' % self.n_clusters)
plt.xlabel('Spectral band')
plt.ylabel('Pixel values')
plt.legend(loc='upper right')
print(out)
plt.show()
def plot_cluster_histogram(self, figsize=(15, 5)):
# type: (tuple) -> None
"""Show a histogram indicating the proportion of each cluster label in percent.
:param figsize: figure size (inches)
"""
# grab the number of different clusters and create a histogram
# based on the number of pixels assigned to each cluster
numLabels = np.arange(0, len(np.unique(self.clusters.labels_)) + 1)
hist, bins = np.histogram(self.clusters.labels_, bins=numLabels)
# normalize the histogram, such that it sums to 100
hist = hist.astype("float")
hist /= hist.sum() / 100
# plot the histogram as bar plot
plt.figure(figsize=figsize)
plt.bar(bins[:-1], hist, width=1)
plt.title('Proportion of cluster labels (%s clusters)' % self.n_clusters)
plt.xlabel('# Cluster')
plt.ylabel('Proportion [%]')
plt.show()
def plot_clustered_image(self, figsize=(15, 15)):
# type: (tuple) -> None
"""Show a the clustered image.
:param figsize: figure size (inches)
"""
plt.figure(figsize=figsize)
rows, cols = self.im_clust.shape[:2]
plt.imshow(self.im_clust, plt.cm.prism, interpolation='none', extent=(0, cols, rows, 0))
plt.show()
def get_random_spectra_from_each_cluster(self, samplesize=50):
"""Returns a given number of spectra randomly selected within each cluster.
E.g., 50 spectra of belonging to cluster 1, 50 spectra of belonging to cluster 2 and so on."""
# TODO
......@@ -9,14 +9,19 @@ Tests for gms_preprocessing.algorithms.L2B_P.KMeansRSImage
"""
import unittest
import numpy as np
import os
import matplotlib
import numpy as np
from sklearn.cluster import k_means_
from geoarray import GeoArray
matplotlib.use('Template') # disables matplotlib figure popups
from gms_preprocessing import __file__
from gms_preprocessing.config import set_config
from gms_preprocessing.algorithms.L2B_P import KMeansRSImage
from geoarray import GeoArray # noqa E402 module level import not at top of file
from gms_preprocessing import __file__ # noqa E402 module level import not at top of file
from gms_preprocessing.config import set_config # noqa E402 module level import not at top of file
from gms_preprocessing.algorithms.L2B_P import KMeansRSImage # noqa E402 module level import not at top of file
testdata = os.path.join(os.path.dirname(__file__),
......@@ -36,3 +41,22 @@ class Test_KMeansRSImage(unittest.TestCase):
def test_compute_clusters(self):
self.kmeans.compute_clusters()
self.assertIsInstance(self.kmeans.clusters, k_means_.KMeans)
def test_apply_clusters(self):
labels = self.kmeans.apply_clusters(self.geoArr)
self.assertIsInstance(labels, np.ndarray)
self.assertTrue(labels.size == self.geoArr.rows * self.geoArr.cols)
def test_get_random_spectra_from_each_cluster(self):
self.kmeans.get_random_spectra_from_each_cluster()
# TODO
def test_plot_cluster_centers(self):
self.kmeans.plot_cluster_centers()
def test_plot_cluster_histogram(self):
self.kmeans.plot_cluster_histogram()
def test_plot_clustered_image(self):
self.kmeans.plot_clustered_image()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment