Tie_Point_Grid.py 47.4 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'

import collections
import multiprocessing
import os
import warnings
8
import time
9
10

# custom
11
12
13
14
try:
    import gdal
except ImportError:
    from osgeo import gdal
15
import numpy as np
16
from matplotlib import pyplot as plt
17
from geopandas         import GeoDataFrame, GeoSeries
18
19
from pykrige.ok        import OrdinaryKriging
from shapely.geometry  import Point
20
21
from skimage.measure   import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform
22
23

# internal modules
24
25
from .CoReg import COREG
from . import io as IO
26
27
28
29
from py_tools_ds.geo.projection          import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen              import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion   import points_to_raster
30
from geoarray import GeoArray
31
32
33
34
35
36
37



global_shared_imref    = None
global_shared_im2shift = None


38
39
class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""
40

41
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
42
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False, q=False):
43

44
45
46
        """Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
        are calculated for each point in grid of which the parameters can be adjusted using keyword arguments. Shift
        correction performs a polynomial transformation using te calculated shifts of each point in the grid as GCPs.
47
        Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target image.
48

49
        :param COREG_obj(object):       an instance of COREG class
50
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
51
        :param max_points(int):         maximum number of points used to find coregistration tie points
52
53
54
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the point does not provide enough points, all available points
                                        are chosen.
Daniel Scheffler's avatar
Daniel Scheffler committed
55
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
56
                                        no match could be found during co-registration (default: -9999)
57
58
59
60
61
        :param resamp_alg_calc(str)     the resampling algorithm to be used for all warping processes during calculation
                                        of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
                                                       max, min, med, q1, q3)
                                        default: cubic (highly recommended)
62
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
63
                                        NOTE: lower levels are also included if a higher level is chosen
64
                                            - Level 0: no tie point filtering
65
66
67
                                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                                reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
68
69
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
70
                                            - Level 3: RANSAC outlier detection
71
72
73
74
        :param outlDetect_settings      a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
75
76
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
Daniel Scheffler's avatar
Daniel Scheffler committed
77
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
78
                                        (default: None, which means 'all CPUs available')
79
        :param progress(bool):          show progress bars (default: True)
80
81
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
82
        """
83
84
85

        if not isinstance(COREG_obj, COREG): raise ValueError("'COREG_obj' must be an instance of COREG class.")

86
87
        self.COREG_obj         = COREG_obj
        self.grid_res          = grid_res
88
        self.max_points        = max_points
89
90
91
        self.outFillVal        = outFillVal
        self.rspAlg_calc       = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
92
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
93
94
95
96
97
        self.dir_out           = dir_out
        self.CPUs              = CPUs
        self.v                 = v
        self.q                 = q        if not v else False # overridden by v
        self.progress          = progress if not q else False # overridden by q
98

99
100
        self.ref               = self.COREG_obj.ref
        self.shift             = self.COREG_obj.shift
101
102

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
103
104
105
        self._CoRegPoints_table           = None # set by self.CoRegPoints_table
        self._GCPList                     = None # set by self.to_GCPList()
        self.kriged                       = None # set by Raster_using_Kriging()
106
107


108
109
    @property
    def CoRegPoints_table(self):
110
111
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE' containing all
Daniel Scheffler's avatar
Daniel Scheffler committed
112
        information containing all the results frm coregistration for all points in the tie points grid.
113
        """
114
115
116
117
118
119
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

120

121
122
123
124
125
126
127
    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table


    @property
    def GCPList(self):
128
129
        """Returns a list of GDAL compatible GCP objects.
        """
130
131
132
133
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
134
            return self._GCPList
135
136
137
138
139
140
141
142


    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList


    def _get_imXY__mapXY_points(self, grid_res):
143
144
145
146
147
148
        """Returns a numpy array containing possible positions for coregistration tie points according to the given
        grid resolution.

        :param grid_res:
        :return:
        """
149
        if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
150
            print('Initializing tie points grid...')
151

152
153
154
        Xarr,Yarr       = np.meshgrid(np.arange(0,self.shift.shape[1],grid_res),
                                      np.arange(0,self.shift.shape[0],grid_res))

Daniel Scheffler's avatar
Daniel Scheffler committed
155
156
        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr*self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr*abs(self.shift.gt[5])
157
158
159
160
161
162
163
164
165

        XY_points      = np.empty((Xarr.size,2),Xarr.dtype)
        XY_points[:,0] = Xarr.flat
        XY_points[:,1] = Yarr.flat

        XY_mapPoints      = np.empty((mapXarr.size,2),mapXarr.dtype)
        XY_mapPoints[:,0] = mapXarr.flat
        XY_mapPoints[:,1] = mapYarr.flat

Daniel Scheffler's avatar
Daniel Scheffler committed
166
167
        assert XY_points.shape == XY_mapPoints.shape

168
169
170
        return XY_points,XY_mapPoints


171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:
        """

        # exclude all points outside of overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        #GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!' # FIXME track that


        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
Daniel Scheffler's avatar
Daniel Scheffler committed
188
        orig_len_GDF       = len(GDF) # length of GDF after dropping all points outside the overlap polygon
189
190
191
192
193
194
195
        mapXY              = np.array(GDF.loc[:,['X_UTM','Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref  .mask_baddata.read_pointData(mapXY) \
                                if self.COREG_obj.ref  .mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY)\
                                if self.COREG_obj.shift.mask_baddata is not None else False
        GDF                = GDF[(GDF['REF_BADDATA']==False) & (GDF['TGT_BADDATA']==False)]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
Daniel Scheffler's avatar
Daniel Scheffler committed
196
197
198
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      %(orig_len_GDF-len(GDF), orig_len_GDF))
199
200
201
202

        return GDF


203
204
    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
Daniel Scheffler's avatar
Daniel Scheffler committed
205
        # unpack
206
207
208
        pointID    = coreg_kwargs['pointID']
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']
209

Daniel Scheffler's avatar
Daniel Scheffler committed
210
        # assertions
211
212
        assert global_shared_imref    is not None
        assert global_shared_im2shift is not None
Daniel Scheffler's avatar
Daniel Scheffler committed
213
214

        # run CoReg
215
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
216
        CR.fftw_works = fftw_works
217
        CR.calculate_spatial_shifts()
Daniel Scheffler's avatar
Daniel Scheffler committed
218
219

        # fetch results
220
        last_err           = CR.tracked_errors[-1] if CR.tracked_errors else None
221
222
223
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res   = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                    CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
224
                    CR.shift_reliability, last_err]
225
226

        return [pointID]+CR_res
227
228


229
    def get_CoRegPoints_table(self):
230
231
        assert self.XY_points is not None and self.XY_mapPoints is not None

Daniel Scheffler's avatar
Daniel Scheffler committed
232
        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM' (convert imCoords to mapCoords
233
        XYarr2PointGeom = np.vectorize(lambda X,Y: Point(X,Y), otypes=[Point])
234
235
236
237
238
239
240
        geomPoints      = np.array(XYarr2PointGeom(self.XY_mapPoints[:,0],self.XY_mapPoints[:,1]))

        if isProjectedOrGeographic(self.COREG_obj.shift.prj)=='geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj)=='projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south   = get_UTMzone(prj=self.COREG_obj.shift.prj)<0
241
            crs     = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
242
243
244
245
            if not south: del crs['south']
        else:
            crs = None

246
247
        GDF                          = GeoDataFrame(index=range(len(geomPoints)),crs=crs,
                                                    columns=['geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'])
248
249
        GDF       ['geometry']       = geomPoints
        GDF       ['POINT_ID']       = range(len(geomPoints))
250
        GDF.loc[:,['X_IM' ,'Y_IM' ]] = self.XY_points
251
        GDF.loc[:,['X_UTM','Y_UTM']] = self.XY_mapPoints
252

253
254
        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
255
256
257
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table
258

259
        # choose a random subset of points if a maximum number has been given
260
        if self.max_points and len(GDF) > self.max_points:
261
            GDF = GDF.sample(self.max_points).copy()
262

263
264
265
266
        # equalize pixel grids in order to save warping time
        if len(GDF)>100:
            self.COREG_obj.equalize_pixGrids() # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is connected to that

267
268
        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref,global_shared_im2shift
269
270
        assert self.ref  .footprint_poly # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly
271
272
273

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
Daniel Scheffler's avatar
Daniel Scheffler committed
274
275
276
        self.ref.cache_array_subset([self.COREG_obj.ref.band4match]) # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

277
278
        global_shared_imref    = self.ref
        global_shared_im2shift = self.shift
279
280

        # get all variations of kwargs for coregistration
281
282
        get_coreg_kwargs = lambda pID, wp: dict(
            pointID            = pID,
283
            fftw_works         = self.COREG_obj.fftw_works,
284
285
286
287
288
            wp                 = wp,
            ws                 = self.COREG_obj.win_size_XY,
            resamp_alg_calc    = self.rspAlg_calc,
            footprint_poly_ref = self.COREG_obj.ref.poly,
            footprint_poly_tgt = self.COREG_obj.shift.poly,
Daniel Scheffler's avatar
Daniel Scheffler committed
289
290
            r_b4match          = self.ref.band4match+1,   # band4match is internally saved as index, starting from 0
            s_b4match          = self.shift.band4match+1, # band4match is internally saved as index, starting from 0
291
292
293
            max_iter           = self.COREG_obj.max_iter,
            max_shift          = self.COREG_obj.max_shift,
            nodata             = (self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
Daniel Scheffler's avatar
Daniel Scheffler committed
294
            force_quadratic_win= self.COREG_obj.force_quadratic_win,
295
296
297
298
299
            binary_ws          = self.COREG_obj.bin_ws,
            v                  = False, # otherwise this would lead to massive console output
            q                  = True,  # otherwise this would lead to massive console output
            ignore_errors      = True
        )
300
301
302
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index) # generator

        # run co-registration for whole grid
303
        if self.CPUs is None or self.CPUs>1:
Daniel Scheffler's avatar
CoReg:    
Daniel Scheffler committed
304
            if not self.q:
305
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
Daniel Scheffler's avatar
Daniel Scheffler committed
306
                print("Calculating tie points grid (%s points) using %s CPU cores..." %(len(GDF), cpus))
307

308
            with multiprocessing.Pool(self.CPUs) as pool:
309
310
311
312
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
313
                    bar     = ProgressBar(prefix='\tprogress:')
314
315
316
                    while True:
                        time.sleep(.1)
                        numberDone = len(GDF)-results._number_left # this does not really represent the remaining tasks but the remaining chunks -> thus chunksize=1
317
318
                        if self.progress:
                            bar.print_progress(percent=numberDone/len(GDF)*100)
319
                        if results.ready():
320
                            results = results.get() # <= this is the line where multiprocessing can freeze if an exception appears within COREG ans is not raised
321
                            break
Daniel Scheffler's avatar
Daniel Scheffler committed
322

323
        else:
Daniel Scheffler's avatar
CoReg:    
Daniel Scheffler committed
324
            if not self.q:
Daniel Scheffler's avatar
Daniel Scheffler committed
325
                print("Calculating tie points grid (%s points) 1 CPU core..." %len(GDF))
326
            results = np.empty((len(geomPoints),14), np.object)
327
            bar     = ProgressBar(prefix='\tprogress:')
328
            for i,coreg_kwargs in enumerate(list_coreg_kwargs):
329
330
                if self.progress:
                    bar.print_progress((i+1)/len(GDF)*100)
331
332
                results[i,:] = self._get_spatial_shifts(coreg_kwargs)

333
         # merge results with GDF
334
335
336
        records = GeoDataFrame(np.array(results, np.object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
337
338
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

339
340
341
        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

342
343
344
        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

345
346
        # filter tie points according to given filter level
        if self.tieP_filter_level>0:
347
348
            if not self.q:
                print('Performing validity checks...')
349
            TPR                   = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
350
351
352
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF                   = GDF.merge(GDF_filt[ ['POINT_ID']+new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))
353

354
        self.CoRegPoints_table = GDF
355
356
357
358

        return self.CoRegPoints_table


359
360
361
362
    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

Daniel Scheffler's avatar
Daniel Scheffler committed
363
        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
364
365
366
        """

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
367
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))


    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy()

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i * i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))


    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int) -> tuple
        """Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """

        if not unit in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." %unit)

        tbl = self.CoRegPoints_table
Daniel Scheffler's avatar
Daniel Scheffler committed
410
411
412
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER'] == True].copy() if 'OUTLIER' in tbl.columns else None
413
414
415
416
417
418
419
420
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse   = self.calc_rmse(include_outliers=False) # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10,10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
Daniel Scheffler's avatar
Daniel Scheffler committed
421
            # FIXME outliers are not plotted
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

Daniel Scheffler's avatar
Daniel Scheffler committed
443
            if include_outliers and 'OUTLIER' in tbl.columns:
Daniel Scheffler's avatar
Daniel Scheffler committed
444
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
445
446
447
448
449
450
451
452
453
454
455
456
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

Daniel Scheffler's avatar
Daniel Scheffler committed
457
            # add text box containing RMSE of plotted shifts
458
459
460
461
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1]-(xlim[1]/20),-ylim[1]+(ylim[1]/20),
                     'RMSE:  %s m / %s px' %(np.round(rmse, 2), np.round(rmse/self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))
462

Daniel Scheffler's avatar
Daniel Scheffler committed
463
            # add grid and increase linewidth of middle line
464
465
466
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
Daniel Scheffler's avatar
Daniel Scheffler committed
467
            middle_xgl.set_linewidth(2)
468
469
470
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
Daniel Scheffler's avatar
Daniel Scheffler committed
471
            middle_ygl.set_linewidth(2)
472
473
            middle_ygl.set_linestyle('-')

Daniel Scheffler's avatar
Daniel Scheffler committed
474
475
            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
476
477
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
Daniel Scheffler's avatar
Daniel Scheffler committed
478
479
            plt.xlabel('x-shift [%s]' % 'meters' if unit == 'm' else 'pixels', fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % 'meters' if unit == 'm' else 'pixels', fontsize=fontsize)
480

481
482
            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
Daniel Scheffler's avatar
Daniel Scheffler committed
483
484
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')
485

486
487
488
489
490
            plt.show()

            return fig, ax


491
492
493
494
495
496
497
    def dump_CoRegPoints_table(self, path_out=None):
        path_out = path_out if path_out else get_generic_outpath(dir_out=self.dir_out,
            fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)
498
499


500
    def to_GCPList(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
501
        # get copy of tie points grid without no data
Daniel Scheffler's avatar
Daniel Scheffler committed
502
503
504
505
506
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []
507

508
509
510
        if getattr(GDF,'empty'): # GDF.empty returns AttributeError
            return []
        else:
511
            # exclude all points flagged as outliers
512
513
            if 'OUTLIER' in GDF.columns:
                GDF = GDF[GDF.OUTLIER == False].copy()
514
515
            avail_TP = len(GDF)

516
517
518
519
            if not avail_TP:
                # no point passed all validity checks
                return []

520
521
522
523
524
            if avail_TP>7000:
                GDF = GDF.sample(7000)
                warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                              'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
                              'out of the %s available tie points.' %avail_TP)
525

526
527
528
529
530
531
532
533
            # calculate GCPs
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
            GDF['GCP']       = GDF.apply(lambda GDF_row:    gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                                     GDF_row.X_IM, GDF_row.Y_IM), axis=1)
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
534
                print('Found %s valid tie points.' %len(self.GCPList))
535
536

            return self.GCPList
537
538


539
    def test_if_singleprocessing_equals_multiprocessing_result(self):
540
541
        self.tieP_filter_level=1 # RANSAC filtering always produces different results because it includes random sampling

Daniel Scheffler's avatar
Daniel Scheffler committed
542
        self.CPUs = None
543
        dataframe = self.get_CoRegPoints_table()
544
545
        mp_out    = np.empty_like(dataframe.values)
        mp_out[:] = dataframe.values
Daniel Scheffler's avatar
Daniel Scheffler committed
546
        self.CPUs = 1
547
        dataframe = self.get_CoRegPoints_table()
548
549
550
551
552
553
        sp_out    = np.empty_like(dataframe.values)
        sp_out[:] = dataframe.values

        return np.array_equal(sp_out,mp_out)


554
555
    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]
556
557


558
    def _get_lines_by_PIDs(self, PIDs):
559
        assert isinstance(PIDs,list)
560
        lines = np.zeros((len(PIDs),self.CoRegPoints_table.shape[1]))
561
        for i,PID in enumerate(PIDs):
562
            lines[i,:] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
563
564
565
        return lines


566
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table
        # work on a copy so that replacing the boolean values below does not modify the original table
        GDF2pass = (GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]).copy()

        # replace boolean values (cannot be written to shapefile)
        # NOTE: the 'np.bool' alias was removed in NumPy 1.24 -> use the builtin 'bool' for the dtype check
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0)  # replace remaining booleans where dtype is object instead of bool
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else get_generic_outpath(
            dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        if not self.q:
            print('Writing %s ...' % path_out)

        GDF2pass.to_file(path_out)


594
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        """Deprecated predecessor of to_PointShapefile() writing the grid via IO.write_shp.

        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        """
        warnings.warn(DeprecationWarning("'_tiepoints_grid_to_PointShapefile' is deprecated." # TODO delete if other method validated
                                         " 'tiepoints_grid_to_PointShapefile' is much faster."))
        GDF      = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # convert each record into a shapely point plus an ordered attribute dict
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts     = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[idx].values))
                          for idx in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out  = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)


608
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str>   'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        """
        assert mode in ['uv', 'md'], \
            "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' (outputs magnitude and direction)'. Got %s." % mode

        # choose the two attribute columns to be rasterized into band 1 and band 2
        attr_b1, attr_b2 = ('X_SHIFT_M', 'Y_SHIFT_M') if mode == 'uv' else ('ABS_SHIFT', 'ANGLE')

        # rasterize both attributes onto the same grid
        band_arrays = []
        for attr in (attr_b1, attr_b2):
            arr, gt, prj = points_to_raster(points=self.CoRegPoints_table['geometry'],
                                            values=self.CoRegPoints_table[attr],
                                            tgt_res=self.shift.xgsd * self.grid_res,
                                            prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                            fillVal=self.outFillVal)
            band_arrays.append(arr)

        out_GA = GeoArray(np.dstack(band_arrays), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else get_generic_outpath(
            dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA


647
648
    def to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                   path_out=None, tilepos=None):
        """Interpolate the given attribute column of the tie point grid to a raster using ordinary kriging.

        Deprecated - use to_Raster_using_Kriging instead.

        :param attrName:        <str> name of the column of CoRegPoints_table to be interpolated
        :param skip_nodata:     <bool/int> whether to exclude points without a valid match
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      <int> output grid resolution; if None it is derived from the point cloud extent
        :param path_out:        <str> output path; auto-generated if not given
        :param tilepos:         tile position as (rows, cols) ranges used to subset the point cloud
                                # NOTE(review): falls back to self.shift.shape if not given - presumably
                                # full-image extent; verify against caller
        :return:                <np.ndarray> the kriged raster values
        """
        warnings.warn(DeprecationWarning("'to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead.")) # TODO delete

        GDF             = self.CoRegPoints_table
        GDF2pass        = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

        # subset if tilepos is given
        rows,cols = tilepos if tilepos else self.shift.shape
        GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
                                       (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]


        X_coords,Y_coords,ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'],GDF2pass[attrName]

        xmin,ymin,xmax,ymax = GDF2pass.total_bounds

        # derive a grid resolution yielding ~250 cells along the smaller extent if none was given
        grid_res            = outGridRes if outGridRes else int(min(xmax-xmin,ymax-ymin)/250)
        grid_x,grid_y       = np.arange(xmin, xmax+grid_res, grid_res), np.arange(ymax, ymin-grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical',verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)#,backend='C',)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"  % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' %path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin,ymin,xmax,ymax = xmin-grid_res/2,ymin-grid_res/2,xmax+grid_res/2,ymax+grid_res/2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


684
685
    def Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                             fName_out=None, tilepos=None, tilesize=500, mp=None):
        """Interpolate the given attribute column of the tie point grid to a raster using ordinary kriging.

        :param attrName:        <str> name of the column of CoRegPoints_table to be interpolated
        :param skip_nodata:     <bool/int> whether to exclude points without a valid match
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      <int> output grid resolution
        :param fName_out:       <str> output file name
        :param tilepos:         tile position passed through to the kriging worker
        :param tilesize:        <int> tile size for the (currently disabled) multiprocessing code path
        :param mp:              <bool> enable multiprocessing; if None it is derived from self.CPUs
                                NOTE: the multiprocessing implementation is currently disabled, so the
                                singleprocessing path is always used.
        """
        # BUGFIX: the passed 'mp' argument was previously always overwritten; only derive it if not given
        if mp is None:
            mp = self.CPUs != 1

        # only the singleprocessing implementation is currently active
        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)

        # BUGFIX: 'self.kriged' is only assigned inside the disabled multiprocessing code above;
        # use getattr to avoid an AttributeError when mp is True
        res = getattr(self, 'kriged', None) if mp else None
        return res


713
714
715
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Singleprocessing worker performing ordinary kriging of a tie point attribute and writing the result.

        :param attrName:        <str> name of the column of CoRegPoints_table to be interpolated
        :param skip_nodata:     <bool/int> whether to exclude points without a valid match
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      <int> output grid resolution; if None it is derived from the point cloud extent
        :param fName_out:       <str> output file name; auto-generated if not given
        :param tilepos:         tile position; only used to make the auto-generated file name unique
                                in the multiprocessing case
        :return:                <np.ndarray> the kriged raster values
        """
        GDF             = self.CoRegPoints_table
        GDF2pass        = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

#         # subset if tilepos is given
# #        overlap_factor =
#         rows,cols = tilepos if tilepos else self.tgt_shape
#         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
#         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
#         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
#         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
#         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
#         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
#                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords,Y_coords,ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'],GDF2pass[attrName]

        xmin,ymin,xmax,ymax = GDF2pass.total_bounds

        # derive a grid resolution yielding ~250 cells along the smaller extent if none was given
        grid_res            = outGridRes if outGridRes else int(min(xmax-xmin,ymax-ymin)/250)
        grid_x,grid_y       = np.arange(xmin, xmax+grid_res, grid_res), np.arange(ymax, ymin-grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical',verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y,backend='C',n_closest_points=12)

        # append the tile position to the file name when running multiprocessed, so that
        # each tile gets its own output file
        if self.CPUs is None or self.CPUs>1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY,tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY)
        path_out  = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' %path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin,ymin,xmax,ymax = xmin-grid_res/2,ymin-grid_res/2,xmax+grid_res/2,ymax+grid_res/2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


756
    def _Kriging_mp(self, args_kwargs_dict):
757
        args   = args_kwargs_dict.get('args'  ,[])
758
759
        kwargs = args_kwargs_dict.get('kwargs',[])

760
        return self._Kriging_sp(*args, **kwargs)
761
762
763



764
class Tie_Point_Refiner(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
765
766
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                      rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
767
        """A class for performing outlier detection.
Daniel Scheffler's avatar
Daniel Scheffler committed
768

769
770
        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
Daniel Scheffler's avatar
Daniel Scheffler committed
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
795
796
797
        self.ransac_model_robust = None


798
    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters all tie points out where shift
                                correction does not increase image similarity within matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:    (self.GDF, self.new_cols) - the tie point table with one boolean outlier column
                    per applied filter level plus a combined 'OUTLIER' column, and the list of
                    column names that were added
        """
        # TODO catch empty GDF

        # RELIABILITY filtering
        if level>0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).' % (len(marked_recs[marked_recs==True])))

        # SSIM filtering
        if level>1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs==True])))

        # RANSAC filtering
        if level>2:
            # exclude previous outliers
            ransacInGDF = self.GDF[self.GDF[self.new_cols].any(axis=1) == False].copy()\
                            if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF)>4:
                # running RANSAC with less than four tie points makes no sense
                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)' % (len(marked_recs[marked_recs == True])))
            else:
                print('RANSAC skipped because too less valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        # combine all level flags: a point is an outlier if any filter level flagged it
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols


Daniel Scheffler's avatar
Daniel Scheffler committed
863
    def _reliability_thresholding(self):
864
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""
865

Daniel Scheffler's avatar
Daniel Scheffler committed
866
        return self.GDF.RELIABILITY < self.min_reliability
867
868
869


    def _SSIM_filtering(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
870
        """Exclude all records where SSIM decreased."""
871
872
873
874
875

        #ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        #self.GDF.SSIM_IMPROVED = self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

876
877
878
        return self.GDF.SSIM_IMPROVED == False


Daniel Scheffler's avatar
Daniel Scheffler committed
879
880
    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm.

        The RANSAC residual threshold is tuned iteratively so that the fraction of detected
        outliers stays close to self.rs_max_outlier_percentage (within self.rs_tolerance).

        :param inGDF:   GeoDataFrame with the tie points to be checked (columns X_UTM, Y_UTM,
                        X_SHIFT_M, Y_SHIFT_M, POINT_ID); may be a subset of self.GDF if previous
                        outliers were excluded
        :return:        boolean Series aligned with self.GDF - True marks an outlier
        """
        # source points and their shift-corrected positions
        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
        xyShift    = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim==2 and co.shape[1]==2, "'%s' must have shape [Nx2]. Got shape %s."%(n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100: raise ValueError
        min_inlier_percentage = 100-self.rs_max_outlier_percentage

        # NOTE: currently unused - first-order polynomial transform kept as an alternative model
        class PolyTF_1(PolynomialTransform):
            def estimate(*data):
                return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None
        count_inliers         = None
        th                    = 5  # start RANSAC threshold
        th_checked            = {} # dict of thresholds that already have been tried + calculated inlier percentage
        th_substract          = 2  # threshold increment/decrement per iteration
        count_iter            = 0
        time_start            = time.time()
        ideal_count           = min_inlier_percentage * src_coords.shape[0] / 100  # target number of inliers

        # optimize RANSAC threshold so that it marks not much more or less than the given outlier percentage
        while True:
            if th_checked:
                th_too_strict = count_inliers < ideal_count # True if too less inliers remaining

                # calculate new theshold using old increment (but ensure th_new>0 by adjusting increment if needed)
                th_new = 0
                while th_new <= 0:
                    th_new = th+th_substract if th_too_strict else th-th_substract
                    if th_new <= 0:
                        th_substract /=2

                # check if calculated new threshold has been used before
                th_already_checked = th_new in th_checked.keys()

                # if yes, decrease increment and recalculate new threshold
                th_substract       = th_substract if not th_already_checked else th_substract / 2
                th                 = th_new if not th_already_checked else \
                                        (th+th_substract if th_too_strict else th-th_substract)

            # RANSAC call
            # model_robust, inliers = ransac((src, dst), PolynomialTransform, min_samples=3,
            if src_coords.size and est_coords.size:
                model_robust, inliers = \
                    ransac((src_coords, est_coords), AffineTransform,
                           min_samples        = 6,
                           residual_threshold = th,
                           max_trials         = 2000,
                           stop_sample_num    = int((min_inlier_percentage-self.rs_tolerance) /100*src_coords.shape[0]),
                           stop_residuals_sum = int((self.rs_max_outlier_percentage-self.rs_tolerance)/100*src_coords.shape[0])
                           )
            else:
                # no input points -> nothing to classify
                inliers = np.array([])
                break

            count_inliers  = np.count_nonzero(inliers)

            # remember the inlier percentage achieved with this threshold
            th_checked[th] = count_inliers / src_coords.shape[0] * 100
            #print(th,'\t', th_checked[th], )
            if min_inlier_percentage-self.rs_tolerance < th_checked[th] < min_inlier_percentage+self.rs_tolerance:
                #print('in tolerance')
                break
            if count_iter > self.rs_max_iter or time.time()-time_start > self.rs_timeout:
                break # keep last values and break while loop

            count_iter+=1

        # invert the inlier mask; empty array if RANSAC never produced a result
        outliers = inliers == False if inliers is not None and inliers.size else np.array([])

        if inGDF.empty or outliers is None or (isinstance(outliers, list) and not outliers) or \
                (isinstance(outliers, np.ndarray) and not outliers.size):
            # no usable RANSAC result -> flag nothing
            gs              = GeoSeries([False]*len(self.GDF))
        elif len(inGDF) < len(self.GDF):
            # inGDF is a subset of self.GDF -> merge the outlier flags back via POINT_ID
            inGDF['outliers'] = outliers
            fullGDF         = GeoDataFrame(self.GDF['POINT_ID'])
            fullGDF         = fullGDF.merge(inGDF[['POINT_ID', 'outliers']], on='POINT_ID', how="outer")
            #fullGDF.outliers.copy()[~fullGDF.POINT_ID.isin(GDF.POINT_ID)] = False
            fullGDF         = fullGDF.fillna(False) # NaNs are due to exclude_previous_outliers
            gs              = fullGDF['outliers']
        else:
            gs              = GeoSeries(outliers)

        assert len(gs)==len(self.GDF), 'RANSAC output validation failed.'

        # keep the fitted model for later inspection
        self.ransac_model_robust = model_robust

        return gs
975