Commit 7360e31b authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Revised L2B_P.SpecHomo_Classifier. Multiprocessing now parallelizes on block...

Revised L2B_P.SpecHomo_Classifier. Multiprocessing now parallelizes on block level, not on image level anymore. Removed deprecated functions. Improved tests.
Former-commit-id: 452e2d2a
Former-commit-id: 38ca93be
parent c414fd96
......@@ -5,7 +5,7 @@ Level 2B Processor: Spectral homogenization
import os
import numpy as np
import logging
import logging # noqa F401 # flake8 issue
from scipy.interpolate import interp1d
import scipy as sp
import matplotlib.pyplot as plt
......@@ -14,12 +14,9 @@ from pandas import DataFrame
from typing import Union, List # noqa F401 # flake8 issue
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import tempfile
import time
from sklearn.cluster import k_means_ # noqa F401 # flake8 issue
from geoarray import GeoArray # noqa F401 # flake8 issue
from py_tools_ds.numeric.array import get_array_tilebounds
from py_tools_ds.processing.progress_mon import ProgressBar
from ..config import GMS_config as CFG
......@@ -78,30 +75,24 @@ class L2B_object(L2A_object):
class SpectralResampler(object):
"""Class for spectral resampling of a single spectral signature (1D-array) or an image (3D-array)."""
def __init__(self, wvl_src, srf_tgt, logger=None):
    # type: (np.ndarray, SRF, logging.Logger) -> None
    """Get an instance of the SpectralResampler class.

    NOTE: this diff render interleaved the pre- and post-commit lines of this method;
    this is the reconstructed post-commit version (wvl_unit parameter removed,
    wavelength unit now auto-detected).

    :param wvl_src:  center wavelength positions of the source spectrum
    :param srf_tgt:  spectral response of the target instrument as an instance of io.Input_reader.SRF.
    :param logger:   optional logger instance (defaults to a pickable GMS_logger)
    """
    # privates (lazily computed by the wvl_1nm / srf_1nm properties)
    self._wvl_1nm = None
    self._srf_1nm = {}

    wvl_src = np.array(wvl_src, dtype=np.float).flatten()

    # make sure the target SRFs are given in nanometers
    if srf_tgt.wvl_unit != 'nanometers':
        srf_tgt.convert_wvl_unit()

    # heuristic unit detection: values <= 100 are assumed to be micrometers -> convert to nanometers
    self.wvl_src_nm = wvl_src if max(wvl_src) > 100 else wvl_src * 1000
    self.srf_tgt = srf_tgt
    self.logger = logger or GMS_logger(__name__)  # must be pickable
@property
def wvl_1nm(self):
......@@ -161,50 +152,49 @@ class SpectralResampler(object):
return np.array(spectrum_rsp)
def resample_image(self, image_cube, tiledims=(20, 20), CPUs=None):
    # type: (Union[GeoArray, np.ndarray], tuple, Union[int, None]) -> np.ndarray
    """Resample the given spectral image cube according to the spectral response functions of the target instrument.

    NOTE: this diff render interleaved the pre- and post-commit lines of this method;
    this is the reconstructed post-commit version (parallelization on block level).

    :param image_cube:  image (3D array) containing the spectral information in the third dimension
    :param tiledims:    dimension of tiles to be used during computation
    :param CPUs:        CPUs to use for processing (None or a value > 1 enables multiprocessing)
    :return:            resampled spectral image cube
    """
    # input validation
    if not isinstance(image_cube, (GeoArray, np.ndarray)):
        raise TypeError(image_cube)

    if not image_cube.ndim == 3:
        raise ValueError("The given image cube must be 3-dimensional. Received a %s-dimensional array."
                         % image_cube.ndim)
    assert image_cube.shape[2] == self.wvl_src_nm.size

    image_cube = GeoArray(image_cube)

    (R, C), B = image_cube.shape[:2], len(self.srf_tgt.bands)
    image_rsp = np.zeros((R, C, B), dtype=image_cube.dtype)

    # resample each tile; parallelize on block level if multiple CPUs are allowed
    if CPUs is None or CPUs > 1:
        with Pool(CPUs) as pool:
            tiles_rsp = pool.starmap(self._specresample, image_cube.tiles(tiledims))
    else:
        tiles_rsp = [self._specresample(bounds, tiledata) for bounds, tiledata in image_cube.tiles(tiledims)]

    # merge the resampled tiles back into the output cube using the returned tile bounds
    for ((rS, rE), (cS, cE)), tile_rsp in tiles_rsp:
        image_rsp[rS: rE + 1, cS: cE + 1, :] = tile_rsp

    return image_rsp
