Tie_Point_Grid.py 49 KB
Newer Older
1
2
3
4
5
6
# -*- coding: utf-8 -*-

import collections
import multiprocessing
import os
import warnings
7
import time
8
9

# custom
10
11
12
13
try:
    import gdal
except ImportError:
    from osgeo import gdal
14
import numpy as np
15
from matplotlib import pyplot as plt
16
17
18
from geopandas import GeoDataFrame, GeoSeries
from shapely.geometry import Point
from skimage.measure import points_in_poly, ransac
19
from skimage.transform import AffineTransform, PolynomialTransform
20
21

# internal modules
22
23
from .CoReg import COREG
from . import io as IO
24
25
from py_tools_ds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen import get_generic_outpath
26
from py_tools_ds.processing.progress_mon import ProgressBar
27
from py_tools_ds.geo.vector.conversion import points_to_raster
28
from geoarray import GeoArray
29

30
__author__ = 'Daniel Scheffler'
31

32
global_shared_imref = None
33
34
35
global_shared_im2shift = None


36
37
38
39
40
41
42
43
44
45
46
def mp_initializer(imref, imtgt):
    """Initialize the global image variables consumed by Tie_Point_Grid._get_spatial_shifts() in worker processes.

    :param imref:   reference image
    :param imtgt:   target image
    """
    global global_shared_imref, global_shared_im2shift
    global_shared_imref, global_shared_im2shift = imref, imtgt


47
48
class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""
49

50
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
                 q=False):
        """Get an instance of Tie_Point_Grid.

        Applies the algorithm to detect spatial shifts to the whole overlap area of the input images.
        Spatial shifts are calculated for each point in a grid of which the parameters can be adjusted
        using keyword arguments. Shift correction performs a polynomial transformation using the
        calculated shifts of each point in the grid as GCPs. Thus 'Tie_Point_Grid' can be used to
        correct for locally varying geometric distortions of the target image.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid
                                        (specified by 'grid_res'). If the grid does not provide enough
                                        points, all available points are chosen.
        :param outFillVal(int):         if given the generated tie points grid is filled with this value
                                        in case no match could be found during co-registration
                                        (default: -9999)
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes
                                        during calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline,
                                        lanczos, average, mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels
                                        (default: 3). NOTE: lower levels are also included if a higher
                                        level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter all tie points out
                                              that have a low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where
                                              shift correction does not increase image similarity within
                                              the matching window (measured by mean structural similarity
                                              index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings:     a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys:
                                        min_reliability, rs_max_outlier, rs_tolerance, rs_max_iter,
                                        rs_exclude_previous_outliers, rs_timeout, q.
                                        See documentation there.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is
                                        given to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """
        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        # input parameters
        self.COREG_obj = COREG_obj
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.outlDetect_settings = outlDetect_settings or dict(q=q)
        self.dir_out = dir_out
        self.CPUs = CPUs

        # verbosity flags (verbose mode disables quiet mode; quiet mode disables progress bars)
        self.v = v
        self.q = False if v else q  # overridden by v
        self.progress = False if q else progress  # overridden by q

        # shortcuts to the input images
        self.ref = self.COREG_obj.ref
        self.shift = self.COREG_obj.shift

        # candidate tie point positions in image and map coordinates
        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        # lazily computed attributes
        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()
        self.kriged = None  # set by Raster_using_Kriging()

    # Convenience accessors: mean X/Y shift (in pixels resp. map units) over all tie points
    # where a valid match was found (rows equal to self.outFillVal are excluded before averaging).
    mean_x_shift_px = property(lambda self:
                               self.CoRegPoints_table['X_SHIFT_PX'][
                                   self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean())
    mean_y_shift_px = property(lambda self:
                               self.CoRegPoints_table['Y_SHIFT_PX'][
                                   self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean())
    mean_x_shift_map = property(lambda self:
                                self.CoRegPoints_table['X_SHIFT_M'][
                                    self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean())
    mean_y_shift_map = property(lambda self:
                                self.CoRegPoints_table['Y_SHIFT_M'][
                                    self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean())
130

131
132
    @property
    def CoRegPoints_table(self):
133
134
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE' containing all
135
        information containing all the results from coregistration for all points in the tie points grid.
136
        """
137
138
139
140
141
142
143
144
145
146
147
148
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table

    @property
    def GCPList(self):
149
150
        """Returns a list of GDAL compatible GCP objects.
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
151

152
153
154
155
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
156
            return self._GCPList
157
158
159
160
161
162

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
163
164
165
166
167
168
        """Returns a numpy array containing possible positions for coregistration tie points according to the given
        grid resolution.

        :param grid_res:
        :return:
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
169

170
        if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
171
            print('Initializing tie points grid...')
172

173
174
        Xarr, Yarr = np.meshgrid(np.arange(0, self.shift.shape[1], grid_res),
                                 np.arange(0, self.shift.shape[0], grid_res))
175

176
177
        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr * self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr * abs(self.shift.gt[5])
178

179
180
181
        XY_points = np.empty((Xarr.size, 2), Xarr.dtype)
        XY_points[:, 0] = Xarr.flat
        XY_points[:, 1] = Yarr.flat
182

183
184
185
        XY_mapPoints = np.empty((mapXarr.size, 2), mapXarr.dtype)
        XY_mapPoints[:, 0] = mapXarr.flat
        XY_mapPoints[:, 1] = mapYarr.flat
186

Daniel Scheffler's avatar
Daniel Scheffler committed
187
188
        assert XY_points.shape == XY_mapPoints.shape

189
        return XY_points, XY_mapPoints
190

191
192
193
194
195
196
197
198
199
200
201
    def _exclude_bad_XYpos(self, GDF):
        """Exclude all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:        <geopandas.GeoDataFrame> the input GDF reduced to all valid tie point positions
        """

        # exclude all points outside of overlap area
        # (point-in-polygon test against the exterior ring of the overlap polygon, whose
        #  coordinates are reshaped from 2xN to Nx2 as expected by points_in_poly)
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        # FIXME track that
        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_UTM', 'Y_UTM']])
        # NOTE: if no bad data mask is provided, the whole column is set to the scalar False
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      % (orig_len_GDF - len(GDF), orig_len_GDF))

        return GDF

222
223
    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        """Run COREG for a single tie point and return its results as a flat list.

        :param coreg_kwargs:    keyword arguments for COREG, additionally containing the keys
                                'pointID' and 'fftw_works' (removed before COREG is called)
        :return:                [pointID, win_sz_x, win_sz_y, x_shift_px, y_shift_px, x_shift_map,
                                 y_shift_map, vec_length_map, vec_angle_deg, ssim_orig,
                                 ssim_deshifted, ssim_improved, shift_reliability, last_err]
        """
        # pull the two non-COREG keys out of the kwargs
        pointID = coreg_kwargs.pop('pointID')
        fftw_works = coreg_kwargs.pop('fftw_works')

        # the global shared images must have been set (by mp_initializer or the caller)
        assert global_shared_imref is not None
        assert global_shared_im2shift is not None

        # run CoReg for this point
        coreg = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        coreg.fftw_works = fftw_works
        coreg.calculate_spatial_shifts()

        # fetch results
        last_err = coreg.tracked_errors[-1] if coreg.tracked_errors else None
        win_sz_y, win_sz_x = coreg.matchBox.imDimsYX if coreg.matchBox else (None, None)

        return [pointID,
                win_sz_x, win_sz_y,
                coreg.x_shift_px, coreg.y_shift_px,
                coreg.x_shift_map, coreg.y_shift_map,
                coreg.vec_length_map, coreg.vec_angle_deg,
                coreg.ssim_orig, coreg.ssim_deshifted, coreg.ssim_improved,
                coreg.shift_reliability,
                last_err]
246

247
    def get_CoRegPoints_table(self):
        """Calculate the tie point grid and return it as GeoDataFrame.

        Runs COREG for every tie point position (in parallel if self.CPUs permits), merges the
        results into the point table and applies the configured tie point filtering. The result
        is also cached in self.CoRegPoints_table.
        """
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'
        # (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        # derive the CRS of the tie point grid from the projection of the image to be shifted
        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']] = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is
            # connected to that
            self.COREG_obj.equalize_pixGrids()

        # validate reference and target image inputs
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        def get_coreg_kwargs(pID, wp):
            # build the keyword arguments for one COREG run at the given point ID / window position
            # (local function instead of the previous lambda assignment -> PEP 8 E731)
            return dict(
                pointID=pID,
                fftw_works=self.COREG_obj.fftw_works,
                wp=wp,
                ws=self.COREG_obj.win_size_XY,
                resamp_alg_calc=self.rspAlg_calc,
                footprint_poly_ref=self.COREG_obj.ref.poly,
                footprint_poly_tgt=self.COREG_obj.shift.poly,
                r_b4match=self.ref.band4match + 1,  # band4match is internally saved as index, starting from 0
                s_b4match=self.shift.band4match + 1,  # band4match is internally saved as index, starting from 0
                max_iter=self.COREG_obj.max_iter,
                max_shift=self.COREG_obj.max_shift,
                nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
                force_quadratic_win=self.COREG_obj.force_quadratic_win,
                binary_ws=self.COREG_obj.bin_ws,
                v=False,  # otherwise this would lead to massive console output
                q=True,  # otherwise this would lead to massive console output
                ignore_errors=True
            )
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie points grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs, initializer=mp_initializer, initargs=(self.ref, self.shift)) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks
                        # -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within
                            # COREG and is not raised
                            results = results.get()
                            break

        else:
            # declare global variables needed for self._get_spatial_shifts()
            global global_shared_imref, global_shared_im2shift
            global_shared_imref = self.ref
            global_shared_im2shift = self.shift

            if not self.q:
                print("Calculating tie points grid (%s points) 1 CPU core..." % len(GDF))
            # FIX: size the result array by the number of remaining points (len(GDF)) instead of
            #      len(geomPoints); the previous oversizing left all-None rows that only disappeared
            #      due to the inner merge below
            # FIX: 'np.object' was deprecated in NumPy 1.20 and removed in 1.24 -> use builtin 'object'
            results = np.empty((len(GDF), 14), object)
            bar = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            # LAST_ERR equals the fill value for all points where no error was tracked, i.e., matches
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table

386
387
388
389
    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

Daniel Scheffler's avatar
Daniel Scheffler committed
390
        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
391
392
393
        """

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
394
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
395
396
397
398
399
400
401
402
403
404
405
406
407
408

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))

    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
409
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy()
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i * i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))

    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
        """Create a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """
        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        # split the table into valid tie points (inliers) and detected false-positives (outliers)
        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits (symmetric around zero by default)
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            # FIX: the conditional expression must be parenthesized - previously
            #      "'x-shift [%s]' % 'meters' if unit == 'm' else 'pixels'" evaluated to the bare
            #      string 'pixels' for unit='px' due to operator precedence ('%' binds tighter)
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax

515
    def dump_CoRegPoints_table(self, path_out=None):
        """Pickle Tie_Point_Grid.CoRegPoints_table to disk.

        :param path_out:    the output path; automatically derived from the input images and
                            self.dir_out if not given
        """
        if not path_out:
            fName_out = "CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" \
                        % (self.grid_res, self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1],
                           self.shift.basename, self.ref.basename)
            path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)

        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)
524

525
    def to_GCPList(self):
        """Convert the tie point grid to a list of GDAL compatible GCP objects.

        :return:    list of gdal.GCP instances (empty if no valid tie points are available)
        """
        # get copy of tie points grid without no data
        try:
            gdf = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if gdf.empty:
            return []

        # exclude all points flagged as outliers
        if 'OUTLIER' in gdf.columns:
            gdf = gdf[gdf.OUTLIER == False].copy()  # noqa: E712 (element-wise comparison kept intentionally)

        avail_TP = len(gdf)
        if not avail_TP:
            # no point passed all validity checks
            return []

        if avail_TP > 7000:
            gdf = gdf.sample(7000)
            warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                          'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
                          'out of the %s available tie points.' % avail_TP)

        # calculate GCPs
        gdf['X_UTM_new'] = gdf.X_UTM + gdf.X_SHIFT_M
        gdf['Y_UTM_new'] = gdf.Y_UTM + gdf.Y_SHIFT_M
        gdf['GCP'] = gdf.apply(lambda row: gdal.GCP(row.X_UTM_new, row.Y_UTM_new, 0,
                                                    row.X_IM, row.Y_IM), axis=1)
        self.GCPList = gdf.GCP.tolist()

        if not self.q:
            print('Found %s valid tie points.' % len(self.GCPList))

        return self.GCPList
562

