Commit bc5fc58d authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

implemented shifts vector length and angle directly into CoReg; calculation of...

implemented shifts vector length and angle directly into CoReg; calculation of geometric quality grid now 100% faster

COREG:
- __init__(): added new attributes 'vec_length_map' and 'vec_angle_deg'
- _get_image_windows_to_match(): bugfix for running warp_ndarray in multiprocessing if multiprocessing is disabled
- calculate_spatial_shifts(): moved calculation of X/Y shifts in map units here; implemented calculation of shifts vector length and angle

DESHIFTER:
- __init__(): added attribute 'GeoArray_shifted'
- deshift_results(): added keys 'updated geotransform' and 'GeoArray_shifted' to returned dict

Geom_Quality_Grid:
- _get_spatial_shifts(): now returns values for all columns of quality grid
- revised get_quality_grid(): now much simpler and 100% faster
- view_results() added functionality for saving the output figure to disk
parent 2745a55f
......@@ -180,6 +180,8 @@ class COREG(object):
self.y_shift_px = None # always in shift image units (image coords) # set by calculate_spatial_shifts()
self.x_shift_map = None # set by self.get_updated_map_info()
self.y_shift_map = None # set by self.get_updated_map_info()
self.vec_length_map = None
self.vec_angle_deg = None
self.updated_map_info = None # set by self.get_updated_map_info()
self.tracked_errors = [] # expanded each time an error occurs
......@@ -436,6 +438,7 @@ class COREG(object):
out_bounds = ([tgt_xmin, tgt_ymin, tgt_xmax, tgt_ymax]),
rspAlg = 'cubic',
in_nodata = self.otherWin.imParams.nodata,
CPUs = None if self.mp else 1,
progress = False) [0]
if self.matchWin.data.shape != self.otherWin.data.shape:
......@@ -683,89 +686,102 @@ class COREG(object):
# calculate cross power spectrum without any de-shifting applied
scps = self._calc_shifted_cross_power_spectrum()
## calculate X/Y shifts for target image ##
x_shift_px,y_shift_px = None,None # defaults
if scps is None:
self.success = False
warnings.simplefilter('default')
return 'fail'
# calculate spatial shifts
count_iter = 1
x_intshift, y_intshift = self._calc_integer_shifts(scps)
if (x_intshift, y_intshift) == (0, 0):
self.success = True
else:
# 1st attempt
count_iter = 1
x_intshift, y_intshift = self._calc_integer_shifts(scps)
valid_invalid, x_val_shift, y_val_shift, scps = \
self._validate_integer_shifts(im0, im1, x_intshift, y_intshift)
if (x_intshift, y_intshift) == (0, 0):
self.success = True
else:
valid_invalid, x_val_shift, y_val_shift, scps = \
self._validate_integer_shifts(im0, im1, x_intshift, y_intshift)
while valid_invalid!='valid':
count_iter += 1
if count_iter > self.max_iter:
self.success = False
self.tracked_errors.append(RuntimeError('No match found in the given window.'))
if not self.ignErr:
raise self.tracked_errors[-1]
else:
warnings.warn('No match found in the given window.'); break
if valid_invalid=='invalid' and (x_val_shift, y_val_shift)==(None, None):
# this happens if matching window became too small
self.success = False
break
if not self.q: print('No clear match found yet. Jumping to iteration %s...' % count_iter)
if not self.q: print('input shifts: ', x_val_shift, y_val_shift)
valid_invalid, x_val_shift, y_val_shift, scps = \
self._validate_integer_shifts(im0, im1, x_val_shift, y_val_shift)
# overwrite previous integer shifts if a valid match has been found
if valid_invalid=='valid':
self.success = True
x_intshift, y_intshift = x_val_shift, y_val_shift
if not self.success==False:
x_subshift, y_subshift = self.calc_subpixel_shifts(scps)
x_totalshift, y_totalshift = self._get_total_shifts(x_intshift, y_intshift, x_subshift, y_subshift)
x_shift_px, y_shift_px = x_totalshift*gsd_factor, y_totalshift*gsd_factor
if not self.q:
print('Detected integer shifts (X/Y): %s/%s' %(x_intshift,y_intshift))
print('Detected subpixel shifts (X/Y): %s/%s' %(x_subshift,y_subshift))
print('Calculated total shifts in fft pixel units (X/Y): %s/%s' %(x_totalshift,y_totalshift))
print('Calculated total shifts in reference pixel units (X/Y): %s/%s' %(x_totalshift,y_totalshift))
print('Calculated total shifts in target pixel units (X/Y): %s/%s' %(x_shift_px,y_shift_px))
while valid_invalid!='valid':
count_iter += 1
if max([abs(x_totalshift),abs(y_totalshift)]) > self.max_shift:
if count_iter > self.max_iter:
self.success = False
self.tracked_errors.append(
RuntimeError("The calculated shift (X: %s px / Y: %s px) is recognized as too large to "
"be valid. If you know that it is valid, just set the '-max_shift' "
"parameter to an appropriate value. Otherwise try to use a different window "
"size for matching via the '-ws' parameter or define the spectral bands "
"to be used for matching manually ('-br' and '-bs')."
% (x_totalshift, y_totalshift)))
self.tracked_errors.append(RuntimeError('No match found in the given window.'))
if not self.ignErr:
raise self.tracked_errors[-1]
else:
else:
warnings.warn('No match found in the given window.'); break
if valid_invalid=='invalid' and (x_val_shift, y_val_shift)==(None, None):
# this happens if matching window became too small
self.success = False
break
if not self.q: print('No clear match found yet. Jumping to iteration %s...' % count_iter)
if not self.q: print('input shifts: ', x_val_shift, y_val_shift)
valid_invalid, x_val_shift, y_val_shift, scps = \
self._validate_integer_shifts(im0, im1, x_val_shift, y_val_shift)
# overwrite previous integer shifts if a valid match has been found
if valid_invalid=='valid':
self.success = True
x_intshift, y_intshift = x_val_shift, y_val_shift
if self.success or self.success is None:
# get total pixel shifts
x_subshift, y_subshift = self.calc_subpixel_shifts(scps)
x_totalshift, y_totalshift = self._get_total_shifts(x_intshift, y_intshift, x_subshift, y_subshift)
if max([abs(x_totalshift),abs(y_totalshift)]) > self.max_shift:
self.success = False
self.tracked_errors.append(
RuntimeError("The calculated shift (X: %s px / Y: %s px) is recognized as too large to "
"be valid. If you know that it is valid, just set the '-max_shift' "
"parameter to an appropriate value. Otherwise try to use a different window "
"size for matching via the '-ws' parameter or define the spectral bands "
"to be used for matching manually ('-br' and '-bs')."
% (x_totalshift, y_totalshift)))
if not self.ignErr:
raise self.tracked_errors[-1]
else:
self.success = True
self.x_shift_px, self.y_shift_px = x_totalshift*gsd_factor, y_totalshift*gsd_factor
# get map shifts
new_originY, new_originX = pixelToMapYX([self.x_shift_px, self.y_shift_px],
geotransform=self.shift.gt, projection=self.shift.prj)[0]
self.x_shift_map, self.y_shift_map = new_originX - self.shift.gt[0], new_originY - self.shift.gt[3]
# get length of shift vecor in map units
self.vec_length_map = float(np.sqrt(self.x_shift_map ** 2 + self.y_shift_map ** 2))
# get angle of shift vector
self.vec_angle_deg = GEO.angle_to_north((self.x_shift_px,self.y_shift_px)).tolist()[0]
# print results
if not self.q:
print('Detected integer shifts (X/Y): %s/%s' %(x_intshift,y_intshift))
print('Detected subpixel shifts (X/Y): %s/%s' %(x_subshift,y_subshift))
print('Calculated total shifts in fft pixel units (X/Y): %s/%s' %(x_totalshift,y_totalshift))
print('Calculated total shifts in reference pixel units (X/Y): %s/%s' %(x_totalshift,y_totalshift))
print('Calculated total shifts in target pixel units (X/Y): %s/%s' %(self.x_shift_px,self.y_shift_px))
print('Calculated map shifts (X,Y):\t\t\t\t %s/%s' %(self.x_shift_map, self.y_shift_map))
print('Calculated absolute shift vector length in map units: %s' %self.vec_length_map)
print('Calculated angle of shift vector in degrees from North: %s' %self.vec_angle_deg)
self.x_shift_px, self.y_shift_px = (x_shift_px,y_shift_px) if self.success else (None,None)
if self.x_shift_px or self.y_shift_px:
self._get_updated_map_info()
warnings.simplefilter('default')
return 'success'
def _get_updated_map_info(self):
original_map_info = geotransform2mapinfo(self.shift.gt, self.shift.prj)
new_originY, new_originX = pixelToMapYX([self.x_shift_px,self.y_shift_px],
geotransform=self.shift.gt, projection=self.shift.prj)[0]
self.x_shift_map = new_originX - self.shift.gt[0]
self.y_shift_map = new_originY - self.shift.gt[3]
if not self.q: print('Calculated map shifts (X,Y):\t\t\t\t ', self.x_shift_map,self.y_shift_map)
def _get_updated_map_info(self):
original_map_info = geotransform2mapinfo(self.shift.gt, self.shift.prj)
self.updated_map_info = original_map_info.copy()
self.updated_map_info[3] = str(float(original_map_info[3]) + self.x_shift_map)
self.updated_map_info[4] = str(float(original_map_info[4]) + self.y_shift_map)
......
......@@ -106,6 +106,7 @@ class DESHIFTER(object):
self.is_resampled = False # this is not included in COREG.coreg_info
self.tracked_errors = []
self.arr_shifted = None # set by self.correct_shifts
self.GeoArray_shifted = None # set by self.correct_shifts
def _get_out_grid(self, init_kwargs):
......@@ -253,6 +254,7 @@ class DESHIFTER(object):
else:
self.arr_shifted = rasterio.open(path_tmp).read(self.band2process)
self.GeoArray_shifted = GeoArray(self.arr_shifted,tuple(self.shift_gt), self.shift_prj)
self.is_shifted = True
self.is_resampled = True
......@@ -281,6 +283,8 @@ class DESHIFTER(object):
out_gsd = self.out_gsd,
out_bounds = self._get_out_extent(),
gcpList = self.GCPList,
polynomialOrder= None,
options = None,#'-refine_gcps 500',
CPUs = self.CPUs,
q = self.q)
......@@ -288,6 +292,7 @@ class DESHIFTER(object):
self.arr_shifted = out_arr
self.updated_map_info = geotransform2mapinfo(out_gt,out_prj)
self.shift_gt = mapinfo2geotransform(self.updated_map_info)
self.GeoArray_shifted = GeoArray(self.arr_shifted, tuple(self.shift_gt), self.updated_projection)
self.is_shifted = True
self.is_resampled = True
......@@ -301,10 +306,12 @@ class DESHIFTER(object):
@property
def deshift_results(self):
deshift_results = collections.OrderedDict()
deshift_results.update({'band' :self.band2process})
deshift_results.update({'is shifted' :self.is_shifted})
deshift_results.update({'is resampled' :self.is_resampled})
deshift_results.update({'updated map info' :self.updated_map_info})
deshift_results.update({'updated projection':self.updated_projection})
deshift_results.update({'arr_shifted' :self.arr_shifted})
deshift_results.update({'band' : self.band2process})
deshift_results.update({'is shifted' : self.is_shifted})
deshift_results.update({'is resampled' : self.is_resampled})
deshift_results.update({'updated map info' : self.updated_map_info})
deshift_results.update({'updated geotransform': self.shift_gt})
deshift_results.update({'updated projection' : self.updated_projection})
deshift_results.update({'arr_shifted' : self.arr_shifted})
deshift_results.update({'GeoArray_shifted' : self.GeoArray_shifted})
return deshift_results
\ No newline at end of file
......@@ -20,7 +20,7 @@ from .CoReg import COREG, DESHIFTER
from . import geometry as GEO
from . import io as IO
from py_tools_ds.ptds import GeoArray
from py_tools_ds.ptds.geo.projection import isProjectedOrGeographic, get_UTMzone, EPSG2WKT
from py_tools_ds.ptds.geo.projection import isProjectedOrGeographic, get_UTMzone
from py_tools_ds.ptds.geo.coord_trafo import transform_any_prj, reproject_shapelyGeometry
......@@ -138,10 +138,14 @@ class Geom_Quality_Grid(object):
pointID = coreg_kwargs['pointID']
del coreg_kwargs['pointID']
CR = COREG(global_shared_imref, global_shared_im2shift, **coreg_kwargs)
CR = COREG(global_shared_imref, global_shared_im2shift, **coreg_kwargs, multiproc=False)
CR.calculate_spatial_shifts()
res = pointID, CR.ref.win.size_YX[0], CR.x_shift_px, CR.y_shift_px
return res
CR_res = [int(CR.matchWin.imDimsYX[0]), int(CR.matchWin.imDimsYX[1]),
CR.x_shift_px, CR.y_shift_px, CR.x_shift_map, CR.y_shift_map,
CR.vec_length_map, CR.vec_angle_deg]
return [pointID]+CR_res
def get_quality_grid(self,exclude_outliers=1,dump_values=1):
......@@ -173,14 +177,12 @@ class Geom_Quality_Grid(object):
crs = None
GDF = GeoDataFrame(index=range(len(geomPoints)),crs=crs,
columns=['geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','WIN_SIZE',
'X_SHIFT_PX','Y_SHIFT_PX','X_SHIFT_M','Y_SHIFT_M','ABS_SHIFT','ANGLE'])
columns=['geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM'])
GDF ['geometry'] = geomPoints
GDF ['POINT_ID'] = range(len(geomPoints))
GDF.loc[:,['X_IM','Y_IM']] = self.XY_points
GDF.loc[:,['X_UTM','Y_UTM']] = self.XY_mapPoints
GDF = GDF if not exclude_outliers else GDF[GDF['geometry'].within(self.overlap_poly)]
GDF.loc[:,['WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX','X_SHIFT_M','Y_SHIFT_M','ABS_SHIFT','ANGLE']] = self.outFillVal # Fehlwert
# declare global variables needed for self._get_spatial_shifts()
global global_shared_imref,global_shared_im2shift
......@@ -203,7 +205,7 @@ class Geom_Quality_Grid(object):
'nodata' : self.nodata,
'binary_ws' : self.bin_ws,
'v' : self.v, # FIXME this could lead to massive console output
'q' : self.q, # FIXME this could lead to massive console output
'q' : True, # otherwise this would lead to massive console output
'ignore_errors' : True
}
list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index) # generator
......@@ -216,30 +218,23 @@ class Geom_Quality_Grid(object):
results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
else:
if not self.q:
print("Calculating geometric quality grid in mode 'multiprocessing'...")
results = []
print("Calculating geometric quality grid in mode 'singleprocessing'...")
results = np.empty((len(geomPoints),9))
for i,coreg_kwargs in enumerate(list_coreg_kwargs):
#print(argset[1])
#if not 0<i<10: continue
#if i>300 or i<100: continue
#if i!=127: continue
if i%100==0: print('Point #%s, ID %s' %(i,coreg_kwargs['pointID']))
res = self._get_spatial_shifts(coreg_kwargs)
results.append(res)
for res in results:
pointID = res[0]
GDF.loc[pointID,'WIN_SIZE'] = res[1] if res[1] is not None else self.outFillVal
GDF.loc[pointID,'X_SHIFT_PX'] = res[2] if res[2] is not None else self.outFillVal
GDF.loc[pointID,'Y_SHIFT_PX'] = res[3] if res[3] is not None else self.outFillVal
oFV = self.outFillVal
GDF['X_SHIFT_M'] = [*GDF['X_SHIFT_PX'].map(lambda px: oFV if px==oFV else px*self.im2shift.xgsd)]
GDF['Y_SHIFT_M'] = [*GDF['Y_SHIFT_PX'].map(lambda px: oFV if px==oFV else px*self.im2shift.ygsd)]
get_absShift = lambda row: float(np.sqrt(row['X_SHIFT_M']**2 + row['Y_SHIFT_M']**2))
GDF['ABS_SHIFT'] = GDF.apply(lambda row: oFV if row['X_SHIFT_M'] == oFV else get_absShift(row), axis=1)
get_angle = lambda row: GEO.angle_to_north((row['X_SHIFT_PX'],row['Y_SHIFT_PX'])).tolist()[0]
GDF['ANGLE'] = GDF.apply(lambda row: oFV if row['X_SHIFT_PX'] == oFV else get_angle(row) , axis=1)
results[i,:] = self._get_spatial_shifts(coreg_kwargs)
# merge results with GDF
records = GeoDataFrame(np.array(results, np.object), columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE',
'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M',
'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE'])
GDF = GDF.merge(records, on='POINT_ID', how="inner")
GDF = GDF.fillna(int(self.outFillVal))
self.quality_grid = GDF
......@@ -442,24 +437,8 @@ class Geom_Quality_Grid(object):
return self.Kriging_sp(*args,**kwargs)
def correct_shifts(self, max_GCP_count=None):
coreg_info = self.COREG_obj.coreg_info
coreg_info['GCPList'] = self.GCPList if self.GCPList else self.to_GCPList()
if max_GCP_count:
coreg_info['GCPList'] = coreg_info['GCPList'][:max_GCP_count]
DS = DESHIFTER(self.im2shift, coreg_info,
path_out=None,
out_gsd=(self.im2shift.xgsd,self.im2shift.ygsd),
align_grids=True,
v=self.v,
q=self.q)
deshift_results = DS.correct_shifts()
return deshift_results
def view_results(self, attribute2plot='ABS_SHIFT', cmap=None, exclude_fillVals=True, backgroundIm='tgt'):
def view_results(self, attribute2plot='ABS_SHIFT', cmap=None, exclude_fillVals=True, backgroundIm='tgt',
savefigPath='', savefigDPI=96):
"""Shows a map of the calculated quality grid with the target image as background.
:param attribute2plot: <str> the attribute of the quality grid to be shown (default: 'ABS_SHIFT')
......@@ -501,7 +480,8 @@ class Geom_Quality_Grid(object):
GDF['plt_X'] = [*GDF['plt_XY'].map(lambda XY: XY[0])]
GDF['plt_Y'] = [*GDF['plt_XY'].map(lambda XY: XY[1])]
points = plt.scatter(GDF['plt_X'],GDF['plt_Y'], c=GDF[attribute2plot],
cmap=palette, marker='o', s=50, alpha=1.0)
#cmap=palette, marker='o', s=50, alpha=1.0)
cmap=palette, marker='.', s=50, alpha=1.0)
# add colorbar
divider = make_axes_locatable(plt.gca())
......@@ -510,6 +490,9 @@ class Geom_Quality_Grid(object):
plt.show()
if savefigPath:
fig.savefig(savefigPath, dpi=savefigDPI)
def view_results_folium(self, attribute2plot='ABS_SHIFT', cmap=None, exclude_fillVals=True):
warnings.warn(UserWarning('This function is still under construction and may not work as expected!'))
......@@ -550,3 +533,30 @@ class Geom_Quality_Grid(object):
return map_osm
def correct_shifts(self, max_GCP_count=None):
"""Performs a local shift correction using all points from the previously calculated geometric quality grid
that contain valid matches as GCP points.
:param max_GCP_count: <int> maximum number of GCPs to use
:return:
"""
coreg_info = self.COREG_obj.coreg_info
coreg_info['GCPList'] = self.GCPList if self.GCPList else self.to_GCPList()
if max_GCP_count:
coreg_info['GCPList'] = coreg_info['GCPList'][:max_GCP_count] # TODO should be a random sample
DS = DESHIFTER(self.im2shift, coreg_info,
path_out = None,
out_gsd = (self.im2shift.xgsd,self.im2shift.ygsd),
align_grids = True,
#cliptoextent = True, # why?
#clipextent = self.im2shift.box.boxMapYX,
#options = '-wm 10000 -order 1',
v = self.v,
q = self.q)
deshift_results = DS.correct_shifts()
return deshift_results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment