# -*- coding: utf-8 -*-

"""
Level 2B Processor: Spectral homogenization
"""

import os
import numpy as np
import logging  # noqa F401  # flake8 issue
from scipy.interpolate import interp1d
import scipy as sp
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from pandas import DataFrame
from pandas.plotting import scatter_matrix
from typing import Union, List, Dict, Tuple  # noqa F401  # flake8 issue
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from glob import glob
import re
import json
from collections import OrderedDict
import dill
from pprint import pformat
from nested_dict import nested_dict
import traceback
import zipfile
import tempfile
import time

from sklearn.cluster import k_means_  # noqa F401  # flake8 issue
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline, Pipeline  # noqa F401  # flake8 issue
from geoarray import GeoArray  # noqa F401  # flake8 issue

from ..options.config import GMS_config as CFG
from ..io.input_reader import SRF  # noqa F401  # flake8 issue
from ..misc.logging import GMS_logger
from ..misc.definition_dicts import datasetid_to_sat_sen, sat_sen_to_datasetid
from ..misc.exceptions import ClassifierNotAvailableError
from ..misc.logging import close_logger
from ..model.gms_object import GMS_object
from ..model.metadata import get_LayerBandsAssignment, get_center_wavelengths_by_LBA
from .L2A_P import L2A_object
from ..model.gms_object import GMS_identifier
from .classification import classify_image

__author__ = 'Daniel Scheffler'


class L2B_object(L2A_object):
    def __init__(self, L2A_obj=None):
        super(L2B_object, self).__init__()

        if L2A_obj:
            # populate attributes
            [setattr(self, key, value) for key, value in L2A_obj.__dict__.items()]

        self.proc_level = 'L2B'
        self.proc_status = 'initialized'

    def spectral_homogenization(self):
        """Apply spectral homogenization, i.e., prediction of the spectral bands of the target sensor."""
        #################################################################
        # collect some information specifying the needed homogenization #
        #################################################################

        method = CFG.spechomo_method
        src_dsID = sat_sen_to_datasetid(self.satellite, self.sensor)
        src_cwls = [float(self.MetaObj.CWL[bN]) for bN in self.MetaObj.LayerBandsAssignment]
        # FIXME exclude or include thermal bands; respect sorted CWLs in context of LayerBandsAssignment
        tgt_sat, tgt_sen = datasetid_to_sat_sen(CFG.datasetid_spectral_ref)
        # NOTE: get target LBA at L2A, because spectral characteristics of target sensor do not change after AC
        tgt_LBA = get_LayerBandsAssignment(
            GMS_identifier(satellite=tgt_sat, sensor=tgt_sen, subsystem='',
                           image_type='RSD', proc_level='L2A', dataset_ID=src_dsID, logger=None))

        if CFG.datasetid_spectral_ref is None:
            tgt_cwl = CFG.target_CWL
            tgt_fwhm = CFG.target_FWHM
        else:
            # exclude those bands from CFG.target_CWL and CFG.target_FWHM that have been removed after AC
            full_LBA = get_LayerBandsAssignment(
                GMS_identifier(satellite=tgt_sat, sensor=tgt_sen, subsystem='',
                               image_type='RSD', proc_level='L1A', dataset_ID=src_dsID, logger=None),
                no_thermal=True, no_pan=False, return_fullLBA=True, sort_by_cwl=True, proc_level='L1A')
            tgt_cwl = [dict(zip(full_LBA, CFG.target_CWL))[bN] for bN in tgt_LBA]
            tgt_fwhm = [dict(zip(full_LBA, CFG.target_FWHM))[bN] for bN in tgt_LBA]

        ####################################################
        # special cases where homogenization is not needed #
        ####################################################

        if self.dataset_ID == CFG.datasetid_spectral_ref:
            self.logger.info("Spectral homogenization has been skipped because the dataset ID equals the dataset ID "
                             "of the spectral reference sensor.")
            return

        if src_cwls == CFG.target_CWL or (self.satellite == tgt_sat and self.sensor == tgt_sen):
            # FIXME catch the case if LayerBandsAssignments are unequal with np.take
            self.logger.info("Spectral homogenization has been skipped because the current spectral characteristics "
                             "are already equal to the target sensor's.")
            return

        #################################################
        # perform spectral homogenization of image data #
        #################################################

        SpH = SpectralHomogenizer(classifier_rootDir=CFG.path_spechomo_classif, logger=self.logger)

        if method == 'LI' or CFG.datasetid_spectral_ref is None:
            # linear interpolation (if intended by user or in case of custom spectral characteristics of target sensor)
            # -> no classifier available for that case -> linear interpolation
            im = SpH.interpolate_cube(self.arr, src_cwls, tgt_cwl, kind='linear')

            if CFG.spechomo_estimate_accuracy:
                self.logger.warning("Unable to compute any error information if the spectral homogenization algorithm "
                                    "is set to 'LI' (Linear Interpolation).")

            errs = None

        else:
            # a known sensor has been specified as spectral reference => apply a machine learner
            im, errs = SpH.predict_by_machine_learner(self.arr,
                                                      method=method,
                                                      src_satellite=self.satellite,
                                                      src_sensor=self.sensor,
                                                      src_LBA=self.LayerBandsAssignment,
                                                      tgt_satellite=tgt_sat,
                                                      tgt_sensor=tgt_sen,
                                                      tgt_LBA=tgt_LBA,
                                                      nodataVal=self.arr.nodata,
                                                      compute_errors=CFG.spechomo_estimate_accuracy,
                                                      bandwise_errors=CFG.spechomo_bandwise_accuracy,
                                                      fallback_argskwargs=dict(
                                                          args=dict(source_CWLs=src_cwls, target_CWLs=tgt_cwl),
                                                          kwargs=dict(kind='linear')
                                                      ))

        ###################
        # update metadata #
        ###################

        self.LayerBandsAssignment = tgt_LBA
        self.MetaObj.CWL = dict(zip(tgt_LBA, tgt_cwl))
        self.MetaObj.FWHM = dict(zip(tgt_LBA, tgt_fwhm))
        self.MetaObj.bands = len(tgt_LBA)

        self.arr = im  # type: GeoArray
        self.spec_homo_errors = errs  # type: Union[np.ndarray, None]  # int16, None if ~CFG.spechomo_estimate_accuracy

        #########################################################################################
        # perform spectral homogenization of bandwise error information from earlier processors #
        #########################################################################################

        if self.ac_errors and self.ac_errors.ndim == 3:
            self.logger.info("Performing linear interpolation for 'AC errors' array to match the number of target "
                             "sensor bands...")
            outarr = interp1d(np.array(src_cwls), self.ac_errors,
                              axis=2, kind='linear', fill_value='extrapolate')(tgt_cwl)
            self.ac_errors = outarr.astype(self.ac_errors.dtype)


class SpectralHomogenizer(object):
    """Class for applying spectral homogenization by applying an interpolation or machine learning approach."""
    def __init__(self, classifier_rootDir='', logger=None):
        """Get instance of SpectralHomogenizer.

        :param classifier_rootDir:  root directory where machine learning classifiers are stored.
        :param logger:              instance of logging.Logger
        """
        self.classifier_rootDir = classifier_rootDir or CFG.path_spechomo_classif
        self.logger = logger or logging.getLogger(self.__class__.__name__)  # FIXME own logger logs nothing

    def __getstate__(self):
        """Defines how the attributes of SpectralHomogenizer instances are pickled."""
        close_logger(self.logger)
        self.logger = None

        return self.__dict__

    def __del__(self):
        close_logger(self.logger)
        self.logger = None

    def interpolate_cube(self, arrcube, source_CWLs, target_CWLs, kind='linear'):
        # type: (Union[np.ndarray, GeoArray], list, list, str) -> np.ndarray
        """Spectrally interpolate the spectral bands of a remote sensing image to new band positions.

        :param arrcube:     array to be spectrally interpolated
        :param source_CWLs: list of source central wavelength positions
        :param target_CWLs: list of target central wavelength positions
        :param kind:        interpolation kind to be passed to scipy.interpolate.interp1d (default: 'linear')
        :return:
        """
        assert kind in ['linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'], \
            "%s is not a supported kind of spectral interpolation." % kind
        assert arrcube is not None, \
            'SpectralHomogenizer.interpolate_cube expects a numpy array as input. Got %s.' % type(arrcube)

        orig_CWLs, target_CWLs = np.array(source_CWLs), np.array(target_CWLs)

        self.logger.info(
            'Performing spectral homogenization (%s interpolation) with target wavelength positions at %s nm.'
            % (kind, ', '.join(np.array(target_CWLs[:-1]).astype(str)) + ' and %s' % target_CWLs[-1]))
        outarr = interp1d(np.array(orig_CWLs), arrcube, axis=2, kind=kind, fill_value='extrapolate')(target_CWLs)
        outarr = outarr.astype(np.int16)

        assert outarr.shape == tuple([*arrcube.shape[:2], len(target_CWLs)])

        return outarr

    def predict_by_machine_learner(self, arrcube, method, src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor,
                                   tgt_LBA, n_clusters=50, classif_alg='MinDist', kNN_n_neighbors=10,
                                   nodataVal=None, compute_errors=False, bandwise_errors=True,
                                   fallback_argskwargs=None):
        # type: (Union[np.ndarray, GeoArray], str, str, str, list, str, str, list, int, str, int, int, ...) -> tuple
        """Predict spectral bands of target sensor by applying a machine learning approach.

        :param arrcube:             input image array for target sensor spectral band prediction (rows x cols x bands)
        :param method:              machine learning approach to be used for spectral bands prediction
                                    'LR': Linear Regression
                                    'RR': Ridge Regression
                                    'QR': Quadratic Regression
        :param src_satellite:       source satellite, e.g., 'Landsat-8'
        :param src_sensor:          source sensor, e.g., 'OLI_TIRS'
        :param src_LBA:             source LayerBandsAssignment
        :param tgt_satellite:       target satellite, e.g., 'Landsat-8'
        :param tgt_sensor:          target sensor, e.g., 'OLI_TIRS'
        :param tgt_LBA:             target LayerBandsAssignment
        :param n_clusters:          Number of spectral clusters to be used during LR/ RR/ QR homogenization.
                                    E.g., 50 means that the image to be converted to the spectral target sensor
                                    is clustered into 50 spectral clusters and one separate machine learner per
                                    cluster is applied to the input data to predict the homogenized image. If
                                    'spechomo_n_clusters' is set to 1, the source image is not clustered and
                                    only one machine learning classifier is used for prediction.
        :param classif_alg:         Multispectral classification algorithm to be used to determine the spectral cluster
                                    each pixel belongs to.
                                    'MinDist': Minimum Distance (Nearest Centroid) Classification
                                    'kNN': k-Nearest-Neighbor Classification
                                    'SAM': Spectral Angle Mapping
        :param kNN_n_neighbors:     The number of neighbors to be considered in case 'classif_alg' is set to 'kNN'.
                                    Otherwise, this parameter is ignored.
        :param nodataVal:           no data value
        :param compute_errors:      whether to compute pixel- / bandwise model errors for estimated pixel values
                                    (default: False)
        :param bandwise_errors:     whether to compute error information for each band separately (True - default)
                                    or to average errors over bands using median (False) (ignored in case of fallback)
        :param fallback_argskwargs: arguments and keyword arguments for the fallback algorithm
                                    ({'args': {}, 'kwargs': {}})
        :return:                    predicted array (rows x columns x bands)
        :rtype:                     Tuple[np.ndarray, Union[np.ndarray, None]]
        """
        # TODO: add LBA validation to .predict()
        # if n_clusters > 1:
        PR = RSImage_ClusterPredictor(method=method,
                                      classifier_rootDir=self.classifier_rootDir,
                                      n_clusters=n_clusters,
                                      classif_alg=classif_alg,
                                      kNN_n_neighbors=kNN_n_neighbors)
        # else:
        #     PR = RSImage_Predictor(method=method,
        #                            classifier_rootDir=self.classifier_rootDir)

        ######################
        # get the classifier #
        ######################

        cls = None
        exc = Exception()
        try:
            cls = PR.get_classifier(src_satellite, src_sensor, src_LBA, tgt_satellite, tgt_sensor, tgt_LBA)

        except FileNotFoundError as e:
            self.logger.warning('No machine learning classifier available that fulfills the specifications of the '
                                'spectral reference sensor. Falling back to linear interpolation for performing '
                                'spectral homogenization.')
            exc = e

        except ClassifierNotAvailableError as e:
            self.logger.error('\nAn error occurred during spectral homogenization using machine learning. '
                              'Falling back to linear interpolation. Error message was: ')
            self.logger.error(traceback.format_exc())
            exc = e

        ##################
        # run prediction #
        ##################

        errors = None
        if cls:
            self.logger.info('Performing spectral homogenization using %s. Target is %s %s %s.'
                             % (method, tgt_satellite, tgt_sensor, tgt_LBA))
            im_homo = PR.predict(arrcube, classifier=cls, nodataVal=nodataVal)

            if compute_errors:
                errors = PR.compute_prediction_errors(im_homo, cls, nodataVal=nodataVal)

                if not bandwise_errors:
                    errors = np.median(errors, axis=2).astype(errors.dtype)

        elif fallback_argskwargs:
            # fallback: use linear interpolation and set errors to an array of zeros
            im_homo = self.interpolate_cube(arrcube, **fallback_argskwargs['args'], **fallback_argskwargs['kwargs'])

            if compute_errors:
                self.logger.warning("Spectral homogenization algorithm had to be performed by linear interpolation "
                                    "(fallback). Unable to compute any accuracy information from that.")
                if bandwise_errors:
                    errors = np.zeros_like(im_homo, dtype=np.int16)
                else:
                    errors = np.zeros(im_homo.shape[:2], dtype=np.int16)

        else:
            raise exc

        return im_homo, errors

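# Usage sketch (illustrative only, not part of the processing chain): how SpectralHomogenizer interpolates
# a small synthetic image cube to new band positions. All wavelengths and the classifier_rootDir below are
# hypothetical values chosen purely for demonstration; only the interpolation path is exercised, so no
# classifier files and no GMS_config setup are needed.
def _example_interpolate_cube():
    homogenizer = SpectralHomogenizer(classifier_rootDir='/path/to/classifiers')  # dummy path, CFG is not touched
    cube = np.random.randint(0, 10000, (5, 5, 4)).astype(np.int16)  # rows x cols x bands
    src_cwls = [480., 560., 660., 830.]  # source center wavelengths [nm]
    tgt_cwls = [490., 665., 842.]        # target center wavelengths [nm]
    cube_homo = homogenizer.interpolate_cube(cube, src_cwls, tgt_cwls, kind='linear')
    assert cube_homo.shape == (5, 5, 3) and cube_homo.dtype == np.int16
    return cube_homo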

class SpectralResampler(object):
    """Class for spectral resampling of a single spectral signature (1D-array) or an image (3D-array)."""

    def __init__(self, wvl_src, srf_tgt, logger=None):
        # type: (np.ndarray, SRF, logging.Logger) -> None
        """Get an instance of the SpectralResampler class.

        :param wvl_src:     center wavelength positions of the source spectrum
        :param srf_tgt:     spectral response of the target instrument as an instance of io.input_reader.SRF
        :param logger:      instance of logging.Logger
        """
        # privates
        self._wvl_1nm = None
        self._srf_1nm = {}

        wvl_src = np.array(wvl_src, dtype=np.float).flatten()
        if srf_tgt.wvl_unit != 'nanometers':
            srf_tgt.convert_wvl_unit()

        self.wvl_src_nm = wvl_src if max(wvl_src) > 100 else wvl_src * 1000
        self.srf_tgt = srf_tgt
        self.logger = logger or GMS_logger(__name__)  # must be picklable

    def __getstate__(self):
        """Defines how the attributes of SpectralResampler instances are pickled."""
        close_logger(self.logger)
        self.logger = None

        return self.__dict__

    def __del__(self):
        close_logger(self.logger)
        self.logger = None

    @property
    def wvl_1nm(self):
        # spectral resampling of input image to 1 nm resolution
        if self._wvl_1nm is None:
            self._wvl_1nm = np.arange(np.ceil(self.wvl_src_nm.min()), np.floor(self.wvl_src_nm.max()), 1)
        return self._wvl_1nm

    @property
    def srf_1nm(self):
        if not self._srf_1nm:
            for band in self.srf_tgt.bands:
                # resample srf to 1 nm
                self._srf_1nm[band] = \
                    sp.interpolate.interp1d(self.srf_tgt.srfs_wvl, self.srf_tgt.srfs[band],
                                            bounds_error=False, fill_value=0, kind='linear')(self.wvl_1nm)

                # validate
                assert len(self._srf_1nm[band]) == len(self.wvl_1nm)

        return self._srf_1nm

    def resample_signature(self, spectrum, scale_factor=10000, v=False):
        # type: (np.ndarray, int, bool) -> np.ndarray
        """Resample the given spectrum according to the spectral response functions of the target instrument.

        :param spectrum:        spectral signature data
        :param scale_factor:    the scale factor to apply to the given spectrum when it is plotted (default: 10000)
        :param v:               enable verbose mode (shows a plot of the resampled spectrum) (default: False)
        :return:    resampled spectral signature
        """
        if not spectrum.ndim == 1:
            raise ValueError("The array of the given spectral signature must be 1-dimensional. "
                             "Received a %s-dimensional array." % spectrum.ndim)
        spectrum = np.array(spectrum, dtype=np.float).flatten()
        assert spectrum.size == self.wvl_src_nm.size

        # resample input spectrum and wavelength to 1nm
        spectrum_1nm = interp1d(self.wvl_src_nm, spectrum,
                                bounds_error=False, fill_value=0, kind='linear')(self.wvl_1nm)

        if v:
            plt.figure()
            plt.plot(self.wvl_1nm, spectrum_1nm/scale_factor, '.')

        spectrum_rsp = []

        for band, wvl_center in zip(self.srf_tgt.bands, self.srf_tgt.wvl):
            # compute the resampled spectral value (np.average computes the weighted mean value)
            specval_rsp = np.average(spectrum_1nm, weights=self.srf_1nm[band])

            if v:
                plt.plot(self.wvl_1nm, self.srf_1nm[band]/max(self.srf_1nm[band]))
                plt.plot(wvl_center, specval_rsp/scale_factor, 'x', color='r')

            spectrum_rsp.append(specval_rsp)

        return np.array(spectrum_rsp)

    def resample_spectra(self, spectra, chunksize=200, CPUs=None):
        # type: (Union[GeoArray, np.ndarray], int, Union[None, int]) -> np.ndarray
        """Resample the given spectral signatures according to the spectral response functions of the target instrument.

        :param spectra:     spectral signatures, provided as 2D array
                            (rows: spectral samples, columns: spectral information / bands)
        :param chunksize:   defines how many spectral signatures are resampled per CPU
        :param CPUs:        CPUs to use for processing
        """
        # input validation
        if not spectra.ndim == 2:
            raise ValueError("The given spectra array must be 2-dimensional. Received a %s-dimensional array."
                             % spectra.ndim)
        assert spectra.shape[1] == self.wvl_src_nm.size

        # convert spectra to one multispectral image column
        im_col = spectra.reshape(spectra.shape[0], 1, spectra.shape[1])

        im_col_rsp = self.resample_image(im_col, tiledims=(1, chunksize), CPUs=CPUs)
        spectra_rsp = im_col_rsp.reshape(im_col_rsp.shape[0], im_col_rsp.shape[2])

        return spectra_rsp

    def resample_image(self, image_cube, tiledims=(20, 20), CPUs=None):
        # type: (Union[GeoArray, np.ndarray], tuple, Union[None, int]) -> np.ndarray
        """Resample the given spectral image cube according to the spectral response functions of the target instrument.

        :param image_cube:      image (3D array) containing the spectral information in the third dimension
        :param tiledims:        dimension of tiles to be used during computation (rows, columns)
        :param CPUs:            CPUs to use for processing
        :return:    resampled spectral image cube
        """
        # input validation
        if not image_cube.ndim == 3:
            raise ValueError("The given image cube must be 3-dimensional. Received a %s-dimensional array."
                             % image_cube.ndim)
        assert image_cube.shape[2] == self.wvl_src_nm.size

        image_cube = GeoArray(image_cube)

        (R, C), B = image_cube.shape[:2], len(self.srf_tgt.bands)
        image_rsp = np.zeros((R, C, B), dtype=image_cube.dtype)

        if CPUs is None or CPUs > 1:
            with Pool(CPUs) as pool:
                tiles_rsp = pool.starmap(self._specresample, image_cube.tiles(tiledims))
        else:
            tiles_rsp = [self._specresample(bounds, tiledata) for bounds, tiledata in image_cube.tiles(tiledims)]

        for ((rS, rE), (cS, cE)), tile_rsp in tiles_rsp:
            image_rsp[rS: rE + 1, cS: cE + 1, :] = tile_rsp

        return image_rsp

    def _specresample(self, tilebounds, tiledata):
        # spectral resampling of input image to 1 nm resolution
        tile_1nm = interp1d(self.wvl_src_nm, tiledata,
                            axis=2, bounds_error=False, fill_value=0, kind='linear')(self.wvl_1nm)

        tile_rsp = np.zeros((*tile_1nm.shape[:2], len(self.srf_tgt.bands)), dtype=tiledata.dtype)
        for band_idx, band in enumerate(self.srf_tgt.bands):
            # compute the resampled image cube (np.average computes the weighted mean value)
            tile_rsp[:, :, band_idx] = np.average(tile_1nm, weights=self.srf_1nm[band], axis=2)

        return tilebounds, tile_rsp

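# Minimal numpy sketch (illustrative only) of the band-integration principle used above: a source spectrum
# on a 1 nm grid is reduced to one target band value as the SRF-weighted mean. The Gaussian response curve
# below is a synthetic stand-in for a real spectral response function from io.input_reader.SRF.
def _example_srf_weighted_band_value():
    wvl_1nm = np.arange(400, 901, 1, dtype=float)                            # 1 nm wavelength grid [nm]
    spectrum_1nm = np.interp(wvl_1nm, [400, 650, 900], [500, 3000, 1500])    # synthetic spectrum
    srf_band = np.exp(-0.5 * ((wvl_1nm - 560.) / 20.) ** 2)                  # synthetic Gaussian SRF around 560 nm
    band_value = np.average(spectrum_1nm, weights=srf_band)                  # weighted mean = resampled band value
    return band_value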

class KMeansRSImage(object):
    """Class for clustering a given input image by using K-Means algorithm."""
    def __init__(self, im, n_clusters, CPUs=1, v=False):
        # type: (GeoArray, int, int, bool) -> None

        # privates
        self._clusters = None
        self._im_clust = None
        self._spectra = None

        self.im = im
        self.n_clusters = n_clusters
        self.CPUs = CPUs
        self.v = v

    @property
    def clusters(self):
        # type: () -> k_means_.KMeans
        if not self._clusters:
            self._clusters = self.compute_clusters()
        return self._clusters

    @clusters.setter
    def clusters(self, clusters):
        self._clusters = clusters

    @property
    def im_clust(self):
        if self._im_clust is None:
            self._im_clust = self.clusters.labels_.reshape((self.im.rows, self.im.cols))
        return self._im_clust

    def compute_clusters(self):
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_jobs=self.CPUs, verbose=self.v)
        self.clusters = kmeans.fit(self._im2spectra(self.im))

        return self.clusters

    def apply_clusters(self, image):
        labels = self.clusters.predict(self._im2spectra(GeoArray(image)))
        return labels

    @staticmethod
    def _im2spectra(geoArr):
        return geoArr.reshape((geoArr.rows * geoArr.cols, geoArr.bands))

    def plot_cluster_centers(self, figsize=(15, 5)):
        # type: (tuple) -> None
        """Show a plot of the cluster center signatures.

        :param figsize:     figure size (inches)
        """
        plt.figure(figsize=figsize)
        for i, center_signature in enumerate(self.clusters.cluster_centers_):
            plt.plot(range(1, self.im.bands + 1), center_signature, label='Cluster #%s' % (i + 1))

        plt.title('KMeans cluster centers for %s clusters' % self.n_clusters)
        plt.xlabel('Spectral band')
        plt.ylabel('Pixel values')
        plt.legend(loc='upper right')

        plt.show()

    def plot_cluster_histogram(self, figsize=(15, 5)):
        # type: (tuple) -> None
        """Show a histogram indicating the proportion of each cluster label in percent.

        :param figsize:     figure size (inches)
        """
        # grab the number of different clusters and create a histogram
        # based on the number of pixels assigned to each cluster
        numLabels = np.arange(0, len(np.unique(self.clusters.labels_)) + 1)
        hist, bins = np.histogram(self.clusters.labels_, bins=numLabels)

        # normalize the histogram, such that it sums to 100
        hist = hist.astype("float")
        hist /= hist.sum() / 100

        # plot the histogram as bar plot
        plt.figure(figsize=figsize)

        plt.bar(bins[:-1], hist, width=1)

        plt.title('Proportion of cluster labels (%s clusters)' % self.n_clusters)
        plt.xlabel('# Cluster')
        plt.ylabel('Proportion [%]')

        plt.show()

    def plot_clustered_image(self, figsize=(15, 15)):
        # type: (tuple) -> None
        """Show the clustered image.

        :param figsize:     figure size (inches)
        """
        plt.figure(figsize=figsize)
        rows, cols = self.im_clust.shape[:2]
        plt.imshow(self.im_clust, plt.get_cmap('prism'), interpolation='none', extent=(0, cols, rows, 0))
        plt.show()

    def get_random_spectra_from_each_cluster(self, samplesize=50, src_im=None):
        # type: (int, GeoArray) -> dict
        """Returns a given number of spectra randomly selected within each cluster.

        E.g., 50 spectra belonging to cluster 1, 50 spectra belonging to cluster 2 and so on.

        :param samplesize:  number of spectra to be randomly selected from each cluster
        :param src_im:      image to get random samples from (default: self.im)
        :return:
        """
        src_im = src_im if src_im is not None else self.im

        # get DataFrame with columns [cluster_label, B1, B2, B3, ...]
        df = DataFrame(self._im2spectra(src_im), columns=['B%s' % band for band in range(1, src_im.bands + 1)], )
        df.insert(0, 'cluster_label', self.clusters.labels_)

        # get random sample from each cluster and generate a dict like {cluster_label: random_sample}
        random_samples = dict()
        for label in range(self.n_clusters):
            cluster_subset = df[df.cluster_label == label].loc[:, 'B1':]
            # get random sample while filling it with duplicates of the same sample when cluster has not enough spectra
            random_samples[label] = np.array(cluster_subset.sample(samplesize, replace=True))

        return random_samples

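# Usage sketch (illustrative only): cluster a small synthetic image with KMeansRSImage and draw a fixed
# number of random spectra per cluster (the same pattern ReferenceCube_Generator uses further below).
# Assumes a scikit-learn version compatible with this module (KMeans still accepting 'n_jobs').
def _example_kmeans_random_spectra():
    im = GeoArray(np.random.randint(0, 10000, (20, 20, 6)).astype(np.int16))
    kmeans = KMeansRSImage(im, n_clusters=3, CPUs=1)
    kmeans.compute_clusters()
    samples = kmeans.get_random_spectra_from_each_cluster(samplesize=10)
    # samples maps each cluster label to a (10 x 6) array of randomly drawn spectra
    assert set(samples.keys()) == {0, 1, 2} and samples[0].shape == (10, 6)
    return samples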

class TrainingData(object):
    """Class for analyzing statistical relations between a pair of machine learning training data cubes."""
    def __init__(self, im_X, im_Y, test_size):
        # type: (Union[GeoArray, np.ndarray], Union[GeoArray, np.ndarray], Union[float, int]) -> None
        """Get instance of TrainingData.

        :param im_X:        input image X
        :param im_Y:        input image Y
        :param test_size:   test size (proportion as float between 0 and 1) or absolute value as integer
        """
        self.im_X = GeoArray(im_X)
        self.im_Y = GeoArray(im_Y)

        # Set spectra (3D to 2D conversion)
        self.spectra_X = im2spectra(self.im_X)
        self.spectra_Y = im2spectra(self.im_Y)

        # Set train and test variables
        # NOTE: If random_state is set to an Integer, train_test_split will always select the same 'pseudo-random' set
        #       of the input data.
        self.train_X, self.test_X, self.train_Y, self.test_Y = \
            train_test_split(self.spectra_X, self.spectra_Y, test_size=test_size, shuffle=True, random_state=0)

    def plot_scatter_matrix(self, figsize=(15, 15), mode='intersensor'):
        # TODO complete this function
        train_X = self.train_X[np.random.choice(self.train_X.shape[0], 1000, replace=False), :]
        train_Y = self.train_Y[np.random.choice(self.train_Y.shape[0], 1000, replace=False), :]

        if mode == 'intersensor':
            import seaborn

            fig, axes = plt.subplots(train_X.shape[1], train_Y.shape[1],
                                     figsize=(25, 9), sharex='all', sharey='all')
            # fig.suptitle('Correlation of %s and %s bands' % (self.src_cube.satellite, self.tgt_cube.satellite),
            #              size=25)

            color = seaborn.hls_palette(13)

            for i, ax in zip(range(train_X.shape[1]), axes.flatten()):
                for j, ax in zip(range(train_Y.shape[1]), axes.flatten()):
                    axes[i, j].scatter(train_X[:, j], train_Y[:, i], c=color[j], label=str(j))
                    # axes[i, j].set_xlim(-0.1, 1.1)
                    # axes[i, j].set_ylim(-0.1, 1.1)
                    #  if j == 8:
                    #      axes[5, j].set_xlabel('S2 B8A\n' + str(metadata_s2['Bands_S2'][j]) + ' nm', size=10)
                    #  elif j in range(9, 13):
                    #      axes[5, j].set_xlabel('S2 B' + str(j) + '\n' + str(metadata_s2['Bands_S2'][j]) + ' nm',
                    #                            size=10)
                    #  else:
                    #      axes[5, j].set_xlabel('S2 B' + str(j + 1) + '\n' + str(metadata_s2['Bands_S2'][j]) + ' nm',
                    #                            size=10)
                    #  axes[i, 0].set_ylabel(
                    #      'S3 SLSTR B' + str(6 - i) + '\n' + str(metadata_s3['Bands_S3'][5 - i]) + ' nm',
                    #      size=10)
                    # axes[4, j].set_xticks(np.arange(0, 1.2, 0.2))
                    # axes[i, j].plot([0, 1], [0, 1], c='red')

        else:
            df = DataFrame(train_X, columns=['Band %s' % b for b in range(1, self.im_X.bands + 1)])
            scatter_matrix(df, figsize=figsize, marker='.', hist_kwds={'bins': 50}, s=30, alpha=0.8)
            plt.suptitle('Image X band to band correlation')

            df = DataFrame(train_Y, columns=['Band %s' % b for b in range(1, self.im_Y.bands + 1)])
            scatter_matrix(df, figsize=figsize, marker='.', hist_kwds={'bins': 50}, s=30, alpha=0.8)
            plt.suptitle('Image Y band to band correlation')

    def plot_scattermatrix(self):
        # TODO complete this function
        import seaborn

        fig, axes = plt.subplots(self.im_X.data.bands, self.im_Y.data.bands,
                                 figsize=(25, 9), sharex='all', sharey='all')
        fig.suptitle('Correlation of %s and %s bands' % (self.im_X.satellite, self.im_Y.satellite), size=25)

        color = seaborn.hls_palette(13)

        for i, ax in zip(range(6), axes.flatten()):
            for j, ax in zip(range(13), axes.flatten()):
                axes[i, j].scatter(self.train_X[:, j], self.train_Y[:, 5 - i], c=color[j], label=str(j))
                axes[i, j].set_xlim(-0.1, 1.1)
                axes[i, j].set_ylim(-0.1, 1.1)
                # if j == 8:
                #     axes[5, j].set_xlabel('S2 B8A\n' + str(metadata_s2['Bands_S2'][j]) + ' nm', size=10)
                # elif j in range(9, 13):
                #     axes[5, j].set_xlabel('S2 B' + str(j) + '\n' + str(metadata_s2['Bands_S2'][j]) + ' nm', size=10)
                # else:
                #     axes[5, j].set_xlabel('S2 B' + str(j + 1) + '\n' + str(metadata_s2['Bands_S2'][j]) + ' nm',
                #                           size=10)
                # axes[i, 0].set_ylabel('S3 SLSTR B' + str(6 - i) + '\n' + str(metadata_s3['Bands_S3'][5 - i]) + ' nm',
                #                       size=10)
                axes[4, j].set_xticks(np.arange(0, 1.2, 0.2))
                axes[i, j].plot([0, 1], [0, 1], c='red')

    def show_band_scatterplot(self, band_src_im, band_tgt_im):
        # TODO complete this function
        from scipy.stats import gaussian_kde

        x = self.im_X.data[band_src_im].flatten()[:10000]
        y = self.im_Y.data[band_tgt_im].flatten()[:10000]

        # Calculate the point density
        xy = np.vstack([x, y])
        z = gaussian_kde(xy)(xy)

        plt.figure(figsize=(15, 15))
        plt.scatter(x, y, c=z, s=30, edgecolor='')
        plt.show()

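# Usage sketch (illustrative only): pair two synthetic image cubes and split their spectra into train and
# test sets the way TrainingData does; the array shapes and values below are hypothetical.
def _example_trainingdata_split():
    im_X = np.random.randint(0, 10000, (10, 10, 4)).astype(np.int16)
    im_Y = np.random.randint(0, 10000, (10, 10, 6)).astype(np.int16)
    td = TrainingData(im_X, im_Y, test_size=0.4)
    # 100 pixel spectra are split into 60 training and 40 test samples per image
    assert td.train_X.shape == (60, 4) and td.test_Y.shape == (40, 6)
    return td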

def im2spectra(geoArr):
    # type: (Union[GeoArray, np.ndarray]) -> np.ndarray
    """Convert 3D images to array of spectra samples (rows: samples;  cols: spectral information)."""
    return geoArr.reshape((geoArr.shape[0] * geoArr.shape[1], geoArr.shape[2]))


def spectra2im(spectra, tgt_rows, tgt_cols):
    # type: (Union[GeoArray, np.ndarray], int, int) -> np.ndarray
    """Convert array of spectra samples (rows: samples;  cols: spectral information) to a 3D image.

    :param spectra:     2D array with rows: spectral samples / columns: spectral information (bands)
    :param tgt_rows:    number of target image rows
    :param tgt_cols:    number of target image columns
    :return:            3D array (rows x columns x spectral bands)
    """
    return spectra.reshape(tgt_rows, tgt_cols, spectra.shape[1])

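# Round-trip sketch (illustrative only): im2spectra flattens an image cube to (samples x bands) and
# spectra2im restores the original (rows x cols x bands) layout.
def _example_im2spectra_roundtrip():
    cube = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    spectra = im2spectra(cube)            # shape: (6, 4)
    restored = spectra2im(spectra, 2, 3)  # shape: (2, 3, 4)
    assert spectra.shape == (6, 4) and np.array_equal(cube, restored)
    return restored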

class ReferenceCube_Generator(object):
    """Class for creating reference cubes that are later used as training data for SpecHomo_Classifier."""

    def __init__(self, filelist_refs, tgt_sat_sen_list=None, dir_refcubes='', n_clusters=10, tgt_n_samples=1000,
                 v=False, logger=None, CPUs=None):
        # type: (List[str], List[Tuple[str, str]], str, int, int, bool, logging.Logger, Union[None, int]) -> None
        """Initialize ReferenceCube_Generator.

        :param filelist_refs:   list of (hyperspectral) reference images,
                                representing BOA reflectance, scaled between 0 and 10000
        :param tgt_sat_sen_list:    list of satellite/sensor tuples containing those sensors for which the reference
                                    cube is to be computed, e.g. [('Landsat-8', 'OLI_TIRS',), ('Landsat-5', 'TM')]
        :param dir_refcubes:    output directory for the generated reference cube
        :param n_clusters:      number of clusters to be used for clustering the input images (KMeans)
        :param tgt_n_samples:   number of spectra to be collected from each input image
        :param v:               verbose mode
        :param logger:          instance of logging.Logger()
        :param CPUs:            number of CPUs to use for computation
        """
        # args + kwargs
        self.ims_ref = [filelist_refs, ] if isinstance(filelist_refs, str) else filelist_refs
        self.tgt_sat_sen_list = tgt_sat_sen_list or [
            ('Landsat-8', 'OLI_TIRS'),
            ('Landsat-7', 'ETM+'),
            ('Landsat-5', 'TM'),
            ('Sentinel-2A', 'MSI'),
            # ('Terra', 'ASTER'),  # currently does not work
            ('SPOT-4', 'HRVIR1'),
            ('SPOT-4', 'HRVIR2'),
            ('SPOT-5', 'HRG1'),
            ('SPOT-5', 'HRG2'),
            ('RapidEye-5', 'MSI')
            ]
        self.dir_refcubes = os.path.abspath(dir_refcubes) if dir_refcubes else ''
        self.n_clusters = n_clusters
        self.tgt_n_samples = tgt_n_samples
        self.v = v
        self.logger = logger or GMS_logger(__name__)  # must be picklable
        self.CPUs = CPUs or cpu_count()

        # privates
        self._refcubes = \
            {(sat, sen): RefCube(satellite=sat, sensor=sen,
                                 LayerBandsAssignment=self._get_tgt_LayerBandsAssignment(sat, sen))
             for sat, sen in self.tgt_sat_sen_list}

        # validation
        if dir_refcubes and not os.path.isdir(self.dir_refcubes):
            raise ValueError("%s is not a directory." % self.dir_refcubes)

    def __getstate__(self):
        """Defines how the attributes of ReferenceCube_Generator instances are pickled."""
        close_logger(self.logger)
        self.logger = None

        return self.__dict__

    def __del__(self):
        close_logger(self.logger)
        self.logger = None

    @property
    def refcubes(self):
        # type: () -> Dict[Tuple[str, str], RefCube]
        """Return a dictionary holding instances of RefCube for each target satellite / sensor of self.tgt_sat_sen_list.
        """
        if not self._refcubes:

            # fill self._ref_cubes with GeoArray instances of already existing reference cubes read from disk
            if self.dir_refcubes:
                for path_refcube in glob(os.path.join(self.dir_refcubes, 'refcube__*.bsq')):
                    # TODO check if that really works
                    # check if current refcube path matches the target refcube specifications
                    identifier = re.search('refcube__(.*).bsq', os.path.basename(path_refcube)).group(1)
                    sat, sen, nclust_str, nsamp_str = identifier.split('__')  # type: str
                    nclust, nsamp = int(nclust_str.split('nclust')[1]), int(nsamp_str.split('nsamp')[1])
                    correct_specs = all([(sat, sen) in self.tgt_sat_sen_list,
                                         nclust == self.n_clusters,
                                         nsamp == self.tgt_n_samples])

                    # import the existing ref cube if it matches the target refcube specs
                    if correct_specs:
                        self._refcubes[(sat, sen)] = \
                            RefCube(satellite=sat, sensor=sen, filepath=path_refcube,
                                    LayerBandsAssignment=self._get_tgt_LayerBandsAssignment(sat, sen))

        return self._refcubes

    def _get_tgt_GMS_identifier(self, tgt_sat, tgt_sen):
        # type: (str, str) -> GMS_identifier
        """Get a GMS identifier for the specified target sensor such that all possible bands are included (L1A).

        :param tgt_sat:     target satellite
        :param tgt_sen:     target sensor
        :return:
        """
        return GMS_identifier(satellite=tgt_sat, sensor=tgt_sen, subsystem=None, image_type='RSD', dataset_ID=-9999,
                              proc_level='L1A', logger=self.logger)  # use L1A to have all bands available

    def _get_tgt_LayerBandsAssignment(self, tgt_sat, tgt_sen):
        # type: (str, str) -> list
        """Get the LayerBandsAssignment for the specified target sensor.

        NOTE:   The returned bands list always contains all possible bands. Specific band selections are later done
                using np.take().

        :param tgt_sat:     target satellite
        :param tgt_sen:     target sensor
        :return:
        """
        return get_LayerBandsAssignment(self._get_tgt_GMS_identifier(tgt_sat, tgt_sen), no_pan=False)

    def _get_tgt_SRF_object(self, tgt_sat, tgt_sen):
        # type: (str, str) -> SRF
        """Get an SRF instance containing the spectral response functions for the specified target sensor.

        :param tgt_sat:     target satellite
        :param tgt_sen:     target sensor
        :return:
        """
        return SRF(self._get_tgt_GMS_identifier(tgt_sat, tgt_sen), no_pan=False)

    def generate_reference_cubes(self, fmt_out='ENVI', progress=True):
        # type: (str, bool) -> Dict[Tuple[str, str], RefCube]
        """Generate reference spectra from all hyperspectral input images.

        Workflow:
        1. Clustering/classification of hyperspectral images and selection of a given number of random signatures
            (a. Spectral downsamling to lower spectral resolution (speedup))
            b. KMeans clustering
            c. Selection of the same number of signatures from each cluster to avoid unequal amount of training data.
        2. Spectral resampling of the selected hyperspectral signatures (for each input image)
        3. Add resampled spectra to reference cubes for each target sensor and write cubes to disk

        :param fmt_out:         output format (GDAL driver code)
        :param progress:        show progress bar (default: True)
        :return:                np.array: [tgt_n_samples x images x spectral bands of the target sensor]
        """
        for im in self.ims_ref:
            # TODO implement check if current image is already included in the refcube -> skip in that case
            src_im = GeoArray(im)

            # get random spectra of the original (hyperspectral) image, equally distributed over all computed clusters
            unif_random_spectra = self.cluster_image_and_get_uniform_spectra(src_im, progress=progress).astype(np.int16)

            # resample the set of random spectra to match the spectral characteristics of all target sensors
            for tgt_sat, tgt_sen in self.tgt_sat_sen_list:
                # perform spectral resampling
                self.logger.info('Performing spectral resampling to match %s %s specifications...' % (tgt_sat, tgt_sen))
                unif_random_spectra_rsp = \
                    self.resample_spectra(unif_random_spectra,
                                          src_cwl=np.array(src_im.meta.loc['wavelength'], dtype=np.float).flatten(),
                                          tgt_srf=self._get_tgt_SRF_object(tgt_sat, tgt_sen))

                # add the spectra as GeoArray instance to the in-mem ref cubes
                refcube = self.refcubes[(tgt_sat, tgt_sen)]  # type: RefCube
                refcube.add_spectra(unif_random_spectra_rsp, src_imname=src_im.basename,
                                    LayerBandsAssignment=self._get_tgt_LayerBandsAssignment(tgt_sat, tgt_sen))

                # update the reference cubes on disk
                if self.dir_refcubes:
                    refcube.save(path_out=os.path.join(self.dir_refcubes, 'refcube__%s__%s__nclust%s__nsamp%s.bsq'
                                                       % (tgt_sat, tgt_sen, self.n_clusters, self.tgt_n_samples)),
                                 fmt=fmt_out)

        return self.refcubes

    def cluster_image_and_get_uniform_spectra(self, im, downsamp_sat='Sentinel-2A', downsamp_sen='MSI', progress=False):
        # type: (Union[str, GeoArray, np.ndarray], str, str, bool) -> np.ndarray
        """Compute KMeans clusters for the given image and return an array of uniform random samples.

        :param im:              image to be clustered
        :param downsamp_sat:    satellite code used for intermediate image dimensionality reduction (input image is
                                spectrally resampled to this satellite before it is clustered). requires downsamp_sen.
                                If it is None, no intermediate downsampling is performed.
        :param downsamp_sen:    sensor code used for intermediate image dimensionality reduction (required downsamp_sat)
        :param progress:        whether to show progress bars or not
        :return:    2D array (rows: tgt_n_samples, columns: spectral information / bands)
        """
        # input validation
        if downsamp_sat and not downsamp_sen or downsamp_sen and not downsamp_sat:
            raise ValueError("The parameters 'downsamp_sat' and 'downsamp_sen' must both be provided or completely "
                             "omitted.")

        im2clust = GeoArray(im)

        # first, perform spectral resampling to Sentinel-2 to reduce dimensionality (speedup)
        if downsamp_sat and downsamp_sen:
            tgt_srf = SRF(GMS_identifier(satellite=downsamp_sat, sensor=downsamp_sen, subsystem=None, image_type='RSD',
                                         dataset_ID=-9999, proc_level='L1A', logger=self.logger))
            im2clust = self.resample_image_spectrally(im2clust, tgt_srf, progress=progress)  # output = int16

        # compute KMeans clusters for the spectrally resampled image
        self.logger.info('Computing %s KMeans clusters from the input image %s...'
                         % (self.n_clusters, im2clust.basename))
        kmeans = KMeansRSImage(im2clust, n_clusters=self.n_clusters, CPUs=self.CPUs, v=self.v)
        kmeans.compute_clusters()

        if self.v:
            kmeans.plot_cluster_centers()
            kmeans.plot_cluster_histogram()

        # randomly grab the given number of spectra from each cluster
        self.logger.info('Getting %s random spectra from each cluster...' % (self.tgt_n_samples // self.n_clusters))
        random_samples = kmeans.get_random_spectra_from_each_cluster(src_im=GeoArray(im),
                                                                     samplesize=self.tgt_n_samples // self.n_clusters)

        # combine the spectra (2D arrays) of all clusters to a single 2D array
        self.logger.info('Combining random samples from all clusters.')
        random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples])

        return random_samples

    def resample_spectra(self, spectra, src_cwl, tgt_srf):
        # type: (Union[GeoArray, np.ndarray], Union[list, np.array], SRF) -> np.ndarray
        """Perform spectral resampling of the given spectra to match the given spectral response functions.

        :param spectra:     2D array (rows: spectral samples;  columns: spectral information / bands)
        :param src_cwl:     central wavelength positions of input spectra
        :param tgt_srf:     target spectral response functions to be used for spectral resampling
        :return:
        """
        spectra = GeoArray(spectra)

        # perform spectral resampling of input image to match spectral properties of target sensor
        self.logger.info('Performing spectral resampling to match spectral properties of target sensor...')

        SR = SpectralResampler(src_cwl, tgt_srf)
        spectra_rsp = SR.resample_spectra(spectra, chunksize=200, CPUs=self.CPUs)

        return spectra_rsp

    def resample_image_spectrally(self, src_im, tgt_srf, progress=False):
        # type: (Union[str, GeoArray], SRF, bool) -> Union[GeoArray, None]
        """Perform spectral resampling of the given image to match the given spectral response functions.

        :param src_im:      source image to be resampled
        :param tgt_srf:     target spectral response functions to be used for spectral resampling
        :param progress:    show progress bar (default: false)
        :return:
        """
        # handle src_im provided as file path or GeoArray instance
        if isinstance(src_im, str):
            im_name = os.path.basename(src_im)
            im_gA = GeoArray(src_im)
        else:
            im_name = src_im.basename
            im_gA = src_im

        # read input image
        self.logger.info('Reading the input image %s...' % im_name)
        im_gA.cwl = np.array(im_gA.meta.loc['wavelength'], dtype=np.float).flatten()

        # perform spectral resampling of input image to match spectral properties of target sensor
        self.logger.info('Performing spectral resampling to match spectral properties of target sensor...')
        SR = SpectralResampler(im_gA.cwl, tgt_srf)

        tgt_im = GeoArray(np.zeros((*im_gA.shape[:2], len(tgt_srf.bands)), dtype=np.int16), im_gA.gt, im_gA.prj)
        tiles = im_gA.tiles((1000, 1000))  # use tiles to save memory
        for ((rS, rE), (cS, cE)), tiledata in (tqdm(tiles) if progress else tiles):
            tgt_im[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata.astype(np.int16), CPUs=self.CPUs)

        return tgt_im


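# Parsing sketch (illustrative only): how a reference cube file name written by
# ReferenceCube_Generator.generate_reference_cubes() encodes its specifications. The file name below is a
# hypothetical example following the 'refcube__<sat>__<sen>__nclust<N>__nsamp<M>.bsq' pattern used above.
def _example_parse_refcube_filename():
    basename = 'refcube__Landsat-8__OLI_TIRS__nclust50__nsamp1000.bsq'
    identifier = re.search('refcube__(.*).bsq', basename).group(1)
    sat, sen, nclust_str, nsamp_str = identifier.split('__')
    nclust, nsamp = int(nclust_str.split('nclust')[1]), int(nsamp_str.split('nsamp')[1])
    assert (sat, sen, nclust, nsamp) == ('Landsat-8', 'OLI_TIRS', 50, 1000)
    return sat, sen, nclust, nsamp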
class RefCube(object):
    """Data model class for reference cubes holding the training data for later fitted machine learning classifiers."""
    def __init__(self, filepath='', satellite='', sensor='', LayerBandsAssignment=None):
        # type: (str, str, str, list) -> None
        """Get instance of RefCube.

        :param filepath:                file path for importing an existing reference cube from disk
        :param satellite:               the satellite for which the reference cube holds its spectral data
        :param sensor:                  the sensor for which the reference cube holds its spectral data
        :param LayerBandsAssignment:    the LayerBandsAssignment for which the reference cube holds its spectral data
        """
        # privates
        self._col_imName_dict = dict()
        self._wavelenths = []

        # defaults