563
    def test_if_singleprocessing_equals_multiprocessing_result(self):
        """Compute the tie point grid once in multi- and once in singleprocessing mode and compare the results.

        :return:    True if both runs produced identical tables, otherwise False
        """
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        snapshots = []
        for n_cpus in (None, 1):  # None -> all available cores (multiprocessing), 1 -> singleprocessing
            self.CPUs = n_cpus
            table = self.get_CoRegPoints_table()
            snapshot = np.empty_like(table.values)
            snapshot[:] = table.values
            snapshots.append(snapshot)

        mp_out, sp_out = snapshots
        return np.array_equal(sp_out, mp_out)
577

578
579
    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]
580

581
    def _get_lines_by_PIDs(self, PIDs):
582
583
584
585
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
586
587
        return lines

588
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table
        # FIX: always work on a copy, so the boolean->int conversion below cannot mutate
        #      self.CoRegPoints_table in place when skip_nodata=False
        GDF2pass = (GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]).copy()

        # replace boolean values (cannot be written to a shapefile attribute table)
        for col in GDF2pass.columns:
            # FIX: use builtin 'bool' - the 'np.bool' alias was deprecated in NumPy 1.20
            #      and removed in NumPy 1.24
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0)  # replace all remaining booleans where dtype is object
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)

        GDF2pass.to_file(path_out)

618
619
620
621
622
623
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        """Deprecated shapefile export kept only until the faster variant is validated.

        :param skip_nodata:     whether to skip all points where no valid match could be found
        :param skip_nodata_col: column used to identify points without a valid match
        """
        warnings.warn(DeprecationWarning(
            "'_tiepoints_grid_to_PointShapefile' is deprecated."  # TODO delete if other method validated
            " 'tiepoints_grid_to_PointShapefile' is much faster."))

        table = self.CoRegPoints_table
        if skip_nodata:
            table = table[table[skip_nodata_col] != self.outFillVal]

        shapely_points = table['geometry'].values.tolist()
        attr_dicts = [collections.OrderedDict(zip(table.columns, table.loc[idx].values)) for idx in table.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)

631
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        :return:            the written 2-band GeoArray
        """
        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                     "(outputs magnitude and direction)'. Got %s." % mode

        attr_b1 = 'X_SHIFT_M' if mode == 'uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode == 'uv' else 'ANGLE'

        def _rasterize(attrname):
            # rasterize one attribute of the tie point table onto the tie point grid resolution;
            # both bands share the same geotransform/projection since points and target grid are identical
            return points_to_raster(points=self.CoRegPoints_table['geometry'],
                                    values=self.CoRegPoints_table[attrname],
                                    tgt_res=self.shift.xgsd * self.grid_res,
                                    prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                    fillVal=self.outFillVal)

        xshift_arr, gt, prj = _rasterize(attr_b1)
        yshift_arr, gt, prj = _rasterize(attr_b2)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA

672
    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                    path_out=None, tilepos=None):
        """Deprecated: interpolate a tie point attribute to a raster using Ordinary Kriging.

        :param attrName:        name of the CoRegPoints_table column to interpolate (e.g. 'ABS_SHIFT')
        :param skip_nodata:     whether to skip all points where no valid match could be found
        :param skip_nodata_col: column used to identify points without a valid match
        :param outGridRes:      output grid resolution; derived from the point bounds if not given
        :param path_out:        output path; auto-generated if not given
        :param tilepos:         (rows, cols) tile bounds to subset the tie points before kriging
        :return:                the kriged raster values
        """
        warnings.warn(DeprecationWarning("'to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead."))  # TODO delete

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # subset if tilepos is given
        rows, cols = tilepos if tilepos else self.shift.shape
        GDF2pass = GDF2pass.loc[(GDF2pass['X_IM'] >= cols[0]) & (GDF2pass['X_IM'] <= cols[1]) &
                                (GDF2pass['Y_IM'] >= rows[0]) & (GDF2pass['Y_IM'] <= rows[1])]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        # default grid resolution: ~250 cells along the shorter extent of the point bounds
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)  # ,backend='C',)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"
                                          % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

711
    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):
        """Interpolate a tie point attribute to a raster using Ordinary Kriging.

        :param attrName:        name of the CoRegPoints_table column to interpolate (e.g. 'ABS_SHIFT')
        :param skip_nodata:     whether to skip all points where no valid match could be found
        :param skip_nodata_col: column used to identify points without a valid match
        :param outGridRes:      output grid resolution; derived from the point bounds if not given
        :param fName_out:       output file name; auto-generated if not given
        :param tilepos:         tile position in case only a subset should be interpolated
        :param tilesize:        tile size for the (currently disabled) tiled multiprocessing mode
        :param mp:              enable multiprocessing; if None, it is derived from self.CPUs
                                NOTE: tiled multiprocessing is currently disabled, so kriging
                                always runs singleprocessed
        :return:                the kriged raster values
        """
        # FIX: the 'mp' parameter was previously overwritten unconditionally, ignoring any
        #      caller-provided value; only derive it from self.CPUs when it was not given
        if mp is None:
            mp = self.CPUs != 1

        # NOTE: a tiled multiprocessing implementation existed here but has been disabled;
        #       the 'mp'/'tilesize' parameters are kept for interface compatibility.
        # FIX: return the kriging result directly - the old code returned 'self.kriged',
        #      an attribute that was never assigned (its producer was commented out),
        #      which raised AttributeError whenever mp evaluated to True
        return self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                                outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

739
740
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Interpolate a tie point attribute to a raster using Ordinary Kriging (singleprocessed).

        :param attrName:        name of the CoRegPoints_table column to interpolate (e.g. 'ABS_SHIFT')
        :param skip_nodata:     whether to skip all points where no valid match could be found
        :param skip_nodata_col: column used to identify points without a valid match
        :param outGridRes:      output grid resolution; derived from the point bounds if not given
        :param fName_out:       output file name; auto-generated if not given
        :param tilepos:         tile position; only used for naming the output file here
        :return:                the kriged raster values
        """
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        #         # subset if tilepos is given
        # #        overlap_factor =
        #         rows,cols = tilepos if tilepos else self.tgt_shape
        #         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
        #         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
        #         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
        #         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
        #         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
        #         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
        #                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        # default grid resolution: ~250 cells along the shorter extent of the point bounds
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        # include the tile position in the file name when running multiprocessed,
        # so the per-tile outputs do not overwrite each other
        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY, tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' % path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

