# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'

import collections
import multiprocessing
import os
import warnings
import time

# custom
try:
    import gdal
except ImportError:
    from osgeo import gdal
import numpy as np
from matplotlib import pyplot as plt
from geopandas         import GeoDataFrame, GeoSeries
from shapely.geometry  import Point
from skimage.measure   import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform

# internal modules
from .CoReg import COREG
from . import io as IO
from py_tools_ds.geo.projection          import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen              import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion   import points_to_raster
from geoarray import GeoArray



global_shared_imref    = None
global_shared_im2shift = None


class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""

    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False, q=False):
        """Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
        are calculated for each point in grid of which the parameters can be adjusted using keyword arguments. Shift
        correction performs a polynomial transformation using te calculated shifts of each point in the grid as GCPs.
46
        Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target image.
47

48
        :param COREG_obj(object):       an instance of COREG class
49
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
50
        :param max_points(int):         maximum number of points used to find coregistration tie points
51
52
53
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the point does not provide enough points, all available points
                                        are chosen.
Daniel Scheffler's avatar
Daniel Scheffler committed
54
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
55
                                        no match could be found during co-registration (default: -9999)
56
57
58
59
60
        :param resamp_alg_calc(str)     the resampling algorithm to be used for all warping processes during calculation
                                        of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
                                                       max, min, med, q1, q3)
                                        default: cubic (highly recommended)
61
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
62
                                        NOTE: lower levels are also included if a higher level is chosen
63
                                            - Level 0: no tie point filtering
64
65
66
                                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                                reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
67
68
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
69
                                            - Level 3: RANSAC outlier detection
70
71
72
73
        :param outlDetect_settings      a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
74
75
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
Daniel Scheffler's avatar
Daniel Scheffler committed
76
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
77
                                        (default: None, which means 'all CPUs available')
78
        :param progress(bool):          show progress bars (default: True)
79
80
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
81
        """

        if not isinstance(COREG_obj, COREG): raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj         = COREG_obj
        self.grid_res          = grid_res
        self.max_points        = max_points
        self.outFillVal        = outFillVal
        self.rspAlg_calc       = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
        self.dir_out           = dir_out
        self.CPUs              = CPUs
        self.v                 = v
        self.q                 = q        if not v else False # overridden by v
        self.progress          = progress if not q else False # overridden by q

        self.ref               = self.COREG_obj.ref
        self.shift             = self.COREG_obj.shift

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        self._CoRegPoints_table           = None # set by self.CoRegPoints_table
        self._GCPList                     = None # set by self.to_GCPList()
        self.kriged                       = None # set by self.to_Raster_using_Kriging()
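
    # Example (illustrative sketch, not executed on import): a typical way to build a tie point grid from an
    # existing COREG instance. The file paths and keyword values below are hypothetical assumptions.
    #
    #     from arosics import COREG
    #     CR  = COREG('/path/to/reference.tif', '/path/to/target.tif', ws=(256, 256), max_shift=5)
    #     TPG = Tie_Point_Grid(CR, grid_res=200, max_points=5000, tieP_filter_level=3, CPUs=None)
    #     table = TPG.CoRegPoints_table   # runs get_CoRegPoints_table() on first access
    #     gcps  = TPG.GCPList             # GDAL GCP objects derived from the filtered tie points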

    mean_x_shift_px   = property(lambda self:
        self.CoRegPoints_table['X_SHIFT_PX'][self.CoRegPoints_table['X_SHIFT_PX']!=self.outFillVal].mean())
    mean_y_shift_px   = property(lambda self:
        self.CoRegPoints_table['Y_SHIFT_PX'][self.CoRegPoints_table['Y_SHIFT_PX']!=self.outFillVal].mean())
    mean_x_shift_map  = property(lambda self:
        self.CoRegPoints_table['X_SHIFT_M'][self.CoRegPoints_table['X_SHIFT_M']!=self.outFillVal].mean())
    mean_y_shift_map  = property(lambda self:
        self.CoRegPoints_table['Y_SHIFT_M'][self.CoRegPoints_table['Y_SHIFT_M']!=self.outFillVal].mean())


    @property
    def CoRegPoints_table(self):
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE', containing all
        results from the co-registration of all points in the tie points grid.
        """
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table


    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table


    @property
    def GCPList(self):
        """Returns a list of GDAL compatible GCP objects.
        """
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
            return self._GCPList


    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList


    def _get_imXY__mapXY_points(self, grid_res):
        """Returns two numpy arrays containing the image and map coordinates of possible positions for coregistration
        tie points according to the given grid resolution.

        :param grid_res:
        :return:
        """
        if not self.q:
            print('Initializing tie points grid...')

        Xarr,Yarr       = np.meshgrid(np.arange(0,self.shift.shape[1],grid_res),
                                      np.arange(0,self.shift.shape[0],grid_res))

        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr*self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr*abs(self.shift.gt[5])

        XY_points      = np.empty((Xarr.size,2),Xarr.dtype)
        XY_points[:,0] = Xarr.flat
        XY_points[:,1] = Yarr.flat

        XY_mapPoints      = np.empty((mapXarr.size,2),mapXarr.dtype)
        XY_mapPoints[:,0] = mapXarr.flat
        XY_mapPoints[:,1] = mapYarr.flat

        assert XY_points.shape == XY_mapPoints.shape

        return XY_points,XY_mapPoints


    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:
        """

        # exclude all points outside of overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        #GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!' # FIXME track that


        # exclude all points where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF       = len(GDF) # length of GDF after dropping all points outside the overlap polygon
        mapXY              = np.array(GDF.loc[:,['X_UTM','Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref  .mask_baddata.read_pointData(mapXY) \
                                if self.COREG_obj.ref  .mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY)\
                                if self.COREG_obj.shift.mask_baddata is not None else False
        GDF                = GDF[(GDF['REF_BADDATA']==False) & (GDF['TGT_BADDATA']==False)]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      %(orig_len_GDF-len(GDF), orig_len_GDF))

        return GDF


    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        # unpack
        pointID    = coreg_kwargs['pointID']
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']

        # assertions
        assert global_shared_imref    is not None
        assert global_shared_im2shift is not None

        # run CoReg
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err           = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res   = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                    CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                    CR.shift_reliability, last_err]

        return [pointID]+CR_res


    def get_CoRegPoints_table(self):
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM' (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X,Y: Point(X,Y), otypes=[Point])
        geomPoints      = np.array(XYarr2PointGeom(self.XY_mapPoints[:,0],self.XY_mapPoints[:,1]))

        if isProjectedOrGeographic(self.COREG_obj.shift.prj)=='geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj)=='projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south   = get_UTMzone(prj=self.COREG_obj.shift.prj)<0
            crs     = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south: del crs['south']
        else:
            crs = None

        GDF                          = GeoDataFrame(index=range(len(geomPoints)),crs=crs,
                                                    columns=['geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'])
        GDF       ['geometry']       = geomPoints
        GDF       ['POINT_ID']       = range(len(geomPoints))
        GDF.loc[:,['X_IM' ,'Y_IM' ]] = self.XY_points
        GDF.loc[:,['X_UTM','Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF)>100:
            self.COREG_obj.equalize_pixGrids() # NOTE: actually grid res should also be changed here because self.shift.xgsd changes and grid res is connected to that

        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref,global_shared_im2shift
        assert self.ref  .footprint_poly # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
        self.ref.cache_array_subset([self.COREG_obj.ref.band4match]) # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        global_shared_imref    = self.ref
        global_shared_im2shift = self.shift

        # get all variations of kwargs for coregistration
        get_coreg_kwargs = lambda pID, wp: dict(
            pointID            = pID,
            fftw_works         = self.COREG_obj.fftw_works,
            wp                 = wp,
            ws                 = self.COREG_obj.win_size_XY,
            resamp_alg_calc    = self.rspAlg_calc,
            footprint_poly_ref = self.COREG_obj.ref.poly,
            footprint_poly_tgt = self.COREG_obj.shift.poly,
            r_b4match          = self.ref.band4match+1,   # band4match is internally saved as index, starting from 0
            s_b4match          = self.shift.band4match+1, # band4match is internally saved as index, starting from 0
            max_iter           = self.COREG_obj.max_iter,
            max_shift          = self.COREG_obj.max_shift,
            nodata             = (self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
            force_quadratic_win= self.COREG_obj.force_quadratic_win,
            binary_ws          = self.COREG_obj.bin_ws,
            v                  = False, # otherwise this would lead to massive console output
            q                  = True,  # otherwise this would lead to massive console output
            ignore_errors      = True
        )
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index) # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs>1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie points grid (%s points) using %s CPU cores..." %(len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar     = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        numberDone = len(GDF)-results._number_left # this does not really represent the remaining tasks but the remaining chunks -> thus chunksize=1
                        if self.progress:
                            bar.print_progress(percent=numberDone/len(GDF)*100)
                        if results.ready():
                            results = results.get() # <= this is the line where multiprocessing can freeze if an exception appears within COREG and is not raised
                            break

        else:
            if not self.q:
                print("Calculating tie points grid (%s points) using 1 CPU core..." %len(GDF))
            results = np.empty((len(geomPoints),14), np.object)
            bar     = ProgressBar(prefix='\tprogress:')
            for i,coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i+1)/len(GDF)*100)
                results[i,:] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, np.object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level>0:
            if not self.q:
                print('Performing validity checks...')
            TPR                   = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF                   = GDF.merge(GDF_filt[['POINT_ID']+new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table


    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))
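
    # Worked example (illustrative numbers only, not taken from real data): for valid absolute shifts of
    # 0.5, 1.5 and 1.0 m, RMSE = sqrt((0.5**2 + 1.5**2 + 1.0**2) / 3) = sqrt(3.5 / 3), i.e. approx. 1.08 m.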


    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy()

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))


    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
        """Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """

        if not unit in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." %unit)

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER'] == True].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse   = self.calc_rmse(include_outliers=False) # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10,10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1]-(xlim[1]/20),-ylim[1]+(ylim[1]/20),
                     'RMSE:  %s m / %s px' %(np.round(rmse, 2), np.round(rmse/self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax


    def dump_CoRegPoints_table(self, path_out=None):
        path_out = path_out if path_out else get_generic_outpath(dir_out=self.dir_out,
            fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)


    def to_GCPList(self):
        # get copy of tie points grid without no data
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if getattr(GDF,'empty'): # GDF.empty returns AttributeError
            return []
        else:
            # exclude all points flagged as outliers
            if 'OUTLIER' in GDF.columns:
                GDF = GDF[GDF.OUTLIER == False].copy()
            avail_TP = len(GDF)

            if not avail_TP:
                # no point passed all validity checks
                return []

            if avail_TP > 7000:
                GDF = GDF.sample(7000)
                warnings.warn('Not more than 7000 tie points can be used for warping within a limited computation '
                              'time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen out of '
                              'the %s available tie points.' % avail_TP)

            # calculate GCPs
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
            GDF['GCP']       = GDF.apply(lambda GDF_row:    gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                                     GDF_row.X_IM, GDF_row.Y_IM), axis=1)
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
                print('Found %s valid tie points.' % len(self.GCPList))

            return self.GCPList
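
    # Example (hedged sketch): one possible way to attach the resulting GCPs to a GDAL dataset outside of AROSICS,
    # e.g. for custom warping. 'src.tif'/'dst.tif' are hypothetical paths; AROSICS itself applies the GCPs via its
    # DESHIFTER machinery, so this is only an illustration of the GDAL side.
    #
    #     from osgeo import gdal
    #     gcps = TPG.to_GCPList()
    #     ds   = gdal.Translate('/vsimem/with_gcps.tif', 'src.tif', GCPs=gcps, outputSRS=TPG.shift.prj)
    #     gdal.Warp('dst.tif', ds, tps=True, dstSRS=TPG.shift.prj)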


    def test_if_singleprocessing_equals_multiprocessing_result(self):
        self.tieP_filter_level = 1  # RANSAC filtering always produces different results because it includes random sampling

        self.CPUs = None
        dataframe = self.get_CoRegPoints_table()
        mp_out    = np.empty_like(dataframe.values)
        mp_out[:] = dataframe.values

        self.CPUs = 1
        dataframe = self.get_CoRegPoints_table()
        sp_out    = np.empty_like(dataframe.values)
        sp_out[:] = dataframe.values

        return np.array_equal(sp_out,mp_out)


    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]


    def _get_lines_by_PIDs(self, PIDs):
        assert isinstance(PIDs,list)
        lines = np.zeros((len(PIDs),self.CoRegPoints_table.shape[1]))
        for i,PID in enumerate(PIDs):
            lines[i,:] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
        return lines


    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col ='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """

        GDF      = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

        # replace boolean values (cannot be written)
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == np.bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0) # replace all remaining booleans where dtype is not np.bool but np.object
        GDF2pass = GDF2pass.replace(True,  1)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' %path_out)
        GDF2pass.to_file(path_out)


    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col ='ABS_SHIFT'):
        warnings.warn(DeprecationWarning("'_to_PointShapefile' is deprecated."  # TODO delete if other method validated
                                         " 'to_PointShapefile' is much faster."))
        GDF            = self.CoRegPoints_table
        GDF2pass       = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts     = [collections.OrderedDict(zip(GDF2pass.columns,GDF2pass.loc[i].values)) for i in GDF2pass.index]


        fName_out = "CoRegPoints_grid%s_ws%s.shp" %(self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)


    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS).

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str> the mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        """

        assert mode in ['uv','md'], \
            "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' (outputs magnitude and direction)'. Got %s." %mode
        attr_b1 = 'X_SHIFT_M' if mode=='uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode=='uv' else 'ANGLE'

        xshift_arr, gt, prj = points_to_raster(points = self.CoRegPoints_table['geometry'],
                                               values = self.CoRegPoints_table[attr_b1],
                                               tgt_res = self.shift.xgsd * self.grid_res,
                                               prj = proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal = self.outFillVal)

        yshift_arr, gt, prj = points_to_raster(points = self.CoRegPoints_table['geometry'],
                                               values = self.CoRegPoints_table[attr_b2],
                                               tgt_res = self.shift.xgsd * self.grid_res,
                                               prj = proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal = self.outFillVal)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA


    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                    path_out=None, tilepos=None):
        warnings.warn(DeprecationWarning("'_to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead."))  # TODO delete

        GDF             = self.CoRegPoints_table
        GDF2pass        = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

        # subset if tilepos is given
        rows,cols = tilepos if tilepos else self.shift.shape
        GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
                                       (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]


        X_coords,Y_coords,ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'],GDF2pass[attrName]

        xmin,ymin,xmax,ymax = GDF2pass.total_bounds

        grid_res            = outGridRes if outGridRes else int(min(xmax-xmin,ymax-ymin)/250)
        grid_x,grid_y       = np.arange(xmin, xmax+grid_res, grid_res), np.arange(ymax, ymin-grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical',verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)#,backend='C',)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"  % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' %path_out)
        # add half a pixel so that the grid points are centered on the output pixels
        xmin,ymin,xmax,ymax = xmin-grid_res/2,ymin-grid_res/2,xmax+grid_res/2,ymax+grid_res/2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):

        mp = False if self.CPUs==1 else True
        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)
        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     # self.kriged=[]
        #     # for i in args_kwargs_dicts:
        #     #     res = self.Kriging_mp(i)
        #     #     self.kriged.append(res)
        #     #     print(res)
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)
        res = self.kriged if mp else None
        return res


    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        GDF             = self.CoRegPoints_table
        GDF2pass        = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

#         # subset if tilepos is given
# #        overlap_factor =
#         rows,cols = tilepos if tilepos else self.tgt_shape
#         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
#         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
#         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
#         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
#         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
#         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
#                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords,Y_coords,ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'],GDF2pass[attrName]

        xmin,ymin,xmax,ymax = GDF2pass.total_bounds

        grid_res            = outGridRes if outGridRes else int(min(xmax-xmin,ymax-ymin)/250)
        grid_x,grid_y       = np.arange(xmin, xmax+grid_res, grid_res), np.arange(ymax, ymin-grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistics: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical',verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y,backend='C',n_closest_points=12)

        if self.CPUs is None or self.CPUs>1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY,tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY)
        path_out  = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' %path_out)
        # add half a pixel so that the grid points are centered on the output pixels
        xmin,ymin,xmax,ymax = xmin-grid_res/2,ymin-grid_res/2,xmax+grid_res/2,ymax+grid_res/2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


    def _Kriging_mp(self, args_kwargs_dict):
        args   = args_kwargs_dict.get('args'  , [])
        kwargs = args_kwargs_dict.get('kwargs', {})

        return self._Kriging_sp(*args, **kwargs)



class Tie_Point_Refiner(object):
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                      rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
        """A class for performing outlier detection.

        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:                               <bool> quiet mode (default: False)
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
        self.ransac_model_robust = None


    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliability filtering - filters out all tie points with a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters out all tie points where shift
                                correction does not increase image similarity within the matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:
        """

        # TODO catch empty GDF

        # RELIABILITY filtering
        if level>0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).' % (len(marked_recs[marked_recs==True])))

        # SSIM filtering
        if level>1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs==True])))

        # RANSAC filtering
        if level>2:
            # exclude previous outliers
            ransacInGDF = self.GDF[self.GDF[self.new_cols].any(axis=1) == False].copy()\
                            if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF)>4:
                # running RANSAC with less than four tie points makes no sense
                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                self.GDF['L3_OUTLIER'] = marked_recs.tolist() # we need to join a list here because otherwise it's merged by the 'index' column

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)' % (len(marked_recs[marked_recs == True])))
            else:
                print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols
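
    # Example (hedged sketch): standalone use of the refiner on an existing tie point table, mirroring how
    # get_CoRegPoints_table() calls it above. The variable names are assumptions for illustration only.
    #
    #     TPR = Tie_Point_Refiner(TPG.CoRegPoints_table[TPG.CoRegPoints_table.ABS_SHIFT != TPG.outFillVal],
    #                             min_reliability=60, rs_max_outlier=10)
    #     filtered_GDF, added_cols = TPR.run_filtering(level=3)   # adds the L1-L3 outlier flags and 'OUTLIER'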


    def _reliability_thresholding(self):
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""

        return self.GDF.RELIABILITY < self.min_reliability


    def _SSIM_filtering(self):
        """Exclude all records where SSIM decreased."""

        #ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        #self.GDF.SSIM_IMPROVED = self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

        return self.GDF.SSIM_IMPROVED == False


    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm."""

        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
        xyShift    = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim==2 and co.shape[1]==2, "'%s' must have shape [Nx2]. Got shape %s."%(n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100: raise ValueError
        min_inlier_percentage = 100-self.rs_max_outlier_percentage
        class PolyTF_1(PolynomialTransform):
            def estimate(*data):
                return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points
        model_robust, inliers = None, None
        count_inliers         = None
        th                    = 5  # start RANSAC threshold
        th_checked            = {} # dict of thresholds that already have been tried + calculated inlier percentage
        th_substract          = 2
        count_iter            = 0
        time_start            = time.time()
        ideal_count           = min_inlier_percentage * src_coords.shape[0] / 100
924

925
        # optimize RANSAC threshold so that it marks not much more or less than the given outlier percentage
926
927
        while True:
            if th_checked:
928
929
                th_too_strict = count_inliers < ideal_count # True if too less inliers remaining

                # calculate new threshold using old increment (but ensure th_new>0 by adjusting increment if needed)
                th_new = 0
                while th_new <= 0:
                    th_new = th+th_substract if th_too_strict else th-th_substract
                    if th_new <= 0:
                        th_substract /= 2
                # check if calculated new threshold has been used before
                th_already_checked = th_new in th_checked.keys()

                # if yes, decrease increment and recalculate new threshold
                th_substract       = th_substract if not th_already_checked else th_substract / 2
                th                 = th_new if not th_already_checked else \
                                        (th+th_substract if th_too_strict else th-th_substract)

            # RANSAC call
            # model_robust, inliers = ransac((src, dst), PolynomialTransform, min_samples=3,
            if src_coords.size and est_coords.size:
                model_robust, inliers = \
                    ransac((src_coords, est_coords), AffineTransform,
                           min_samples        = 6,
                           residual_threshold = th,
                           max_trials         = 2000,
                           stop_sample_num    = int((min_inlier_percentage-self.rs_tolerance) /100*src_coords.shape[0]),
                           stop_residuals_sum = int((self.rs_max_outlier_percentage-self.rs_tolerance)/100*src_coords.shape[0])
                           )
            else:
                inliers = np.array([])
                break

            count_inliers  = np.count_nonzero(inliers)

            th_checked[th] = count_inliers / src_coords.shape[0] * 100
            #print(th,'\t', th_checked[th], )
            if min_inlier_percentage-self.rs_tolerance < th_checked[th] < min_inlier_percentage+self.rs_tolerance:
                #print('in tolerance')
                break
            if count_iter > self.rs_max_iter or time.time()-time_start > self.rs_timeout:
                break # keep last values and break while loop

            count_iter+=1

        outliers = inliers == False if inliers is not None and inliers.size else np.array([])

        if inGDF.empty or outliers is None or (isinstance(outliers, list) and not outliers) or \
                (isinstance(outliers, np.ndarray) and not outliers.size):
            gs              = GeoSeries([False]*len(self.GDF))
        elif len(inGDF) < len(self.GDF):
            inGDF['outliers'] = outliers
            fullGDF         = GeoDataFrame(self.GDF['POINT_ID'])
            fullGDF         = fullGDF.merge(inGDF[['POINT_ID', 'outliers']], on='POINT_ID', how="outer")
            #fullGDF.outliers.copy()[~fullGDF.POINT_ID.isin(GDF.POINT_ID)] = False
            fullGDF         = fullGDF.fillna(False) # NaNs are due to exclude_previous_outliers
            gs              = fullGDF['outliers']
        else:
            gs              = GeoSeries(outliers)

        assert len(gs)==len(self.GDF), 'RANSAC output validation failed.'

        self.ransac_model_robust = model_robust

        return gs