# -*- coding: utf-8 -*-

import collections
import multiprocessing
import os
import warnings
import time

# custom
try:
    import gdal
except ImportError:
    from osgeo import gdal
import numpy as np
from matplotlib import pyplot as plt
from geopandas import GeoDataFrame, GeoSeries
from shapely.geometry import Point
from skimage.measure import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform

# internal modules
from .CoReg import COREG
from . import io as IO
from py_tools_ds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion import points_to_raster
from geoarray import GeoArray

__author__ = 'Daniel Scheffler'

global_shared_imref = None
global_shared_im2shift = None
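# NOTE: these module-level variables hold the reference and the target image so that the worker processes spawned
# in Tie_Point_Grid.get_CoRegPoints_table() can access them without pickling the full arrays for every tie point
# (they are assigned there and read by Tie_Point_Grid._get_spatial_shifts()).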


class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""

    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
                 q=False):
        """Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial
        shifts are calculated for each point of a grid whose parameters can be adjusted using keyword arguments.
        Shift correction performs a polynomial transformation using the calculated shifts of each point in the grid
        as GCPs. Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the
        target image.

        :param COREG_obj(object):       an instance of the COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the grid does not provide enough points, all available points
                                        are chosen.
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str):   the resampling algorithm to be used for all warping processes during the
                                        calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                           mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter out all tie points that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filter out all tie points where shift
                                                correction does not increase image similarity within the matching
                                                window (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings:    a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """

        if not isinstance(COREG_obj, COREG): raise ValueError("'COREG_obj' must be an instance of COREG class.")
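        # Minimal usage sketch (hedged example; 'geoArr_reference', 'geoArr_target' and the parameter values are
        # placeholders, not defaults prescribed by this module):
        #     CR = COREG(geoArr_reference, geoArr_target, ws=(256, 256))
        #     TPG = Tie_Point_Grid(CR, grid_res=200, max_points=1000, CPUs=None)
        #     TPG.CoRegPoints_table   # GeoDataFrame with one row per tie point
        #     TPG.to_GCPList()        # GDAL GCPs usable for warping the target image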

        self.COREG_obj = COREG_obj
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
        self.dir_out = dir_out
        self.CPUs = CPUs
        self.v = v
        self.q = q if not v else False  # overridden by v
        self.progress = progress if not q else False  # overridden by q

        self.ref = self.COREG_obj.ref
        self.shift = self.COREG_obj.shift

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()
        self.kriged = None  # set by Raster_using_Kriging()

    mean_x_shift_px = property(lambda self:
                               self.CoRegPoints_table['X_SHIFT_PX'][
                                   self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean())
    mean_y_shift_px = property(lambda self:
                               self.CoRegPoints_table['Y_SHIFT_PX'][
                                   self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean())
    mean_x_shift_map = property(lambda self:
                                self.CoRegPoints_table['X_SHIFT_M'][
                                    self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean())
    mean_y_shift_map = property(lambda self:
                                self.CoRegPoints_table['Y_SHIFT_M'][
                                    self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean())

    @property
    def CoRegPoints_table(self):
        """Returns a GeoDataFrame with the columns 'geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM',
        'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE',
        containing all results of the co-registration for all points of the tie point grid.
        """
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table

    @property
    def GCPList(self):
        """Returns a list of GDAL compatible GCP objects.
        """

        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
            return self._GCPList

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
        """Returns two arrays with the image and map coordinates of all candidate tie point positions resulting
        from the given grid resolution.

        :param grid_res:    grid resolution in pixels of the target image
        :return:            two arrays of shape (N, 2): image coordinates (column/row) and map coordinates (X/Y)
        """

        if not self.q:
            print('Initializing tie points grid...')

        Xarr, Yarr = np.meshgrid(np.arange(0, self.shift.shape[1], grid_res),
                                 np.arange(0, self.shift.shape[0], grid_res))
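        # For example, a 10000 x 10000 pixel target image with grid_res=200 yields a 50 x 50 meshgrid,
        # i.e. 2500 candidate tie point positions (before any filtering).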

        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr * self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr * abs(self.shift.gt[5])

        XY_points = np.empty((Xarr.size, 2), Xarr.dtype)
        XY_points[:, 0] = Xarr.flat
        XY_points[:, 1] = Yarr.flat

        XY_mapPoints = np.empty((mapXarr.size, 2), mapXarr.dtype)
        XY_mapPoints[:, 0] = mapXarr.flat
        XY_mapPoints[:, 1] = mapYarr.flat

        assert XY_points.shape == XY_mapPoints.shape

        return XY_points, XY_mapPoints

    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:
        """

        # exclude all points outside of overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        # FIXME track that
        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all points where the bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_UTM', 'Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      % (orig_len_GDF - len(GDF), orig_len_GDF))

        return GDF

    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        # unpack
        pointID = coreg_kwargs['pointID']
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']

        # assertions
        assert global_shared_imref is not None
        assert global_shared_im2shift is not None

        # run CoReg
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                  CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                  CR.shift_reliability, last_err]

        return [pointID] + CR_res

    def get_CoRegPoints_table(self):
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'
        # (convert image coordinates to map coordinates)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']] = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is
            # connected to that
            self.COREG_obj.equalize_pixGrids()

        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref, global_shared_im2shift
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighbouring matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        global_shared_imref = self.ref
        global_shared_im2shift = self.shift

        # get all variations of kwargs for coregistration
        get_coreg_kwargs = lambda pID, wp: dict(
            pointID=pID,
            fftw_works=self.COREG_obj.fftw_works,
            wp=wp,
            ws=self.COREG_obj.win_size_XY,
            resamp_alg_calc=self.rspAlg_calc,
            footprint_poly_ref=self.COREG_obj.ref.poly,
            footprint_poly_tgt=self.COREG_obj.shift.poly,
            r_b4match=self.ref.band4match + 1,  # band4match is internally saved as index, starting from 0
            s_b4match=self.shift.band4match + 1,  # band4match is internally saved as index, starting from 0
            max_iter=self.COREG_obj.max_iter,
            max_shift=self.COREG_obj.max_shift,
            nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
            force_quadratic_win=self.COREG_obj.force_quadratic_win,
            binary_ws=self.COREG_obj.bin_ws,
            v=False,  # otherwise this would lead to massive console output
            q=True,  # otherwise this would lead to massive console output
            ignore_errors=True
        )
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie points grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks
                        # -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within
                            # COREG and is not raised
                            results = results.get()
                            break

        else:
            if not self.q:
                print("Calculating tie points grid (%s points) using 1 CPU core..." % len(GDF))
            results = np.empty((len(geomPoints), 14), dtype=object)
            bar = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, dtype=object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table

    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))

    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[~tbl['OUTLIER']].copy()

        # use the SSIM values computed after shift correction and exclude the fill value
        mssim_col = np.array(tbl['SSIM_AFTER'])
        mssim_col = [i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))

    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
        """Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """

        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax

    def dump_CoRegPoints_table(self, path_out=None):
        """Writes Tie_Point_Grid.CoRegPoints_table to a pickle file.

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        """
        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=self.dir_out, fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl"
                                                                % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                                                   self.COREG_obj.win_size_XY[1], self.shift.basename,
                                                                   self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)

    def to_GCPList(self):
        """Returns a list of GDAL-compatible GCP objects built from all valid tie points (excluding outliers)."""
        # get a copy of the tie point grid without nodata records
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if GDF.empty:
            return []
        else:
            # exclude all points flagged as outliers
            if 'OUTLIER' in GDF.columns:
                GDF = GDF[GDF.OUTLIER == False].copy()
            avail_TP = len(GDF)

            if not avail_TP:
                # no point passed all validity checks
                return []

            if avail_TP > 7000:
                GDF = GDF.sample(7000)
                warnings.warn('Not more than 7000 tie points can be used for warping within a reasonable '
                              'computation time (due to a GDAL bottleneck). Thus 7000 points are randomly chosen '
                              'out of the %s available tie points.' % avail_TP)

            # calculate GCPs
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
            GDF['GCP'] = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                            GDF_row.X_IM, GDF_row.Y_IM), axis=1)
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
                print('Found %s valid tie points.' % len(self.GCPList))

            return self.GCPList
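            # Hedged sketch of how the returned GCPs could be consumed via the GDAL Python bindings
            # ('target.tif' and 'tpg' are placeholders for a target raster path and a Tie_Point_Grid instance):
            #     ds = gdal.Open('target.tif', gdal.GA_Update)
            #     ds.SetGCPs(tpg.to_GCPList(), tpg.ref.prj)
            #     del ds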

    def test_if_singleprocessing_equals_multiprocessing_result(self):
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        self.CPUs = None
        dataframe = self.get_CoRegPoints_table()
        mp_out = np.empty_like(dataframe.values)
        mp_out[:] = dataframe.values
        self.CPUs = 1
        dataframe = self.get_CoRegPoints_table()
        sp_out = np.empty_like(dataframe.values)
        sp_out[:] = dataframe.values

        return np.array_equal(sp_out, mp_out)

    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]

    def _get_lines_by_PIDs(self, PIDs):
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
        return lines

    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # replace boolean values (cannot be written to a shapefile)
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        # replace all remaining booleans where the column dtype is object rather than bool
        GDF2pass = GDF2pass.replace(False, 0)
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        GDF2pass.to_file(path_out)
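        # Hedged example call (the output path is a placeholder):
        #     self.to_PointShapefile(path_out='/path/to/CoRegPoints.shp', skip_nodata=True)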

    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        warnings.warn(DeprecationWarning(
            "'_to_PointShapefile' is deprecated."  # TODO delete if other method validated
            " 'to_PointShapefile' is much faster."))
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[i].values)) for i in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)

    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        """

        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' (outputs magnitude and " \
                                     "direction)'. Got %s." % mode
        attr_b1 = 'X_SHIFT_M' if mode == 'uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode == 'uv' else 'ANGLE'
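        # i.e. with mode='uv' band 1/2 hold the X/Y shift components, while with mode='md' they hold ABS_SHIFT
        # (shift magnitude) and ANGLE (shift direction)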

        xshift_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                               values=self.CoRegPoints_table[attr_b1],
                                               tgt_res=self.shift.xgsd * self.grid_res,
                                               prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal=self.outFillVal)

        yshift_arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                               values=self.CoRegPoints_table[attr_b2],
                                               tgt_res=self.shift.xgsd * self.grid_res,
                                               prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal=self.outFillVal)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA

    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                    path_out=None, tilepos=None):
        warnings.warn(DeprecationWarning("'to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead."))  # TODO delete

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # subset if tilepos is given
        rows, cols = tilepos if tilepos else self.shift.shape
        GDF2pass = GDF2pass.loc[(GDF2pass['X_IM'] >= cols[0]) & (GDF2pass['X_IM'] <= cols[1]) &
                                (GDF2pass['Y_IM'] >= rows[0]) & (GDF2pass['Y_IM'] <= rows[1])]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)  # ,backend='C',)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"
                                          % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)
        # add half a pixel so that the grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):

        mp = self.CPUs != 1
        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     # self.kriged=[]
        #     # for i in args_kwargs_dicts:
        #     #     res = self.Kriging_mp(i)
        #     #     self.kriged.append(res)
        #     #     print(res)
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)
        res = self.kriged if mp else None
        return res

    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        #         # subset if tilepos is given
        # #        overlap_factor =
        #         rows,cols = tilepos if tilepos else self.tgt_shape
        #         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
        #         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
        #         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
        #         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
        #         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
        #         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
        #                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)
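        # 'zvalues' holds the interpolated surface on the regular grid; 'sigmasq' is the corresponding kriging
        # variance, which is currently not written to disk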

        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY, tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' % path_out)
        # add half a pixel so that the grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

    def _Kriging_mp(self, args_kwargs_dict):
        args = args_kwargs_dict.get('args', [])
        kwargs = args_kwargs_dict.get('kwargs', {})

        return self._Kriging_sp(*args, **kwargs)


class Tie_Point_Refiner(object):
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                 rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
        """A class for performing outlier detection.

        :param GDF:                             GeoDataFrame like Tie_Point_Grid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:                               <bool> quiet mode (default: False)
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
        self.ransac_model_robust = None

    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliability filtering - filter out all tie points that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filter out all tie points where shift
                                correction does not increase image similarity within the matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:
        """

        # TODO catch empty GDF
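        # Hedged usage sketch ('tpg' stands for an already computed Tie_Point_Grid instance):
        #     TPR = Tie_Point_Refiner(tpg.CoRegPoints_table, min_reliability=60)
        #     GDF_filt, new_cols = TPR.run_filtering(level=3)  # adds 'L1_OUTLIER', 'L2_OUTLIER', 'L3_OUTLIER', 'OUTLIER'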

        # RELIABILITY filtering
        if level > 0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % (len(marked_recs[marked_recs])))

        # SSIM filtering
        if level > 1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with less than four tie points makes no sense

                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)'
                          % (len(marked_recs[marked_recs])))
            else:
                print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

    def _reliability_thresholding(self):
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""

        return self.GDF.RELIABILITY < self.min_reliability

    def _SSIM_filtering(self):
        """Exclude all records where SSIM decreased."""

        # ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        # self.GDF.SSIM_IMPROVED = \
        #     self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

        # boolean Series marking all records where the SSIM did not improve
        return self.GDF.SSIM_IMPROVED == False  # noqa: E712  (element-wise comparison, not an identity check)

    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm."""

        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError("'rs_max_outlier' must be a percentage between 0 and 100.")
        min_inlier_percentage = 100 - self.rs_max_outlier_percentage

        class PolyTF_1(PolynomialTransform):
            def estimate(*data):
                return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None