# -*- coding: utf-8 -*-

import collections
import multiprocessing
import os
import warnings
import time

# custom
try:
    import gdal
except ImportError:
    from osgeo import gdal
import numpy as np
from matplotlib import pyplot as plt
from geopandas import GeoDataFrame, GeoSeries
from shapely.geometry import Point
from skimage.measure import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform

# internal modules
from .CoReg import COREG
from . import io as IO
from py_tools_ds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion import points_to_raster
from geoarray import GeoArray

__author__ = 'Daniel Scheffler'

global_shared_imref = None
global_shared_im2shift = None


class Tie_Point_Grid(object):
    """Calculate a grid of tie points between a reference and a target image and the local spatial shift at each point."""

    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
                 q=False):
        """Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
        are calculated for each point of a grid whose parameters can be adjusted using keyword arguments. Shift
        correction performs a polynomial transformation using the calculated shifts of each point in the grid as GCPs.
        Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target image.

        :param COREG_obj(object):       an instance of the COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the grid does not provide enough points, all available points
                                        are chosen.
        :param outFillVal(int):         if given, the generated tie point grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes during the
                                        calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                           mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filters out all tie points that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters out all tie points where the shift
                                                correction does not increase the image similarity within the matching
                                                window (measured by the mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings:     a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of the tie point grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """

        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj = COREG_obj
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
        self.dir_out = dir_out
        self.CPUs = CPUs
        self.v = v
        self.q = q if not v else False  # overridden by v
        self.progress = progress if not q else False  # overridden by q

        self.ref = self.COREG_obj.ref
        self.shift = self.COREG_obj.shift

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()
        self.kriged = None  # set by Raster_using_Kriging()
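
    # The following commented block is a minimal usage sketch (not part of the original module). It assumes that
    # 'reference.tif' and 'target.tif' are valid raster paths and that COREG finds a global match; all parameter
    # values are illustrative only.
    #
    #     from arosics import COREG
    #     CR = COREG('reference.tif', 'target.tif', ws=(256, 256), max_shift=5)
    #     TPG = Tie_Point_Grid(CR, grid_res=200, max_points=1000, tieP_filter_level=3, CPUs=None)
    #     tie_points = TPG.CoRegPoints_table   # GeoDataFrame with one row per tie point
    #     gcps = TPG.to_GCPList()              # GDAL GCP objects usable for warping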

    mean_x_shift_px = property(lambda self:
                               self.CoRegPoints_table['X_SHIFT_PX'][
                                   self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean())
    mean_y_shift_px = property(lambda self:
                               self.CoRegPoints_table['Y_SHIFT_PX'][
                                   self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean())
    mean_x_shift_map = property(lambda self:
                                self.CoRegPoints_table['X_SHIFT_M'][
                                    self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean())
    mean_y_shift_map = property(lambda self:
                                self.CoRegPoints_table['Y_SHIFT_M'][
                                    self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean())

    @property
    def CoRegPoints_table(self):
        """Returns a GeoDataFrame with the columns 'geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM',
        'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE',
        containing all results from the co-registration of all points in the tie point grid.
        """
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table
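
    # Hedged usage sketch (comment only, not part of the original code): how the table above is typically inspected.
    # Assumes the grid has already been computed and the default outFillVal of -9999 is in use.
    #
    #     tab = TPG.CoRegPoints_table
    #     valid = tab[tab['ABS_SHIFT'] != -9999]           # drop points without a match
    #     print(valid[['X_SHIFT_M', 'Y_SHIFT_M']].mean())  # mean shift in map units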

    @property
    def GCPList(self):
        """Returns a list of GDAL compatible GCP objects.
        """

        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
            return self._GCPList

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
        """Returns two numpy arrays containing the image and map coordinates of possible tie point positions
        according to the given grid resolution.

        :param grid_res:
        :return:
        """

        if not self.q:
            print('Initializing tie points grid...')

        Xarr, Yarr = np.meshgrid(np.arange(0, self.shift.shape[1], grid_res),
                                 np.arange(0, self.shift.shape[0], grid_res))

        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr * self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr * abs(self.shift.gt[5])

        XY_points = np.empty((Xarr.size, 2), Xarr.dtype)
        XY_points[:, 0] = Xarr.flat
        XY_points[:, 1] = Yarr.flat

        XY_mapPoints = np.empty((mapXarr.size, 2), mapXarr.dtype)
        XY_mapPoints[:, 0] = mapXarr.flat
        XY_mapPoints[:, 1] = mapYarr.flat

        assert XY_points.shape == XY_mapPoints.shape

        return XY_points, XY_mapPoints
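
    # Worked example (illustrative comment only, added for clarity): with a GDAL geotransform
    # gt = (300000.0, 10.0, 0.0, 5900000.0, 0.0, -10.0) and grid_res = 50, the grid node at image column 50 maps to
    # map_x = gt[0] + 50 * gt[1] = 300500.0 and the node at image row 50 maps to
    # map_y = gt[3] - 50 * abs(gt[5]) = 5899500.0, which is exactly what the mapXarr/mapYarr computation above does
    # for every grid node.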

    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True
        (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:
        """

        # exclude all points outside of the overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        # FIXME track that
        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all points where the bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_UTM', 'Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s), %s points of initially %s have been excluded.'
                      % (orig_len_GDF - len(GDF), orig_len_GDF))

        return GDF
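
    # Brief illustration (comment only, added for clarity): skimage.measure.points_in_poly returns a boolean mask,
    # e.g. points_in_poly([[0.5, 0.5], [2.0, 2.0]], [[0, 0], [1, 0], [1, 1], [0, 1]]) -> array([True, False]),
    # which is why the mask computed above can be used directly for boolean indexing of the GeoDataFrame.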

    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        # unpack
        pointID = coreg_kwargs['pointID']
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']

        # assertions
        assert global_shared_imref is not None
        assert global_shared_im2shift is not None

        # run CoReg
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                  CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                  CR.shift_reliability, last_err]

        return [pointID] + CR_res

    def get_CoRegPoints_table(self):
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'
        # (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']] = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on the bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually the grid resolution should also be changed here because self.shift.xgsd changes and the
            # grid resolution is connected to that
            self.COREG_obj.equalize_pixGrids()

        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref, global_shared_im2shift
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighbouring matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        global_shared_imref = self.ref
        global_shared_im2shift = self.shift

        # get all variations of kwargs for coregistration
        def get_coreg_kwargs(pID, wp):
            return dict(
                pointID=pID,
                fftw_works=self.COREG_obj.fftw_works,
                wp=wp,
                ws=self.COREG_obj.win_size_XY,
                resamp_alg_calc=self.rspAlg_calc,
                footprint_poly_ref=self.COREG_obj.ref.poly,
                footprint_poly_tgt=self.COREG_obj.shift.poly,
                r_b4match=self.ref.band4match + 1,  # band4match is internally saved as index, starting from 0
                s_b4match=self.shift.band4match + 1,  # band4match is internally saved as index, starting from 0
                max_iter=self.COREG_obj.max_iter,
                max_shift=self.COREG_obj.max_shift,
                nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
                force_quadratic_win=self.COREG_obj.force_quadratic_win,
                binary_ws=self.COREG_obj.bin_ws,
                v=False,  # otherwise this would lead to massive console output
                q=True,  # otherwise this would lead to massive console output
                ignore_errors=True
            )
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for the whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie points grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks
                        # -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within
                            # COREG and is not raised
                            results = results.get()
                            break

        else:
            if not self.q:
                print("Calculating tie points grid (%s points) using 1 CPU core..." % len(GDF))
            results = np.empty((len(geomPoints), 14), object)
            bar = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to the given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table
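
    # Design note with a minimal generic sketch (comment only, not arosics API): get_CoRegPoints_table() hands the two
    # potentially large GeoArray instances to the worker processes via the module-level globals
    # 'global_shared_imref'/'global_shared_im2shift' instead of passing them through Pool.map, so only the small
    # kwargs dicts are pickled per tie point. This relies on fork-style process creation inheriting module globals.
    # The same pattern in isolation (illustrative names only):
    #
    #     import multiprocessing
    #
    #     _shared = None
    #
    #     def _double(i):
    #         return _shared[i] * 2              # worker reads the inherited global
    #
    #     if __name__ == '__main__':
    #         _shared = list(range(10))          # set in the parent before the pool forks
    #         with multiprocessing.Pool(2) as pool:
    #             print(pool.map(_double, range(10)))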

    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))
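
    # Worked example (comment only, added for clarity): for absolute shifts of 3.0, 4.0 and 5.0 map units the method
    # above computes RMSE = sqrt((3**2 + 4**2 + 5**2) / 3) = sqrt(50 / 3), approx. 4.08, i.e. the quadratic mean of
    # all valid 'ABS_SHIFT' values (fill values and, by default, flagged outliers are excluded beforehand).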

    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy()

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i for i in mssim_col if i != self.outFillVal]  # exclude fill values (no squaring needed here)

        return float(np.median(mssim_col))

    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
        """Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """

        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # plot and embed in the IPython notebook
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add a text box containing the RMSE of the plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add a grid and increase the linewidth of the middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set the title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add a legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax
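
    # Hedged usage sketch (comment only, not executed): typical call of the plotting method above, assuming the tie
    # point grid has already been computed and matplotlib can open a window or render inline.
    #
    #     fig, ax = TPG.plot_shift_distribution(unit='m', figsize=(8, 8), title='X/Y shift distribution')
    #     fig.savefig('shift_distribution.png', dpi=150)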

    def dump_CoRegPoints_table(self, path_out=None):
        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=self.dir_out, fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl"
                                                                % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                                                   self.COREG_obj.win_size_XY[1], self.shift.basename,
                                                                   self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)

    def to_GCPList(self):
        # get a copy of the tie point grid without no-data records
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if getattr(GDF, 'empty'):  # GDF.empty returns AttributeError
            return []
        else:
            # exclude all points flagged as outliers
            if 'OUTLIER' in GDF.columns:
                GDF = GDF[GDF.OUTLIER == False].copy()
            avail_TP = len(GDF)

            if not avail_TP:
                # no point passed all validity checks
                return []

            if avail_TP > 7000:
                GDF = GDF.sample(7000)
                warnings.warn('Not more than 7000 tie points can be used for warping within a limited computation '
                              'time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen out of '
                              'the %s available tie points.' % avail_TP)

            # calculate GCPs
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
            GDF['GCP'] = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                            GDF_row.X_IM, GDF_row.Y_IM), axis=1)
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
                print('Found %s valid tie points.' % len(self.GCPList))

            return self.GCPList
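
    # Hedged usage sketch (comment only): one way the GCP list above can be applied with plain GDAL, assuming
    # 'target.tif' is the image to be corrected and 'ref_projection' is a hypothetical variable holding the WKT of
    # the reference projection. AROSICS itself applies the GCPs through its own deshifting module, so this is just
    # an illustration.
    #
    #     gcps = TPG.to_GCPList()
    #     ds = gdal.Translate('/vsimem/with_gcps.tif', 'target.tif', GCPs=gcps, outputSRS=ref_projection)
    #     gdal.Warp('target_corrected.tif', ds, tps=True, errorThreshold=0)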

    def test_if_singleprocessing_equals_multiprocessing_result(self):
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        self.CPUs = None
        dataframe = self.get_CoRegPoints_table()
        mp_out = np.empty_like(dataframe.values)
        mp_out[:] = dataframe.values
        self.CPUs = 1
        dataframe = self.get_CoRegPoints_table()
        sp_out = np.empty_like(dataframe.values)
        sp_out[:] = dataframe.values

        return np.array_equal(sp_out, mp_out)

    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]

    def _get_lines_by_PIDs(self, PIDs):
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
        return lines

    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie point grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # replace boolean values (cannot be written)
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0)  # replace all remaining booleans where the dtype is object, not bool
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        GDF2pass.to_file(path_out)
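
    # Hedged usage sketch (comment only): writing the tie points to disk for inspection in QGIS/ArcGIS, assuming
    # 'dir_out' was set when the Tie_Point_Grid instance was created.
    #
    #     TPG.to_PointShapefile()                          # auto-generated path below dir_out/CoRegPoints
    #     TPG.to_PointShapefile(path_out='/tmp/tie_points.shp', skip_nodata=False)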

    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        warnings.warn(DeprecationWarning(
            "'_to_PointShapefile' is deprecated."  # TODO delete if other method validated
            " 'to_PointShapefile' is much faster."))
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[i].values)) for i in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)

    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS).

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        """

        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                     "(outputs magnitude and direction). Got %s." % mode
        attr_b1 = 'X_SHIFT_M' if mode == 'uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode == 'uv' else 'ANGLE'

        xshift_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                               values=self.CoRegPoints_table[attr_b1],
                                               tgt_res=self.shift.xgsd * self.grid_res,
                                               prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal=self.outFillVal)

        yshift_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                               values=self.CoRegPoints_table[attr_b2],
                                               tgt_res=self.shift.xgsd * self.grid_res,
                                               prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal=self.outFillVal)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA
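
    # Hedged usage sketch (comment only): exporting the shifts as a 2-band GeoTIFF, assuming a tie point grid has
    # been calculated beforehand. With mode='uv' band 1/2 hold the X/Y shifts; with mode='md' they hold magnitude
    # and direction, which many GIS packages can render directly as arrows.
    #
    #     vf = TPG.to_vectorfield(path_out='/tmp/shift_vectorfield.tif', mode='uv')
    #     print(vf.shape)   # (rows, cols, 2)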

    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                    path_out=None, tilepos=None):
        warnings.warn(DeprecationWarning("'_to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead."))  # TODO delete

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # subset if tilepos is given
        rows, cols = tilepos if tilepos else self.shift.shape
        GDF2pass = GDF2pass.loc[(GDF2pass['X_IM'] >= cols[0]) & (GDF2pass['X_IM'] <= cols[1]) &
                                (GDF2pass['Y_IM'] >= rows[0]) & (GDF2pass['Y_IM'] <= rows[1])]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)  # ,backend='C',)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"
                                          % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)
        # add a half pixel so that the grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):

        mp = False if self.CPUs == 1 else True
        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     # self.kriged=[]
        #     # for i in args_kwargs_dicts:
        #     #     res = self.Kriging_mp(i)
        #     #     self.kriged.append(res)
        #     #     print(res)
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)
        res = self.kriged if mp else None
        return res

    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        #         # subset if tilepos is given
        # #        overlap_factor =
        #         rows,cols = tilepos if tilepos else self.tgt_shape
        #         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
        #         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
        #         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
        #         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
        #         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
        #         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
        #                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY, tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' % path_out)
        # add a half pixel so that the grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues
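
    # Minimal self-contained kriging sketch (comment only, illustrative data): the same pykrige call pattern as
    # above, reduced to its essentials. Coordinates, values and grid are made up for demonstration.
    #
    #     import numpy as np
    #     from pykrige.ok import OrdinaryKriging
    #
    #     x = np.array([0., 100., 200., 300.])
    #     y = np.array([0., 150., 50., 250.])
    #     z = np.array([1.2, 0.8, 1.5, 0.9])              # e.g. ABS_SHIFT values at the tie points
    #     OK = OrdinaryKriging(x, y, z, variogram_model='spherical', verbose=False)
    #     grid, variance = OK.execute('grid', np.arange(0., 300., 50.), np.arange(0., 300., 50.))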

    def _Kriging_mp(self, args_kwargs_dict):
        args = args_kwargs_dict.get('args', [])
        kwargs = args_kwargs_dict.get('kwargs', {})  # an empty dict (not a list) is required for ** unpacking

        return self._Kriging_sp(*args, **kwargs)


class Tie_Point_Refiner(object):
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                 rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
        """A class for performing outlier detection.

        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default: True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for the iteration loop in seconds
                                                (default: 20)
        :param q:                               <bool> quiet mode (default: False)
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
        self.ransac_model_robust = None

    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliability filtering - filters out all tie points that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters out all tie points where the shift
                                correction does not increase the image similarity within the matching window
                                (measured by the mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:
        """

        # TODO catch empty GDF

        # RELIABILITY filtering
        if level > 0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % (len(marked_recs[marked_recs])))

        # SSIM filtering
        if level > 1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with fewer than four tie points makes no sense

                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC).'
                          % (len(marked_recs[marked_recs])))
            else:
                print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols
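
    # Hedged usage sketch (comment only): standalone use of the refiner on an existing tie point table, assuming
    # 'tab' is a CoRegPoints_table-like GeoDataFrame that still contains its fill values (-9999 by default).
    #
    #     refiner = Tie_Point_Refiner(tab[tab.ABS_SHIFT != -9999], min_reliability=60)
    #     filtered, new_cols = refiner.run_filtering(level=3)
    #     good = filtered[~filtered['OUTLIER']]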

    def _reliability_thresholding(self):
        """Exclude all records where the estimated reliability of the calculated shifts is below the given threshold."""

        return self.GDF.RELIABILITY < self.min_reliability

    def _SSIM_filtering(self):
        """Exclude all records where the SSIM decreased."""

        # ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        # self.GDF.SSIM_IMPROVED = \
        #     self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

        return self.GDF.SSIM_IMPROVED == False  # noqa: E712  # element-wise comparison; 'is False' is always False

    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between source and estimated coordinates using the RANSAC algorithm."""

        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError("'rs_max_outlier' must be a percentage between 0 and 100.")
        min_inlier_percentage = 100 - self.rs_max_outlier_percentage

        class PolyTF_1(PolynomialTransform):
            def estimate(*data):
                return PolynomialTransform.estimate(