783
    def _Kriging_mp(self, args_kwargs_dict):
        """Multiprocessing wrapper for _Kriging_sp: unpack positional and keyword arguments and run it.

        :param args_kwargs_dict:    dict with optional keys 'args' (list of positional arguments)
                                    and 'kwargs' (dict of keyword arguments)
        :return:                    the return value of _Kriging_sp
        """
        args = args_kwargs_dict.get('args', [])
        # FIX: the default was previously a list, which would raise a TypeError when
        #      unpacked with ** below; keyword arguments must default to an empty dict
        kwargs = args_kwargs_dict.get('kwargs', {})

        return self._Kriging_sp(*args, **kwargs)
788
789


790
class Tie_Point_Refiner(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
791
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
792
                 rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
793
        """A class for performing outlier detection.
Daniel Scheffler's avatar
Daniel Scheffler committed
794

795
796
        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
Daniel Scheffler's avatar
Daniel Scheffler committed
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
821
822
        self.ransac_model_robust = None

823
    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters all tie points out where shift
                                correction does not increase image similarity within matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:    2-tuple of the filtered GeoDataFrame and the list of column names added by filtering
        """
        # TODO catch empty GDF

        # RELIABILITY filtering
        if level > 0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % (len(marked_recs[marked_recs])))

        # SSIM filtering
        if level > 1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with less than four tie points makes no sense
                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)'
                          % (len(marked_recs[marked_recs])))
            else:
                # FIX: respect quiet mode here as well - this message was previously
                #      printed unconditionally, unlike all other messages in this method
                if not self.q:
                    print('RANSAC skipped because too less valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        # a record is an overall outlier if any of the individual filter levels flagged it
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

Daniel Scheffler's avatar
Daniel Scheffler committed
887
    def _reliability_thresholding(self):
888
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""
889

Daniel Scheffler's avatar
Daniel Scheffler committed
890
        return self.GDF.RELIABILITY < self.min_reliability
891
892

    def _SSIM_filtering(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
893
        """Exclude all records where SSIM decreased."""
894

895
        # ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])
896

897
898
        # self.GDF.SSIM_IMPROVED = \
        #     self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)
899

900
        return self.GDF.SSIM_IMPROVED is False
901

Daniel Scheffler's avatar
Daniel Scheffler committed
902
903
    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm."""
904

Daniel Scheffler's avatar
Daniel Scheffler committed
905
        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
906
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
907
908
909
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
910
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)
911

912
913
        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError