Commit c414fd96 authored by Daniel Scheffler

Revised L2B_P.SpecHomo_Classifier.

Former-commit-id: 2c5f60d8
Former-commit-id: 764dda22
parent ce31b59b
@@ -15,10 +15,12 @@ from typing import Union, List # noqa F401 # flake8 issue
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import tempfile
import time
from sklearn.cluster import k_means_ # noqa F401 # flake8 issue
from geoarray import GeoArray # noqa F401 # flake8 issue
from py_tools_ds.numeric.array import get_array_tilebounds
from py_tools_ds.processing.progress_mon import ProgressBar
from ..config import GMS_config as CFG
from ..io.input_reader import SRF # noqa F401 # flake8 issue
@@ -206,7 +208,7 @@ class SpectralResampler(object):
class KMeansRSImage(object):
def __init__(self, im, n_clusters):
def __init__(self, im, n_clusters, CPUs=1):
# type: (GeoArray, int, int) -> None
# privates
@@ -216,6 +218,7 @@ class KMeansRSImage(object):
self.im = im
self.n_clusters = n_clusters
self.CPUs = CPUs
@property
def clusters(self):
@@ -235,7 +238,7 @@ class KMeansRSImage(object):
return self._im_clust
def compute_clusters(self):
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0)
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_jobs=self.CPUs)
self.clusters = kmeans.fit(self._im2spectra(self.im))
return self.clusters
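For reference, the clustering step boils down to fitting scikit-learn's KMeans on the image reshaped to a (pixels x bands) spectra matrix, now parallelized via n_jobs. A minimal standalone sketch of that step follows; the (rows, cols, bands) array layout and the helper name are assumptions for illustration, not the KMeansRSImage internals:

import numpy as np
from sklearn.cluster import KMeans

def cluster_spectra(image, n_clusters=10, cpus=1):
    """Cluster an image by its spectra (rows x cols x bands -> pixels x bands)."""
    spectra = image.reshape(-1, image.shape[2])  # flatten the spatial dimensions
    # n_jobs existed in scikit-learn releases contemporary with this commit (removed in 1.0)
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_jobs=cpus)
    kmeans.fit(spectra)  # per-pixel labels end up in kmeans.labels_
    return kmeans

# usage: labels_image = cluster_spectra(im, 10, 4).labels_.reshape(im.shape[:2])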
@@ -340,8 +343,8 @@ class SpecHomo_Classifier(object):
self.CPUs = CPUs or cpu_count()
self.tmpdir_multiproc = ''
def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
fmt_out='ENVI'):
def generate_reference_cubeOLD(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
fmt_out='ENVI'):
# type: (str, str, int, int, str, str) -> np.ndarray
"""Generate reference spectra from all hyperspectral input images.
@@ -403,42 +406,130 @@ class SpecHomo_Classifier(object):
return self.ref_cube
def _get_uniform_random_samples(self, im_ref, tgt_srf, n_clusters, tgt_n_samples):
im_name = os.path.basename(im_ref)
def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
fmt_out='ENVI', progress=True):
# type: (str, str, int, int, str, str, bool) -> np.ndarray
"""Generate reference spectra from all hyperspectral input images.
The hyperspectral images are spectrally resampled to the target sensor specifications. The resulting target
sensor images are then clustered and the same number of spectra is randomly selected from each cluster. All
spectra are combined into a single 'reference cube' that contains the same number of spectra per cluster,
with the spectra originating from all input images.
:param tgt_satellite: target satellite, e.g., 'Landsat-8'
:param tgt_sensor: target sensor, e.g., 'OLI_TIRS'
:param n_clusters: number of clusters to be used for clustering the input images (KMeans)
:param tgt_n_samples: number of spectra to be collected from each input image
:param path_out: output path for the generated reference cube
:param fmt_out: output format (GDAL driver code)
:param progress: show progress bar
:return: np.array: [tgt_n_samples x images x spectral bands of the target sensor]
"""
self.logger.info('Generating reference spectra from all input images...')
# get SRFs
self.logger.info('Reading spectral response functions of target sensor...')
tgt_srf = SRF(dict(Satellite=tgt_satellite, Sensor=tgt_sensor, Subsystem=None, image_type='RSD',
proc_level='L1A', logger=self.logger))
if self.v:
tgt_srf.plot_srfs()
# Build the reference cube from random samples of each image
# => rows: tgt_n_samples, columns: images, bands: spectral information
# generate random spectra samples equally for each KMeans cluster
args = ((im_num, im, tgt_srf, n_clusters, tgt_n_samples) for im_num, im in enumerate(self.ims_ref))
if self.CPUs > 1:
processes = len(self.ims_ref) if self.CPUs > len(self.ims_ref) else self.CPUs
with Pool(processes) as pool:
results = pool.starmap_async(self._get_uniform_random_samples, args, chunksize=1)
bar = ProgressBar(prefix='\tprogress:')
while True:
time.sleep(.1)
# this does not really represent the remaining tasks but the remaining chunks, hence chunksize=1
# noinspection PyProtectedMember
numberDone = len(self.ims_ref) - results._number_left
if progress:
bar.print_progress(percent=numberDone / len(self.ims_ref) * 100)
if results.ready():
# <= this is the line where multiprocessing may freeze if an exception occurs
results = results.get()
break
else:
results = []
bar = ProgressBar(prefix='\tprogress:')
for i, argset in enumerate(args):
if progress:
bar.print_progress((i + 1) / len(self.ims_ref) * 100)
results.append(self._get_uniform_random_samples(*argset))
# combine returned random samples to ref_cube
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
for im_num, random_samples in results:
self.logger.info('Adding random samples of %s to reference cube...'
% os.path.basename(self.ims_ref[im_num]))
self.ref_cube[:, im_num, :] = random_samples
# save
if path_out:
GeoArray(self.ref_cube).save(out_path=path_out, fmt=fmt_out)
return self.ref_cube
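The progress reporting above polls the private _number_left attribute of the AsyncResult returned by starmap_async; because chunksize=1, each remaining chunk equals one remaining task. A standalone sketch of that pattern with a plain print instead of ProgressBar (the worker function and timings are made up for illustration, and _number_left is an implementation detail that may change across Python versions):

import time
from multiprocessing import Pool

def _work(i, x):
    time.sleep(0.2)  # stand-in for per-image processing
    return i, x * x

if __name__ == '__main__':
    args = [(i, i) for i in range(8)]
    with Pool(4) as pool:
        async_res = pool.starmap_async(_work, args, chunksize=1)
        while True:
            time.sleep(0.1)
            # _number_left counts remaining chunks; chunksize=1 makes it equal remaining tasks
            done = len(args) - async_res._number_left
            print('\rprogress: %d%%' % (done / len(args) * 100), end='')
            if async_res.ready():
                results = async_res.get()  # re-raises any exception from the workers
                break
    print()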
@staticmethod
def perform_spectral_resampling(src_im, tgt_srf, logger=None, progress=False):
# type: (str, SRF, str, bool) -> Union[GeoArray, None]
im_name = os.path.basename(src_im)
logger = logger or GMS_logger('spectral_resamp_%s' % im_name)
# read input image
self.logger.info('Reading the input image %s...' % im_name)
im_gA = GeoArray(im_ref)
logger.info('Reading the input image %s...' % im_name)
im_gA = GeoArray(src_im)
im_gA.cwl = np.array(im_gA.meta.loc['wavelength'], dtype=np.float).flatten()
wvl_unit = 'nanometers' if max(im_gA.cwl) > 15 else 'micrometers'
# perform spectral resampling of input image to match spectral properties of target sensor
self.logger.info('Performing spectral resampling to match spectral properties of target sensor...')
logger.info('Performing spectral resampling to match spectral properties of target sensor...')
SR = SpectralResampler(im_gA.cwl, tgt_srf, wvl_unit=wvl_unit)
im_tgt = np.empty((*im_gA.shape[:2], len(tgt_srf.bands)))
tiles = im_gA.tiles((1000, 1000))
for ((rS, rE), (cS, cE)), tiledata in (tqdm(tiles) if self.CPUs == 1 else tiles):
im_tgt[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata)
im_tgt = GeoArray(im_tgt, im_gA.gt, im_gA.prj)
tgt_im = GeoArray(np.zeros((*im_gA.shape[:2], len(tgt_srf.bands)), dtype=np.int16), im_gA.gt, im_gA.prj)
tiles = im_gA.tiles((1000, 1000)) # use tiles to save memory
for ((rS, rE), (cS, cE)), tiledata in (tqdm(tiles) if progress else tiles):
tgt_im[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata.astype(np.int16))
return tgt_im
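Conceptually, resampling a hyperspectral spectrum to one target band is a weighted average over the source bands, with the target sensor's spectral response function (SRF) supplying the weights. The following is an illustrative stand-in for SpectralResampler.resample_image under that assumption, not the actual implementation:

import numpy as np

def resample_to_target_bands(image, src_cwl, tgt_srfs):
    """Resample a (rows x cols x src_bands) image to the bands defined by tgt_srfs.

    src_cwl:  central wavelengths of the source bands
    tgt_srfs: list of callables; tgt_srfs[b](wavelength) returns the response of target band b
    """
    rows, cols, _ = image.shape
    out = np.zeros((rows, cols, len(tgt_srfs)), dtype=np.float32)
    for b, srf in enumerate(tgt_srfs):
        weights = np.array([srf(wl) for wl in src_cwl], dtype=np.float32)
        weights /= weights.sum()  # normalize so each target band is a weighted mean
        out[:, :, b] = np.tensordot(image, weights, axes=([2], [0]))
    return out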
@staticmethod
def cluster_image_and_get_uniform_samples(im, n_clusters, tgt_n_samples, logger=None, v=False):
logger = logger or GMS_logger('cluster_logger')
# compute KMeans clusters for the spectrally resampled image
self.logger.info('Computing %s KMeans clusters...' % n_clusters)
kmeans = KMeansRSImage(im_tgt, n_clusters=n_clusters)
logger.info('Computing %s KMeans clusters...' % n_clusters)
kmeans = KMeansRSImage(im, n_clusters=n_clusters)
if self.v:
if v:
kmeans.plot_cluster_centers()
kmeans.plot_cluster_histogram()
# randomly grab the given number of spectra from each cluster
self.logger.info('Getting %s random spectra from each cluster...' % (tgt_n_samples // n_clusters))
logger.info('Getting %s random spectra from each cluster...' % (tgt_n_samples // n_clusters))
random_samples = kmeans.get_random_spectra_from_each_cluster(samplesize=tgt_n_samples // n_clusters)
# combine the spectra (2D arrays) of all clusters to a single 2D array
logger.info('Combining random samples from all clusters.')
random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples])
# return random samples or cache them on disk in multiprocessing
if self.CPUs > 1:
GeoArray(random_samples, nodata=-9999).save(os.path.join(self.tmpdir_multiproc, 'random_samples', im_name))
return random_samples
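The per-cluster sampling delegates to KMeansRSImage.get_random_spectra_from_each_cluster, whose internals are not shown in this diff. Conceptually it draws the same number of random spectra from every cluster label; a NumPy sketch of that idea:

import numpy as np

def random_spectra_per_cluster(spectra, labels, samplesize):
    """Draw `samplesize` random spectra from each cluster.

    spectra: (n_pixels x n_bands) array; labels: (n_pixels,) cluster label per spectrum
    """
    samples = {}
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        # sample with replacement if a cluster holds fewer spectra than requested
        chosen = np.random.choice(idx, size=samplesize, replace=len(idx) < samplesize)
        samples[label] = spectra[chosen]
    return samples

# combining to one 2D array as done above:
# random_samples = np.vstack(list(samples.values()))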
def _get_uniform_random_samples(self, im_num, src_im, tgt_srf, n_clusters, tgt_n_samples, path_out=''):
im_resamp = self.perform_spectral_resampling(src_im, tgt_srf)
random_samples = self.cluster_image_and_get_uniform_samples(im_resamp, n_clusters, tgt_n_samples)
if path_out:
GeoArray(random_samples).save(path_out)
else:
return random_samples
return im_num, random_samples
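A minimal usage sketch of the refactored code path. The import path and the constructor signature of SpecHomo_Classifier are not part of this diff, so they are assumptions inferred from the attributes used above (ims_ref, CPUs, v):

# hypothetical usage; file paths, import path and constructor arguments are assumptions
from gms_preprocessing.algorithms.L2B_P import SpecHomo_Classifier

ims_ref = ['/data/hyperspectral/scene1.bsq', '/data/hyperspectral/scene2.bsq']
classifier = SpecHomo_Classifier(ims_ref, CPUs=4)  # assumed signature
ref_cube = classifier.generate_reference_cube(
    tgt_satellite='Landsat-8', tgt_sensor='OLI_TIRS',
    n_clusters=10, tgt_n_samples=1000,
    path_out='/data/ref_cube.bsq', fmt_out='ENVI', progress=True)
# ref_cube shape: (tgt_n_samples, number of input images, number of target sensor bands)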