Tie_Point_Grid.py 48.9 KB
Newer Older
1
2
3
4
5
6
# -*- coding: utf-8 -*-

import collections
import multiprocessing
import os
import warnings
7
import time
8
9

# custom
10
11
12
13
try:
    import gdal
except ImportError:
    from osgeo import gdal
14
import numpy as np
15
from matplotlib import pyplot as plt
16
17
18
from geopandas import GeoDataFrame, GeoSeries
from shapely.geometry import Point
from skimage.measure import points_in_poly, ransac
19
from skimage.transform import AffineTransform, PolynomialTransform
20
21

# internal modules
22
23
from .CoReg import COREG
from . import io as IO
24
25
from py_tools_ds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen import get_generic_outpath
26
from py_tools_ds.processing.progress_mon import ProgressBar
27
from py_tools_ds.geo.vector.conversion import points_to_raster
28
from geoarray import GeoArray
29

30
31
from .CoReg import GeoArray_CoReg

32
__author__ = 'Daniel Scheffler'
33

34
global_shared_imref = None
35
36
37
global_shared_im2shift = None


38
39
40
41
42
43
44
45
46
47
48
def mp_initializer(imref, imtgt):
    """Initialize a multiprocessing worker process.

    Stores both images in module-level globals so that Tie_Point_Grid._get_spatial_shifts() can access them
    from within the worker without re-pickling them for every single tie point.

    :param imref:   reference image
    :param imtgt:   target image
    """
    global global_shared_imref, global_shared_im2shift
    global_shared_imref, global_shared_im2shift = imref, imtgt


49
50
class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""
51

52
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
                 q=False):
        """Apply the shift detection algorithm to a grid of points covering the whole overlap of the input images.

        Spatial shifts are calculated for each point of a regular grid whose density is controlled by 'grid_res'.
        Shift correction then performs a polynomial transformation using the calculated shifts of each grid point
        as GCPs. Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the
        target image.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the point grid does not provide enough points, all available
                                        points are chosen.
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str)     the resampling algorithm to be used for all warping processes during
                                        calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                           mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter all tie points out that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings      a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """
        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj = COREG_obj  # type: COREG

        # grid / matching configuration
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
        self.dir_out = dir_out
        self.CPUs = CPUs

        # console output configuration: verbose mode overrides quiet mode, quiet mode disables progress bars
        self.v = v
        self.q = False if v else q
        self.progress = False if q else progress

        self.ref = self.COREG_obj.ref  # type: GeoArray_CoReg
        self.shift = self.COREG_obj.shift  # type: GeoArray_CoReg

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        # lazily computed attributes
        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()
        self.kriged = None  # set by Raster_using_Kriging()

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
    @property
    def mean_x_shift_px(self):
        return self.CoRegPoints_table['X_SHIFT_PX'][self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean()

    @property
    def mean_y_shift_px(self):
        return self.CoRegPoints_table['Y_SHIFT_PX'][self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean()

    @property
    def mean_x_shift_map(self):
        return self.CoRegPoints_table['X_SHIFT_M'][self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean()

    @property
    def mean_y_shift_map(self):
        return self.CoRegPoints_table['Y_SHIFT_M'][self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean()
135

136
137
    @property
    def CoRegPoints_table(self):
138
139
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE' containing all
140
        information containing all the results from coregistration for all points in the tie points grid.
141
        """
142
143
144
145
146
147
148
149
150
151
152
153
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table

    @property
    def GCPList(self):
154
155
        """Returns a list of GDAL compatible GCP objects.
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
156

157
158
159
160
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
161
            return self._GCPList
162
163
164
165
166
167

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
168
169
170
171
172
173
        """Returns a numpy array containing possible positions for coregistration tie points according to the given
        grid resolution.

        :param grid_res:
        :return:
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
174

175
        if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
176
            print('Initializing tie points grid...')
177

178
179
        Xarr, Yarr = np.meshgrid(np.arange(0, self.shift.shape[1], grid_res),
                                 np.arange(0, self.shift.shape[0], grid_res))
180

181
182
        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr * self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr * abs(self.shift.gt[5])
183

184
185
186
        XY_points = np.empty((Xarr.size, 2), Xarr.dtype)
        XY_points[:, 0] = Xarr.flat
        XY_points[:, 1] = Yarr.flat
187

188
189
190
        XY_mapPoints = np.empty((mapXarr.size, 2), mapXarr.dtype)
        XY_mapPoints[:, 0] = mapXarr.flat
        XY_mapPoints[:, 1] = mapYarr.flat
191

Daniel Scheffler's avatar
Daniel Scheffler committed
192
193
        assert XY_points.shape == XY_mapPoints.shape

194
        return XY_points, XY_mapPoints
195

196
197
198
199
200
201
202
203
204
205
206
    def _exclude_bad_XYpos(self, GDF):
        """Exclude all points outside of the image overlap area and all points where the bad data mask is True.

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:        <geopandas.GeoDataFrame> the input GDF reduced to usable points, with the additional
                        boolean columns 'REF_BADDATA' and 'TGT_BADDATA'
        """

        # exclude all points outside of overlap area
        # NOTE(review): 'inliers' is computed from self.XY_mapPoints, not from GDF's own coordinates -- this
        # assumes GDF still contains one row per grid point in the original order; verify at the call site.
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        # FIXME track that
        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_UTM', 'Y_UTM']])
        # if no bad data mask is provided, the corresponding column is a constant False (nothing gets excluded)
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        # keep only points that are clean in BOTH the reference and the target image
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      % (orig_len_GDF - len(GDF), orig_len_GDF))

        return GDF

227
228
    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
Daniel Scheffler's avatar
Daniel Scheffler committed
229
        # unpack
230
        pointID = coreg_kwargs['pointID']
231
232
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']
233

Daniel Scheffler's avatar
Daniel Scheffler committed
234
        # assertions
235
        assert global_shared_imref is not None
236
        assert global_shared_im2shift is not None
Daniel Scheffler's avatar
Daniel Scheffler committed
237
238

        # run CoReg
239
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
240
        CR.fftw_works = fftw_works
241
        CR.calculate_spatial_shifts()
Daniel Scheffler's avatar
Daniel Scheffler committed
242
243

        # fetch results
244
        last_err = CR.tracked_errors[-1] if CR.tracked_errors else None
245
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
246
247
248
        CR_res = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                  CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                  CR.shift_reliability, last_err]
249

250
        return [pointID] + CR_res
251

252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    def _get_coreg_kwargs(self, pID, wp):
        return dict(
            pointID=pID,
            fftw_works=self.COREG_obj.fftw_works,
            wp=wp,
            ws=self.COREG_obj.win_size_XY,
            resamp_alg_calc=self.rspAlg_calc,
            footprint_poly_ref=self.COREG_obj.ref.poly,
            footprint_poly_tgt=self.COREG_obj.shift.poly,
            r_b4match=self.ref.band4match + 1,  # band4match is internally saved as index, starting from 0
            s_b4match=self.shift.band4match + 1,  # band4match is internally saved as index, starting from 0
            max_iter=self.COREG_obj.max_iter,
            max_shift=self.COREG_obj.max_shift,
            nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
            force_quadratic_win=self.COREG_obj.force_quadratic_win,
            binary_ws=self.COREG_obj.bin_ws,
            v=False,  # otherwise this would lead to massive console output
            q=True,  # otherwise this would lead to massive console output
            ignore_errors=True
        )

273
    def get_CoRegPoints_table(self):
        """Calculate the tie point grid and return it as a GeoDataFrame.

        Runs COREG for every grid point within the image overlap area (in parallel unless self.CPUs == 1),
        merges all results into one table and applies the configured tie point filtering.

        :return:    GeoDataFrame with one row per grid point, containing the point geometry, image/map
                    coordinates, matching window sizes, X/Y shifts (pixels and map units), shift vector
                    length/angle, SSIM metrics, reliability and the last tracked error (if any)
        """
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'
        # (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']] = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is
            # connected to that
            self.COREG_obj.equalize_pixGrids()

        # validate reference and target image inputs
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        # get all variations of kwargs for coregistration
        list_coreg_kwargs = (self._get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie point grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs, initializer=mp_initializer, initargs=(self.ref, self.shift)) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks
                        # -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within
                            # COREG and is not raised
                            results = results.get()
                            break

        else:
            # declare global variables needed for self._get_spatial_shifts()
            global global_shared_imref, global_shared_im2shift
            global_shared_imref = self.ref
            global_shared_im2shift = self.shift

            if not self.q:
                print("Calculating tie point grid (%s points) using 1 CPU core..." % len(GDF))
            # FIX: the result array is now sized by len(GDF) instead of len(geomPoints) - GDF may have been
            # filtered/sampled above, so sizing by the full grid left unfilled rows.
            # FIX: 'np.object' (deprecated alias, removed in NumPy 1.24) replaced by the builtin 'object'.
            results = np.empty((len(GDF), 14), object)
            bar = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table

393
394
395
396
    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

Daniel Scheffler's avatar
Daniel Scheffler committed
397
        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
398
399
400
        """

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
401
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
402
403
404
405
406
407
408
409
410
411
412
413
414
415

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))

    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
416
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy()
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i * i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))

    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
        """Create a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        :return:                    (figure, axis), or (None, None) in interactive mode
        """
        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        # split the table into valid tie points (inliers) and false-positives (outliers), if filtering has been run
        tbl_il = tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits (symmetric around zero by default)
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            # FIX: parentheses added around the conditional expression - previously it bound tighter than the
            # string formatting, so for unit='px' the axis labels were just 'pixels' instead of '…-shift [pixels]'
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax

522
    def dump_CoRegPoints_table(self, path_out=None):
        """Pickle the tie point table (CoRegPoints_table) to disk.

        :param path_out:    the output path; if not given, it is generated automatically within self.dir_out
        """
        if not path_out:
            fName_out = "CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" \
                        % (self.grid_res, self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1],
                           self.shift.basename, self.ref.basename)
            path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)

        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)
531

532
    def to_GCPList(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
533
        # get copy of tie points grid without no data
Daniel Scheffler's avatar
Daniel Scheffler committed
534
535
536
537
538
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []
539

540
        if getattr(GDF, 'empty'):  # GDF.empty returns AttributeError
541
542
            return []
        else:
543
            # exclude all points flagged as outliers
544
            if 'OUTLIER' in GDF.columns:
545
                GDF = GDF[GDF.OUTLIER.__eq__(False)].copy()
546
547
            avail_TP = len(GDF)

548
549
550
551
            if not avail_TP:
                # no point passed all validity checks
                return []

552
            if avail_TP > 7000:
553
554
555
                GDF = GDF.sample(7000)
                warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                              'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
556
                              'out of the %s available tie points.' % avail_TP)
557

558
559
560
            # calculate GCPs
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
561
562
            GDF['GCP'] = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                            GDF_row.X_IM, GDF_row.Y_IM), axis=1)
563
564
565
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
566
                print('Found %s valid tie points.' % len(self.GCPList))
567
568

            return self.GCPList
569

570
    def test_if_singleprocessing_equals_multiprocessing_result(self):
        """Compute the tie point grid once with all CPUs and once with a single CPU and compare the results.

        :return:    True if both result tables are equal, otherwise False
        """
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        def _grid_values(cpus):
            # compute the tie point grid with the given CPU count and snapshot the resulting table values
            self.CPUs = cpus
            return self.get_CoRegPoints_table().values.copy()

        mp_out = _grid_values(None)  # multiprocessing (all available CPUs)
        sp_out = _grid_values(1)     # singleprocessing

        return np.array_equal(sp_out, mp_out)
584

585
586
    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]
587

588
    def _get_lines_by_PIDs(self, PIDs):
589
590
591
592
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
593
594
        return lines

595
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table
        # operate on a copy, otherwise the boolean replacement below would modify (a view of)
        # self.CoRegPoints_table itself
        GDF2pass = (GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]).copy()

        # replace boolean values (cannot be written)
        # NOTE: the deprecated alias np.bool was removed in NumPy 1.24; the builtin 'bool'
        #       matches boolean column dtypes just as well
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0)  # replace all remaining booleans where dtype is object instead of bool
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)

        GDF2pass.to_file(path_out)

625
626
627
628
629
630
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        """Deprecated, slow shapefile export based on IO.write_shp.

        Use to_PointShapefile() instead.
        """
        warnings.warn(DeprecationWarning(
            "'_tiepoints_grid_to_PointShapefile' is deprecated."  # TODO delete if other method validated
            " 'tiepoints_grid_to_PointShapefile' is much faster."))
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # geometries and per-record attribute dictionaries for IO.write_shp
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts = []
        for rec_idx in GDF2pass.index:
            attr_dicts.append(collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[rec_idx].values)))

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)

638
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        """
        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                     "(outputs magnitude and direction)'. Got %s." % mode
        attr_b1 = 'X_SHIFT_M' if mode == 'uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode == 'uv' else 'ANGLE'

        # rasterize both attributes onto the same grid (deduplicated: both calls share all
        # parameters except the attribute column)
        bands = []
        gt, prj = None, None
        for attr in (attr_b1, attr_b2):
            band_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                                 values=self.CoRegPoints_table[attr],
                                                 tgt_res=self.shift.xgsd * self.grid_res,
                                                 prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                                 fillVal=self.outFillVal)
            bands.append(band_arr)

        out_GA = GeoArray(np.dstack(bands), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA

679
    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                    path_out=None, tilepos=None):
        # Deprecated predecessor of to_Raster_using_Kriging():
        # interpolates one attribute column of CoRegPoints_table to a regular grid via
        # ordinary kriging (pykrige) and writes the result to a GeoTIFF.
        #
        # :param attrName:        name of the CoRegPoints_table column to interpolate
        # :param skip_nodata:     whether to exclude records flagged as nodata (default: 1)
        # :param skip_nodata_col: column used to identify nodata records
        # :param outGridRes:      output grid resolution in map units (default: auto, ~250 cells
        #                         along the shorter extent)
        # :param path_out:        output path (default: auto-generated below self.dir_out)
        # :param tilepos:         (rows, cols) pixel ranges to subset the tie points
        # :return:                interpolated 2D value grid
        warnings.warn(DeprecationWarning("'to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead."))  # TODO delete

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # subset if tilepos is given
        # NOTE(review): when tilepos is None, self.shift.shape unpacks to two scalars, so
        # cols[0]/rows[0] below would raise TypeError — confirm callers always pass tilepos
        rows, cols = tilepos if tilepos else self.shift.shape
        GDF2pass = GDF2pass.loc[(GDF2pass['X_IM'] >= cols[0]) & (GDF2pass['X_IM'] <= cols[1]) &
                                (GDF2pass['Y_IM'] >= rows[0]) & (GDF2pass['Y_IM'] <= rows[1])]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        # auto grid resolution: roughly 250 cells along the shorter extent
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)  # ,backend='C',)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"
                                          % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

718
    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):
        """Interpolate the given attribute of CoRegPoints_table to a raster using ordinary kriging.

        :param attrName:        name of the CoRegPoints_table column to interpolate
        :param skip_nodata:     whether to exclude records flagged as nodata (default: 1)
        :param skip_nodata_col: column used to identify nodata records
        :param outGridRes:      output grid resolution in map units
        :param fName_out:       output file name
        :param tilepos:         tile position (currently only used for the output file name)
        :param tilesize:        tile size for the (currently disabled) multiprocessing implementation
        :param mp:              whether to use multiprocessing (default: derived from self.CPUs)
                                NOTE: the multiprocessing code path is currently disabled (see
                                commented block below), so computation always runs single-threaded.
        :return:                interpolated 2D value grid
        """
        # BUGFIX: 'mp' was previously overwritten unconditionally, and - since the multiprocessing
        # branch is commented out - 'self.kriged' was never assigned, so requesting mp mode raised
        # an AttributeError while the single-process result was silently discarded.
        if mp is None:
            mp = self.CPUs is None or self.CPUs > 1

        # TODO re-enable the multiprocessing implementation:
        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self._Kriging_mp,args_kwargs_dicts)

        return self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                                outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

746
747
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Interpolate one attribute of CoRegPoints_table to a regular grid via ordinary kriging
        (single-process) and write the result to a GeoTIFF.

        :param attrName:        name of the CoRegPoints_table column to interpolate
        :param skip_nodata:     whether to exclude records flagged as nodata (default: 1)
        :param skip_nodata_col: column used to identify nodata records
        :param outGridRes:      output grid resolution in map units (default: auto, ~250 cells
                                along the shorter extent)
        :param fName_out:       output file name (default: auto-generated)
        :param tilepos:         tile position; only used within the auto-generated file name
        :return:                interpolated 2D value grid
        """
        table = self.CoRegPoints_table
        points = table[table[skip_nodata_col] != self.outFillVal] if skip_nodata else table

        x_vals, y_vals, z_vals = points['X_UTM'], points['Y_UTM'], points[attrName]

        xmin, ymin, xmax, ymax = points.total_bounds

        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x = np.arange(xmin, xmax + grid_res, grid_res)
        grid_y = np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(x_vals, y_vals, z_vals, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        # in multiprocessing mode the tile position is encoded in the file name
        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY, tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' % path_out)

        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

790
    def _Kriging_mp(self, args_kwargs_dict):
791
792
        args = args_kwargs_dict.get('args', [])
        kwargs = args_kwargs_dict.get('kwargs', [])
793

794
        return self._Kriging_sp(*args, **kwargs)
795
796


797
class Tie_Point_Refiner(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
798
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
799
                 rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
800
        """A class for performing outlier detection.
Daniel Scheffler's avatar
Daniel Scheffler committed
801

802
803
        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
Daniel Scheffler's avatar
Daniel Scheffler committed
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
828
829
        self.ransac_model_robust = None

830
    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters all tie points out where shift
                                correction does not increase image similarity within matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:    (filtered GeoDataFrame, list of appended outlier column names)
        """
        # TODO catch empty GDF

        # RELIABILITY filtering
        if level > 0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % (len(marked_recs[marked_recs])))

        # SSIM filtering
        if level > 1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with less than four tie points makes no sense
                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)'
                          % (len(marked_recs[marked_recs])))
            else:
                # BUGFIX: this message previously ignored the quiet mode flag (self.q)
                if not self.q:
                    print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        # combined outlier flag: a record is an outlier if ANY filter level flagged it
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

Daniel Scheffler's avatar
Daniel Scheffler committed
894
    def _reliability_thresholding(self):
895
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""
896

Daniel Scheffler's avatar
Daniel Scheffler committed
897
        return self.GDF.RELIABILITY < self.min_reliability
898
899

    def _SSIM_filtering(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
900
        """Exclude all records where SSIM decreased."""
901

902
        # ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])
903

904
905
        # self.GDF.SSIM_IMPROVED = \
        #     self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)
906

907
        return ~self.GDF.SSIM_IMPROVED
908

Daniel Scheffler's avatar
Daniel Scheffler committed
909
910
    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm."""
911

Daniel Scheffler's avatar
Daniel Scheffler committed
912
        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
913
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
914
915
916
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
917
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)