Commit f72075cd authored by Daniel Scheffler
Browse files

Fixed multiprocessing issue related to OpenMP multiprocessing within pykdtree...

Fixed multiprocessing issue related to OpenMP multiprocessing within pykdtree as called by pyresample. Fixed type hints.
parent 5bd37c1e
Pipeline #3708 passed with stages
in 1 minute and 30 seconds
......@@ -49,7 +49,7 @@ class Geocoding(object):
self.from_geotransform_projection(gt, prj)
def from_geotransform_projection(self, gt, prj):
# type: (Union[list, tuple], str) -> self
# type: (Union[list, tuple], str) -> 'Geocoding'
"""Create Geocoding object from GDAL GeoTransform + WKT projection string.
HOW COMPUTATION OF RADIANS WORKS:
......@@ -121,13 +121,12 @@ class Geocoding(object):
return self
def from_mapinfo(self, mapinfo):
# type: (Union[list, tuple]) -> self
# type: (Union[list, tuple]) -> 'Geocoding'
"""Create Geocoding object from ENVI map info.
:param mapinfo: ENVI map info, e.g., ['UTM', 1, 1, 192585.0, 5379315.0, 30.0, 30.0, 41, 'North', 'WGS-84']
:return: instance of Geocoding
"""
# type: (Union[list, tuple]) -> self
if mapinfo:
# validate input map info
if not isinstance(mapinfo, (list, tuple)):
......
......@@ -21,7 +21,6 @@ from rasterio.warp import calculate_default_transform as rio_calc_transform
from rasterio.warp import Resampling
from pyresample.geometry import AreaDefinition, SwathDefinition
from pyresample.bilinear import resample_bilinear
from pyresample.kd_tree import resample_nearest, resample_gauss, resample_custom
from pyresample.utils import get_area_def
from ...dtypes.conversion import dTypeDic_NumPy2GDAL
......@@ -545,6 +544,13 @@ class SensorMapGeometryTransformer(object):
del self.opts['radius_of_influence']
self.opts['radius'] = radius_of_influence
# NOTE: If pykdtree is built with OpenMP support (default) the number of threads is controlled with the
# standard OpenMP environment variable OMP_NUM_THREADS. The nprocs argument has no effect on pykdtree.
if 'nprocs' in self.opts:
if self.opts['nprocs'] > 1:
os.environ['OMP_NUM_THREADS'] = '%d' % opts['nprocs']
del self.opts['nprocs']
self.lats = lats
self.lons = lons
self.swath_definition = SwathDefinition(lons=lons, lats=lats)
......@@ -682,6 +688,8 @@ class SensorMapGeometryTransformer(object):
:param target_geo_def: target geo definition
:return:
"""
from pyresample.kd_tree import resample_nearest, resample_gauss, resample_custom
if self.resamp_alg == 'nearest':
opts = {k: v for k, v in self.opts.items() if k not in ['sigmas']}
result = resample_nearest(source_geo_def, data, target_geo_def, **opts)
......@@ -832,18 +840,10 @@ class SensorMapGeometryTransformer3D(object):
# define number of CPUs to use (but avoid sub-multiprocessing)
# -> parallelize either over bands or over image tiles
ncpus_avail = multiprocessing.cpu_count()
if 'nprocs' in opts and opts['nprocs']:
if self.lons.shape[2] > opts['nprocs']:
self.CPUs = opts['nprocs'] # parallelize over bands
opts['nprocs'] = 1
else:
self.CPUs = 1 # parallelize over image tiles
elif self.lons.shape[2] > ncpus_avail:
self.CPUs = ncpus_avail # parallelize over bands
else:
self.CPUs = 1
opts['nprocs'] = ncpus_avail # parallelize over image tiles
# bands: multiprocessing uses multiprocessing.Pool, implemented in to_map_geometry / to_sensor_geometry
# tiles: multiprocessing uses OpenMP implemented in pykdtree which is used by pyresample
self.opts['nprocs'] = opts.get('nprocs', multiprocessing.cpu_count())
self.mp_alg = 'bands' if self.lons.shape[2] >= opts['nprocs'] else 'tiles'
@staticmethod
def _to_map_geometry_2D(kwargs_dict):
......@@ -894,12 +894,16 @@ class SensorMapGeometryTransformer3D(object):
tgt_epsg = WKT2EPSG(proj4_to_WKT(get_proj4info(proj=tgt_prj)))
tgt_extent = tgt_extent or self._get_common_target_extent(tgt_epsg)
init_opts = self.opts.copy()
if self.mp_alg == 'bands':
del init_opts['nprocs'] # avoid sub-multiprocessing
args = [dict(
lons_2D=self.lons[:, :, band],
lats_2D=self.lats[:, :, band],
resamp_alg=self.resamp_alg,
radius_of_influence=self.radius_of_influence,
init_opts=self.opts,
init_opts=init_opts,
data_sensor_geo_2D=data[:, :, band],
tgt_prj=tgt_prj,
tgt_extent=tgt_extent,
......@@ -907,8 +911,8 @@ class SensorMapGeometryTransformer3D(object):
band_idx=band
) for band in range(data.shape[2])]
if self.CPUs > 1:
with multiprocessing.Pool(self.CPUs) as pool:
if self.mp_alg == 'bands':
with multiprocessing.Pool(self.opts['nprocs']) as pool:
result = pool.map(self._to_map_geometry_2D, args)
else:
result = [self._to_map_geometry_2D(argsdict) for argsdict in args]
......@@ -945,20 +949,24 @@ class SensorMapGeometryTransformer3D(object):
if data.ndim != 3:
raise ValueError(data.ndim, "'data' must have 3 dimensions.")
init_opts = self.opts.copy()
if self.mp_alg == 'bands':
del init_opts['nprocs'] # avoid sub-multiprocessing
args = [dict(
lons_2D=self.lons[:, :, band],
lats_2D=self.lats[:, :, band],
resamp_alg=self.resamp_alg,
radius_of_influence=self.radius_of_influence,
init_opts=self.opts,
init_opts=init_opts,
data_map_geo_2D=data[:, :, band],
src_prj=src_prj,
src_extent=src_extent,
band_idx=band
) for band in range(data.shape[2])]
if self.CPUs > 1:
with multiprocessing.Pool(self.CPUs) as pool:
if self.mp_alg == 'bands':
with multiprocessing.Pool(self.opts['nprocs']) as pool:
result = pool.map(self._to_sensor_geometry_2D, args)
else:
result = [self._to_sensor_geometry_2D(argsdict) for argsdict in args]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment