Tie_Point_Grid.py 48 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'

import collections
import multiprocessing
import os
import warnings
8
import time
9
10

# custom
11
12
13
14
try:
    import gdal
except ImportError:
    from osgeo import gdal
15
import numpy as np
16
from matplotlib import pyplot as plt
17
from geopandas         import GeoDataFrame, GeoSeries
18
19
from pykrige.ok        import OrdinaryKriging
from shapely.geometry  import Point
20
21
from skimage.measure   import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform
22
23

# internal modules
24
25
from .CoReg import COREG
from . import io as IO
26
27
28
29
from py_tools_ds.geo.projection          import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen              import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion   import points_to_raster
30
from geoarray import GeoArray
31
32
33
34
35
36
37



global_shared_imref    = None
global_shared_im2shift = None


38
39
class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""
40

41
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False, q=False):
        """Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
        are calculated for each point in grid of which the parameters can be adjusted using keyword arguments. Shift
        correction performs a polynomial transformation using the calculated shifts of each point in the grid as GCPs.
        Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target image.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the point does not provide enough points, all available points
                                        are chosen.
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str)     the resampling algorithm to be used for all warping processes during calculation
                                        of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                        mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                                reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings      a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """

        if not isinstance(COREG_obj, COREG): raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj         = COREG_obj
        self.grid_res          = grid_res
        self.max_points        = max_points
        self.outFillVal        = outFillVal
        self.rspAlg_calc       = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        # fall back to only propagating the quiet flag if no explicit refiner settings were given
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
        self.dir_out           = dir_out
        self.CPUs              = CPUs
        self.v                 = v
        self.q                 = q        if not v else False # overridden by v
        self.progress          = progress if not q else False # overridden by q

        # shortcuts to the already validated input images held by the COREG instance
        self.ref               = self.COREG_obj.ref
        self.shift             = self.COREG_obj.shift

        # pre-compute all candidate tie point positions (image and map coordinates)
        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        self._CoRegPoints_table           = None # set by self.CoRegPoints_table
        self._GCPList                     = None # set by self.to_GCPList()
        self.kriged                       = None # set by Raster_using_Kriging()
108
109


110
111
112
113
114
115
116
117
118
119
    # Convenience accessors: mean shift over all tie points with a valid match
    # (entries equal to self.outFillVal are excluded before averaging).
    mean_x_shift_px   = property(lambda self:
        self.CoRegPoints_table['X_SHIFT_PX'][self.CoRegPoints_table['X_SHIFT_PX']!=self.outFillVal].mean())
    mean_y_shift_px   = property(lambda self:
        self.CoRegPoints_table['Y_SHIFT_PX'][self.CoRegPoints_table['Y_SHIFT_PX']!=self.outFillVal].mean())
    mean_x_shift_map  = property(lambda self:
        self.CoRegPoints_table['X_SHIFT_M'][self.CoRegPoints_table['X_SHIFT_M']!=self.outFillVal].mean())
    mean_y_shift_map  = property(lambda self:
        self.CoRegPoints_table['Y_SHIFT_M'][self.CoRegPoints_table['Y_SHIFT_M']!=self.outFillVal].mean())


120
121
    @property
    def CoRegPoints_table(self):
122
123
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE' containing all
124
        information containing all the results from coregistration for all points in the tie points grid.
125
        """
126
127
128
129
130
131
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

132

133
134
135
136
137
138
139
    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        # allow callers to inject or override the cached tie point table
        self._CoRegPoints_table = CoRegPoints_table


    @property
    def GCPList(self):
140
141
        """Returns a list of GDAL compatible GCP objects.
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
142

143
144
145
146
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
147
            return self._GCPList
148
149
150
151
152
153
154
155


    @GCPList.setter
    def GCPList(self, GCPList):
        # allow callers to inject or override the cached GCP list
        self._GCPList = GCPList


    def _get_imXY__mapXY_points(self, grid_res):
156
157
158
159
160
161
        """Returns a numpy array containing possible positions for coregistration tie points according to the given
        grid resolution.

        :param grid_res:
        :return:
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
162

163
        if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
164
            print('Initializing tie points grid...')
165

166
167
168
        Xarr,Yarr       = np.meshgrid(np.arange(0,self.shift.shape[1],grid_res),
                                      np.arange(0,self.shift.shape[0],grid_res))

Daniel Scheffler's avatar
Daniel Scheffler committed
169
170
        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr*self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr*abs(self.shift.gt[5])
171
172
173
174
175
176
177
178
179

        XY_points      = np.empty((Xarr.size,2),Xarr.dtype)
        XY_points[:,0] = Xarr.flat
        XY_points[:,1] = Yarr.flat

        XY_mapPoints      = np.empty((mapXarr.size,2),mapXarr.dtype)
        XY_mapPoints[:,0] = mapXarr.flat
        XY_mapPoints[:,1] = mapYarr.flat

Daniel Scheffler's avatar
Daniel Scheffler committed
180
181
        assert XY_points.shape == XY_mapPoints.shape

182
183
184
        return XY_points,XY_mapPoints


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:        <geopandas.GeoDataFrame> the filtered GeoDataFrame
        """

        # exclude all points outside of overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        #GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!' # FIXME track that

        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF       = len(GDF) # length of GDF after dropping all points outside the overlap polygon
        mapXY              = np.array(GDF.loc[:,['X_UTM','Y_UTM']])
        # read the bad data masks at the tie point positions; 'False' (scalar) if no mask is provided
        GDF['REF_BADDATA'] = self.COREG_obj.ref  .mask_baddata.read_pointData(mapXY) \
                                if self.COREG_obj.ref  .mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY)\
                                if self.COREG_obj.shift.mask_baddata is not None else False
        # keep only points that are valid in BOTH masks
        GDF                = GDF[(GDF['REF_BADDATA']==False) & (GDF['TGT_BADDATA']==False)]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      %(orig_len_GDF-len(GDF), orig_len_GDF))

        return GDF


217
218
    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        """Run COREG for a single tie point and return its results as a flat list.

        Executed by worker processes; expects the reference and target image to be available via the
        module-level globals 'global_shared_imref' / 'global_shared_im2shift' (set in
        get_CoRegPoints_table() before the workers are started).

        :param coreg_kwargs:    keyword arguments for COREG plus the extra keys 'pointID' and 'fftw_works'
                                (both are removed before the dict is passed on to COREG)
        :return:                [pointID, win_sz_x, win_sz_y, x_shift_px, y_shift_px, x_shift_map,
                                 y_shift_map, vec_length_map, vec_angle_deg, ssim_orig, ssim_deshifted,
                                 ssim_improved, shift_reliability, last_err]
        """
        # unpack
        pointID    = coreg_kwargs['pointID']
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']

        # assertions
        assert global_shared_imref    is not None
        assert global_shared_im2shift is not None

        # run CoReg
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err           = CR.tracked_errors[-1] if CR.tracked_errors else None
        # matchBox may be unset if no matching window could be created for this point
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res   = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                    CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                    CR.shift_reliability, last_err]

        return [pointID]+CR_res
241
242


243
    def get_CoRegPoints_table(self):
        """Calculate the tie point grid and return it as GeoDataFrame.

        Runs COREG for every point of the grid (in parallel unless self.CPUs == 1), merges the
        per-point results into the point table and applies the configured tie point filtering.
        The result is also stored in self.CoRegPoints_table.

        :return: <geopandas.GeoDataFrame> the tie point table (empty if no point lies within the overlap area)
        """
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM' (convert imCoords to mapCoords
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints      = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        # derive a proj4-style CRS dict from the projection of the image to be shifted
        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south   = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs     = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']]   = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            self.COREG_obj.equalize_pixGrids()  # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is connected to that

        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref, global_shared_im2shift
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
        self.ref.cache_array_subset([self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        global_shared_imref    = self.ref
        global_shared_im2shift = self.shift

        # get all variations of kwargs for coregistration
        def get_coreg_kwargs(pID, wp):
            # one kwargs dict per tie point; 'pointID'/'fftw_works' are popped again in _get_spatial_shifts()
            return dict(
                pointID            = pID,
                fftw_works         = self.COREG_obj.fftw_works,
                wp                 = wp,
                ws                 = self.COREG_obj.win_size_XY,
                resamp_alg_calc    = self.rspAlg_calc,
                footprint_poly_ref = self.COREG_obj.ref.poly,
                footprint_poly_tgt = self.COREG_obj.shift.poly,
                r_b4match          = self.ref.band4match + 1,   # band4match is internally saved as index, starting from 0
                s_b4match          = self.shift.band4match + 1, # band4match is internally saved as index, starting from 0
                max_iter           = self.COREG_obj.max_iter,
                max_shift          = self.COREG_obj.max_shift,
                nodata             = (self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
                force_quadratic_win= self.COREG_obj.force_quadratic_win,
                binary_ws          = self.COREG_obj.bin_ws,
                v                  = False,  # otherwise this would lead to massive console output
                q                  = True,   # otherwise this would lead to massive console output
                ignore_errors      = True
            )

        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie points grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar     = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within COREG ans is not raised
                            results = results.get()
                            break
        else:
            if not self.q:
                print("Calculating tie points grid (%s points) 1 CPU core..." % len(GDF))
            # FIX: use the builtin 'object' dtype ('np.object' was a deprecated alias, removed in NumPy 1.24)
            # FIX: size the buffer by len(GDF) instead of len(geomPoints) - GDF may be a subset (max_points,
            #      bad data exclusion); the former left uninitialized rows that were only dropped by the inner merge
            results = np.empty((len(GDF), 14), object)
            bar     = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR                   = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF                   = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table


373
374
375
376
    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

Daniel Scheffler's avatar
Daniel Scheffler committed
377
        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
378
379
380
        """

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
381
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))


    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy()

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i * i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))


    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int) -> tuple
        """Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        :return:                    (figure, axes) in matplotlib mode, (None, None) in interactive mode
        """

        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        # split into inliers and flagged outliers ('OUTLIER' column only exists after filtering)
        tbl_il = tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER'] == True].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse   = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits (symmetric around zero if not given)
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            # FIX: parenthesize the conditional expression - '%' binds tighter than 'if/else', so the
            # previous code produced the bare label 'pixels' (without 'x-shift [..]') when unit == 'px'
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax


505
506
507
508
509
510
511
    def dump_CoRegPoints_table(self, path_out=None):
        """Write the tie point table to disk as a pickle file.

        :param path_out:    output path; if not given, it is generated automatically within self.dir_out
        """
        if not path_out:
            fName = "CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" \
                    % (self.grid_res, self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1],
                       self.shift.basename, self.ref.basename)
            path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName)
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)
512
513


514
    def to_GCPList(self):
        """Convert the tie point grid into a list of GDAL compatible GCP objects.

        Points without a valid match and points flagged as outliers are excluded. If more than 7000
        valid tie points are available, a random subset of 7000 is used (GDAL bottleneck, see below).
        The result is also cached in self.GCPList.

        :return: <list> of gdal.GCP instances (empty list if no valid tie point exists)
        """
        # get copy of tie points grid without no data
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        # FIX: use the attribute directly - getattr(GDF, 'empty') is equivalent and the old comment
        # claiming 'GDF.empty returns AttributeError' was incorrect
        if GDF.empty:
            return []

        # exclude all points flagged as outliers
        if 'OUTLIER' in GDF.columns:
            GDF = GDF[GDF.OUTLIER == False].copy()
        avail_TP = len(GDF)

        if not avail_TP:
            # no point passed all validity checks
            return []

        if avail_TP > 7000:
            GDF = GDF.sample(7000)
            warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                          'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
                          'out of the %s available tie points.' % avail_TP)

        # calculate GCPs: shift-corrected map coordinates mapped back to the original image coordinates
        GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
        GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
        GDF['GCP']       = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                              GDF_row.X_IM, GDF_row.Y_IM), axis=1)
        self.GCPList = GDF.GCP.tolist()

        if not self.q:
            print('Found %s valid tie points.' % len(self.GCPList))

        return self.GCPList
551
552


553
    def test_if_singleprocessing_equals_multiprocessing_result(self):
        """Compute the tie point grid once with all CPUs and once with a single CPU and compare the results.

        NOTE: mutates self.tieP_filter_level and self.CPUs.

        :return: <bool> True if both runs produced identical tables
        """
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        # multiprocessing run
        self.CPUs = None
        mp_out = self.get_CoRegPoints_table().values.copy()

        # singleprocessing run
        self.CPUs = 1
        sp_out = self.get_CoRegPoints_table().values.copy()

        return np.array_equal(sp_out, mp_out)


568
569
    def _get_line_by_PID(self, PID):
        """Return the row of CoRegPoints_table belonging to the given point ID.

        NOTE(review): indexes by DataFrame label, i.e. assumes the table index equals POINT_ID - verify.

        :param PID:     the point ID
        :return:        the corresponding table row
        """
        return self.CoRegPoints_table.loc[PID, :]
570
571


572
    def _get_lines_by_PIDs(self, PIDs):
573
        assert isinstance(PIDs,list)
574
        lines = np.zeros((len(PIDs),self.CoRegPoints_table.shape[1]))
575
        for i,PID in enumerate(PIDs):
576
            lines[i,:] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
577
578
579
        return lines


580
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table
        # work on a copy - the boolean-to-int conversion below must not modify self.CoRegPoints_table
        GDF2pass = (GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]).copy()

        # replace boolean values (cannot be written to a shapefile)
        for col in GDF2pass.columns:
            # bugfix: 'np.bool' is a removed alias in NumPy >= 1.24; the builtin 'bool' is the correct dtype check
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0)  # replace all remaining booleans where dtype is not bool but object
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else get_generic_outpath(
            dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        if not self.q:
            print('Writing %s ...' % path_out)
        GDF2pass.to_file(path_out)


609
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        """Deprecated, slower variant of to_PointShapefile() kept until the new method is validated.

        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        """
        warnings.warn(DeprecationWarning("'_tiepoints_grid_to_PointShapefile' is deprecated." # TODO delete if other method validated
                                         " 'tiepoints_grid_to_PointShapefile' is much faster."))

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # convert each record into a shapely point plus an ordered attribute dict
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[idx].values))
                      for idx in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)


623
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str>   'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        :return:            <GeoArray> the written 2-band raster
        """
        assert mode in ['uv','md'], \
            "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' (outputs magnitude and direction)'. Got %s." %mode

        # choose which attribute columns end up in band 1 and band 2
        band_attrs = ('X_SHIFT_M', 'Y_SHIFT_M') if mode == 'uv' else ('ABS_SHIFT', 'ANGLE')

        def _attr2raster(attr):
            # rasterize a single attribute column of the tie point table
            return points_to_raster(points=self.CoRegPoints_table['geometry'],
                                    values=self.CoRegPoints_table[attr],
                                    tgt_res=self.shift.xgsd * self.grid_res,
                                    prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                    fillVal=self.outFillVal)

        xshift_arr, gt, prj = _attr2raster(band_attrs[0])
        yshift_arr, gt, prj = _attr2raster(band_attrs[1])

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        if not path_out:
            path_out = get_generic_outpath(
                dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                            self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA


662
    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                   path_out=None, tilepos=None):
        """Deprecated Kriging interpolation of a tie point attribute to a raster file.

        :param attrName:        <str> column of CoRegPoints_table to interpolate
        :param skip_nodata:     <bool/int> whether to exclude points without a valid match
        :param skip_nodata_col: <str> column used to identify nodata points
        :param outGridRes:      <int> output grid resolution; derived from the point extent if not given
        :param path_out:        <str> output path; automatically generated if not given
        :param tilepos:         <tuple> ((rowS, rowE), (colS, colE)) image subset to process
        :return:                <np.ndarray> kriged value grid
        """
        warnings.warn(DeprecationWarning("'to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead.")) # TODO delete

        table = self.CoRegPoints_table
        pts = table if not skip_nodata else table[table[skip_nodata_col] != self.outFillVal]

        # subset if tilepos is given
        rows, cols = tilepos if tilepos else self.shift.shape
        pts = pts.loc[(pts['X_IM'] >= cols[0]) & (pts['X_IM'] <= cols[1]) &
                      (pts['Y_IM'] >= rows[0]) & (pts['Y_IM'] <= rows[1])]

        x_utm, y_utm, values = pts['X_UTM'], pts['Y_UTM'], pts[attrName]
        xmin, ymin, xmax, ymax = pts.total_bounds

        # derive a default grid resolution from the point extent if none is given
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x = np.arange(xmin, xmax + grid_res, grid_res)
        grid_y = np.arange(ymax, ymin - grid_res, -grid_res)  # top-down (north to south)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        krige = OrdinaryKriging(x_utm, y_utm, values, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = krige.execute('grid', grid_x, grid_y)  # ,backend='C',

        path_out = path_out if path_out else get_generic_outpath(
            dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"  % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)

        # shift the bounds by half a pixel so that grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues


699
    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                             fName_out=None, tilepos=None, tilesize=500, mp=None):
        """Interpolates the given attribute of CoRegPoints_table to a raster using Ordinary Kriging.

        :param attrName:        <str> column of CoRegPoints_table to interpolate
        :param skip_nodata:     <bool/int> whether to exclude points without a valid match
        :param skip_nodata_col: <str> column used to identify nodata points
        :param outGridRes:      <int> output grid resolution; derived automatically if not given
        :param fName_out:       <str> output file name
        :param tilepos:         <tuple> tile position (currently only used for the output file name)
        :param tilesize:        <int> tile size for the (currently disabled) tiled multiprocessing mode
        :param mp:              <bool> enable multiprocessing; defaults to True unless self.CPUs == 1
        :return:                kriged value grid in multiprocessing mode, otherwise None
        """
        # bugfix: the caller's 'mp' argument was previously overwritten unconditionally;
        # it is now only derived from the CPU count if not explicitly given
        if mp is None:
            mp = self.CPUs is None or self.CPUs > 1

        # NOTE(review): the tile-wise multiprocessing implementation was removed (it relied on
        # helpers that are no longer available), so kriging always runs singleprocessed here.
        zvalues = self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                                   outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # bugfix: 'self.kriged' was never assigned (only by the disabled multiprocessing code path),
        # so accessing it raised an AttributeError whenever mp evaluated to True
        return getattr(self, 'kriged', zvalues) if mp else None


728
729
730
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Interpolates the given attribute to a regular grid using Ordinary Kriging (singleprocessed).

        :param attrName:        <str> column of CoRegPoints_table to interpolate
        :param skip_nodata:     <bool/int> whether to exclude points without a valid match
        :param skip_nodata_col: <str> column used to identify nodata points
        :param outGridRes:      <int> output grid resolution; derived from the point extent if not given
        :param fName_out:       <str> output file name
        :param tilepos:         <tuple> tile position (only used for the output file name)
        :return:                <np.ndarray> kriged value grid
        """
        table = self.CoRegPoints_table
        pts = table if not skip_nodata else table[table[skip_nodata_col] != self.outFillVal]

        x_utm, y_utm, values = pts['X_UTM'], pts['Y_UTM'], pts[attrName]
        xmin, ymin, xmax, ymax = pts.total_bounds

        # derive a default grid resolution from the point extent if none is given
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x = np.arange(xmin, xmax + grid_res, grid_res)
        grid_y = np.arange(ymax, ymin - grid_res, -grid_res)  # top-down (north to south)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        krige = OrdinaryKriging(x_utm, y_utm, values, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = krige.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        if self.CPUs is None or self.CPUs > 1:
            # multiprocessing mode -> include the tile position in the file name
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY,tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)

        print('Writing %s ...' % path_out)
        # shift the bounds by half a pixel so that grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues


771
    def _Kriging_mp(self, args_kwargs_dict):
772
        args   = args_kwargs_dict.get('args'  ,[])
773
774
        kwargs = args_kwargs_dict.get('kwargs',[])

775
        return self._Kriging_sp(*args, **kwargs)
776
777
778



779
class Tie_Point_Refiner(object):
    """Applies several levels of outlier filtering to a grid of tie points."""

    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                      rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
        """A class for performing outlier detection.

        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)
        :param q:                               <bool> quiet mode (suppress console output)
        """
        # keep an own copy so that the filtering never mutates the caller's table
        self.GDF = GDF.copy()

        # filtering parameters
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q

        # names of the outlier columns added during filtering
        self.new_cols = []
        # fitted RANSAC transformation model (set by _RANSAC_outlier_detection)
        self.ransac_model_robust = None


813
    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters all tie points out where shift
                                correction does not increase image similarity within matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:    (self.GDF, self.new_cols) - the filtered tie point table and the names of the added columns
        """
        # TODO catch empty GDF

        # RELIABILITY filtering
        if level > 0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).' % (len(marked_recs[marked_recs==True])))

        # SSIM filtering
        if level > 1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs==True])))

        # RANSAC filtering
        if level > 2:
            # exclude points already flagged by earlier filter levels
            ransacInGDF = self.GDF[self.GDF[self.new_cols].any(axis=1) == False].copy()\
                            if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with less than four tie points makes no sense
                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)' % (len(marked_recs[marked_recs == True])))
            else:
                # bugfix: this message previously ignored quiet mode, unlike all other prints above
                if not self.q:
                    print('RANSAC skipped because too less valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        # combine all filter levels into one overall outlier flag
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols


Daniel Scheffler's avatar
Daniel Scheffler committed
878
    def _reliability_thresholding(self):
879
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""
880

Daniel Scheffler's avatar
Daniel Scheffler committed
881
        return self.GDF.RELIABILITY < self.min_reliability
882
883
884


    def _SSIM_filtering(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
885
        """Exclude all records where SSIM decreased."""
886
887
888
889
890

        #ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        #self.GDF.SSIM_IMPROVED = self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

891
892
893
        return self.GDF.SSIM_IMPROVED == False


Daniel Scheffler's avatar
Daniel Scheffler committed
894
895
    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm.

        An affine transform is robustly fitted between the source coordinates and the
        shift-corrected coordinates; the RANSAC residual threshold is tuned iteratively
        so that roughly the configured maximum outlier percentage gets flagged.

        :param inGDF:   GeoDataFrame of tie points to be checked (needs the columns
                        'X_UTM', 'Y_UTM', 'X_SHIFT_M', 'Y_SHIFT_M' and 'POINT_ID')
        :return:        <GeoSeries> boolean outlier flags, aligned with self.GDF
        """

        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
        xyShift    = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        # target point cloud: source coordinates corrected by the calculated shifts
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim==2 and co.shape[1]==2, "'%s' must have shape [Nx2]. Got shape %s."%(n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100: raise ValueError
        min_inlier_percentage = 100-self.rs_max_outlier_percentage

        # NOTE(review): PolyTF_1 is defined but never used below - candidate for removal
        class PolyTF_1(PolynomialTransform):
            def estimate(*data):
                return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None
        count_inliers         = None
        th                    = 5  # start RANSAC threshold
        th_checked            = {} # dict of thresholds that already have been tried + calculated inlier percentage
        th_substract          = 2  # threshold increment/decrement per iteration (halved when revisiting)
        count_iter            = 0
        time_start            = time.time()
        ideal_count           = min_inlier_percentage * src_coords.shape[0] / 100  # target inlier count

        # optimize RANSAC threshold so that it marks not much more or less than the given outlier percentage
        while True:
            if th_checked:
                # from the second iteration on: adapt the threshold based on the last inlier count
                th_too_strict = count_inliers < ideal_count # True if too less inliers remaining

                # calculate new theshold using old increment (but ensure th_new>0 by adjusting increment if needed)
                th_new = 0
                while th_new <= 0:
                    th_new = th+th_substract if th_too_strict else th-th_substract
                    if th_new <= 0:
                        th_substract /=2

                # check if calculated new threshold has been used before
                th_already_checked = th_new in th_checked.keys()

                # if yes, decrease increment and recalculate new threshold
                th_substract       = th_substract if not th_already_checked else th_substract / 2
                th                 = th_new if not th_already_checked else \
                                        (th+th_substract if th_too_strict else th-th_substract)

            # RANSAC call
            # model_robust, inliers = ransac((src, dst), PolynomialTransform, min_samples=3,
            if src_coords.size and est_coords.size:
                model_robust, inliers = \
                    ransac((src_coords, est_coords), AffineTransform,
                           min_samples        = 6,
                           residual_threshold = th,
                           max_trials         = 2000,
                           stop_sample_num    = int((min_inlier_percentage-self.rs_tolerance) /100*src_coords.shape[0]),
                           stop_residuals_sum = int((self.rs_max_outlier_percentage-self.rs_tolerance)/100*src_coords.shape[0])
                           )
            else:
                # no input points -> nothing to flag
                inliers = np.array([])
                break

            count_inliers  = np.count_nonzero(inliers)

            th_checked[th] = count_inliers / src_coords.shape[0] * 100  # inlier percentage at this threshold
            #print(th,'\t', th_checked[th], )
            if min_inlier_percentage-self.rs_tolerance < th_checked[th] < min_inlier_percentage+self.rs_tolerance:
                #print('in tolerance')
                break
            if count_iter > self.rs_max_iter or time.time()-time_start > self.rs_timeout:
                break # keep last values and break while loop

            count_iter+=1

        # invert the inlier mask to get the outlier flags
        outliers = inliers == False if inliers is not None and inliers.size else np.array([])

        if inGDF.empty or outliers is None or (isinstance(outliers, list) and not outliers) or \
                (isinstance(outliers, np.ndarray) and not outliers.size):
            # nothing flagged (or no input) -> return all-False flags for the full table
            gs              = GeoSeries([False]*len(self.GDF))
        elif len(inGDF) < len(self.GDF):
            # inGDF is a subset (previous outliers excluded) -> re-align the flags with self.GDF via POINT_ID
            inGDF['outliers'] = outliers
            fullGDF         = GeoDataFrame(self.GDF['POINT_ID'])
            fullGDF         = fullGDF.merge(inGDF[['POINT_ID', 'outliers']], on='POINT_ID', how="outer")
            #fullGDF.outliers.copy()[~fullGDF.POINT_ID.isin(GDF.POINT_ID)] = False
            fullGDF         = fullGDF.fillna(False) # NaNs are due to exclude_previous_outliers
            gs              = fullGDF['outliers']
        else:
            gs              = GeoSeries(outliers)

        assert len(gs)==len(self.GDF), 'RANSAC output validation failed.'

        self.ransac_model_robust = model_robust

        return gs
990