# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'

import collections
import multiprocessing
import os
import warnings
import time

# custom
try:
    import gdal
except ImportError:
    from osgeo import gdal
import numpy as np
from matplotlib import pyplot as plt
from geopandas         import GeoDataFrame, GeoSeries
from pykrige.ok        import OrdinaryKriging
from shapely.geometry  import Point
from skimage.measure   import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform

# internal modules
from .CoReg import COREG
from . import io as IO
from py_tools_ds.geo.projection          import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen              import get_generic_outpath
from py_tools_ds.processing.progress_mon import ProgressBar
from py_tools_ds.geo.vector.conversion   import points_to_raster
from geoarray import GeoArray



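# module-level variables holding the reference and the target image during tie point grid calculation;
# they are read by the worker processes spawned in Tie_Point_Grid.get_CoRegPoints_table()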
global_shared_imref    = None
global_shared_im2shift = None


class Tie_Point_Grid(object):
    """Class for creating a grid of tie points.

    For each point of the grid, the COREG algorithm is run within a local matching window in order to compute the
    spatial shift at that position.
    """

    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=2, dir_out=None, CPUs=None, progress=True, v=False, q=False):
        """Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
        are calculated for each point of a grid whose parameters can be adjusted using keyword arguments. Shift
        correction performs a polynomial transformation using the calculated shifts of each grid point as GCPs.
        Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions of the target image.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the grid does not provide enough points, all available points
                                        are chosen.
        :param outFillVal(int):         if given, the generated tie points grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes during the
                                        calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                        mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 2).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filters out all tie points that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters out all tie points where shift
                                                correction does not increase image similarity within the matching
                                                window (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of the tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """

        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj         = COREG_obj
        self.grid_res          = grid_res
        self.max_points        = max_points
        self.outFillVal        = outFillVal
        self.rspAlg_calc       = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.dir_out           = dir_out
        self.CPUs              = CPUs
        self.v                 = v
        self.q                 = q        if not v else False # overridden by v
        self.progress          = progress if not q else False # overridden by q

        self.ref               = self.COREG_obj.ref
        self.shift             = self.COREG_obj.shift

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
        self._CoRegPoints_table           = None # set by self.CoRegPoints_table
        self._GCPList                     = None # set by self.to_GCPList()
        self.kriged                       = None # set by Raster_using_Kriging()


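    # Minimal usage sketch (hypothetical file names; COREG is imported at the top of this module):
    #
    #     CR   = COREG('reference.tif', 'target.tif', ws=(256, 256))
    #     TPG  = Tie_Point_Grid(CR, grid_res=200, max_points=1000)
    #     GDF  = TPG.CoRegPoints_table   # GeoDataFrame with one record per tie point
    #     GCPs = TPG.GCPList             # list of GDAL GCP objects derived from the valid tie points
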
    @property
    def CoRegPoints_table(self):
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE', containing all
        results of the coregistration for all points of the tie points grid.
        """
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table


    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table


    @property
    def GCPList(self):
        """Returns a list of GDAL compatible GCP objects.
        """
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
            return self._GCPList


    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList


    def _get_imXY__mapXY_points(self, grid_res):
        """Returns two arrays containing image coordinates and map coordinates of the positions that are usable as
        coregistration tie points according to the given grid resolution.

        :param grid_res:    grid resolution in pixels of the target image
        :return:            two arrays of shape (N, 2): image XY coordinates and the corresponding map XY coordinates
        """
        if not self.q:
            print('Initializing tie points grid...')

        Xarr,Yarr       = np.meshgrid(np.arange(0,self.shift.shape[1],grid_res),
                                      np.arange(0,self.shift.shape[0],grid_res))

        mapXarr = np.full_like(Xarr, self.shift.gt[0], dtype=np.float64) + Xarr*self.shift.gt[1]
        mapYarr = np.full_like(Yarr, self.shift.gt[3], dtype=np.float64) - Yarr*abs(self.shift.gt[5])

        XY_points      = np.empty((Xarr.size,2),Xarr.dtype)
        XY_points[:,0] = Xarr.flat
        XY_points[:,1] = Yarr.flat

        XY_mapPoints      = np.empty((mapXarr.size,2),mapXarr.dtype)
        XY_mapPoints[:,0] = mapXarr.flat
        XY_mapPoints[:,1] = mapYarr.flat

        assert XY_points.shape == XY_mapPoints.shape

        return XY_points,XY_mapPoints


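    # The map coordinates above follow the GDAL geotransform convention (gt[0]/gt[3]: upper-left X/Y,
    # gt[1]: pixel width, gt[5]: negative pixel height), i.e. ignoring the rotation terms gt[2]/gt[4]:
    #
    #     map_x = gt[0] + col * gt[1]
    #     map_y = gt[3] - row * abs(gt[5])
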
    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:
        """

        # exclude all points outside of overlap area
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        #GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!' # FIXME track that


        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF       = len(GDF) # length of GDF after dropping all points outside the overlap polygon
        mapXY              = np.array(GDF.loc[:,['X_UTM','Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref  .mask_baddata.read_pointData(mapXY) \
                                if self.COREG_obj.ref  .mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY)\
                                if self.COREG_obj.shift.mask_baddata is not None else False
        GDF                = GDF[(GDF['REF_BADDATA']==False) & (GDF['TGT_BADDATA']==False)]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      %(orig_len_GDF-len(GDF), orig_len_GDF))

        return GDF


    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        # unpack
        pointID    = coreg_kwargs['pointID']
        fftw_works = coreg_kwargs['fftw_works']
        del coreg_kwargs['pointID'], coreg_kwargs['fftw_works']

        # assertions
        assert global_shared_imref    is not None
        assert global_shared_im2shift is not None

        # run CoReg
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err           = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)
        CR_res   = [win_sz_x, win_sz_y, CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
                    CR.vec_length_map, CR.vec_angle_deg, CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                    CR.shift_reliability, last_err]

        return [pointID]+CR_res


    def get_CoRegPoints_table(self):
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM' (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X,Y: Point(X,Y), otypes=[Point])
        geomPoints      = np.array(XYarr2PointGeom(self.XY_mapPoints[:,0],self.XY_mapPoints[:,1]))

        if isProjectedOrGeographic(self.COREG_obj.shift.prj)=='geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj)=='projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south   = get_UTMzone(prj=self.COREG_obj.shift.prj)<0
            crs     = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south: del crs['south']
        else:
            crs = None

        GDF                          = GeoDataFrame(index=range(len(geomPoints)),crs=crs,
                                                    columns=['geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'])
        GDF       ['geometry']       = geomPoints
        GDF       ['POINT_ID']       = range(len(geomPoints))
        GDF.loc[:,['X_IM' ,'Y_IM' ]] = self.XY_points
        GDF.loc[:,['X_UTM','Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF)>100:
            self.COREG_obj.equalize_pixGrids() # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is connected to that

        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref,global_shared_im2shift
        assert self.ref  .footprint_poly # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighbouring matching windows overlap during reading from disk!!
        self.ref.cache_array_subset([self.COREG_obj.ref.band4match]) # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        global_shared_imref    = self.ref
        global_shared_im2shift = self.shift

        # get all variations of kwargs for coregistration
        get_coreg_kwargs = lambda pID, wp: dict(
            pointID            = pID,
            fftw_works         = self.COREG_obj.fftw_works,
            wp                 = wp,
            ws                 = self.COREG_obj.win_size_XY,
            resamp_alg_calc    = self.rspAlg_calc,
            footprint_poly_ref = self.COREG_obj.ref.poly,
            footprint_poly_tgt = self.COREG_obj.shift.poly,
            r_b4match          = self.ref.band4match+1,   # band4match is internally saved as index, starting from 0
            s_b4match          = self.shift.band4match+1, # band4match is internally saved as index, starting from 0
            max_iter           = self.COREG_obj.max_iter,
            max_shift          = self.COREG_obj.max_shift,
            nodata             = (self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
            force_quadratic_win= self.COREG_obj.force_quadratic_win,
            binary_ws          = self.COREG_obj.bin_ws,
            v                  = False, # otherwise this would lead to massive console output
            q                  = True,  # otherwise this would lead to massive console output
            ignore_errors      = True
        )
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index) # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs>1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie points grid (%s points) using %s CPU cores..." %(len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar     = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        numberDone = len(GDF)-results._number_left # this does not really represent the remaining tasks but the remaining chunks -> thus chunksize=1
                        if self.progress:
                            bar.print_progress(percent=numberDone/len(GDF)*100)
                        if results.ready():
                            results = results.get() # <= this is the line where multiprocessing can freeze if an exception appears within COREG and is not raised
                            break

        else:
            if not self.q:
                print("Calculating tie points grid (%s points) using 1 CPU core..." %len(GDF))
            results = np.empty((len(geomPoints),14), np.object)
            bar     = ProgressBar(prefix='\tprogress:')
            for i,coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i+1)/len(GDF)*100)
                results[i,:] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, np.object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level>0:
            if not self.q:
                print('Performing validity checks...')
            TPR                   = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], q=self.q)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF                   = GDF.merge(GDF_filt[ ['POINT_ID']+new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table


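    # Sketch of the parallelization pattern used in get_CoRegPoints_table() (simplified, not the full
    # implementation): the images are published via the module-level globals so that the worker processes
    # can access them without pickling, and the static worker is mapped over per-point keyword arguments:
    #
    #     global_shared_imref, global_shared_im2shift = self.ref, self.shift
    #     with multiprocessing.Pool(self.CPUs) as pool:
    #         results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
    #
    # NOTE: this relies on the workers inheriting the module state (fork start method); with 'spawn',
    #       the globals would remain None in the workers.
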
    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates root mean square error of absolute shifts from the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
        """

        tbl = self.CoRegPoints_table
        if not include_outliers and 'OUTLIER' in tbl.columns:
            tbl = tbl[tbl['OUTLIER'] == False].copy()

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))


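    # Equivalent vectorized formulation of the RMSE computed above (sketch):
    #
    #     d = shifts[shifts != self.outFillVal]
    #     rmse = np.sqrt(np.mean(d ** 2))
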
    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculates the median value of all MSSIM values contained in tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        """

        tbl = self.CoRegPoints_table
        tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy()

        mssim_col = np.array(tbl['MSSIM'])
        mssim_col = [i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))


    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int) -> tuple
        """Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        """

        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." %unit)

        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[tbl['OUTLIER'] == False].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER'] == True].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse   = self.calc_rmse(include_outliers=False) # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10,10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1]-(xlim[1]/20),-ylim[1]+(ylim[1]/20),
                     'RMSE:  %s m / %s px' %(np.round(rmse, 2), np.round(rmse/self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
                tick.label.set_fontsize(fontsize)
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax


    def dump_CoRegPoints_table(self, path_out=None):
        path_out = path_out if path_out else get_generic_outpath(dir_out=self.dir_out,
            fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)


    def to_GCPList(self):
        # get copy of tie points grid without no data
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if getattr(GDF,'empty'): # GDF.empty returns AttributeError
            return []
        else:
            # exclude all points flagged as outliers
            if 'OUTLIER' in GDF.columns:
                GDF = GDF[GDF.OUTLIER == False].copy()
            avail_TP = len(GDF)

            if not avail_TP:
                # no point passed all validity checks
                return []

            if avail_TP > 7000:
                GDF = GDF.sample(7000)
                warnings.warn('Not more than 7000 tie points can be used for warping within a limited computation '
                              'time (due to a GDAL bottleneck). Thus 7000 points are randomly chosen out of the %s '
                              'available tie points.' % avail_TP)

            # calculate GCPs
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
            GDF['GCP']       = GDF.apply(lambda GDF_row:    gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                                     GDF_row.X_IM, GDF_row.Y_IM), axis=1)
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
                print('Found %s valid tie points.' %len(self.GCPList))

            return self.GCPList


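    # One possible way to apply the GCPs returned by to_GCPList() using plain GDAL (a sketch with
    # placeholder paths and SRS; not necessarily how the warping is done elsewhere in this package):
    #
    #     gcps = TPG.to_GCPList()
    #     ds   = gdal.Translate('/vsimem/target_gcps.vrt', 'target.tif', GCPs=gcps, outputSRS=ref_prj_wkt)
    #     gdal.Warp('target_corrected.tif', ds, tps=True, dstSRS=ref_prj_wkt)
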
    def test_if_singleprocessing_equals_multiprocessing_result(self):
        self.tieP_filter_level=1 # RANSAC filtering always produces different results because it includes random sampling

        self.CPUs = None
        dataframe = self.get_CoRegPoints_table()
        mp_out    = np.empty_like(dataframe.values)
        mp_out[:] = dataframe.values
        self.CPUs = 1
        dataframe = self.get_CoRegPoints_table()
        sp_out    = np.empty_like(dataframe.values)
        sp_out[:] = dataframe.values

        return np.array_equal(sp_out,mp_out)


    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]


    def _get_lines_by_PIDs(self, PIDs):
        assert isinstance(PIDs,list)
        lines = np.zeros((len(PIDs),self.CoRegPoints_table.shape[1]))
        for i,PID in enumerate(PIDs):
            lines[i,:] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
        return lines


    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col ='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF      = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

        # replace boolean values (cannot be written)
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == np.bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        GDF2pass = GDF2pass.replace(False, 0) # replace all remaining booleans where dtype is not np.bool but np.object
        GDF2pass = GDF2pass.replace(True,  1)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' %path_out)
        GDF2pass.to_file(path_out)


    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col ='ABS_SHIFT'):
        warnings.warn(DeprecationWarning("'_to_PointShapefile' is deprecated." # TODO delete if other method validated
                                         " 'to_PointShapefile' is much faster."))
        GDF            = self.CoRegPoints_table
        GDF2pass       = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]
        shapely_points = GDF2pass['geometry'].values.tolist()
        attr_dicts     = [collections.OrderedDict(zip(GDF2pass.columns,GDF2pass.loc[i].values)) for i in GDF2pass.index]


        fName_out = "CoRegPoints_grid%s_ws%s.shp" %(self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)


    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str>   'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        """

        assert mode in ['uv','md'], \
            "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' (outputs magnitude and direction)'. Got %s." %mode
        attr_b1 = 'X_SHIFT_M' if mode=='uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode=='uv' else 'ANGLE'

        xshift_arr, gt, prj = points_to_raster(points = self.CoRegPoints_table['geometry'],
                                               values = self.CoRegPoints_table[attr_b1],
                                               tgt_res = self.shift.xgsd * self.grid_res,
                                               prj = proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal = self.outFillVal)

        yshift_arr, gt, prj = points_to_raster(points = self.CoRegPoints_table['geometry'],
                                               values = self.CoRegPoints_table[attr_b2],
                                               tgt_res = self.shift.xgsd * self.grid_res,
                                               prj = proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                                               fillVal = self.outFillVal)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA


    def to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                   path_out=None, tilepos=None):
        GDF             = self.CoRegPoints_table
        GDF2pass        = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

        # subset if tilepos is given
        rows,cols = tilepos if tilepos else self.shift.shape
        GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
                                       (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]


        X_coords,Y_coords,ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'],GDF2pass[attrName]

        xmin,ymin,xmax,ymax = GDF2pass.total_bounds

        grid_res            = outGridRes if outGridRes else int(min(xmax-xmin,ymax-ymin)/250)
        grid_x,grid_y       = np.arange(xmin, xmax+grid_res, grid_res), np.arange(ymax, ymin-grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical',verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)#,backend='C',)

        path_out = path_out if path_out else get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
            fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"  % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' %path_out)
        # add half a pixel so that the grid points are centered on the output pixels
        xmin,ymin,xmax,ymax = xmin-grid_res/2,ymin-grid_res/2,xmax+grid_res/2,ymax+grid_res/2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


    def Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                             fName_out=None, tilepos=None, tilesize=500, mp=None):

        mp = False if self.CPUs==1 else True
        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     # self.kriged=[]
        #     # for i in args_kwargs_dicts:
        #     #     res = self.Kriging_mp(i)
        #     #     self.kriged.append(res)
        #     #     print(res)
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)
        res = self.kriged if mp else None
        return res


    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        GDF             = self.CoRegPoints_table
        GDF2pass        = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]

#         # subset if tilepos is given
# #        overlap_factor =
#         rows,cols = tilepos if tilepos else self.tgt_shape
#         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
#         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
#         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
#         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
#         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
#         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
#                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords,Y_coords,ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'],GDF2pass[attrName]

        xmin,ymin,xmax,ymax = GDF2pass.total_bounds

        grid_res            = outGridRes if outGridRes else int(min(xmax-xmin,ymax-ymin)/250)
        grid_x,grid_y       = np.arange(xmin, xmax+grid_res, grid_res), np.arange(ymax, ymin-grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical',verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y,backend='C',n_closest_points=12)

        if self.CPUs is None or self.CPUs>1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY,tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY)
        path_out  = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' %path_out)
        # add half a pixel so that the grid points are centered on the output pixels
        xmin,ymin,xmax,ymax = xmin-grid_res/2,ymin-grid_res/2,xmax+grid_res/2,ymax+grid_res/2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


    def _Kriging_mp(self, args_kwargs_dict):
        args   = args_kwargs_dict.get('args'  ,[])
        kwargs = args_kwargs_dict.get('kwargs', {})

        return self._Kriging_sp(*args, **kwargs)



class Tie_Point_Refiner(object):
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                      rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
        """

        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:                               quiet mode (default: False)
        """

        self.GDF = GDF.copy()
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        self.new_cols = []
        self.ransac_model_robust = None


    def run_filtering(self, level=2):
        """Run the tie point filtering for the given level.

        :param level:   filtering level (0: no filtering; 1: reliability filtering; 2: SSIM filtering;
                        3: RANSAC outlier detection); lower levels are included if a higher level is chosen
        :return:        the filtered GeoDataFrame and a list of the added column names
        """

        # TODO catch empty GDF

        # RELIABILITY filtering
        if level>0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).' % (len(marked_recs[marked_recs==True])))


        # SSIM filtering
        if level>1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs==True])))


        # RANSAC filtering
        if level>2:
            # exclude previous outliers
            ransacInGDF = self.GDF[self.GDF[self.new_cols].any(axis=1) == False].copy()\
                            if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF)>4:
                # running RANSAC with less than four tie points makes no sense

                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                self.GDF['L3_OUTLIER'] = marked_recs.tolist() # we need to join a list here because otherwise it's merged by the 'index' column

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)' % (len(marked_recs[marked_recs == True])))
            else:
                print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')


        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols


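    # Minimal usage sketch (GDF must provide the columns produced by Tie_Point_Grid.get_CoRegPoints_table()):
    #
    #     TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != outFillVal])
    #     GDF_filt, new_cols = TPR.run_filtering(level=2)
    #     outliers = GDF_filt[GDF_filt.OUTLIER == True]
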
    def _reliability_thresholding(self):
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""

        return self.GDF.RELIABILITY < self.min_reliability


    def _SSIM_filtering(self):
        """Exclude all records where SSIM decreased."""

        #ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])

        #self.GDF.SSIM_IMPROVED = self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)

        return self.GDF.SSIM_IMPROVED == False


    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm."""

        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
        xyShift    = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim==2 and co.shape[1]==2, "'%s' must have shape [Nx2]. Got shape %s."%(n, co.shape)

        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError("'rs_max_outlier' must be between 0 and 100 percent.")
        min_inlier_percentage = 100-self.rs_max_outlier_percentage

        class PolyTF_1(PolynomialTransform):
            def estimate(*data):
                return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate affine transform model with RANSAC
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None
        count_inliers         = None
        th                    = 5  # start RANSAC threshold
        th_checked            = {} # dict of thresholds that already have been tried + calculated inlier percentage
        th_substract          = 2
        count_iter            = 0
        time_start            = time.time()
        ideal_count           = min_inlier_percentage * src_coords.shape[0] / 100

        # optimize RANSAC threshold so that it marks not much more or less than the given outlier percentage
        while True:
            if th_checked:
                th_too_strict = count_inliers < ideal_count # True if too few inliers remain

                # calculate new threshold using old increment (but ensure th_new>0 by adjusting increment if needed)
                th_new = 0
                while th_new <= 0:
                    th_new = th+th_substract if th_too_strict else th-th_substract
                    if th_new <= 0:
                        th_substract /=2

                # check if calculated new threshold has been used before
                th_already_checked = th_new in th_checked.keys()

                # if yes, decrease increment and recalculate new threshold
                th_substract       = th_substract if not th_already_checked else th_substract / 2
                th                 = th_new if not th_already_checked else \
                                        (th+th_substract if th_too_strict else th-th_substract)

            # RANSAC call
            # model_robust, inliers = ransac((src, dst), PolynomialTransform, min_samples=3,
            if src_coords.size and est_coords.size:
                model_robust, inliers = \
                    ransac((src_coords, est_coords), AffineTransform,
                           min_samples        = 6,
                           residual_threshold = th,
                           max_trials         = 2000,
                           stop_sample_num    = int((min_inlier_percentage-self.rs_tolerance) /100*src_coords.shape[0]),
                           stop_residuals_sum = int((self.rs_max_outlier_percentage-self.rs_tolerance)/100*src_coords.shape[0])
                           )
            else:
                inliers = np.array([])
                break

            count_inliers  = np.count_nonzero(inliers)

            th_checked[th] = count_inliers / src_coords.shape[0] * 100
            #print(th,'\t', th_checked[th], )
            if min_inlier_percentage-self.rs_tolerance < th_checked[th] < min_inlier_percentage+self.rs_tolerance:
                #print('in tolerance')
                break
            if count_iter > self.rs_max_iter or time.time()-time_start > self.rs_timeout:
                break # keep last values and break while loop

            count_iter+=1

        outliers = inliers == False if inliers is not None and inliers.size else np.array([])

        if inGDF.empty or outliers is None or (isinstance(outliers, list) and not outliers) or \
                (isinstance(outliers, np.ndarray) and not outliers.size):
            gs              = GeoSeries([False]*len(self.GDF))
        elif len(inGDF) < len(self.GDF):
            inGDF['outliers'] = outliers
            fullGDF         = GeoDataFrame(self.GDF['POINT_ID'])
            fullGDF         = fullGDF.merge(inGDF[['POINT_ID', 'outliers']], on='POINT_ID', how="outer")
            #fullGDF.outliers.copy()[~fullGDF.POINT_ID.isin(GDF.POINT_ID)] = False
            fullGDF         = fullGDF.fillna(False) # NaNs are due to exclude_previous_outliers
            gs              = fullGDF['outliers']
        else:
            gs              = GeoSeries(outliers)

        assert len(gs)==len(self.GDF), 'RANSAC output validation failed.'

        self.ransac_model_robust = model_robust

        return gs
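
    # For reference, a single non-adaptive call of the estimator tuned in the loop above would look like
    # this (sketch; the loop only adjusts 'residual_threshold' towards the requested outlier percentage):
    #
    #     model, inliers = ransac((src_coords, est_coords), AffineTransform,
    #                             min_samples=6, residual_threshold=5, max_trials=2000)
    #     outliers = inliers == False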