Commit 8ef57667 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

some bugfixes and further developments

geo.raster.reproject:
- warp_ndarray(): added output assertion

geo.projection:
- prj_equal(): added better docstring

io.raster.GeoArray:
- GeoArray:
    - revised mask_nodata property
    - added assertion to footprint_poly
    - revised calc_mask_nodata()
-_clip_array_at_mapPos(): bugfix for missing fillVal in case nodata value could not be derived
parent 1c9140cf
...@@ -51,8 +51,12 @@ def proj4_to_dict(proj4): ...@@ -51,8 +51,12 @@ def proj4_to_dict(proj4):
def prj_equal(prj1, prj2): def prj_equal(prj1, prj2):
#type: (str,str) -> bool #type: (any,any) -> bool
"""Checks if the given two projections are equal.""" """Checks if the given two projections are equal.
:param prj1: projection 1 (WKT or 'epsg:1234' or <EPSG_int>)
:param prj2: projection 2 (WKT or 'epsg:1234' or <EPSG_int>)
"""
return get_proj4info(proj=prj1)==get_proj4info(proj=prj2) return get_proj4info(proj=prj1)==get_proj4info(proj=prj2)
......
...@@ -18,7 +18,7 @@ from rasterio.warp import reproject as rio_reproject ...@@ -18,7 +18,7 @@ from rasterio.warp import reproject as rio_reproject
from rasterio.warp import calculate_default_transform as rio_calc_transform from rasterio.warp import calculate_default_transform as rio_calc_transform
from rasterio.warp import Resampling from rasterio.warp import Resampling
from ..projection import WKT2EPSG, isProjectedOrGeographic from ..projection import WKT2EPSG, isProjectedOrGeographic, prj_equal
from ..coord_trafo import pixelToLatLon from ..coord_trafo import pixelToLatLon
from ...io.raster.gdal import get_GDAL_ds_inmem from ...io.raster.gdal import get_GDAL_ds_inmem
from ...processing.progress_mon import printProgress from ...processing.progress_mon import printProgress
...@@ -383,7 +383,7 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=( ...@@ -383,7 +383,7 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
# GDAL Warp # GDAL Warp
gdal_Warp = get_gdal_func('Warp') gdal_Warp = get_gdal_func('Warp')
out_ds = gdal_Warp( res_ds = gdal_Warp(
'', in_ds, format='MEM', '', in_ds, format='MEM',
dstSRS = get_SRS(out_prj), dstSRS = get_SRS(out_prj),
outputType = get_GDT(out_dtype) if out_dtype else get_GDT(ndarray.dtype), outputType = get_GDT(out_dtype) if out_dtype else get_GDT(ndarray.dtype),
...@@ -412,18 +412,22 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=( ...@@ -412,18 +412,22 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
gdal.SetConfigOption('GDAL_NUM_THREADS', None) gdal.SetConfigOption('GDAL_NUM_THREADS', None)
if out_ds is None: if res_ds is None:
raise Exception('Warping Error: ' + gdal.GetLastErrorMsg()) raise Exception('Warping Error: ' + gdal.GetLastErrorMsg())
# get outputs # get outputs
out_arr = gdalnumeric.DatasetReadAsArray(out_ds) res_arr = gdalnumeric.DatasetReadAsArray(res_ds)
if len(out_arr.shape) == 3: if len(res_arr.shape) == 3:
out_arr = np.swapaxes(np.swapaxes(out_arr, 0, 2), 0, 1) res_arr = np.swapaxes(np.swapaxes(res_arr, 0, 2), 0, 1)
out_gt = out_ds.GetGeoTransform() res_gt = res_ds.GetGeoTransform()
out_prj = out_ds.GetProjection() res_prj = res_ds.GetProjection()
# cleanup # cleanup
in_ds = out_ds = None in_ds = res_ds = None
return out_arr, out_gt, out_prj # test output
\ No newline at end of file if prj_equal(out_prj,4626):
assert -180 < res_gt[0] < 180 and -90 < res_gt[3] < 90, 'Testing of gdal_warp output failed.'
return res_arr, res_gt, res_prj
\ No newline at end of file
...@@ -29,7 +29,7 @@ from ...geo.raster.conversion import raster2polygon ...@@ -29,7 +29,7 @@ from ...geo.raster.conversion import raster2polygon
from ...geo.vector.topology import get_overlap_polygon, get_footprint_polygon from ...geo.vector.topology import get_overlap_polygon, get_footprint_polygon
from ...geo.vector.geometry import boxObj from ...geo.vector.geometry import boxObj
from ...io.raster.gdal import get_GDAL_ds_inmem from ...io.raster.gdal import get_GDAL_ds_inmem
from ...numeric.array import find_noDataVal from ...numeric.array import find_noDataVal, get_outFillZeroSaturated
...@@ -257,12 +257,8 @@ class GeoArray(object): ...@@ -257,12 +257,8 @@ class GeoArray(object):
if self._mask_nodata is not None: if self._mask_nodata is not None:
return self._mask_nodata return self._mask_nodata
else: else:
if self.nodata is not None: self.calc_mask_nodata() # sets self._mask_nodata
self.calc_mask_nodata() # sets self._mask_nodata return self._mask_nodata
return self._mask_nodata
else:
warnings.warn('Calculation of nodata mask failed due to missing no data value.')
return None
@mask_nodata.setter @mask_nodata.setter
...@@ -275,6 +271,7 @@ class GeoArray(object): ...@@ -275,6 +271,7 @@ class GeoArray(object):
if self._footprint_poly is not None: if self._footprint_poly is not None:
return self._footprint_poly return self._footprint_poly
else: else:
assert self.mask_nodata is not None, 'A nodata mask is needed for calculating the footprint polygon. '
self._footprint_poly = raster2polygon(self, exact=False) self._footprint_poly = raster2polygon(self, exact=False)
return self._footprint_poly return self._footprint_poly
...@@ -352,10 +349,15 @@ class GeoArray(object): ...@@ -352,10 +349,15 @@ class GeoArray(object):
self.__dict__ = state self.__dict__ = state
def calc_mask_nodata(self, fromBand=None): def calc_mask_nodata(self, fromBand=None, overwrite=False):
arr = self[:,:,fromBand] if self._mask_nodata is None or overwrite:
self.mask_nodata = np.where(arr == self.nodata, 0, 1).astype(np.uint8) if self.ndim == 2 else \ arr = self[fromBand]
np.all(np.where(arr == self.nodata, 0, 1), axis=2).astype(np.uint8) if self.nodata is None:
self.mask_nodata = np.ones((self.rows, self.cols), np.uint8)
else:
assert self.ndim in [2,3], "Only 2D or 3D arrays are supported. Got a %sD array." %self.ndim
self.mask_nodata = np.where(arr == self.nodata, 0, 1).astype(np.uint8) if arr.ndim == 2 else \
np.all(np.where(arr == self.nodata, 0, 1), axis=2).astype(np.uint8)
def set_gdalDataset_meta(self): def set_gdalDataset_meta(self):
...@@ -820,7 +822,8 @@ def _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=None, fillVal=0): ...@@ -820,7 +822,8 @@ def _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=None, fillVal=0):
tgt_shape = (tgt_rows, tgt_cols, tgt_bands) if tgt_bands > 1 else (tgt_rows, tgt_cols) tgt_shape = (tgt_rows, tgt_cols, tgt_bands) if tgt_bands > 1 else (tgt_rows, tgt_cols)
try: try:
out_arr = np.full(tgt_shape, fillVal, arr_dtype) fillVal = fillVal if fillVal is not None else get_outFillZeroSaturated(arr)[0]
out_arr = np.full(tgt_shape, fillVal, arr_dtype)
except MemoryError: except MemoryError:
raise MemoryError('Calculated target dimensions are %s. Check your inputs!' %str(tgt_shape)) raise MemoryError('Calculated target dimensions are %s. Check your inputs!' %str(tgt_shape))
...@@ -842,7 +845,7 @@ def _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=None, fillVal=0): ...@@ -842,7 +845,7 @@ def _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=None, fillVal=0):
cS_out, rS_out = [int(i) for i in mapXY2imXY((xmin_in, ymax_in), out_gt)] cS_out, rS_out = [int(i) for i in mapXY2imXY((xmin_in, ymax_in), out_gt)]
cE_out, rE_out = [int(i)-1 for i in mapXY2imXY((xmax_in, ymin_in), out_gt)] # -1 because max values do not represent pixel origins cE_out, rE_out = [int(i)-1 for i in mapXY2imXY((xmax_in, ymin_in), out_gt)] # -1 because max values do not represent pixel origins
# fill newy created array with read data from input array # fill newly created array with read data from input array
if tgt_bands==1: if tgt_bands==1:
out_arr[rS_out:rE_out + 1, cS_out:cE_out + 1] = data out_arr[rS_out:rE_out + 1, cS_out:cE_out + 1] = data
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment