Commit 1f9be4ce authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Added resampling algorithms 'bilinear' and 'custom'. Added docstrings and type hints.

parent 7d171c08
Pipeline #3353 failed with stages
in 52 seconds
...@@ -22,9 +22,10 @@ gdal_env.try2set_GDAL_DATA() ...@@ -22,9 +22,10 @@ gdal_env.try2set_GDAL_DATA()
def get_proj4info(ds=None, proj=None): def get_proj4info(ds=None, proj=None):
# type: (gdal.Dataset,str) -> str # type: (gdal.Dataset, Union[str, int]) -> str
"""Returns PROJ4 formatted projection info for the given gdal.Dataset or projection respectivly, """Returns PROJ4 formatted projection info for the given gdal.Dataset or projection respectivly,
e.g. '+proj=utm +zone=43 +datum=WGS84 +units=m +no_defs ' e.g. '+proj=utm +zone=43 +datum=WGS84 +units=m +no_defs '
:param ds: <gdal.Dataset> the gdal dataset to get PROJ4 info for :param ds: <gdal.Dataset> the gdal dataset to get PROJ4 info for
:param proj: <str,int> the projection to get PROJ4 formatted info for (WKT or 'epsg:1234' or <EPSG_int>) :param proj: <str,int> the projection to get PROJ4 formatted info for (WKT or 'epsg:1234' or <EPSG_int>)
""" """
......
...@@ -5,6 +5,7 @@ import warnings ...@@ -5,6 +5,7 @@ import warnings
import multiprocessing import multiprocessing
import os import os
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Union, Tuple, List # noqa: F401
# custom # custom
try: try:
...@@ -19,7 +20,8 @@ from rasterio.warp import reproject as rio_reproject ...@@ -19,7 +20,8 @@ from rasterio.warp import reproject as rio_reproject
from rasterio.warp import calculate_default_transform as rio_calc_transform from rasterio.warp import calculate_default_transform as rio_calc_transform
from rasterio.warp import Resampling from rasterio.warp import Resampling
from pyresample.geometry import AreaDefinition, SwathDefinition from pyresample.geometry import AreaDefinition, SwathDefinition
from pyresample.kd_tree import resample_nearest, resample_gauss from pyresample.bilinear import resample_bilinear
from pyresample.kd_tree import resample_nearest, resample_gauss, resample_custom
from pyresample.utils import get_area_def from pyresample.utils import get_area_def
from ...dtypes.conversion import dTypeDic_NumPy2GDAL from ...dtypes.conversion import dTypeDic_NumPy2GDAL
...@@ -496,6 +498,16 @@ def warp_ndarray(ndarray, in_gt, in_prj=None, out_prj=None, out_dtype=None, ...@@ -496,6 +498,16 @@ def warp_ndarray(ndarray, in_gt, in_prj=None, out_prj=None, out_dtype=None,
class SensorMapGeometryTransformer(object): class SensorMapGeometryTransformer(object):
def __init__(self, data, lons, lats, resamp_alg='nearest', **opts): def __init__(self, data, lons, lats, resamp_alg='nearest', **opts):
# type: (np.ndarray, np.ndarray, np.ndarray, str, dict) -> None # type: (np.ndarray, np.ndarray, np.ndarray, str, dict) -> None
"""Get an instance of SensorMapGeometryTransformer.
:param data: numpy array to be warped to sensor or map geometry
:param lons: longitude array
:param lats: latitude array
:param resamp_alg: resampling algorithm ('nearest', 'bilinear', 'gauss', 'custom')
:param opts: options to be passed as keyword arguments to the pyresample resampling function,
for documentation see here: https://pyresample.readthedocs.io/en/latest/swath.html
"""
self.data = data self.data = data
self.resamp_alg = resamp_alg self.resamp_alg = resamp_alg
self.opts = dict(radius_of_influence=30, self.opts = dict(radius_of_influence=30,
...@@ -507,6 +519,11 @@ class SensorMapGeometryTransformer(object): ...@@ -507,6 +519,11 @@ class SensorMapGeometryTransformer(object):
self.area_extent = [np.min(lons), np.min(lats), np.max(lons), np.max(lats)] self.area_extent = [np.min(lons), np.min(lats), np.max(lons), np.max(lats)]
def compute_output_shape(self): def compute_output_shape(self):
# type: () -> Tuple[int, int]
"""Estimates the map geometry output shape of a sensor geometry array resampled to map geometry.
:return:
"""
with TemporaryDirectory() as td: with TemporaryDirectory() as td:
path_lons_lats = os.path.join(td, 'lons_lats.bsq') path_lons_lats = os.path.join(td, 'lons_lats.bsq')
path_lons_lats_vrt = os.path.join(td, 'lons_lats.vrt') path_lons_lats_vrt = os.path.join(td, 'lons_lats.vrt')
...@@ -551,10 +568,16 @@ class SensorMapGeometryTransformer(object): ...@@ -551,10 +568,16 @@ class SensorMapGeometryTransformer(object):
y_size = ds_out.RasterYSize y_size = ds_out.RasterYSize
del ds_out del ds_out
return x_size, y_size return x_size, y_size
def get_area_definition(self, proj4_args, cols_out, rows_out): def get_area_definition(self, proj4_args, cols_out, rows_out):
"""Get output area definition.""" # type: (Union[str, list], int, int) -> AreaDefinition
"""Get output area definition.
:param proj4_args: Proj4 arguments as list of arguments or string
:param cols_out: number of output columns
:param rows_out: number of output rows
"""
area_def_out = get_area_def(area_id='', area_def_out = get_area_def(area_id='',
area_name='', area_name='',
proj_id='', proj_id='',
...@@ -562,20 +585,46 @@ class SensorMapGeometryTransformer(object): ...@@ -562,20 +585,46 @@ class SensorMapGeometryTransformer(object):
x_size=cols_out, x_size=cols_out,
y_size=rows_out, y_size=rows_out,
area_extent=self.area_extent # xmin, ymin, xmax, ymax area_extent=self.area_extent # xmin, ymin, xmax, ymax
) ) # type: AreaDefinition
return area_def_out return area_def_out
def _resample(self, source_geo_def: object, target_geo_def: object): def _resample(self, source_geo_def, target_geo_def):
# type: (Union[AreaDefinition, SwathDefinition], Union[AreaDefinition, SwathDefinition]) -> np.ndarray
"""Run the resampling algorithm.
:param source_geo_def: source geo definition
:param target_geo_def: target geo definition
:return:
"""
if self.resamp_alg == 'nearest': if self.resamp_alg == 'nearest':
opts = {k: v for k, v in self.opts.items() if k not in ['sigmas']} opts = {k: v for k, v in self.opts.items() if k not in ['sigmas']}
result = resample_nearest(source_geo_def, self.data, target_geo_def, **opts) result = resample_nearest(source_geo_def, self.data, target_geo_def, **opts)
else:
elif self.resamp_alg == 'bilinear':
opts = {k: v for k, v in self.opts.items() if k not in ['sigmas']}
result = resample_bilinear(self.data, source_geo_def, target_geo_def, **opts)
elif self.resamp_alg == 'gauss':
opts = {k: v for k, v in self.opts.items()} opts = {k: v for k, v in self.opts.items()}
result = resample_gauss(source_geo_def, self.data, target_geo_def, **opts) result = resample_gauss(source_geo_def, self.data, target_geo_def, **opts)
elif self.resamp_alg == 'custom':
opts = {k: v for k, v in self.opts.items()}
if 'weight_funcs' not in opts:
raise ValueError(opts, "Options must contain a 'weight_funcs' item.")
result = resample_custom(source_geo_def, self.data, target_geo_def, **opts)
else:
raise ValueError(self.resamp_alg)
return result return result
def to_map_geometry(self, tgt_prj): def to_map_geometry(self, tgt_prj):
# type: (Union[str, int]) -> np.ndarray
"""Transform the input sensor geometry array into map geometry.
:param tgt_prj: target projection (WKT or 'epsg:1234' or <EPSG_int>)
"""
cols_out, rows_out = self.compute_output_shape() cols_out, rows_out = self.compute_output_shape()
self.area_definition = self.get_area_definition(proj4_args=get_proj4info(proj=tgt_prj), self.area_definition = self.get_area_definition(proj4_args=get_proj4info(proj=tgt_prj),
...@@ -584,6 +633,12 @@ class SensorMapGeometryTransformer(object): ...@@ -584,6 +633,12 @@ class SensorMapGeometryTransformer(object):
return self._resample(self.swath_definition, self.area_definition) return self._resample(self.swath_definition, self.area_definition)
def to_sensor_geometry(self, src_prj, src_extent): def to_sensor_geometry(self, src_prj, src_extent):
# type: (Union[str, int], List[float, float, float, float]) -> np.ndarray
"""Transform the input map geometry array into sensor geometry
:param src_prj: projection of the input map geometry array (WKT or 'epsg:1234' or <EPSG_int>)
:param src_extent: extent coordinates of input map geometry array (LL_x, LL_y, UR_x, UR_y) in the src_prj
"""
proj4_args = proj4_to_dict(get_proj4info(proj=src_prj)) proj4_args = proj4_to_dict(get_proj4info(proj=src_prj))
self.area_definition = AreaDefinition('', '', '', proj4_args, self.data.shape[1], self.data.shape[0], self.area_definition = AreaDefinition('', '', '', proj4_args, self.data.shape[1], self.data.shape[0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment