Tie_Point_Grid.py 49 KB
Newer Older
1
2
3
4
5
6
# -*- coding: utf-8 -*-

import collections
import multiprocessing
import os
import warnings
7
import time
8
9

# custom
10
11
12
13
try:
    import gdal
except ImportError:
    from osgeo import gdal
14
import numpy as np
15
from matplotlib import pyplot as plt
16
17
18
from geopandas import GeoDataFrame, GeoSeries
from shapely.geometry import Point
from skimage.measure import points_in_poly, ransac
19
from skimage.transform import AffineTransform, PolynomialTransform
20
21

# internal modules
22
23
from .CoReg import COREG
from . import io as IO
24
25
from py_tools_ds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.io.pathgen import get_generic_outpath
26
from py_tools_ds.processing.progress_mon import ProgressBar
27
from py_tools_ds.geo.vector.conversion import points_to_raster
28
from geoarray import GeoArray
29

30
__author__ = 'Daniel Scheffler'
31

32
global_shared_imref = None
33
34
35
global_shared_im2shift = None


36
37
38
39
40
41
42
43
44
45
46
def mp_initializer(imref, imtgt):
    """Initialize worker processes by publishing both images as module globals.

    This avoids pickling the image data for every single tie point when
    Tie_Point_Grid._get_spatial_shifts() runs in a multiprocessing pool.

    :param imref:   reference image
    :param imtgt:   target image
    """
    global global_shared_imref, global_shared_im2shift

    global_shared_imref, global_shared_im2shift = imref, imtgt


47
48
class Tie_Point_Grid(object):
    """See help(Tie_Point_Grid) for documentation!"""
49

50
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
                 q=False):
        """Apply the shift-detection algorithm to the whole overlap area of the input images.

        Spatial shifts are calculated for each point in a grid whose parameters can be adjusted using keyword
        arguments. Shift correction performs a polynomial transformation using the calculated shifts of each point
        in the grid as GCPs. Thus 'Tie_Point_Grid' can be used to correct for locally varying geometric distortions
        of the target image.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the grid does not provide enough points, all available points
                                        are chosen.
        :param outFillVal(int):         if given the generated tie points grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes during
                                        calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                           mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter all tie points out that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param outlDetect_settings:     a dictionary with the settings to be passed to
                                        arosics.TiePointGrid.Tie_Point_Refiner. Available keys: min_reliability,
                                        rs_max_outlier, rs_tolerance, rs_max_iter, rs_exclude_previous_outliers,
                                        rs_timeout, q. See documentation there.
        :param dir_out(str):            output directory to be used for all outputs if nothing else is given
                                        to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of tie points grid
                                        (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """
        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj = COREG_obj
        self.grid_res = grid_res
        self.max_points = max_points
        self.outFillVal = outFillVal
        self.rspAlg_calc = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        # fall back to quiet-mode flag only, if no explicit refiner settings were given
        self.outlDetect_settings = outlDetect_settings if outlDetect_settings else dict(q=q)
        self.dir_out = dir_out
        self.CPUs = CPUs
        self.v = v
        self.q = q if not v else False  # overridden by v
        self.progress = progress if not q else False  # overridden by q

        # shortcuts to the reference and target (to-be-shifted) images held by the COREG instance
        self.ref = self.COREG_obj.ref
        self.shift = self.COREG_obj.shift

        # precompute all candidate tie point positions (image and map coordinates)
        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList = None  # set by self.to_GCPList()
        self.kriged = None  # set by Raster_using_Kriging()

    # Mean X/Y shifts (in pixels and map units) over all tie points that have a valid
    # match, i.e. whose value differs from the fill value self.outFillVal.
    # Accessing any of these triggers the (lazy) computation of CoRegPoints_table.
    mean_x_shift_px = property(lambda self:
                               self.CoRegPoints_table['X_SHIFT_PX'][
                                   self.CoRegPoints_table['X_SHIFT_PX'] != self.outFillVal].mean())
    mean_y_shift_px = property(lambda self:
                               self.CoRegPoints_table['Y_SHIFT_PX'][
                                   self.CoRegPoints_table['Y_SHIFT_PX'] != self.outFillVal].mean())
    mean_x_shift_map = property(lambda self:
                                self.CoRegPoints_table['X_SHIFT_M'][
                                    self.CoRegPoints_table['X_SHIFT_M'] != self.outFillVal].mean())
    mean_y_shift_map = property(lambda self:
                                self.CoRegPoints_table['Y_SHIFT_M'][
                                    self.CoRegPoints_table['Y_SHIFT_M'] != self.outFillVal].mean())
130

131
132
    @property
    def CoRegPoints_table(self):
        """GeoDataFrame with the full tie point grid results (computed lazily on first access).

        Contains the columns 'geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM',
        'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M',
        'ABS_SHIFT' and 'ANGLE', holding all co-registration results for every point of the
        tie point grid.
        """
        if self._CoRegPoints_table is None:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
        return self._CoRegPoints_table

    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        self._CoRegPoints_table = CoRegPoints_table

    @property
    def GCPList(self):
        """List of GDAL-compatible GCP objects (computed lazily on first access)."""
        # NOTE: truthiness check (not an 'is None' check) as in the original implementation,
        # so an empty cached list is recomputed on the next access
        if not self._GCPList:
            self._GCPList = self.to_GCPList()
        return self._GCPList

    @GCPList.setter
    def GCPList(self, GCPList):
        self._GCPList = GCPList

    def _get_imXY__mapXY_points(self, grid_res):
        """Compute all candidate tie point positions in image and map coordinates.

        :param grid_res:    tie point grid resolution in pixels of the target image
        :return:            two (n_points, 2) arrays: image X/Y positions and map X/Y positions
        """
        if not self.q:
            print('Initializing tie points grid...')

        # regular pixel grid over the full extent of the target image
        cols = np.arange(0, self.shift.shape[1], grid_res)
        rows = np.arange(0, self.shift.shape[0], grid_res)
        Xarr, Yarr = np.meshgrid(cols, rows)

        # convert pixel positions to map coordinates via the geotransform of the target image
        gt = self.shift.gt
        mapXarr = np.full_like(Xarr, gt[0], dtype=np.float64) + Xarr * gt[1]
        mapYarr = np.full_like(Yarr, gt[3], dtype=np.float64) - Yarr * abs(gt[5])

        XY_points = np.column_stack((Xarr.ravel(), Yarr.ravel()))
        XY_mapPoints = np.column_stack((mapXarr.ravel(), mapYarr.ravel()))

        assert XY_points.shape == XY_mapPoints.shape

        return XY_points, XY_mapPoints
190

191
192
193
194
195
196
197
198
199
200
201
    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:        <geopandas.GeoDataFrame> the filtered (copied) GeoDataFrame
        """
        # exclude all points outside of overlap area
        # NOTE(review): the inlier mask is computed from self.XY_mapPoints, so this assumes GDF rows
        # are still aligned 1:1 with self.XY_mapPoints (i.e. GDF has not been filtered before) — verify at callers
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        # GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        # FIXME track that
        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!'

        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
        orig_len_GDF = len(GDF)  # length of GDF after dropping all points outside the overlap polygon
        mapXY = np.array(GDF.loc[:, ['X_UTM', 'Y_UTM']])
        # without a bad-data mask the whole column is set to the scalar False (broadcast by pandas)
        GDF['REF_BADDATA'] = self.COREG_obj.ref.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.ref.mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY) \
            if self.COREG_obj.shift.mask_baddata is not None else False
        # keep only points that are valid in BOTH masks
        GDF = GDF[(~GDF['REF_BADDATA']) & (~GDF['TGT_BADDATA'])]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            if not self.q:
                print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                      % (orig_len_GDF - len(GDF), orig_len_GDF))

        return GDF

222
223
    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        """Run the co-registration for a single tie point and collect its results.

        :param coreg_kwargs:    COREG keyword arguments, additionally containing the keys
                                'pointID' and 'fftw_works' (both are removed before the COREG call)
        :return:                [pointID, win_sz_x, win_sz_y, x/y shift (px), x/y shift (map),
                                 shift vector length/angle, SSIM before/after, SSIM improved flag,
                                 shift reliability, last tracked error]
        """
        # pull the two non-COREG entries out of the kwargs (mutates the dict, like the original 'del')
        pointID = coreg_kwargs.pop('pointID')
        fftw_works = coreg_kwargs.pop('fftw_works')

        # the worker globals must have been set before (by mp_initializer() or the single-CPU branch)
        assert global_shared_imref is not None
        assert global_shared_im2shift is not None

        # run CoReg for this single window position
        CR = COREG(global_shared_imref, global_shared_im2shift, CPUs=1, **coreg_kwargs)
        CR.fftw_works = fftw_works
        CR.calculate_spatial_shifts()

        # fetch results
        last_err = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)

        return [pointID,
                win_sz_x, win_sz_y,
                CR.x_shift_px, CR.y_shift_px,
                CR.x_shift_map, CR.y_shift_map,
                CR.vec_length_map, CR.vec_angle_deg,
                CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                CR.shift_reliability, last_err]
246

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
    def _get_coreg_kwargs(self, pID, wp):
        """Assemble the COREG keyword arguments for a single tie point.

        :param pID:     POINT_ID of the tie point
        :param wp:      window position (map X/Y) of the tie point
        :return:        dict of keyword arguments for _get_spatial_shifts() / COREG
        """
        kwargs = dict(pointID=pID,
                      fftw_works=self.COREG_obj.fftw_works,
                      wp=wp,
                      ws=self.COREG_obj.win_size_XY,
                      resamp_alg_calc=self.rspAlg_calc,
                      footprint_poly_ref=self.COREG_obj.ref.poly,
                      footprint_poly_tgt=self.COREG_obj.shift.poly,
                      # band4match is internally saved as index, starting from 0 -> COREG expects 1-based
                      r_b4match=self.ref.band4match + 1,
                      s_b4match=self.shift.band4match + 1,
                      max_iter=self.COREG_obj.max_iter,
                      max_shift=self.COREG_obj.max_shift,
                      nodata=(self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
                      force_quadratic_win=self.COREG_obj.force_quadratic_win,
                      binary_ws=self.COREG_obj.bin_ws,
                      v=False,  # otherwise this would lead to massive console output
                      q=True,  # otherwise this would lead to massive console output
                      ignore_errors=True)
        return kwargs

268
    def get_CoRegPoints_table(self):
        """Calculate the tie point grid and return it as a GeoDataFrame.

        Builds a GeoDataFrame of all candidate grid points, drops offsite/bad-data points,
        runs the co-registration for every remaining point (multiprocessed if self.CPUs
        allows it), applies the configured tie point filter level and stores the result
        in self.CoRegPoints_table.

        :return:    GeoDataFrame with one row per tie point (also stored as self.CoRegPoints_table)
        """
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a dataframe containing 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'
        # (convert imCoords to mapCoords)
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        # derive a CRS dict for the GeoDataFrame from the projection of the target image
        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']] = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)
        if GDF.empty:
            self.CoRegPoints_table = GDF
            return self.CoRegPoints_table

        # choose a random subset of points if a maximum number has been given
        if self.max_points and len(GDF) > self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # equalize pixel grids in order to save warping time
        if len(GDF) > 100:
            # NOTE: actually grid res should be also changed here because self.shift.xgsd changes and grid res is
            # connected to that
            self.COREG_obj.equalize_pixGrids()

        # validate reference and target image inputs
        assert self.ref.footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly

        # ensure the input arrays for CoReg are in memory -> otherwise the code will get stuck in multiprocessing if
        # neighboured matching windows overlap during reading from disk!!
        self.ref.cache_array_subset(
            [self.COREG_obj.ref.band4match])  # only sets geoArr._arr_cache; does not change number of bands
        self.shift.cache_array_subset([self.COREG_obj.shift.band4match])

        # get all variations of kwargs for coregistration
        list_coreg_kwargs = (self._get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating tie point grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs, initializer=mp_initializer, initargs=(self.ref, self.shift)) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks
                        # -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            # <= this is the line where multiprocessing can freeze if an exception appears within
                            # COREG and is not raised
                            results = results.get()
                            break

        else:
            # declare global variables needed for self._get_spatial_shifts()
            global global_shared_imref, global_shared_im2shift
            global_shared_imref = self.ref
            global_shared_im2shift = self.shift

            if not self.q:
                print("Calculating tie point grid (%s points) 1 CPU core..." % len(GDF))

            # FIX: size the results array by the number of remaining points (len(GDF)) instead of all initial grid
            # points (len(geomPoints)) -> no empty rows if points were dropped/sampled above
            # FIX: use builtin 'object' instead of the deprecated alias np.object (removed in NumPy 1.24)
            results = np.empty((len(GDF), 14), object)
            bar = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        if not self.q:
            print("Found %s matches." % len(GDF[GDF.LAST_ERR == int(self.outFillVal)]))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            if not self.q:
                print('Performing validity checks...')
            TPR = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal], **self.outlDetect_settings)
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
        GDF = GDF.fillna(int(self.outFillVal))

        self.CoRegPoints_table = GDF

        return self.CoRegPoints_table

388
389
390
391
    def calc_rmse(self, include_outliers=False):
        # type: (bool) -> float
        """Calculate the root mean square error of absolute shifts from the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives (if present)
        :return:                    RMSE of the 'ABS_SHIFT' column (fill values excluded)
        :raises ValueError:         if no valid tie point shifts are available
        """
        tbl = self.CoRegPoints_table
        # drop flagged outliers unless explicitly requested (the 'OUTLIER' column only exists
        # if tie point filtering has been run)
        if not include_outliers and 'OUTLIER' in tbl.columns:
            tbl = tbl[~tbl['OUTLIER']].copy()

        shifts = np.array(tbl['ABS_SHIFT'])
        shifts_sq = [i * i for i in shifts if i != self.outFillVal]

        # FIX: raise a meaningful error instead of a ZeroDivisionError if no valid shifts exist
        if not shifts_sq:
            raise ValueError('RMSE cannot be computed because no valid tie point shifts are available.')

        return np.sqrt(sum(shifts_sq) / len(shifts_sq))

    def calc_overall_mssim(self, include_outliers=False):
        # type: (bool) -> float
        """Calculate the median of all MSSIM values contained in the tie point grid.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :return:                    median MSSIM after shift correction (fill values excluded)
        """
        tbl = self.CoRegPoints_table
        # FIX: guard against a missing 'OUTLIER' column (consistent with calc_rmse) -> previously
        # raised a KeyError if tie point filtering had not been run
        if not include_outliers and 'OUTLIER' in tbl.columns:
            tbl = tbl[~tbl['OUTLIER']].copy()

        # FIX: the table built by get_CoRegPoints_table() has no 'MSSIM' column; the structural
        # similarity within the matching window after shift correction is stored in 'SSIM_AFTER'
        mssim_col = np.array(tbl['SSIM_AFTER'])
        # FIX: do not square the values (copy-paste remnant from calc_rmse) - the docstring
        # promises the median of the MSSIM values themselves
        mssim_col = [i for i in mssim_col if i != self.outFillVal]

        return float(np.median(mssim_col))

    def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
                                ylim=None, fontsize=12, title='shift distribution'):
        # type: (bool, str, bool, tuple, list, list, int, str) -> tuple
        """Create a 2D scatterplot containing the distribution of calculated X/Y-shifts.

        :param include_outliers:    whether to include tie points that have been marked as false-positives
        :param unit:                'm' for meters or 'px' for pixels (default: 'm')
        :param interactive:         interactive mode uses plotly for visualization
        :param figsize:             (xdim, ydim)
        :param xlim:                [xmin, xmax]
        :param ylim:                [ymin, ymax]
        :param fontsize:            size of all used fonts
        :param title:               the title to be plotted above the figure
        :return:                    (figure, axis) for the matplotlib backend, (None, None) in interactive mode
        """
        if unit not in ['m', 'px']:
            raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." % unit)

        # restrict the table to valid matches and separate inliers from flagged outliers
        tbl = self.CoRegPoints_table
        tbl = tbl[tbl['ABS_SHIFT'] != self.outFillVal]
        tbl_il = tbl[~tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else tbl
        tbl_ol = tbl[tbl['OUTLIER']].copy() if 'OUTLIER' in tbl.columns else None
        x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
        y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
        rmse = self.calc_rmse(include_outliers=False)  # always exclude outliers when calculating RMSE
        figsize = figsize if figsize else (10, 10)

        if interactive:
            from plotly.offline import iplot, init_notebook_mode
            import plotly.graph_objs as go
            # FIXME outliers are not plotted

            init_notebook_mode(connected=True)

            # Create a trace
            trace = go.Scatter(
                x=tbl_il[x_attr],
                y=tbl_il[y_attr],
                mode='markers'
            )

            data = [trace]

            # Plot and embed in ipython notebook!
            iplot(data, filename='basic-scatter')

            return None, None

        else:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

            if include_outliers and 'OUTLIER' in tbl.columns:
                ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false-positives')
            ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')

            # set axis limits (symmetric around zero unless given)
            if not xlim:
                xmax = np.abs(tbl_il[x_attr]).max()
                xlim = [-xmax, xmax]
            if not ylim:
                ymax = np.abs(tbl_il[y_attr]).max()
                ylim = [-ymax, ymax]
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            # add text box containing RMSE of plotted shifts
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            plt.text(xlim[1] - (xlim[1] / 20), -ylim[1] + (ylim[1] / 20),
                     'RMSE:  %s m / %s px' % (np.round(rmse, 2), np.round(rmse / self.shift.xgsd, 2)),
                     ha='right', va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))

            # add grid and increase linewidth of middle line
            plt.grid()
            xgl = ax.get_xgridlines()
            middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
            middle_xgl.set_linewidth(2)
            middle_xgl.set_linestyle('-')
            ygl = ax.get_ygridlines()
            middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
            middle_ygl.set_linewidth(2)
            middle_ygl.set_linestyle('-')

            # set title and adjust tick labels
            ax.set_title(title, fontsize=fontsize)
            [tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
            [tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
            # FIX: parenthesize the conditional expression - '%' binds tighter than 'if', so the
            # previous code labeled the axes just 'pixels' (without 'x-shift [...]') when unit == 'px'
            plt.xlabel('x-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)
            plt.ylabel('y-shift [%s]' % ('meters' if unit == 'm' else 'pixels'), fontsize=fontsize)

            # add legend with labels in the right order
            handles, labels = ax.get_legend_handles_labels()
            leg = plt.legend(reversed(handles), reversed(labels), fontsize=fontsize, loc='upper right', scatterpoints=3)
            leg.get_frame().set_edgecolor('black')

            plt.show()

            return fig, ax

517
    def dump_CoRegPoints_table(self, path_out=None):
        """Pickle the tie point table (CoRegPoints_table) to disk.

        :param path_out:    output path; if omitted, a default path within self.dir_out is generated
        """
        if not path_out:
            fName_out = "CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" \
                        % (self.grid_res, self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1],
                           self.shift.basename, self.ref.basename)
            path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)

        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)
526

527
    def to_GCPList(self):
        """Convert the tie point grid to a list of GDAL-compatible ground control points.

        Tie points without a valid match or flagged as outliers are dropped. At most 7000 points
        are used (randomly sampled) due to a GDAL warping bottleneck.

        :return:    list of gdal.GCP instances (empty list if no valid tie points exist)
        """
        # get copy of tie points grid without no data
        try:
            GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.ABS_SHIFT != self.outFillVal, :].copy()
        except AttributeError:
            # self.CoRegPoints_table has no attribute 'ABS_SHIFT' because all points have been excluded
            return []

        if getattr(GDF, 'empty'):  # GDF.empty returns AttributeError
            return []
        else:
            # exclude all points flagged as outliers
            if 'OUTLIER' in GDF.columns:
                # NOTE(review): '== False' is kept deliberately - the column may contain fill values
                # besides booleans after fillna(), so '~GDF.OUTLIER' would not be equivalent
                GDF = GDF[GDF.OUTLIER == False].copy()
            avail_TP = len(GDF)

            if not avail_TP:
                # no point passed all validity checks
                return []

            if avail_TP > 7000:
                GDF = GDF.sample(7000)
                warnings.warn('By far not more than 7000 tie points can be used for warping within a limited '
                              'computation time (due to a GDAL bottleneck). Thus these 7000 points are randomly chosen '
                              'out of the %s available tie points.' % avail_TP)

            # calculate GCPs: shifted map position -> original image position
            GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
            GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
            GDF['GCP'] = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                                            GDF_row.X_IM, GDF_row.Y_IM), axis=1)
            self.GCPList = GDF.GCP.tolist()

            if not self.q:
                print('Found %s valid tie points.' % len(self.GCPList))

            return self.GCPList
564

565
    def test_if_singleprocessing_equals_multiprocessing_result(self):
        """Compute the tie point grid once multiprocessed and once single-processed.

        :return:    True if both runs produce identical results, else False
        """
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        def _run_and_copy(cpus):
            # run the grid computation with the given CPU count and return a detached copy of the values
            self.CPUs = cpus
            vals = self.get_CoRegPoints_table().values
            out = np.empty_like(vals)
            out[:] = vals
            return out

        mp_out = _run_and_copy(None)  # None -> all available CPU cores
        sp_out = _run_and_copy(1)

        return np.array_equal(sp_out, mp_out)
579

580
581
    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]
582

583
    def _get_lines_by_PIDs(self, PIDs):
584
585
586
587
        assert isinstance(PIDs, list)
        lines = np.zeros((len(PIDs), self.CoRegPoints_table.shape[1]))
        for i, PID in enumerate(PIDs):
            lines[i, :] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
588
589
        return lines

590
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str) -> None
        """Writes the calculated tie points grid to a point shapefile containing
        Tie_Point_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Tie_Point_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table
        # work on a copy - the boolean conversion below must not modify self.CoRegPoints_table
        # (and assigning to a filtered slice would raise a pandas SettingWithCopyWarning)
        GDF2pass = GDF.copy() if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal].copy()

        # replace boolean values (cannot be written)
        # NOTE: 'np.bool' was deprecated in NumPy 1.20 and removed in 1.24 -> compare against builtin 'bool'
        for col in GDF2pass.columns:
            if GDF2pass[col].dtype == bool:
                GDF2pass[col] = GDF2pass[col].astype(int)
        # replace all remaining booleans where dtype is not bool but object
        GDF2pass = GDF2pass.replace(False, 0)
        GDF2pass = GDF2pass.replace(True, 1)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        GDF2pass.to_file(path_out)

620
621
622
623
624
625
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        """Deprecated, slower predecessor of to_PointShapefile().

        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column of CoRegPoints_table used to identify points without a valid match
        """
        # bugfix: the deprecation message previously referenced non-existent method names
        warnings.warn(DeprecationWarning(
            "'_to_PointShapefile' is deprecated."  # TODO delete if other method validated
            " 'to_PointShapefile' is much faster."))
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]
        shapely_points = GDF2pass['geometry'].values.tolist()
        # one attribute dict per record, preserving the table's column order
        attr_dicts = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[i].values)) for i in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" % (self.grid_res, self.COREG_obj.win_size_XY)
        path_out = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)

633
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
        (e.g. using ArcGIS)

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        :return:            <GeoArray> the written 2-band raster
        """
        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                     "(outputs magnitude and direction)'. Got %s." % mode
        attr_b1 = 'X_SHIFT_M' if mode == 'uv' else 'ABS_SHIFT'
        attr_b2 = 'Y_SHIFT_M' if mode == 'uv' else 'ANGLE'

        # rasterize both attributes onto the same target grid
        # (all keyword arguments except 'values' are identical for both bands)
        common_kw = dict(points=self.CoRegPoints_table['geometry'],
                         tgt_res=self.shift.xgsd * self.grid_res,
                         prj=proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
                         fillVal=self.outFillVal)
        xshift_arr, gt, prj = points_to_raster(values=self.CoRegPoints_table[attr_b1], **common_kw)
        yshift_arr, gt, prj = points_to_raster(values=self.CoRegPoints_table[attr_b2], **common_kw)

        out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                                          % (self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA

674
    def _to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                    path_out=None, tilepos=None):
        """Deprecated: interpolate the given attribute column to a raster surface using ordinary Kriging.

        :param attrName:        <str> name of the CoRegPoints_table column to interpolate
        :param skip_nodata:     <bool/int> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      <int> output grid resolution; derived from the point bounds if not given
        :param path_out:        <str> output path; automatically defined if not given
        :param tilepos:         ((rS, rE), (cS, cE)) image subset to process; full image shape if not given
        :return:                <np.ndarray> the interpolated grid values
        """
        warnings.warn(DeprecationWarning("'to_Raster_using_KrigingOLD' is deprecated. Use to_Raster_using_Kriging "
                                         "instead."))  # TODO delete

        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        # subset if tilepos is given
        rows, cols = tilepos if tilepos else self.shift.shape
        GDF2pass = GDF2pass.loc[(GDF2pass['X_IM'] >= cols[0]) & (GDF2pass['X_IM'] <= cols[1]) &
                                (GDF2pass['Y_IM'] >= rows[0]) & (GDF2pass['Y_IM'] <= rows[1])]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        # default resolution: ~250 cells along the shorter extent of the point cloud
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)  # ,backend='C',)

        path_out = path_out if path_out else \
            get_generic_outpath(dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                                fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"
                                          % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                                             self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

713
    def to_Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                fName_out=None, tilepos=None, tilesize=500, mp=None):
        """Interpolate the given attribute column to a raster surface using ordinary Kriging.

        :param attrName:        <str> name of the CoRegPoints_table column to interpolate
        :param skip_nodata:     <bool/int> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      <int> output grid resolution; derived automatically if not given
        :param fName_out:       <str> output file name; automatically defined if not given
        :param tilepos:         tile position subset, passed through to _Kriging_sp()
        :param tilesize:        <int> tile size for the (currently disabled) multiprocessing implementation
        :param mp:              <bool> use multiprocessing; if None, it is derived from self.CPUs
        :return:                the kriging results if multiprocessing was used, otherwise None
        """
        # honor an explicitly given 'mp' flag (previously the parameter was silently overwritten);
        # otherwise enable multiprocessing unless self.CPUs is exactly 1
        mp = (self.CPUs != 1) if mp is None else mp

        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)

        # TODO: the tiled multiprocessing implementation below is not yet validated
        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     # self.kriged=[]
        #     # for i in args_kwargs_dicts:
        #     #     res = self.Kriging_mp(i)
        #     #     self.kriged.append(res)
        #     #     print(res)
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)

        # 'self.kriged' is only assigned by the disabled multiprocessing code path above;
        # guard with getattr to avoid an AttributeError until that path is (re)enabled
        res = getattr(self, 'kriged', None) if mp else None
        return res

741
742
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Interpolate the given attribute column to a raster surface using ordinary Kriging (singleprocessing).

        :param attrName:        <str> name of the CoRegPoints_table column to interpolate
        :param skip_nodata:     <bool/int> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      <int> output grid resolution; derived from the point bounds if not given
        :param fName_out:       <str> output file name; automatically defined if not given
        :param tilepos:         tile position; currently only used within the output file name
        :return:                <np.ndarray> the interpolated grid values
        """
        GDF = self.CoRegPoints_table
        GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col] != self.outFillVal]

        #         # subset if tilepos is given
        # #        overlap_factor =
        #         rows,cols = tilepos if tilepos else self.tgt_shape
        #         xvals, yvals = np.sort(GDF2pass['X_IM'].values.flat),np.sort(GDF2pass['Y_IM'].values.flat)
        #         cS,cE = UTL.find_nearest(xvals,cols[0],'off',1), UTL.find_nearest(xvals,cols[1],'on',1)
        #         rS,rE = UTL.find_nearest(yvals,rows[0],'off',1), UTL.find_nearest(yvals,rows[1],'on',1)
        #         # GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cols[0])&(GDF2pass['X_IM']<=cols[1])&
        #         #                                (GDF2pass['Y_IM']>=rows[0])&(GDF2pass['Y_IM']<=rows[1])]
        #         GDF2pass        = GDF2pass.loc[(GDF2pass['X_IM']>=cS)&(GDF2pass['X_IM']<=cE)&
        #                                        (GDF2pass['Y_IM']>=rS)&(GDF2pass['Y_IM']<=rE)]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds

        # default resolution: ~250 cells along the shorter extent of the point cloud
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        from pykrige.ok import OrdinaryKriging
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        # append the tile position to the file name when run as part of a multiprocessing job
        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY, tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" % (attrName, self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' % path_out)
        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res),
                                prj=self.COREG_obj.shift.prj)

        return zvalues

785
    def _Kriging_mp(self, args_kwargs_dict):
786
787
        args = args_kwargs_dict.get('args', [])
        kwargs = args_kwargs_dict.get('kwargs', [])
788

789
        return self._Kriging_sp(*args, **kwargs)
790
791


792
class Tie_Point_Refiner(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
793
    def __init__(self, GDF, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, rs_max_iter=15,
                 rs_exclude_previous_outliers=True, rs_timeout=20, q=False):
        """A class for performing outlier detection.

        :param GDF:                             GeoDataFrame like TiePointGrid.CoRegPoints_table containing all tie
                                                points to be filtered and the corresponding metadata
        :param min_reliability:                 <float, int> minimum threshold for previously computed tie X/Y shift
                                                reliability (default: 60%)
        :param rs_max_outlier:                  <float, int> RANSAC: maximum percentage of outliers to be detected
                                                (default: 10%)
        :param rs_tolerance:                    <float, int> RANSAC: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param rs_max_iter:                     <int> RANSAC: maximum iterations for finding the best RANSAC threshold
                                                (default: 15)
        :param rs_exclude_previous_outliers:    <bool> RANSAC: whether to exclude points that have been flagged as
                                                outlier by earlier filtering (default:True)
        :param rs_timeout:                      <float, int> RANSAC: timeout for iteration loop in seconds (default: 20)

        :param q:                               <bool> quiet mode (suppresses all console output)
        """
        # input data (a copy, so the caller's GeoDataFrame is never modified)
        self.GDF = GDF.copy()
        # filtering parameters
        self.min_reliability = min_reliability
        self.rs_max_outlier_percentage = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.rs_max_iter = rs_max_iter
        self.rs_exclude_previous_outliers = rs_exclude_previous_outliers
        self.rs_timeout = rs_timeout
        self.q = q
        # names of the outlier columns appended by run_filtering()
        self.new_cols = []
        # robust RANSAC model; set by _RANSAC_outlier_detection()
        self.ransac_model_robust = None

825
    def run_filtering(self, level=3):
        """Filter tie points used for shift correction.

        :param level:   tie point filter level (default: 3).
                        NOTE: lower levels are also included if a higher level is chosen
                            - Level 0: no tie point filtering
                            - Level 1: Reliablity filtering - filter all tie points out that have a low
                                reliability according to internal tests
                            - Level 2: SSIM filtering - filters all tie points out where shift
                                correction does not increase image similarity within matching window
                                (measured by mean structural similarity index)
                            - Level 3: RANSAC outlier detection

        :return:    (self.GDF with outlier columns appended, list of the appended column names)
        """
        # TODO catch empty GDF

        # RELIABILITY filtering
        if level > 0:
            marked_recs = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % (len(marked_recs[marked_recs])))

        # SSIM filtering
        if level > 1:
            marked_recs = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')

            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).' % (len(marked_recs[marked_recs])))

        # RANSAC filtering
        if level > 2:
            # exclude previous outliers
            ransacInGDF = self.GDF[~self.GDF[self.new_cols].any(axis=1)].copy() \
                if self.rs_exclude_previous_outliers else self.GDF

            if len(ransacInGDF) > 4:
                # running RANSAC with less than four tie points makes no sense
                marked_recs = GeoSeries(self._RANSAC_outlier_detection(ransacInGDF))
                # we need to join a list here because otherwise it's merged by the 'index' column
                self.GDF['L3_OUTLIER'] = marked_recs.tolist()

                if not self.q:
                    print('%s tie points flagged by level 3 filtering (RANSAC)'
                          % (len(marked_recs[marked_recs])))
            else:
                # bugfix: this message was previously printed even in quiet mode
                if not self.q:
                    print('RANSAC skipped because too few valid tie points have been found.')
                self.GDF['L3_OUTLIER'] = False

            self.new_cols.append('L3_OUTLIER')

        # combined flag: a record is an outlier if any of the applied filter levels flagged it
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols

Daniel Scheffler's avatar
Daniel Scheffler committed
889
    def _reliability_thresholding(self):
890
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold."""
891

Daniel Scheffler's avatar
Daniel Scheffler committed
892
        return self.GDF.RELIABILITY < self.min_reliability
893
894

    def _SSIM_filtering(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
895
        """Exclude all records where SSIM decreased."""
896

897
        # ssim_diff  = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])
898

899
900
        # self.GDF.SSIM_IMPROVED = \
        #     self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)
901

902
        return self.GDF.SSIM_IMPROVED is False
903

Daniel Scheffler's avatar
Daniel Scheffler committed
904
905
    def _RANSAC_outlier_detection(self, inGDF):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm."""
906

Daniel Scheffler's avatar
Daniel Scheffler committed
907
        src_coords = np.array(inGDF[['X_UTM', 'Y_UTM']])
908
        xyShift = np.array(inGDF[['X_SHIFT_M', 'Y_SHIFT_M']])
909
910
911
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
912
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)
913

914
915
        if not 0 < self.rs_max_outlier_percentage < 100:
            raise ValueError
916
        min_inlier_percentage = 100 - self.rs_max_outlier_percentage