Commit d398b9fc authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

added local shift-correction based on geometric quality grid; added README

COREG:
- calculate_spatial_shifts(): bugfix for printing None instead of calculated shifts within RuntimeError in case of too large shifts

DESHIFTER:
- updated documentation of keyword warp_alg
- implemented multiprocessing (keyword CPUs)
- implemented first prototype of local shift-correction based on a list of GCPs

Geom_Quality_Grid:
- replaced keyword 'multiproc' by 'CPUs'
- added method to_GCPList()
- added method correct_shifts()

README:
- added README for the whole package
parent 010c04d9
This diff is collapsed.
......@@ -404,10 +404,15 @@ class COREG(object):
# rsp_algor = 5 if is_avail_rsp_average else 2 # average if possible else cubic # OLD
# TODO replace cubic resampling by PSF resampling - average resampling leads to sinus like distortions in the fft image that make a precise coregistration impossible. Thats why there is currently no way around cubic resampling.
tgt_xmin,tgt_xmax,tgt_ymin,tgt_ymax = self.matchWin.boundsMap
self.otherWin.data = warp_ndarray(self.otherWin.data, otherWin_subgt, self.otherWin.imParams.prj,
self.matchWin.imParams.prj, out_gsd=(self.imfft_gsd, self.imfft_gsd),
out_bounds=([tgt_xmin, tgt_ymin, tgt_xmax, tgt_ymax]),
rspAlg='cubic', in_nodata=self.otherWin.imParams.nodata)[0]
self.otherWin.data = warp_ndarray(self.otherWin.data,
otherWin_subgt,
self.otherWin.imParams.prj,
self.matchWin.imParams.prj,
out_gsd = (self.imfft_gsd, self.imfft_gsd),
out_bounds = ([tgt_xmin, tgt_ymin, tgt_xmax, tgt_ymax]),
rspAlg = 'cubic',
in_nodata = self.otherWin.imParams.nodata,
progress = False) [0]
if self.matchWin.data.shape != self.otherWin.data.shape:
self.tracked_errors.append(
......@@ -716,7 +721,7 @@ class COREG(object):
"parameter to an appropriate value. Otherwise try to use a different window "
"size for matching via the '-ws' parameter or define the spectral bands "
"to be used for matching manually ('-br' and '-bs')."
% (self.x_shift_px, self.y_shift_px)))
% (x_totalshift, y_totalshift)))
if not self.ignErr:
raise self.tracked_errors[-1]
else:
......
......@@ -53,22 +53,23 @@ class DESHIFTER(object):
- resamp_alg(str) the resampling algorithm to be used if neccessary
(valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
max, min, med, q1, q3)
- warp_alg(str): 'GDAL' or 'rasterio' (default = 'rasterio')
- warp_alg(str): 'GDAL_cmd' or 'GDAL_lib' (default = 'GDAL_lib')
- cliptoextent (bool): True: clip the input image to its actual bounds while deleting possible no data
areas outside of the actual bounds, default = True
- clipextent (list): xmin, ymin, xmax, ymax - if given the calculation of the actual bounds is skipped.
The given coordinates are automatically snapped to the output grid.
- tempDir(str): directory to be used for tempfiles (default: /dev/shm/)
- CPUs(int): number of CPUs to use (default: None, which means 'all CPUs available')
- v(bool): verbose mode (default: False)
- q(bool): quiet mode (default: False)
"""
# FIXME add mp?
# unpack args
self.im2shift = im2shift if isinstance(im2shift, GeoArray) else GeoArray(im2shift)
self.shift_prj = im2shift.projection
self.shift_gt = list(im2shift.geotransform)
self.shift_prj = self.im2shift.projection
self.shift_gt = list(self.im2shift.geotransform)
self.nodata = get_outFillZeroSaturated(self.im2shift.dtype)[0]
self.GCPList = coreg_results['GCPList'] if 'GCPList' in coreg_results else None
mapI = coreg_results['updated map info']
self.updated_map_info = mapI if mapI else geotransform2mapinfo(self.shift_gt, self.shift_prj)
self.original_map_info = coreg_results['original map info']
......@@ -88,7 +89,8 @@ class DESHIFTER(object):
self.warpAlg = kwargs.get('warp_alg' , 'GDAL_lib')
self.cliptoextent = kwargs.get('cliptoextent', True)
self.clipextent = kwargs.get('clipextent' , None)
self.tempDir = kwargs.get('tempDir' ,'/dev/shm/')
self.tempDir = kwargs.get('tempDir' , '/dev/shm/')
self.CPUs = kwargs.get('CPUs' , None)
self.v = kwargs.get('v' , False)
self.q = kwargs.get('q' , False) if not self.v else False
self.out_grid = self._get_out_grid(kwargs) # needs self.ref_grid, self.im2shift
......@@ -113,8 +115,8 @@ class DESHIFTER(object):
out_grid = init_kwargs.get('target_xyGrid', None)
# assertions
assert out_grid is None or (isinstance(out_grid,(list, tuple)) and len(out_grid)==2)
assert out_gsd is None or (isinstance(out_gsd, (int, list)) and len(out_gsd) ==2)
assert out_grid is None or (isinstance(out_grid,(list, tuple)) and len(out_grid)==2)
assert out_gsd is None or (isinstance(out_gsd, (int, tuple, list)) and len(out_gsd) ==2)
ref_xgsd, ref_ygsd = (self.ref_grid[0][1]-self.ref_grid[0][0],self.ref_grid[1][1]-self.ref_grid[1][0])
get_grid = lambda gt, xgsd, ygsd: [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]
......@@ -267,7 +269,8 @@ class DESHIFTER(object):
elif self.warpAlg=='GDAL_lib':
# apply XY-shifts to shift_gt
in_arr = self.im2shift[self.band2process] if self.band2process else self.im2shift[:]
self.shift_gt[0], self.shift_gt[3] = self.updated_gt[0], self.updated_gt[3]
if not self.GCPList:
self.shift_gt[0], self.shift_gt[3] = self.updated_gt[0], self.updated_gt[3]
# get resampled array
out_arr, out_gt, out_prj = \
......@@ -276,7 +279,9 @@ class DESHIFTER(object):
in_nodata = self.nodata,
out_nodata = self.nodata,
out_gsd = self.out_gsd,
out_bounds = self._get_out_extent())
out_bounds = self._get_out_extent(),
gcpList = self.GCPList,
CPUs = self.CPUs)
self.updated_projection = out_prj
self.arr_shifted = out_arr
......
......@@ -14,7 +14,7 @@ from pykrige.ok import OrdinaryKriging
from shapely.geometry import Point
# internal modules
from .CoReg import COREG
from .CoReg import COREG, DESHIFTER
from . import geometry as GEO
from . import io as IO
from py_tools_ds.ptds import GeoArray
......@@ -27,10 +27,10 @@ global_shared_im2shift = None
class Geom_Quality_Grid(object):
def __init__(self, im_ref, im_tgt, grid_res, window_size=(256,256), dir_out=None, projectName=None, multiproc=True,
def __init__(self, im_ref, im_tgt, grid_res, window_size=(256,256), dir_out=None, projectName=None,
r_b4match=1, s_b4match=1, max_iter=5, max_shift=5, data_corners_im0=None,
data_corners_im1=None, outFillVal=-9999, nodata=(None,None), calc_corners=True, binary_ws=True,
v=False, q=False):
CPUs=None, v=False, q=False):
"""
......@@ -40,7 +40,6 @@ class Geom_Quality_Grid(object):
:param window_size(tuple): custom matching window size [pixels] (default: (512,512))
:param dir_out:
:param projectName:
:param multiproc: enables multiprocessing during calculation of geometric quality grid (default: True)
:param r_b4match(int): band of reference image to be used for matching (starts with 1; default: 1)
:param s_b4match(int): band of shift image to be used for matching (starts with 1; default: 1)
:param max_iter(int): maximum number of iterations for matching (default: 5)
......@@ -54,6 +53,8 @@ class Geom_Quality_Grid(object):
matching window position within the actual image overlap
(default: 1; deactivated if '-cor0' and '-cor1' are given
:param binary_ws(bool): use binary X/Y dimensions for the matching window (default: 1)
:param CPUs(int): number of CPUs to use during calculation of geometric quality grid
(default: None, which means 'all CPUs available')
:param v(bool): verbose mode (default: 0)
:param q(bool): quiet mode (default: 0)
"""
......@@ -62,7 +63,6 @@ class Geom_Quality_Grid(object):
self.dir_out = dir_out
self.grid_res = grid_res
self.window_size = window_size
self.mp = multiproc
self.max_shift = max_shift
self.max_iter = max_iter
self.r_b4match = r_b4match
......@@ -71,6 +71,7 @@ class Geom_Quality_Grid(object):
self.nodata = nodata
self.outFillVal = outFillVal
self.bin_ws = binary_ws
self.CPUs = CPUs
self.v = v
self.q = q
......@@ -91,7 +92,7 @@ class Geom_Quality_Grid(object):
max_iter = max_iter,
max_shift = max_shift,
nodata = nodata,
multiproc = multiproc,
multiproc = self.CPUs is None or self.CPUs>1,
binary_ws = self.bin_ws,
v = v,
q = q,
......@@ -106,6 +107,7 @@ class Geom_Quality_Grid(object):
self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)
self.quality_grid = None # set by self.get_quality_grid()
self.GCPList = None # set by self.to_GCPList()
def _get_imXY__mapXY_points(self,grid_res):
......@@ -204,10 +206,10 @@ class Geom_Quality_Grid(object):
list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index) # generator
# run co-registration for whole grid
if self.mp:
if self.CPUs is None or self.CPUs>1:
if not self.q:
print("Calculating geometric quality grid in mode 'multiprocessing'...")
with multiprocessing.Pool() as pool:
with multiprocessing.Pool(self.CPUs) as pool:
results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
else:
if not self.q:
......@@ -254,6 +256,21 @@ class Geom_Quality_Grid(object):
self.quality_grid.to_pickle(path_out)
def to_GCPList(self):
assert self.quality_grid is not None, 'Calculate quality grid first!'
# get copy of quality grid without no data
GDF = self.quality_grid.loc[self.quality_grid.X_SHIFT_M!=self.outFillVal, :].copy()
# calculate GCPs
GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
GDF['GCP'] = GDF.apply(lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
GDF_row.X_IM, GDF_row.Y_IM), axis=1)
self.GCPList = GDF.GCP.tolist()
return self.GCPList
def test_if_singleprocessing_equals_multiprocessing_result(self):
self.mp = 1
dataframe = self.get_quality_grid()
......@@ -268,12 +285,12 @@ class Geom_Quality_Grid(object):
def get_line_by_PID(self,PID):
assert self.quality_grid, 'Calculate quality grid first!'
assert self.quality_grid is not None, 'Calculate quality grid first!'
return self.quality_grid.loc[PID,:]
def get_lines_by_PIDs(self,PIDs):
assert self.quality_grid, 'Calculate quality grid first!'
assert self.quality_grid is not None, 'Calculate quality grid first!'
assert isinstance(PIDs,list)
lines = np.zeros((len(PIDs),self.quality_grid.shape[1]))
for i,PID in enumerate(PIDs):
......@@ -296,7 +313,7 @@ class Geom_Quality_Grid(object):
def _quality_grid_to_PointShapefile(self,skip_nodata=1,skip_nodata_col = 'ABS_SHIFT'):
warnings.warn(DeprecationWarning("'_quality_grid_to_PointShapefile' deprecated."
warnings.warn(DeprecationWarning("'_quality_grid_to_PointShapefile' is deprecated."
" 'quality_grid_to_PointShapefile' is much faster."))
GDF = self.quality_grid
GDF2pass = GDF if not skip_nodata else GDF[GDF[skip_nodata_col]!=self.outFillVal]
......@@ -419,4 +436,21 @@ class Geom_Quality_Grid(object):
args = args_kwargs_dict.get('args',[])
kwargs = args_kwargs_dict.get('kwargs',[])
return self.Kriging_sp(*args,**kwargs)
\ No newline at end of file
return self.Kriging_sp(*args,**kwargs)
def correct_shifts(self, max_GCP_count=None):
coreg_info = self.COREG_obj.coreg_info
coreg_info['GCPList'] = self.GCPList if self.GCPList else self.to_GCPList()
if max_GCP_count:
coreg_info['GCPList'] = coreg_info['GCPList'][:max_GCP_count]
DS = DESHIFTER(self.im2shift, coreg_info,
path_out=None,
out_gsd=(self.im2shift.xgsd,self.im2shift.ygsd),
align_grids=True,
v=self.v,
q=self.q)
deshift_results = DS.correct_shifts()
return deshift_results
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment