Commit 199c1e3b authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

some new functions and improvements

components.CoReg_local.COREG_LOCAL:
- view_CoRegPoints():
    - added keyword 'return_map'
    - revised size of major ticks
    - added separate color coding for attribute2plot='ANGLE'
    - added legend

components.io:
- write_shp(): now also accepts EPGS codes

components.Tie_Point_Grid.Tie_Point_Grid:
- added calc_rmse()
- added calc_overall_mssim()
- added plot_shift_distribution()
- to_vectorfield(): fix for not properly setting output fill value

 components.Tie_Point_Grid.Tie_Point_Refiner:
 - run_filtering() fix for always appendinf 'L3_OUTLIER'

updated __version__
parent cfd1b29d
......@@ -9,7 +9,7 @@ from .components import utilities
from .components import geometry
__author__ = 'Daniel Scheffler'
__version__= '2017-03-16_01'
__version__= '2017-03-28_01'
__all__=['COREG',
'COREG_LOCAL',
......
......@@ -837,7 +837,11 @@ class COREG(object):
if wsYX:
time0 = time.time()
if self.v: print('final window size: %s/%s (X/Y)' % (wsYX[1], wsYX[0]))
if self.v:
print('final window size: %s/%s (X/Y)' % (wsYX[1], wsYX[0]))
# FIXME size of self.matchWin is not updated
# FIXME CoRegPoints_grid.WIN_SZ is taken from self.matchBox.imDimsYX but this is not updated
center_YX = np.array(im0.shape)/2
xmin,xmax,ymin,ymax = int(center_YX[1]-wsYX[1]/2), int(center_YX[1]+wsYX[1]/2),\
int(center_YX[0]-wsYX[0]/2), int(center_YX[0]+wsYX[0]/2)
......@@ -1027,6 +1031,7 @@ class COREG(object):
# check if integer shifts are now gone (0/0)
scps = self._calc_shifted_cross_power_spectrum(gdsh_im0, crsp_im1)
if scps is not None:
peakpos = self._get_peakpos(scps)
x_shift, y_shift = self._get_shifts_from_peakpos(peakpos, scps.shape)
......@@ -1252,7 +1257,7 @@ class COREG(object):
geotransform=self.shift.gt, projection=self.shift.prj)[0]
self.x_shift_map, self.y_shift_map = new_originX - self.shift.gt[0], new_originY - self.shift.gt[3]
# get length of shift vecor in map units
# get length of shift vector in map units
self.vec_length_map = float(np.sqrt(self.x_shift_map ** 2 + self.y_shift_map ** 2))
# get angle of shift vector
......@@ -1274,7 +1279,7 @@ class COREG(object):
self._get_updated_map_info()
# set self.ssim_before and ssim_after
self._validate_ssim_improvement()
self._validate_ssim_improvement() # FIXME uses the not updated matchWin size
self.shift_reliability = self._calc_shift_reliability(scps)
warnings.simplefilter('default')
......
......@@ -248,11 +248,13 @@ class COREG_LOCAL(object):
return os.path.abspath(self._projectDir)
else:
# return a project name that not already has a corresponding folder on disk
root_dir = os.path.dirname(self.im2shift.filePath) if self.im2shift.filePath else os.path.curdir
projectDir = os.path.join(root_dir, 'UntitledProject_1')
while os.path.isdir(projectDir):
projectDir = '%s_%s' % (projectDir.split('_')[0], int(projectDir.split('_')[-1]) + 1)
self._projectDir = projectDir
root_dir = os.path.dirname(self.im2shift.filePath) if self.im2shift.filePath else os.path.curdir
fold_name = 'UntitledProject_1'
while os.path.isdir(os.path.join(root_dir, fold_name)):
fold_name = '%s_%s' % (fold_name.split('_')[0], int(fold_name.split('_')[-1]) + 1)
self._projectDir = os.path.join(root_dir, fold_name)
return self._projectDir
......@@ -301,7 +303,8 @@ class COREG_LOCAL(object):
def view_CoRegPoints(self, attribute2plot='ABS_SHIFT', cmap=None, exclude_fillVals=True, backgroundIm='tgt',
hide_filtered=True, figsize=None, savefigPath='', savefigDPI=96, showFig=True, zoomable=False):
hide_filtered=True, figsize=None, savefigPath='', savefigDPI=96, showFig=True,
return_map=False, zoomable=False):
"""Shows a map of the calculated quality grid with the target image as background.
:param attribute2plot: <str> the attribute of the quality grid to be shown (default: 'ABS_SHIFT')
......@@ -315,6 +318,7 @@ class COREG_LOCAL(object):
:param savefigPath:
:param savefigDPI:
:param showFig: <bool> whether to show or to hide the figure
:param return_map <bool>
:param zoomable: <bool> enable or disable zooming via mpld3
:return:
"""
......@@ -324,12 +328,16 @@ class COREG_LOCAL(object):
backgroundIm = self.im2shift if backgroundIm=='tgt' else self.imref
fig, ax, map2show = backgroundIm.show_map(figsize=figsize, nodataVal=self.nodata[1], return_map=True,
band=self.COREG_obj.shift.band4match, zoomable=zoomable)
plt.tick_params(axis='both', which='major', labelsize=40)
#ax.tick_params(axis='both', which='minor', labelsize=8)
# fig, ax, map2show = backgroundIm.show_map_utm(figsize=(20,20), nodataVal=self.nodata[1], return_map=True)
plt.title(attribute2plot)
# transform all points of quality grid to LonLat
outlierCols = [c for c in self.CoRegPoints_table.columns if 'OUTLIER' in c]
attr2include = ['geometry', attribute2plot] + outlierCols
attr2include = ['geometry', attribute2plot] + outlierCols + ['X_SHIFT_M', 'Y_SHIFT_M']
GDF = self.CoRegPoints_table.loc\
[self.CoRegPoints_table.X_SHIFT_M != self.outFillVal, attr2include].copy() \
if exclude_fillVals else self.CoRegPoints_table.loc[:, attr2include]
......@@ -342,7 +350,19 @@ class COREG_LOCAL(object):
#vmin = min(GDF[GDF[attribute2plot] != self.outFillVal][attribute2plot])
#vmax = max(GDF[GDF[attribute2plot] != self.outFillVal][attribute2plot])
#norm = mpl_normalize(vmin=vmin, vmax=vmax)
palette = cmap if cmap else plt.cm.RdYlGn_r
palette = cmap if cmap is not None else plt.cm.RdYlGn_r
if cmap is None and attribute2plot == 'ANGLE':
#import matplotlib.colors as mcolors
#colors1 = plt.cm.RdYlGn_r(np.linspace(0., 1, 128))
#colors2 = plt.cm.RdYlGn(np.linspace(0., 1, 128))
## combine them and build a new colormap
#colors = np.vstack((colors1, colors2))
#palette = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
#palette = plt.cm.hsv
import cmocean
palette = cmocean.cm.delta
#GDF['color'] = [*GDF[attribute2plot].map(lambda val: palette(norm(val)))]
# add quality grid to map
......@@ -361,24 +381,33 @@ class COREG_LOCAL(object):
if self.tieP_filter_level > 0:
# flag level 1 outliers
GDF_filt = GDF[GDF.L1_OUTLIER == True].copy()
plt.scatter(GDF_filt['plt_X'], GDF_filt['plt_Y'], c='b', marker=marker, s=250, alpha=1.0)
plt.scatter(GDF_filt['plt_X'], GDF_filt['plt_Y'], c='b', marker=marker, s=250, alpha=1.0, label='reliability')
if self.tieP_filter_level > 1:
# flag level 2 outliers
GDF_filt = GDF[GDF.L2_OUTLIER == True].copy()
plt.scatter(GDF_filt['plt_X'], GDF_filt['plt_Y'], c='r', marker=marker, s=150, alpha=1.0)
plt.scatter(GDF_filt['plt_X'], GDF_filt['plt_Y'], c='r', marker=marker, s=150, alpha=1.0, label='MSSIM')
if self.tieP_filter_level > 2:
# flag level 3 outliers
GDF_filt = GDF[GDF.L3_OUTLIER == True].copy()
plt.scatter(GDF_filt['plt_X'], GDF_filt['plt_Y'], c='y', marker=marker, s=250, alpha=1.0)
plt.scatter(GDF_filt['plt_X'], GDF_filt['plt_Y'], c='y', marker=marker, s=250, alpha=1.0, label='RANSAC')
if self.tieP_filter_level > 0:
plt.legend(loc=0, scatterpoints = 1)
# plot all points on top
if not GDF.empty:
vmin, vmax = np.percentile(GDF[attribute2plot], 0), np.percentile(GDF[attribute2plot], 95)
points = plt.scatter(GDF['plt_X'],GDF['plt_Y'], c=GDF[attribute2plot],
vmin, vmax = (np.percentile(GDF[attribute2plot], 0), np.percentile(GDF[attribute2plot], 95)) \
if attribute2plot!='ANGLE' else (0, 360)
#vmin=None # TODO make this adjustable
#vmax=None
points = plt.scatter(GDF['plt_X'],GDF['plt_Y'], c=GDF[attribute2plot], lw = 0,
cmap=palette, marker='o' if len(GDF)<10000 else '.', s=50, alpha=1.0,
vmin=vmin, vmax=vmax)
# plot shift vectors
#map2show.quiver(GDF['plt_X'], GDF['plt_Y'], GDF['X_SHIFT_M'], GDF['Y_SHIFT_M'])#, scale=700)
# add colorbar
divider = make_axes_locatable(plt.gca())
cax = divider.append_axes("right", size="2%",
......@@ -392,6 +421,9 @@ class COREG_LOCAL(object):
if savefigPath:
fig.savefig(savefigPath, dpi=savefigDPI)
if return_map:
return fig, ax, map2show
if showFig and not self.q:
plt.show(block=True)
else:
......@@ -416,7 +448,7 @@ class COREG_LOCAL(object):
center_lon, center_lat = (lon_min+lon_max)/2, (lat_min+lat_max)/2
# get image to plot
image2plot = self.im2shift[0] # FIXME hardcoded band
image2plot = self.im2shift[:,:,0] # FIXME hardcoded band
from py_tools_ds.ptds.geo.raster.reproject import warp_ndarray
image2plot, gt, prj = warp_ndarray(image2plot, self.im2shift.geotransform, self.im2shift.projection,
......
......@@ -13,6 +13,7 @@ try:
except ImportError:
from osgeo import gdal
import numpy as np
from matplotlib import pyplot as plt
from geopandas import GeoDataFrame, GeoSeries
from pykrige.ok import OrdinaryKriging
from shapely.geometry import Point
......@@ -22,7 +23,7 @@ from skimage.transform import AffineTransform, PolynomialTransform
# internal modules
from .CoReg import COREG
from . import io as IO
from py_tools_ds.ptds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4
from py_tools_ds.ptds.geo.projection import isProjectedOrGeographic, get_UTMzone, dict_to_proj4, proj4_to_WKT
from py_tools_ds.ptds.io.pathgen import get_generic_outpath
from py_tools_ds.ptds.processing.progress_mon import ProgressBar
from py_tools_ds.ptds.geo.vector.conversion import points_to_raster
......@@ -349,6 +350,128 @@ class Tie_Point_Grid(object):
return self.CoRegPoints_table
def calc_rmse(self, include_outliers=False):
# type: (bool) -> float
"""Calculates root mean square error of absolute shifts from the tie point grid.
:param include_outliers: whether to include tie points that have been marked as false-positives
"""
tbl = self.CoRegPoints_table
tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy()
shifts = np.array(tbl['ABS_SHIFT'])
shifts_sq = [i * i for i in shifts if i != self.outFillVal]
return np.sqrt(sum(shifts_sq) / len(shifts_sq))
def calc_overall_mssim(self, include_outliers=False):
# type: (bool) -> float
"""Calculates the median value of all MSSIM values contained in tie point grid.
:param include_outliers: whether to include tie points that have been marked as false-positives
"""
tbl = self.CoRegPoints_table
tbl = tbl if include_outliers else tbl[tbl['OUTLIER'] == False].copy()
mssim_col = np.array(tbl['MSSIM'])
mssim_col = [i * i for i in mssim_col if i != self.outFillVal]
return float(np.median(mssim_col))
def plot_shift_distribution(self, include_outliers=True, unit='m', interactive=False, figsize=None, xlim=None,
ylim=None, fontsize=12, title='shift distribution'):
# type: (bool, str, bool, tuple, list, list, int) -> tuple
"""Creates a 2D scatterplot containing the distribution of calculated X/Y-shifts.
:param include_outliers: whether to include tie points that have been marked as false-positives
:param unit: 'm' for meters or 'px' for pixels (default: 'm')
:param interactive: interactive mode uses plotly for visualization
:param figsize: (xdim, ydim)
:param xlim: [xmin, xmax]
:param ylim: [ymin, ymax]
:param fontsize: size of all used fonts
:param title: the title to be plotted above the figure
"""
if not unit in ['m', 'px']:
raise ValueError("Parameter 'unit' must have the value 'm' (meters) or 'px' (pixels)! Got %s." %unit)
tbl = self.CoRegPoints_table
tbl_il = tbl[tbl['OUTLIER'] == False].copy()
tbl_ol = tbl[tbl['OUTLIER'] == True].copy()
x_attr = 'X_SHIFT_M' if unit == 'm' else 'X_SHIFT_PX'
y_attr = 'Y_SHIFT_M' if unit == 'm' else 'Y_SHIFT_PX'
rmse = self.calc_rmse(include_outliers=False) # always exclude outliers when calculating RMSE
figsize = figsize if figsize else (10,10)
if interactive:
from plotly.offline import iplot, init_notebook_mode
import plotly.graph_objs as go
init_notebook_mode(connected=True)
# Create a trace
trace = go.Scatter(
x=tbl_il[x_attr],
y=tbl_il[y_attr],
mode='markers'
)
data = [trace]
# Plot and embed in ipython notebook!
iplot(data, filename='basic-scatter')
return None, None
else:
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
if include_outliers:
ax.scatter(tbl_ol[x_attr], tbl_ol[y_attr], marker='+', c='r', label='false positives')
ax.scatter(tbl_il[x_attr], tbl_il[y_attr], marker='+', c='g', label='valid tie points')
# set axis limits
if not xlim:
xmax = np.abs(tbl_il[x_attr]).max()
xlim = [-xmax, xmax]
if not ylim:
ymax = np.abs(tbl_il[y_attr]).max()
ylim = [-ymax, ymax]
ax.set_xlim(xlim)
ax.set_ylim(ylim)
plt.text(xlim[1]-(xlim[1]/20),-ylim[1]+(ylim[1]/20), 'RMSE: %s m' % np.round(rmse, 2), ha='right',
va='bottom', fontsize=fontsize, bbox=dict(facecolor='w', pad=None, alpha=0.8))
# add grid
plt.grid()
xgl = ax.get_xgridlines()
middle_xgl = xgl[int(np.median(np.array(range(len(xgl)))))]
middle_xgl.set_linewidth(1)
middle_xgl.set_linestyle('-')
ygl = ax.get_ygridlines()
middle_ygl = ygl[int(np.median(np.array(range(len(ygl)))))]
middle_ygl.set_linewidth(1)
middle_ygl.set_linestyle('-')
[tick.label.set_fontsize(fontsize) for tick in ax.xaxis.get_major_ticks()]
[tick.label.set_fontsize(fontsize) for tick in ax.yaxis.get_major_ticks()]
plt.legend(fontsize=fontsize)
ax.set_title(title, fontsize=fontsize)
plt.xlabel('x-shift [%s]' % 'meters' if unit == 'm' else 'pixels', fontsize=fontsize)
plt.ylabel('y-shift [%s]' % 'meters' if unit == 'm' else 'pixels', fontsize=fontsize)
plt.show()
return fig, ax
def dump_CoRegPoints_table(self, path_out=None):
path_out = path_out if path_out else get_generic_outpath(dir_out=self.dir_out,
fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" % (self.grid_res, self.COREG_obj.win_size_XY[0],
......@@ -467,7 +590,7 @@ class Tie_Point_Grid(object):
def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
# type: (str) -> None
# type: (str) -> GeoArray
"""Saves the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield
(e.g. using ArcGIS)
......@@ -485,12 +608,14 @@ class Tie_Point_Grid(object):
xshift_arr, gt, prj = points_to_raster(points = self.CoRegPoints_table['geometry'],
values = self.CoRegPoints_table[attr_b1],
tgt_res = self.shift.xgsd * self.grid_res,
prj = dict_to_proj4(self.CoRegPoints_table.crs))
prj = proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
fillVal = self.outFillVal)
yshift_arr, gt, prj = points_to_raster(points = self.CoRegPoints_table['geometry'],
values = self.CoRegPoints_table[attr_b2],
tgt_res = self.shift.xgsd * self.grid_res,
prj = dict_to_proj4(self.CoRegPoints_table.crs))
prj = proj4_to_WKT(dict_to_proj4(self.CoRegPoints_table.crs)),
fillVal = self.outFillVal)
out_GA = GeoArray(np.dstack([xshift_arr, yshift_arr]), gt, prj, nodata=self.outFillVal)
......@@ -657,7 +782,7 @@ class Tie_Point_Refiner(object):
print('RANSAC skipped because too less valid tie points have been found.')
self.GDF['L3_OUTLIER'] = False
self.new_cols.append('L3_OUTLIER')
self.new_cols.append('L3_OUTLIER')
self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
......@@ -680,6 +805,11 @@ class Tie_Point_Refiner(object):
:return:
"""
#ssim_diff = np.median(self.GDF['SSIM_AFTER']) - np.median(self.GDF['SSIM_BEFORE'])
#self.GDF.SSIM_IMPROVED = self.GDF.apply(lambda GDF_row: GDF_row['SSIM_AFTER']>GDF_row['SSIM_BEFORE'] + ssim_diff, axis=1)
return self.GDF.SSIM_IMPROVED == False
......
......@@ -15,6 +15,7 @@ from spectral.io import envi
# internal modules
from .utilities import get_image_tileborders, convertGdalNumpyDataType
from py_tools_ds.ptds.geo.map_info import geotransform2mapinfo
from py_tools_ds.ptds.geo.projection import EPSG2WKT
from py_tools_ds.ptds.dtypes.conversion import get_dtypeStr
......@@ -107,6 +108,7 @@ def write_shp(path_out, shapely_geom, prj=None,attrDict=None):
ds = ogr.GetDriverByName("Esri Shapefile").CreateDataSource(path_out)
if prj is not None:
prj = prj if not isinstance(prj, int) else EPSG2WKT(prj)
srs = osr.SpatialReference()
srs.ImportFromWkt(prj)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment