# -*- coding: utf-8 -*-

# AROSICS - Automated and Robust Open-Source Image Co-Registration Software
#
# Copyright (C) 2017-2020  Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
#
# This software was developed within the context of the GeoMultiSens project funded
# by the German Federal Ministry of Education and Research
# (project grant code: 01 IS 14 010 A-C).
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.

24
25
26
27
import collections
import multiprocessing
import os
import warnings
28
import time
29
30

# custom
31
32
33
34
try:
    import gdal
except ImportError:
    from osgeo import gdal
35
import numpy as np
36
37
from geopandas import GeoDataFrame
from pandas import DataFrame, Series
38
from shapely.geometry import Point
39
40

# internal modules
41
from .CoReg import COREG
42
from py_tools_ds.geo.projection import isProjectedOrGeographic, isLocal, get_UTMzone
43
from py_tools_ds.io.pathgen import get_generic_outpath
44
from py_tools_ds.processing.progress_mon import ProgressBar
45
from py_tools_ds.geo.vector.conversion import points_to_raster
46
from py_tools_ds.io.vector.writer import write_shp
47
from geoarray import GeoArray
48

49
from .CoReg import GeoArray_CoReg  # noqa F401  # flake8 issue
50

51
__author__ = 'Daniel Scheffler'

# Module-level references to the reference and the target image. They are assigned by mp_initializer() (in each
# worker process of the multiprocessing pool) or directly by Tie_Point_Grid.get_CoRegPoints_table() (single-CPU
# runs) and read by Tie_Point_Grid._get_spatial_shifts().
global_shared_imref = global_shared_im2shift = None


def mp_initializer(imref, imtgt):
    """Declare global variables needed for self._get_spatial_shifts().

    This function is passed as ``initializer`` to the multiprocessing pool so that every worker process holds
    module-level references to both input images.

    :param imref:   reference image
    :param imtgt:   target image
    """
    global global_shared_imref, global_shared_im2shift
    global_shared_imref, global_shared_im2shift = imref, imtgt


class Tie_Point_Grid(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
69
70
71
72
73
74
75
76
77
78
    """
    The 'Tie_Point_Grid' class applies the algorithm to detect spatial shifts to the overlap area of the input images.

    Spatial shifts are calculated for each point in grid of which the parameters can be adjusted using keyword
    arguments. Shift correction performs a polynomial transformation using te calculated shifts of each point in the
    grid as GCPs. Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target
    image.

    See help(Tie_Point_Grid) for documentation!
    """
79

80
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
                 q=False):
        """Get an instance of the 'Tie_Point_Grid' class.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the grid does not provide enough points, all available points
                                        are chosen.
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes during calculation
                                        of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                           mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter all tie points out that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings:     a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there. The given dictionary is copied, i.e.,
                                        the object passed by the caller is not modified.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        :raises ValueError:             if COREG_obj is not an instance of COREG
        """
        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj = COREG_obj  # type: COREG
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        # work on a copy - otherwise adding the default 'q' key below would mutate the caller's dictionary
        self.outlDetect_settings = dict(outlDetect_settings or {})
        self.dir_out = dir_out
        self.CPUs = CPUs
        self.v = v
        self.q = q if not v else False  # overridden by v
        self.progress = progress if not q else False  # overridden by q

        # the tie point refiner inherits the quiet mode unless it is explicitly configured
        self.outlDetect_settings.setdefault('q', self.q)

        self.ref = self.COREG_obj.ref  # type: GeoArray_CoReg
        self.shift = self.COREG_obj.shift  # type: GeoArray_CoReg

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()
        self.kriged = None  # set by Raster_using_Kriging()

    @property
    def mean_x_shift_px(self):
        return self.CoRegPoints_table['X_SHIFT_PX'][self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean()

    @property
    def mean_y_shift_px(self):
        return self.CoRegPoints_table['Y_SHIFT_PX'][self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean()

    @property
    def mean_x_shift_map(self):
        return self.CoRegPoints_table['X_SHIFT_M'][self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean()

    @property
    def mean_y_shift_map(self):
        return self.CoRegPoints_table['Y_SHIFT_M'][self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean()
161

162
163
    @property
    def CoRegPoints_table(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
164
165
166
167
        """Return a GeoDataFrame containing all the results from coregistration for all points in the tie point grid.

        Columns of the GeoDataFrame: 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE', 'Y_WIN_SIZE',
                                     'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE'
168
        """
169
170
171
172
173
174
175
176
177
178
179
180
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table

    @property
    def GCPList(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
181
        """Return a list of GDAL compatible GCP objects."""
182
183
184
185
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
186
            return self._GCPList
187
188
189
190
191
192

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
Daniel Scheffler's avatar
Daniel Scheffler committed
193
194
195
        """Return a numpy array containing possible positions for coregistration tie points.

        NOTE: The returned positions are dependent from the given grid resolution.
196
197
198
199

        :param grid_res:
        :return:
        """
200
        if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
201
            print('Initializing tie points grid...')
202

203
204
        Xarr, Yarr = np.meshgrid(np.arange(0, self.shift.shape[1], grid_res),
                                 np.arange(0, self.shift.shape[0], grid_res))
205

206
207
        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr * self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr * abs(self.shift.gt[5])
208

209
210
211
        XY_points = np.empty((Xarr.size, 2), Xarr.dtype)
        XY_points[:, 0] = Xarr.flat
        XY_points[:, 1] = Yarr.flat
212

213
214
215
        XY_mapPoints = np.empty((mapXarr.size, 2), mapXarr.dtype)
        XY_mapPoints[:, 0] = mapXarr.flat
        XY_mapPoints[:, 1] = mapYarr.flat
216

Daniel Scheffler's avatar
Daniel Scheffler committed
217
218
        assert XY_points.shape == XY_mapPoints.shape

219
        return XY_points, XY_mapPoints
220

221
    def _exclude_bad_XYpos(self, GDF):
        """Exclude all points outside of the image overlap area and where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:        the remaining points as a new GeoDataFrame with the additional boolean columns
                        'REF_BADDATA' and 'TGT_BADDATA'
        """
        from skimage.measure import points_in_poly  # import here to avoid static TLS ImportError

        # exclude all points outside of overlap area
        # NOTE: The coordinates are taken from GDF itself (not from self.XY_mapPoints), so the boolean mask is always
        #       aligned with the rows of GDF, even if GDF does not contain every grid point.
        pointsXY = GDF[['X_UTM', 'Y_UTM']].to_numpy(dtype=np.float64)
        overlap_vertices = np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1)
        GDF = GDF[points_in_poly(pointsXY, overlap_vertices)].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all points where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_UTM', 'Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]

        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                if not GDF.empty:
                    print('With respect to the provided bad data mask(s) %s points of initially %s have been excluded.'
                          % (orig_len_GDF - len(GDF), orig_len_GDF))
                else:
                    warnings.warn('With respect to the provided bad data mask(s) no coregistration point could be '
                                  'placed within an image area usable for coregistration.')

        return GDF

    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        """Run COREG for a single tie point and return its results as one table row.

        The input images are read from the module-level variables 'global_shared_imref' and
        'global_shared_im2shift', which must have been set beforehand.

        :param coreg_kwargs:    dict as returned by Tie_Point_Grid._get_coreg_kwargs(); the entries 'pointID' and
                                'fftw_works' are removed from it in place, the rest is passed to COREG
        :return:                [pointID, win_sz_x, win_sz_y, x_shift_px, y_shift_px, x_shift_map, y_shift_map,
                                 vec_length_map, vec_angle_deg, ssim_orig, ssim_deshifted, ssim_improved,
                                 shift_reliability, last_err]
        """
        # separate the entries that are not keyword arguments of COREG
        pointID, fftw_works = coreg_kwargs['pointID'], coreg_kwargs['fftw_works']
        for key in ('pointID', 'fftw_works'):
            del coreg_kwargs[key]

        assert global_shared_imref is not None
        assert global_shared_im2shift is not None

        # single-window co-registration at the requested position
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # collect results (errors are not raised but tracked because of ignore_errors=True)
        last_err = CR.tracked_errors[-1] if CR.tracked_errors else None
        if CR.matchBox:
            win_sz_y, win_sz_x = CR.matchBox.imDimsYX
        else:
            win_sz_y = win_sz_x = None

        return [pointID, win_sz_x, win_sz_y,
                CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                CR.vec_length_map, CR.vec_angle_deg,
                CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                CR.shift_reliability, last_err]

    def _get_coreg_kwargs(self, pID, wp):
        return dict(
            pointID=pID,
            fftw_works=self.COREG_obj.fftw_works,
            wp=wp,
            ws=self.COREG_obj.win_size_XY,
            resamp_alg_calc=self.rspAlg_calc,
            footprint_poly_ref=self.COREG_obj.ref.poly,
            footprint_poly_tgt=self.COREG_obj.shift.poly,
            r_b4match=self.ref.band4match + 1,  # band4match is internally saved as index, starting from 0
            s_b4match=self.shift.band4match + 1,  # band4match is internally saved as index, starting from 0
            max_iter=self.COREG_obj.max_iter,
            max_shift=self.COREG_obj.max_shift,
            nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
            force_quadratic_win=self.COREG_obj.force_quadratic_win,
            binary_ws=self.COREG_obj.bin_ws,
            v=False,  # otherwise this would lead to massive console output
            q=True,  # otherwise this would lead to massive console output
            ignore_errors=True
        )

302
    def get_CoRegPoints_table(self):
        """Calculate the tie point grid and return it as a GeoDataFrame.

        Candidate points are placed on the grid. Points outside the image overlap area or on bad data are dropped.
        If max_points is given, the points are randomly subsampled. A single-window co-registration is then run at
        each remaining point, in parallel unless CPUs == 1. The results are merged into the table and filtered
        according to tieP_filter_level. Missing values are replaced by outFillVal. The result is also stored in
        self.CoRegPoints_table.

        :return:    GeoDataFrame with one row per processed tie point
        """
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'
        # (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        if isLocal(self.COREG_obj.shift.prj):
            crs = None
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            zone = get_UTMzone(prj=self.COREG_obj.shift.prj)  # negative for the southern hemisphere
            UTMzone = abs(zone)
            south = zone < 0
            crs = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']] = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)

        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is
            # connected to that
            self.COREG_obj.equalize_pixGrids()

        # validate reference and target image inputs
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        # get all variations of kwargs for coregistration
        # (the GDF index equals the row index within self.XY_mapPoints because GDF was built from it)
        list_coreg_kwargs = (self._get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # columns of the result rows returned by self._get_spatial_shifts()
        result_columns = ['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                          'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                          'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR']

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie point grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs, initializer=mp_initializer, initargs=(self.ref, self.shift)) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    async_result = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks
                        # -> thus chunksize=1
                        # noinspection PyProtectedMember
                        numberDone = len(GDF) - async_result._number_left
                        bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if async_result.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within
                            # COREG and is not raised
                            results = async_result.get()
                            break

        else:
            # declare global variables needed for self._get_spatial_shifts()
            global global_shared_imref, global_shared_im2shift
            global_shared_imref = self.ref
            global_shared_im2shift = self.shift

            if not self.q:
                print("Calculating tie point grid (%s points) on 1 CPU core..." % len(GDF))

            # one row per processed point (not per grid point); dtype object because results may contain None
            results = np.empty((len(GDF), len(result_columns)), object)
            bar = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        # NOTE: We use a pandas.DataFrame here because the geometry column is missing.
        #       GDF.astype(...) fails with geopandas>0.6.0 if the geometry columns is missing.
        records = DataFrame(results, columns=result_columns)

        # merge DataFrames (dtype must be equal to records.dtypes; we need object dtype due to None values)
        # NOTE: the 'np.object' alias was removed in NumPy 1.24 -> use the builtin 'object'
        GDF = GDF.astype(object).merge(records.astype(object), on='POINT_ID', how="inner")
        GDF = GDF.replace([np.nan, None], int(self.outFillVal))  # fillna fails with geopandas==0.6.0
        GDF.crs = crs  # gets lost when using GDF.astype(object), so we have to reassign that

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")

        GDF = GDF.replace([np.nan, None], int(self.outFillVal))  # fillna fails with geopandas==0.6.0

        self.CoRegPoints_table = GDF

        if not self.q:
            if GDF.empty:
                warnings.warn('No valid GCPs could be identified.')
            else:
                if self.tieP_filter_level > 0:
                    print("%d valid tie points remain after filtering." % len(GDF[GDF.OUTLIER.__eq__(False)]))

        return self.CoRegPoints_table

    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
Daniel Scheffler's avatar
Daniel Scheffler committed
439
        """Calculate root mean square error of absolute shifts from the tie point grid.
440

Daniel Scheffler's avatar
Daniel Scheffler committed
441
        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
442
443
        """
        tbl = self.CoRegPoints_table
444
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == 0].copy() if 'OUTLIER' in tbl.columns else tbl
445
446
447
448
449
450

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))

451
452
    def calc_overall_ssim(self, include_outliers=False, after_correction=True):
        # type: (bool, bool) -> float
Daniel Scheffler's avatar
Daniel Scheffler committed
453
        """Calculate the median value of all SSIM values contained in tie point grid.
454
455

        :param include_outliers:    whether to include tie points that have been marked as false-positives
456
        :param after_correction:    whether to compute median SSIM before correction or after
457
458
        """
        tbl = self.CoRegPoints_table
459
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == 0].copy()
460

461
462
        ssim_col = np.array(tbl['SSIM_AFTER' if after_correction else 'SSIM_BEFORE'])
        ssim_col = [i * i for i in ssim_col if i != self.outFillVal]
463

464
        return float(np.median(ssim_col))
465
466
467

    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
Daniel Scheffler's avatar
Daniel Scheffler committed
468
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
Daniel Scheffler's avatar
Daniel Scheffler committed
469
        """Create a 2D scatterplot containing the distribution of calculated X/Y-shifts.
470
471
472
473
474
475
476
477
478
479

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """
480
481
        from matplotlib import pyplot as plt

482
483
        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)
484
485

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
486
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
487
        tbl_il = tbl[tbl['OUTLIER'] == 0].copy() if 'OUTLIER' in tbl.columns else tbl
Daniel Scheffler's avatar
Daniel Scheffler committed
488
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
489
490
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
491
492
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)
493
494
495
496

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
Daniel Scheffler's avatar
Daniel Scheffler committed
497
            # FIXME outliers are not plotted
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

Daniel Scheffler's avatar
Daniel Scheffler committed
519
            if include_outliers and 'OUTLIER' in tbl.columns:
Daniel Scheffler's avatar
Daniel Scheffler committed
520
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
521
522
523
524
525
526
527
528
529
530
531
532
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

Daniel Scheffler's avatar
Daniel Scheffler committed
533
            # add text box containing RMSE of plotted shifts
534
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
535
536
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
537
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))
538

Daniel Scheffler's avatar
Daniel Scheffler committed
539
            # add grid and increase linewidth of middle line
540
541
542
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
Daniel Scheffler's avatar
Daniel Scheffler committed
543
            middle_xgl.set_linewidth(2)
544
545
546
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
Daniel Scheffler's avatar
Daniel Scheffler committed
547
            middle_ygl.set_linewidth(2)
548
549
            middle_ygl.set_linestyle('-')

Daniel Scheffler's avatar
Daniel Scheffler committed
550
551
            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
552
553
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
Daniel Scheffler's avatar
Daniel Scheffler committed
554
555
            plt.xlabel('x-shift [%s]' % 'meters' if unit == 'm' else 'pixels', fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % 'meters' if unit == 'm' else 'pixels', fontsize=fontsize)
556

557
558
            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
Daniel Scheffler's avatar
Daniel Scheffler committed
559
560
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')
561

562
563
564
565
            plt.show()

            return fig, ax

566
    def dump_CoRegPoints_table(self, path_out=None):
        """Pickle the tie point table (Tie_Point_Grid.CoRegPoints_table) to disk.

        :param path_out:    <str> output path of the pickle file. If not given, a file name is generated from the
                            grid resolution, the matching window size and the basenames of the target and
                            reference images, and the file is placed within self.dir_out.
        """
        if not path_out:
            win_size = self.COREG_obj.win_size_XY
            fName = "CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" \
                    % (self.grid_res, win_size[0], win_size[1], self.shift.basename, self.ref.basename)
            path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName)

        if not self.q:
            print('Writing %s ...' % path_out)

        self.CoRegPoints_table.to_pickle(path_out)
575

576
    def to_GCPList(self):
        """Convert the valid, non-outlier tie points into a list of GDAL ground control points (GCPs).

        Points whose ABS_SHIFT equals self.outFillVal and points flagged in the OUTLIER column (if present) are
        excluded. If more than 7000 points remain, a random subset of 7000 points is used and a warning is issued.

        :return:    <list> of gdal.GCP objects (also stored in self.GCPList); an empty list if no usable tie point
                    exists (self.GCPList is not modified in that case)
        """
        try:
            tbl = self.CoRegPoints_table
            # attribute access (instead of item access) is intended here: it raises an AttributeError if the
            # table has no 'ABS_SHIFT' column, which is the case if all points have been excluded
            valid = tbl.loc[tbl.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            return []

        if valid.empty:
            return []

        # drop all points flagged as outliers
        if 'OUTLIER' in valid.columns:
            valid = valid[valid.OUTLIER == False].copy()  # noqa: E712 (element-wise comparison intended)

        n_avail = len(valid)
        if n_avail == 0:
            # no point passed all validity checks
            return []

        if n_avail > 7000:
            valid = valid.sample(7000)
            warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                          'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
                          'out of the %s available tie points.' % n_avail)

        # map coordinates of the GCPs = map coordinates of the tie points + the computed shifts
        valid['X_UTM_new'] = valid.X_UTM + valid.X_SHIFT_M
        valid['Y_UTM_new'] = valid.Y_UTM + valid.Y_SHIFT_M

        def _row_to_gcp(row):
            # gdal.GCP(x, y, z, pixel, line)
            return gdal.GCP(row.X_UTM_new, row.Y_UTM_new, 0, row.X_IM, row.Y_IM)

        valid['GCP'] = valid.apply(_row_to_gcp, axis=1)

        self.GCPList = valid['GCP'].tolist()
        return self.GCPList
610

611
    def test_if_singleprocessing_equals_multiprocessing_result(self):
612
613
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1
614

Daniel Scheffler's avatar
Daniel Scheffler committed
615
        self.CPUs = None
616
        dataframe = self.get_CoRegPoints_table()
617
        mp_out = np.empty_like(dataframe.values)
618
        mp_out[:] = dataframe.values
Daniel Scheffler's avatar
Daniel Scheffler committed
619
        self.CPUs = 1
620
        dataframe = self.get_CoRegPoints_table()
621
        sp_out = np.empty_like(dataframe.values)
622
623
        sp_out[:] = dataframe.values

624
        return np.array_equal(sp_out, mp_out)
625

626
627
    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]
628

629
    def _get_lines_by_PIDs(self, PIDs):
630
631
632
633
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
634
635
        return lines

636
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Write the calculated tie points grid to a point shapefile (e.g., for visualization by a GIS software).

        NOTE: The shapefile uses Tie_Point_Grid.CoRegPoints_table as attribute table. The table itself is not
              modified - all conversions are applied to a copy.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table

        if skip_nodata:
            GDF2pass = GDF[GDF[skip_nodata_col] != self.outFillVal].copy()
        else:
            # copy - otherwise the LAST_ERR column of self.CoRegPoints_table itself would be overwritten below
            GDF2pass = GDF.copy()
            # store the string representation of the LAST_ERR values (presumably error objects of points without
            # valid match, which cannot be written to a shapefile as they are)
            GDF2pass['LAST_ERR'] = GDF2pass['LAST_ERR'].apply(repr)

        # replace boolean values (cannot be written)
        GDF2pass = GDF2pass.replace(False, 0).copy()  # replace booleans where column dtype is not bool but object
        GDF2pass = GDF2pass.replace(True, 1).copy()

        for col in GDF2pass.columns:
            # compare with the builtin bool ('np.bool' was only an alias of it and has been removed in NumPy 1.24)
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        if not self.q:
            print('Writing %s ...' % path_out)

        GDF2pass.to_file(path_out)

671
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):  # pragma: no cover
        """Deprecated: write the tie points grid to a point shapefile using write_shp (use to_PointShapefile instead).

        The output is written to self.dir_out with an automatically defined file name.

        :param skip_nodata:     <bool> whether to skip all points where skip_nodata_col equals self.outFillVal
        :param skip_nodata_col: <str> column of CoRegPoints_table used to identify points without valid match
        """
        # the warning names the actual deprecated method and its replacement
        warnings.warn(DeprecationWarning(
            "'_to_PointShapefile' is deprecated."  # TODO delete if other method validated
            " 'to_PointShapefile' is much faster."))
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]
        shapely_points = GDF2pass['geometry'].values.tolist()
        # one attribute dictionary (column name -> value) per point
        attr_dicts = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[i].values)) for i in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)
683

684
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Save the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield.

        NOTE: For example ArcGIS is able to visualize such 2-band raster files as a vectorfield.

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string (default: 'Gtiff')
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts (columns X_SHIFT_M and Y_SHIFT_M)
                                    'md': outputs magnitude and direction (columns ABS_SHIFT and ANGLE)
        :return:            <GeoArray> the written 2-band raster
        """
        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                     "(outputs magnitude and direction)'. Got %s." % mode

        band_columns = ('X_SHIFT_M', 'Y_SHIFT_M') if mode == 'uv' else ('ABS_SHIFT', 'ANGLE')
        tbl = self.CoRegPoints_table

        # rasterize each attribute separately; all bands share the same grid, so geotransform and projection of
        # the last call are used for the output
        bands = []
        for col in band_columns:
            band_arr, gt, prj = points_to_raster(points=tbl['geometry'],
                                                 values=tbl[col],
                                                 tgt_res=self.shift.xgsd * self.grid_res,
                                                 prj=tbl.crs.to_wkt(),
                                                 fillVal=self.outFillVal)
            bands.append(band_arr)

        out_GA = GeoArray(np.dstack(bands), gt, prj, nodata=self.outFillVal)

        if not path_out:
            win_size = self.COREG_obj.win_size_XY
            path_out = get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                           fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                                     % (self.grid_res, win_size[0], win_size[1],
                                                        self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt or 'Gtiff')

        return out_GA

725
    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):
        """Interpolate a tie point attribute to a raster using ordinary kriging and save it to disk.

        NOTE: Only single-processing is implemented so far (the work is done by _Kriging_sp). The parameters
              'tilesize' and 'mp' currently have no effect - 'mp' is always derived from self.CPUs.

        :param attrName:        <str> column of CoRegPoints_table to be interpolated
        :param skip_nodata:     whether to skip points where skip_nodata_col equals self.outFillVal
        :param skip_nodata_col: <str> column used to identify points where no valid match could be found
        :param outGridRes:      output grid resolution (if not given, it is derived from the tie point extent)
        :param fName_out:       <str> output file name (if not given, it is automatically defined)
        :param tilepos:         tile position; only used within the automatically defined file name if
                                self.CPUs is None or > 1
        :param tilesize:        currently unused
        :param mp:              currently unused (overwritten based on self.CPUs)
        :return:                self.kriged if self.CPUs != 1 and that attribute exists, otherwise None
        """
        mp = self.CPUs != 1

        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # TODO: tile-wise multiprocessing (one _Kriging_mp call per tile, results collected in self.kriged) is not
        #       implemented yet.
        # self.kriged is not set by this method, so avoid an AttributeError if it does not exist
        return getattr(self, 'kriged', None) if mp else None

753
754
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Interpolate a tie point attribute to a regular grid by ordinary kriging and save the result as raster.

        :param attrName:        <str> column of CoRegPoints_table to be interpolated
        :param skip_nodata:     whether to exclude points where skip_nodata_col equals self.outFillVal
        :param skip_nodata_col: <str> column used to identify points where no valid match could be found
        :param outGridRes:      output grid resolution in map units; if not given, 1/250 of the shorter side of
                                the tie point extent is used (truncated to an integer if that is not zero)
        :param fName_out:       <str> output file name (if not given, it is automatically defined)
        :param tilepos:         tile position, included in the automatically defined file name if self.CPUs is
                                None or > 1
        :return:                <np.ndarray> the kriged values on the output grid
        :raises ValueError:     if outGridRes is not given and the tie point extent has zero width or height
        """
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        X_coords, Y_coords, values = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        if outGridRes:
            grid_res = outGridRes
        else:
            shorter_side = min(xmax - xmin, ymax - ymin)
            # integer truncation yields 0 for extents < 250 map units, which would make np.arange fail
            # -> fall back to the non-truncated value in that case
            grid_res = int(shorter_side / 250) or shorter_side / 250
            if not grid_res:
                raise ValueError("Cannot derive the output grid resolution from a tie point extent of zero width "
                                 "or height. Provide 'outGridRes'.")
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging

        OK = OrdinaryKriging(X_coords, Y_coords, values, variogram_model='spherical', verbose=False)
        zvalues, _sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY, tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)

        # the grid points are the pixel centers -> shift the extent by half a pixel to get the outer pixel corners
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2

        GeoArray(zvalues,
                 geotransform=(xmin, grid_res, 0, ymax, 0, -grid_res),
                 projection=self.COREG_obj.shift.prj).save(path_out)

        return zvalues

787
    def _Kriging_mp(self, args_kwargs_dict):
788
789
        args = args_kwargs_dict.get('args', [])
        kwargs = args_kwargs_dict.get('kwargs', [])
790

791
        return self._Kriging_sp(*args, **kwargs)
792
793


794
class Tie_Point_Refiner(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
795
796
    """A class for performing outlier detection."""

Daniel Scheffler's avatar
Daniel Scheffler committed
797
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
798
                 rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
Daniel Scheffler's avatar
Daniel Scheffler committed
799
        """Get an instance of Tie_Point_Refiner.
Daniel Scheffler's avatar
Daniel Scheffler committed
800

801
802
        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
Daniel Scheffler's avatar
Daniel Scheffler committed
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:
        """
        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
826
827
        self.ransac_model_robust = None

828
    def run_filtering(self, level=3):
829
830
        """Filter tie points used for shift correction.

831
        :param level:   tie point filter level (default: 3).
832
833
834
835
836
837
838
839
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters all tie points out where shift
                                correction does not increase image similarity within matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection
Daniel Scheffler's avatar
Daniel Scheffler committed
840
841
842

        :return:
        """
843
844
        # TODO catch empty GDF

845
        # RELIABILITY filtering
846
        if level > 0:
847
            marked_recs = self._reliability_thresholding()  # type: Series
848
849
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')
Daniel Scheffler's avatar
Daniel Scheffler committed
850

851
            if not self.q:
852
                print('%s tie points flagged by level 1 filtering (reliability).'
Daniel Scheffler's avatar
Daniel Scheffler committed
853
                      % (len(marked_recs[marked_recs])))
Daniel Scheffler's avatar
Daniel Scheffler committed
854

855
        # SSIM filtering
856
        if level > 1:
857
858
            marked_recs = self._SSIM_filtering()
            self.GDF['L2_OUTLIER'] = marked_recs  # type: Series
859
            self.new_cols.append('L2_OUTLIER')
Daniel Scheffler's avatar
Daniel Scheffler committed
860

861
            if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
862
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))
Daniel Scheffler's avatar
Daniel Scheffler committed
863

864
        # RANSAC filtering
865
        if level > 2:
Daniel Scheffler's avatar
Daniel Scheffler committed
866
            # exclude previous outliers
Daniel Scheffler's avatar
Daniel Scheffler committed
867
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
868
                    if self.rs_exclude_previous_outliers else self.GDF
Daniel Scheffler's avatar
Daniel Scheffler committed
869

870
            if len(ransacInGDF) > 4:
871
                # running RANSAC with less than four tie points makes no sense
Daniel Scheffler's avatar
Daniel Scheffler committed
872

873
                marked_recs = self._RANSAC_outlier_detection(ransacInGDF)  # type: Series
874
875
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()
Daniel Scheffler's avatar
Daniel Scheffler committed
876

877
                if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
878
879
                    print('%s tie points flagged by level 3 filtering (RANSAC)'
                          % (len(marked_recs[marked_recs])))
880
881
882
883
            else:
                print('RANSAC skipped because too less valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

884
            self.new_cols.append('L3_OUTLIER')
Daniel Scheffler's avatar
Daniel Scheffler committed
885

886
887
888
889
890
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

Daniel Scheffler's avatar
Daniel Scheffler committed
891
    def _reliability_thresholding(self):
892
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""
Daniel Scheffler's avatar
Daniel Scheffler committed
893
        return self.GDF.RELIABILITY < self.min_reliability