def _specresample(self, tilebounds, tiledata):
    """Spectrally resample a single image tile (worker used by resample_image).

    NOTE: this diff render interleaved the pre- and post-commit lines of this method;
    this is the reconstructed post-commit version (tile bounds are passed through so
    that results from parallel workers can be reassembled).

    :param tilebounds:  ((rS, rE), (cS, cE)) bounds of the tile within the full image
    :param tiledata:    3D array holding the tile's spectral data
    :return:            (tilebounds, resampled tile)
    """
    # spectral resampling of input image to 1 nm resolution
    tile_1nm = interp1d(self.wvl_src_nm, tiledata,
                        axis=2, bounds_error=False, fill_value=0, kind='linear')(self.wvl_1nm)

    tile_rsp = np.zeros((*tile_1nm.shape[:2], len(self.srf_tgt.bands)), dtype=tiledata.dtype)
    for band_idx, band in enumerate(self.srf_tgt.bands):
        # compute the resampled image cube (np.average computes the weighted mean value)
        tile_rsp[:, :, band_idx] = np.average(tile_1nm, weights=self.srf_1nm[band], axis=2)

    return tilebounds, tile_rsp
class KMeansRSImage(object):
......@@ -320,6 +310,7 @@ class KMeansRSImage(object):
df.insert(0, 'cluster_label', self.clusters.labels_)
# get random sample from each cluster and generate a dict like {cluster_label: random_sample}
print('Generating random samples from clusters.')
random_samples = dict()
for label in range(self.n_clusters):
cluster_subset = df[df.cluster_label == label].loc[:, 'B1':]
......@@ -343,69 +334,6 @@ class SpecHomo_Classifier(object):
self.CPUs = CPUs or cpu_count()
self.tmpdir_multiproc = ''
def generate_reference_cubeOLD(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
                               fmt_out='ENVI'):
    # type: (str, str, int, int, str, str) -> np.ndarray
    """Generate reference spectra from all hyperspectral input images.

    DEPRECATED: removed by this commit in favour of generate_reference_cube (which
    parallelizes on block level instead of image level).

    The hyperspectral images are spectrally resampled to the target sensor specifications. The resulting target
    sensor image is then clustered and the same number of spectra is randomly selected from each cluster. All
    spectra are combined into a single 'reference cube' containing the same number of spectra for each cluster
    whereas the spectra originate from all the input images.

    :param tgt_satellite:   target satellite, e.g., 'Landsat-8'
    :param tgt_sensor:      target sensor, e.g., 'OLI_TIRS'
    :param n_clusters:      number of clusters to be used for clustering the input images (KMeans)
    :param tgt_n_samples:   number of spectra to be collected from each input image
    :param path_out:        output path for the generated reference cube
    :param fmt_out:         output format (GDAL driver code)
    :return:                np.array: [tgt_n_samples x images x spectral bands of the target sensor]
    """
    self.logger.info('Generating reference spectra from all input images...')

    # get SRFs
    self.logger.info('Reading spectral response functions of target sensor...')
    tgt_srf = SRF(dict(Satellite=tgt_satellite, Sensor=tgt_sensor, Subsystem=None, image_type='RSD',
                       proc_level='L1A', logger=self.logger))
    if self.v:
        tgt_srf.plot_srfs()

    # Build the reference cube from random samples of each image
    # => rows: tgt_n_samples, columns: images, bands: spectral information

    # generate random spectra samples equally for each KMeans cluster
    # NOTE(review): these argument tuples appear out of sync with the current
    # _get_uniform_random_samples signature (which expects im_num first) — verify before use
    args = [(im, tgt_srf, n_clusters, tgt_n_samples) for im in self.ims_ref]

    if self.CPUs > 1:
        # never spawn more worker processes than there are input images
        processes = len(self.ims_ref) if self.CPUs > len(self.ims_ref) else self.CPUs
        with tempfile.TemporaryDirectory() as tmpdir:
            self.tmpdir_multiproc = tmpdir
            with Pool(processes) as pool:
                pool.starmap(self._get_uniform_random_samples, args)

            # combine temporarily saved random samples to ref_cube
            self.logger.info('Combining random samples to reference cube...')
            self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
            for im_n, im in enumerate(self.ims_ref):
                path_randsampl = os.path.join(tmpdir, 'random_samples', os.path.basename(im))
                self.logger.info('Adding content of %s to reference cube...' % im)
                self.ref_cube[:, im_n, :] = GeoArray(path_randsampl)[:]
    else:
        self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
        for im_n, argset in enumerate(args):
            # combine returned random samples to ref_cube
            random_samples = self._get_uniform_random_samples(*argset)

            self.logger.info('Adding random samples for %s to reference cube...' % argset[0])
            self.ref_cube[:, im_n, :] = random_samples

    # save
    if path_out:
        GeoArray(self.ref_cube).save(out_path=path_out, fmt=fmt_out)

    return self.ref_cube
def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
fmt_out='ENVI', progress=True):
# type: (str, str, int, int, str, str, bool) -> np.ndarray
......@@ -422,7 +350,7 @@ class SpecHomo_Classifier(object):
:param tgt_n_samples: number of spectra to be collected from each input image
:param path_out: output path for the generated reference cube
:param fmt_out: output format (GDAL driver code)
:param progress: show progress bar
:param progress: show progress bar (default: True)
:return: np.array: [tgt_n_samples x images x spectral bands of the target sensor]
"""
self.logger.info('Generating reference spectra from all input images...')
......@@ -438,98 +366,80 @@ class SpecHomo_Classifier(object):
# Build the reference cube from random samples of each image
# => rows: tgt_n_samples, columns: images, bands: spectral information
# generate random spectra samples equally for each KMeans cluster
args = ((im_num, im, tgt_srf, n_clusters, tgt_n_samples) for im_num, im in enumerate(self.ims_ref))
if self.CPUs > 1:
processes = len(self.ims_ref) if self.CPUs > len(self.ims_ref) else self.CPUs
with Pool(processes) as pool:
results = pool.starmap_async(self._get_uniform_random_samples, args, chunksize=1)
bar = ProgressBar(prefix='\tprogress:')
while True:
time.sleep(.1)
# this does not really represent the remaining tasks but the remaining chunks -> thus chunksize=1
# noinspection PyProtectedMember
numberDone = len(self.ims_ref) - results._number_left
if progress:
bar.print_progress(percent=numberDone / len(self.ims_ref) * 100)
if results.ready():
# <= this is the line where multiprocessing may freeze if an exception occurs
results = results.get()
break
else:
results = []
bar = ProgressBar(prefix='\tprogress:')
for i, argset in enumerate(args):
if progress:
bar.print_progress((i + 1) / len(self.ims_ref) * 100)
results.append(self._get_uniform_random_samples(*argset))
# combine returned random samples to ref_cube
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
for im_num, random_samples in results:
bar = ProgressBar(prefix='\t overall progress:')
if progress:
bar.print_progress(0 / len(self.ims_ref) * 100)
for im_num, im in enumerate(self.ims_ref):
self.logger.info('Generating random samples for %s (shape: %s)'
% (os.path.basename(im), GeoArray(im).shape))
im_rsp = self.perform_spectral_resampling(im, tgt_srf, progress=progress)
random_samples = self.cluster_image_and_get_uniform_samples(im_rsp, n_clusters, tgt_n_samples)
self.logger.info('Adding random samples of %s to reference cube...'
% os.path.basename(self.ims_ref[im_num]))
self.ref_cube[:, im_num, :] = random_samples
if progress:
bar.print_progress((im_num + 1) / len(self.ims_ref) * 100)
# save
if path_out:
GeoArray(self.ref_cube).save(out_path=path_out, fmt=fmt_out)
return self.ref_cube
def perform_spectral_resampling(self, src_im, tgt_srf, progress=False):
    # type: (str, SRF, bool) -> Union[GeoArray, None]
    """Perform spectral resampling of the given image to match the given spectral response functions.

    NOTE: this diff render interleaved the pre- and post-commit lines of this method;
    this is the reconstructed post-commit version (instance method using self.logger
    and self.CPUs instead of a staticmethod with its own logger).

    :param src_im:    source image to be resampled
    :param tgt_srf:   target spectral response functions to be used for spectral resampling
    :param progress:  show progress bar (default: False)
    :return:          spectrally resampled image as GeoArray
    """
    im_name = os.path.basename(src_im)

    # read input image
    self.logger.info('Reading the input image %s...' % im_name)
    im_gA = GeoArray(src_im)
    im_gA.cwl = np.array(im_gA.meta.loc['wavelength'], dtype=np.float).flatten()

    # perform spectral resampling of input image to match spectral properties of target sensor
    self.logger.info('Performing spectral resampling to match spectral properties of target sensor...')
    SR = SpectralResampler(im_gA.cwl, tgt_srf)

    tgt_im = GeoArray(np.zeros((*im_gA.shape[:2], len(tgt_srf.bands)), dtype=np.int16), im_gA.gt, im_gA.prj)
    tiles = im_gA.tiles((1000, 1000))  # use tiles to save memory
    for ((rS, rE), (cS, cE)), tiledata in (tqdm(tiles) if progress else tiles):
        tgt_im[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata.astype(np.int16), CPUs=self.CPUs)

    return tgt_im
def cluster_image_and_get_uniform_samples(self, im, n_clusters, tgt_n_samples):
    # type: (Union[GeoArray, np.ndarray], int, int) -> np.ndarray
    """Compute KMeans clusters for the given image and return an array of uniform random samples.

    NOTE: this diff render interleaved the pre- and post-commit lines of this method;
    this is the reconstructed post-commit version (instance method using self.logger,
    self.v and self.CPUs).

    :param im:             image to be clustered
    :param n_clusters:     number of clusters to use
    :param tgt_n_samples:  number of returned random samples
    :return:               2D array of random spectra [samples x bands]
    """
    # compute KMeans clusters for the spectrally resampled image
    self.logger.info('Computing %s KMeans clusters...' % n_clusters)
    kmeans = KMeansRSImage(im, n_clusters=n_clusters, CPUs=self.CPUs)

    if self.v:
        kmeans.plot_cluster_centers()
        kmeans.plot_cluster_histogram()

    # randomly grab the given number of spectra from each cluster
    self.logger.info('Getting %s random spectra from each cluster...' % (tgt_n_samples // n_clusters))
    random_samples = kmeans.get_random_spectra_from_each_cluster(samplesize=tgt_n_samples // n_clusters)

    # combine the spectra (2D arrays) of all clusters to a single 2D array
    self.logger.info('Combining random samples from all clusters.')
    random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples])

    return random_samples
def _get_uniform_random_samples(self, im_num, src_im, tgt_srf, n_clusters, tgt_n_samples, path_out=''):
    """Resample one input image, cluster it and draw uniform random samples.

    :param im_num:         index of the input image (passed through to the caller)
    :param src_im:         path of the source image to be processed
    :param tgt_srf:        target spectral response functions
    :param n_clusters:     number of KMeans clusters
    :param tgt_n_samples:  number of random samples to draw
    :param path_out:       if given, samples are saved there instead of being returned
    :return:               (im_num, random samples) or None if path_out is given
    """
    resampled_im = self.perform_spectral_resampling(src_im, tgt_srf)
    samples = self.cluster_image_and_get_uniform_samples(resampled_im, n_clusters, tgt_n_samples)

    if not path_out:
        return im_num, samples

    # persist the samples to disk instead of returning them
    GeoArray(samples).save(path_out)
......@@ -43,7 +43,8 @@ class Test_SpectralResampler(unittest.TestCase):
spectrum = self.geoArr[0, 0, :].flatten()
sr = SR(spectrum_wvl, self.srf_l8)
sr.resample_signature(spectrum)
sig_rsp = sr.resample_signature(spectrum)
self.assertTrue(np.any(sig_rsp), msg='Output signature is empty.')
def test_resample_image(self):
# Get a hyperspectral spectrum.
......@@ -51,4 +52,5 @@ class Test_SpectralResampler(unittest.TestCase):
image = self.geoArr[:]
sr = SR(image_wvl, self.srf_l8)
sr.resample_image(image)
im_rsp = sr.resample_image(image)
self.assertTrue(np.any(im_rsp), msg='Output image is empty.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment