GeoArray.py 66.5 KB
Newer Older
Daniel Scheffler's avatar
Daniel Scheffler committed
1
2
3
4
5
6
7
8
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'


import numpy as np
import os
import warnings
from matplotlib import pyplot as plt
Daniel Scheffler's avatar
Daniel Scheffler committed
9
from six import PY3
Daniel Scheffler's avatar
Daniel Scheffler committed
10
11
12

# custom
from shapely.geometry import Polygon, box
Daniel Scheffler's avatar
Daniel Scheffler committed
13
from shapely.wkt      import loads as shply_loads
Daniel Scheffler's avatar
Daniel Scheffler committed
14
from osgeo import gdal_array
15
16
# mpl_toolkits.basemap -> imported when GeoArray.show_map() is used
# dill -> imported when dumping GeoArray
17

Daniel Scheffler's avatar
Daniel Scheffler committed
18
19
20
21
22
23
try:
    from osgeo import gdal
    from osgeo import gdalnumeric
except ImportError:
    import gdal
    import gdalnumeric
24
from geopandas import GeoDataFrame, GeoSeries
Daniel Scheffler's avatar
Daniel Scheffler committed
25
26


27
28
29
30
31
from ...geo.coord_calc                  import get_corner_coordinates, calc_FullDataset_corner_positions
from ...geo.coord_grid                  import snap_bounds_to_pixGrid
from ...geo.coord_trafo                 import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
from ...geo.projection                  import prj_equal, WKT2EPSG, EPSG2WKT
from ...geo.raster.conversion           import raster2polygon
32
33
from ...geo.vector.topology             import get_overlap_polygon, get_footprint_polygon, polyVertices_outside_poly, \
                                               fill_holes_within_poly
34
35
36
from ...geo.vector.geometry             import boxObj
from ...io.raster.gdal                  import get_GDAL_ds_inmem
from ...numeric.array                   import find_noDataVal, get_outFillZeroSaturated
Daniel Scheffler's avatar
Daniel Scheffler committed
37
38
from ...compatibility.python.exceptions import TimeoutError as TimeoutError_comp, \
                                               FileNotFoundError as FileNotFoundError_comp
39
from ...compatibility.gdal              import get_gdal_func
40
41
42
43


def _alias_property(key):
    return property(
44
        lambda self:      getattr(self, key),
45
        lambda self, val: setattr(self, key, val),
46
        lambda self:      delattr(self, key))
Daniel Scheffler's avatar
Daniel Scheffler committed
47
48
49


class GeoArray(object):
50
51
    def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None, nodata=None, progress=True,
                 q=False):
52
53
54
55
56
57
        # type: (Any, tuple, str, list, float, bool, bool) -> GeoArray
        """This class creates a fast Python inteface to geodata - either on disk or in memory. It can be instanced with
        a file path or with a numpy array and the corresponding geoinformation. Instances can always be indexed like
        normal numpy arrays, no matter if GeoArray has been instanced from file or from an in-memory array. GeoArray
        provides a wide range of geo-related attributes belonging to the dataset as well as some functions for quickly
        visualizing the data as a map, a simple image or an interactive image.
Daniel Scheffler's avatar
Daniel Scheffler committed
58
59
60
61
62

        :param path_or_array:   a numpy.ndarray or a valid file path
        :param geotransform:    GDAL geotransform of the given array or file on disk
        :param projection:      projection of the given array or file on disk as WKT string
                                (only needed if GeoArray is instanced with an array)
63
64
        :param bandnames:       names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds'],
                                (default: ['B1', 'B2', 'B3', ...])
65
        :param nodata:          nodata value
66
        :param progress:        show progress bars (default: True)
67
        :param q:               quiet mode (default: False)
Daniel Scheffler's avatar
Daniel Scheffler committed
68
        """
69

70

71
        # FIXME implement compatibility to GDAL VRTs
72
73
        if not (isinstance(path_or_array, (str, np.ndarray, GeoArray)) or
                issubclass(getattr(path_or_array,'__class__'), GeoArray)):
74
            raise ValueError("%s parameter 'arg' takes only string, np.ndarray or GeoArray(and subclass) instances. "
75
                             "Got %s." %(self.__class__.__name__,type(path_or_array)))
76
77
78

        if path_or_array is None:
            raise ValueError("The %s parameter 'path_or_array' must not be None!" %self.__class__.__name__)
Daniel Scheffler's avatar
Daniel Scheffler committed
79

80
81
82
83
84
85
86
        if isinstance(path_or_array, str):
            assert ' ' not in path_or_array, "The given path contains whitespaces. This is not supported by GDAL."

            if not os.path.exists(path_or_array):
                raise FileNotFoundError(path_or_array) if PY3 else FileNotFoundError_comp(path_or_array)


87
        if isinstance(path_or_array, GeoArray) or issubclass(getattr(path_or_array,'__class__'), GeoArray):
88
            self.__dict__= path_or_array.__dict__.copy()
89
90
91
            self._initParams = dict([x for x in locals().items() if x[0] != "self"])
            self.geotransform = geotransform if geotransform       else self.geotransform
            self.projection   = projection   if projection         else self.projection
92
            self.bandnames    = bandnames    if bandnames          else list(self.bandnames.values())
93
94
95
96
            self._nodata      = nodata       if nodata is not None else self._nodata
            self.progress     = progress     if progress           else self.progress
            self.q            = q            if q is not None      else self.q

Daniel Scheffler's avatar
Daniel Scheffler committed
97
        else:
98
99
            self._initParams     = dict([x for x in locals().items() if x[0] != "self"])
            self.arg             = path_or_array
100
101
            self._arr            = path_or_array if isinstance(path_or_array, np.ndarray)            else None
            self.filePath        = path_or_array if isinstance(path_or_array, str) and path_or_array else None
102
103
104
105
106
107
108
109
110
111
            self.basename        = os.path.splitext(os.path.basename(self.filePath))[0] if not self.is_inmem else 'IN_MEM'
            self.progress        = progress
            self.q               = q
            self._arr_cache      = None
            self._geotransform   = None
            self._projection     = None
            self._shape          = None
            self._dtype          = None
            self._nodata         = nodata
            self._mask_nodata    = None
112
            self._mask_baddata   = None
113
114
            self._footprint_poly = None
            self._gdalDataset_meta_already_set = False
115
            self._metadata       = None
116
            self._bandnames      = None
Daniel Scheffler's avatar
Daniel Scheffler committed
117

118
            if bandnames:
119
                self.bandnames    = bandnames    # use property in order to validate given value
120
121
122
            if geotransform:
                self.geotransform = geotransform # use property in order to validate given value
            if projection:
123
                self.projection   = projection   # use property in order to validate given value
Daniel Scheffler's avatar
Daniel Scheffler committed
124

125
126
127
            if self.filePath:
                self.set_gdalDataset_meta()

128
129
130
131
132
133
134
    @property
    def arr(self):
        return self._arr


    @arr.setter
    def arr(self, ndarray):
Daniel Scheffler's avatar
Daniel Scheffler committed
135
        assert isinstance(ndarray, np.ndarray), "'arr' can only be set to a numpy array! Got %s." %type(ndarray)
136
137
138
139
140
        self._arr = ndarray


    @property
    def bandnames(self):
141
        if self._bandnames and len(self._bandnames)==self.bands:
142
143
144
145
146
147
148
149
            return self._bandnames
        else:
            self._bandnames = {'B%s' % band: i for i, band in enumerate(range(1, self.bands + 1))}
            return self._bandnames


    @bandnames.setter
    def bandnames(self, list_bandnames):
150
        # type: (list)
151
152
153
154
155
156
157
158
        assert len(list_bandnames) == self.bands, \
            'Number of given bandnames does not match number of bands in array.'
        assert len(list(set([type(b) for b in list_bandnames]))) == 1 and type(list_bandnames[0] == 'str'), \
            "'bandnames must be a set of strings. Got other datetypes in there.'"
        bN_dict = {band: i for i, band in enumerate(list_bandnames)}  # syntax supported since Python 2.7
        assert len(bN_dict) == self.bands, 'Bands must not have the same name.'
        self._bandnames = bN_dict

Daniel Scheffler's avatar
Daniel Scheffler committed
159
160
161

    @property
    def is_inmem(self):
162
        """Check if associated image array is completely loaded into memory."""
Daniel Scheffler's avatar
Daniel Scheffler committed
163
164
165
166
167
        return isinstance(self.arr, np.ndarray)


    @property
    def shape(self):
168
        """Get the array shape of the associated image array."""
169
170
        if self.is_inmem:
            return self.arr.shape
Daniel Scheffler's avatar
Daniel Scheffler committed
171
        else:
172
173
174
175
176
            if self._shape:
                return self._shape
            else:
                self.set_gdalDataset_meta()
                return self._shape
Daniel Scheffler's avatar
Daniel Scheffler committed
177
178


179
180
    @property
    def ndim(self):
181
        """Get the number dimensions of the associated image array."""
182
183
184
        return len(self.shape)


Daniel Scheffler's avatar
Daniel Scheffler committed
185
186
    @property
    def rows(self):
187
        """Get the number of rows of the associated image array."""
Daniel Scheffler's avatar
Daniel Scheffler committed
188
189
190
191
        return self.shape[0]


    @property
192
    def columns(self):
193
        """Get the number of columns of the associated image array."""
Daniel Scheffler's avatar
Daniel Scheffler committed
194
195
196
        return self.shape[1]


197
198
199
    cols = _alias_property('columns')


Daniel Scheffler's avatar
Daniel Scheffler committed
200
201
    @property
    def bands(self):
202
        """Get the number of bands of the associated image array."""
Daniel Scheffler's avatar
Daniel Scheffler committed
203
204
205
        return self.shape[2] if len(self.shape)>2 else 1


Daniel Scheffler's avatar
Daniel Scheffler committed
206
207
    @property
    def dtype(self):
208
        """Get the numpy data type of the associated image array."""
Daniel Scheffler's avatar
Daniel Scheffler committed
209
210
        if self._dtype:
            return self._dtype
Daniel Scheffler's avatar
Daniel Scheffler committed
211
212
        elif self.is_inmem:
            return self.arr.dtype
Daniel Scheffler's avatar
Daniel Scheffler committed
213
214
215
216
217
218
219
        else:
            self.set_gdalDataset_meta()
            return self._dtype


    @property
    def geotransform(self):
220
        """Get the GDAL GeoTransform of the associated image."""
Daniel Scheffler's avatar
Daniel Scheffler committed
221
222
223
224
225
226
        if self._geotransform:
            return self._geotransform
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._geotransform
        else:
227
            return [0,1,0,0,0,-1]
Daniel Scheffler's avatar
Daniel Scheffler committed
228
229
230
231


    @geotransform.setter
    def geotransform(self, gt):
232
233
234
        assert isinstance(gt,(list,tuple)) and len(gt)==6, 'geotransform must be a list with 6 numbers. Got %s.' %gt
        for i in gt: assert isinstance(i,(int,float)),     "geotransform must contain only numbers. Got '%s'." %i

235
        self._geotransform = gt
Daniel Scheffler's avatar
Daniel Scheffler committed
236
237


238
239
240
    gt = _alias_property('geotransform')


Daniel Scheffler's avatar
Daniel Scheffler committed
241
242
    @property
    def xgsd(self):
243
        """Get the X resolution in units of the given or detected projection."""
Daniel Scheffler's avatar
Daniel Scheffler committed
244
245
246
247
248
        return self.geotransform[1]


    @property
    def ygsd(self):
249
        """Get the Y resolution in units of the given or detected projection."""
Daniel Scheffler's avatar
Daniel Scheffler committed
250
251
252
        return abs(self.geotransform[5])


253
254
255
256
257
258
259
260
    @property
    def xygrid_specs(self):
        """Get the specifications for the X/Y coordinate grid, e.g. [[15,30], [0,30]] for a coordinate with its origin
        at X/Y[15,0] and a GSD of X/Y[15,30]."""
        get_grid = lambda gt, xgsd, ygsd: [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]
        return get_grid(self.geotransform, self.xgsd, self.ygsd)


Daniel Scheffler's avatar
Daniel Scheffler committed
261
262
    @property
    def projection(self):
263
264
        """Get the projection of the associated image. Setting the projection is only allowed if GeoArray has been
        instanced from memory or the associated file on disk has no projection."""
Daniel Scheffler's avatar
Daniel Scheffler committed
265
266
267
268
269
270
        if self._projection:
            return self._projection
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._projection
        else:
271
            return ''
Daniel Scheffler's avatar
Daniel Scheffler committed
272
273
274
275
276


    @projection.setter
    def projection(self, prj):
        if self.filePath:
Daniel Scheffler's avatar
Daniel Scheffler committed
277
278
279
            assert self.projection is None or prj_equal(self.projection, prj),\
                "Cannot set %s.projection to the given value because it does not match the projection from the file " \
                "on disk." %self.__class__.__name__
Daniel Scheffler's avatar
Daniel Scheffler committed
280
281
282
283
        else:
            self._projection = prj


284
285
286
    prj = _alias_property('projection')


287
288
    @property
    def epsg(self):
        """Get the EPSG code corresponding to the projection of the GeoArray."""
        wkt = self.projection
        return WKT2EPSG(wkt)


    @epsg.setter
    def epsg(self, epsg_code):
        # the EPSG code is converted to WKT and stored via the projection property (incl. its validation)
        wkt = EPSG2WKT(epsg_code)
        self.projection = wkt


Daniel Scheffler's avatar
Daniel Scheffler committed
298
299
    @property
    def box(self):
        """Get a boxObj built from the image corner coordinates, the geotransform and the projection."""
        corner_coords = get_corner_coordinates(gt=self.geotransform, cols=self.columns, rows=self.rows)
        map_poly = get_footprint_polygon(corner_coords)
        return boxObj(gt=self.geotransform, prj=self.projection, mapPoly=map_poly)


304
305
    @property
    def nodata(self):
        """Get the nodata value of the GeoArray.

        If GeoArray has been instanced with a file path, the file is checked for an existing nodata value. Otherwise
        (if no value has been explicitly given during object instantiation) the nodata value is automatically detected.
        An ambiguous detection result is replaced by None (with a warning).
        """
        if self._nodata is not None:
            return self._nodata

        # try to get the nodata value from the file header
        if not self.is_inmem:
            self.set_gdalDataset_meta()

        if self._nodata is None:
            self._nodata = find_noDataVal(self)
            if self._nodata == 'ambiguous':
                warnings.warn('Nodata value could not be clearly identified. It has been set to None.')
                self._nodata = None
            elif self._nodata is not None and not self.q:
                print("Automatically detected nodata value for %s '%s': %s"
                      % (self.__class__.__name__, self.basename, self._nodata))

        return self._nodata


    @nodata.setter
    def nodata(self, value):
        self._nodata = value


    @property
    def mask_nodata(self):
        """Get the nodata mask of the associated image array (calculated from all image bands if not set before)."""
        if self._mask_nodata is None:
            self.calc_mask_nodata()  # sets self._mask_nodata
        return self._mask_nodata


    @mask_nodata.setter
    def mask_nodata(self, mask):
        """Set the nodata mask.

        :param mask:    Can be a file path, a numpy array or an instance of GeoArray (None is ignored).
        """
        if mask is None:
            return

        # geotransform GeoArray returns for in-memory arrays without one (i.e. 'no geocoding')
        default_gt = [0, 1, 0, 0, 0, -1]

        geoArr_mask = NoDataMask(mask, progress=self.progress, q=self.q)

        # inherit geocoding from the image if the mask has none; geotransforms are compared as lists because GDAL
        # returns tuples whereas in-memory defaults are lists (a tuple never equals a list in Python)
        mask_gt = geoArr_mask.gt
        geoArr_mask.gt  = mask_gt if mask_gt is not None and list(mask_gt) != default_gt else self.gt
        geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
        imName          = "the %s '%s'" % (self.__class__.__name__, self.basename)

        assert geoArr_mask.bands == 1, \
            'Expected one single band as nodata mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
        assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided nodata mask must have the same number of ' \
                                                        'rows and columns as the %s itself.' % imName
        assert list(geoArr_mask.gt) == list(self.gt), \
            'The geotransform of the given nodata mask for %s must match the geotransform of the %s itself. ' \
            'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
        assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
            'The projection of the given nodata mask for the %s must match the projection of the %s itself.' \
            % (imName, self.__class__.__name__)

        self._mask_nodata = geoArr_mask
368
369


370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    @property
    def mask_baddata(self):
        """Return the bad data mask for the associated image array if it has been set explicitly before (else None).

        It can be set by passing a file path, a numpy array or an instance of GeoArray to the setter of this property.
        """
        return self._mask_baddata


    @mask_baddata.setter
    def mask_baddata(self, mask):
        """Set the bad data mask.

        :param mask:    Can be a file path, a numpy array or an instance of GeoArray (None is ignored).
        """
        if mask is None:
            return

        # geotransform GeoArray returns for in-memory arrays without one (i.e. 'no geocoding')
        default_gt = [0, 1, 0, 0, 0, -1]

        geoArr_mask = BadDataMask(mask, progress=self.progress, q=self.q)

        # inherit geocoding from the image if the mask has none; geotransforms are compared as lists because GDAL
        # returns tuples whereas in-memory defaults are lists (a tuple never equals a list in Python)
        mask_gt = geoArr_mask.gt
        geoArr_mask.gt  = mask_gt if mask_gt is not None and list(mask_gt) != default_gt else self.gt
        geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
        imName          = "the %s '%s'" % (self.__class__.__name__, self.basename)

        assert geoArr_mask.bands == 1, \
            'Expected one single band as bad data mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
        assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided bad data mask must have the same number of ' \
                                                        'rows and columns as the %s itself.' % imName
        assert list(geoArr_mask.gt) == list(self.gt), \
            'The geotransform of the given bad data mask for %s must match the geotransform of the %s itself. ' \
            'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
        # an empty projection (neither mask nor image georeferenced) is accepted, consistent with mask_nodata
        assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
            'The projection of the given bad data mask for the %s must match the projection of the %s itself.' \
            % (imName, self.__class__.__name__)

        self._mask_baddata = geoArr_mask


405
406
    @property
    def footprint_poly(self):
        """Get the footprint polygon of the associated image array (an instance of shapely.geometry.Polygon).

        The polygon is derived from the nodata mask. If the nodata mask is constant (e.g. the whole image is filled
        with data) or the vectorization fails or times out, the outer image bounds are used instead. The result is
        cached.
        """
        # FIXME should return polygon in image coordinates if no projection is available
        if self._footprint_poly is not None:
            return self._footprint_poly

        assert self.mask_nodata is not None, 'A nodata mask is needed for calculating the footprint polygon. '

        # TimeoutError is a builtin on Python 3 only -> it must not be referenced in the except clause on Python 2
        timeout_errors = (TimeoutError, TimeoutError_comp) if PY3 else (TimeoutError_comp,)

        if np.std(self.mask_nodata[:]) == 0:
            # do not run raster2polygon if the mask is constant (e.g. whole image is filled with data)
            self._footprint_poly = self.box.mapPoly
        else:
            try:
                multipolygon = raster2polygon(self, exact=False, progress=self.progress, q=self.q,
                                              maxfeatCount=10, timeout=3)
                self._footprint_poly = fill_holes_within_poly(multipolygon)
            except (RuntimeError,) + timeout_errors:
                if not self.q:
                    warnings.warn("\nCalculation of footprint polygon failed for %s '%s'. Using outer bounds. One "
                                  "reason could be that the nodata value appears within the actual image (not only "
                                  "as fill value). To avoid this use another nodata value. Current nodata value is "
                                  "%s." % (self.__class__.__name__, self.basename, self.nodata))
                self._footprint_poly = self.box.mapPoly

        # validation
        assert not polyVertices_outside_poly(self._footprint_poly, self.box.mapPoly), \
            "Computing footprint polygon for %s '%s' failed. The resulting polygon is partly or completely outside " \
            "of the image bounds." % (self.__class__.__name__, self.basename)

        return self._footprint_poly


    @footprint_poly.setter
    def footprint_poly(self, poly):
        # accepts a shapely Polygon or its WKT representation
        if isinstance(poly, Polygon):
            self._footprint_poly = poly
        elif isinstance(poly, str):
            self._footprint_poly = shply_loads(poly)
        else:
            raise ValueError("'footprint_poly' can only be set from a shapely polygon or a WKT string.")


453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
    @property
    def metadata(self):
        """
        Returns a GeoDataFrame containing all available metadata (read from file if available).
        Use 'metadata[band_index].to_dict()' to get a metadata dictionary for a specific band.
        Use 'metadata.loc[row_name].to_dict()' to get all metadata values of the same key for all bands as dictionary.
        Use 'metadata.loc[row_name, band_index] = value' to set a new value.

        :return:  geopandas.GeoDataFrame
        """
        if self._metadata is None:
            # start with an empty table that has one column per band
            self._metadata = GeoDataFrame(columns=range(self.bands))

            # for file-based instances, read the band metadata from disk (if not already done before)
            if not self.is_inmem:
                self.set_gdalDataset_meta()

        return self._metadata


    @metadata.setter
    def metadata(self, GDF):
        assert isinstance(GDF, GeoDataFrame) and len(GDF.columns)==self.bands, \
            "%s.metadata can only be set with an instance of geopandas.GeoDataFrame of which the column number " \
            "corresponds to the band number of %s." %(self.__class__.__name__, self.__class__.__name__)
        self._metadata = GDF


    # short alias for GeoArray.metadata
    meta = _alias_property('metadata')


Daniel Scheffler's avatar
Daniel Scheffler committed
489
490
    def __getitem__(self, given):
        # TODO check if array cache contains the needed slice and return data from there
491
492
493
494
495

        if isinstance(given, (int,float,slice)) and self.ndim==3:
            # handle 'given' as index for 3rd (bands) dimension
            if self.is_inmem:
                return self.arr[:, :, given]
Daniel Scheffler's avatar
Daniel Scheffler committed
496
            else:
Daniel Scheffler's avatar
Daniel Scheffler committed
497
                getitem_params = [given]
498
499
500
501
502
503
504
505
506
507

        elif isinstance(given, str):
            # behave like a dictionary and return the corresponding band
            if self.bandnames:
                if given not in self.bandnames:
                    raise ValueError("'%s' is not a known band. Known bands are: %s"
                                     % (given, ', '.join(list(self.bandnames.keys()))))
                if self.is_inmem:
                    return self.arr[:, :, self.bandnames[given]]
                else:
Daniel Scheffler's avatar
Daniel Scheffler committed
508
509
                    getitem_params = [self.bandnames[given]]
            else:
510
511
512
513
                raise ValueError('String indices are only supported if %s has been instanced with bandnames given.'
                                 %self.__class__.__name__)

        elif isinstance(given, (tuple, list)) and len(given)==3 and self.ndim==2:
514
515
516
517
518
            # handle requests like geoArr[[1,2],[3,4]  -> not implemented in from_path if array is not in mem
            types = [type(i) for i in given]
            if list in types or tuple in types:
                self.to_mem()

519
520
521
522
523
            # in case a third dim is requested from 2D-array -> ignore 3rd dim
            if self.is_inmem:
                return self.arr[given[:2]]
            else:
                getitem_params = given[:2]
Daniel Scheffler's avatar
Daniel Scheffler committed
524

525
        else:
526
527
528
529
530
531
            if isinstance(given, (tuple, list)):
                # handle requests like geoArr[[1,2],[3,4]  -> not implemented in from_path if array is not in mem
                types = [type(i) for i in given]
                if list in types or tuple in types:
                    self.to_mem()

532
533
534
535
536
537
538
539
            # behave like a numpy array
            if self.is_inmem:
                return self.arr[given]
            else:
                getitem_params = [given] if isinstance(given, slice) else given


        if not self.is_inmem:
Daniel Scheffler's avatar
Daniel Scheffler committed
540
541
542
543
544
            self._arr_cache = self.from_path(self.arg, getitem_params)

            return self._arr_cache


Daniel Scheffler's avatar
Daniel Scheffler committed
545
546
547
548
549
550
551
552
553
554
555
556
557
558
    def __setitem__(self, idx, array2set):
        """Overwrites the pixel values of GeoArray.arr with the given array.

        :param idx:         <int, list, slice> the index position to overwrite
        :param array2set:   <np.ndarray> array to be set. Must be compatible to the given index position.
        :return:
        """
        if self.is_inmem:
            self.arr[idx] = array2set
        else:
            raise NotImplementedError('Item assignment for %s instances that are not in memory is not yet supported.'
                                      %self.__class__.__name__)


Daniel Scheffler's avatar
Daniel Scheffler committed
559
560
    def __getattr__(self, attr):
        # check if the requested attribute can not be present because GeoArray has been instanced with an array
561
        if attr not in self.__dir__() and not self.is_inmem and attr in ['shape','dtype','geotransform', 'projection']:
Daniel Scheffler's avatar
Daniel Scheffler committed
562
563
            self.set_gdalDataset_meta()

564
565
        if attr in self.__dir__():             #__dir__() includes also methods and properties
            return self.__getattribute__(attr) #__getattribute__ avoids infinite loop
Daniel Scheffler's avatar
Daniel Scheffler committed
566
        elif hasattr(np.array([]),attr):
567
            return self[:].__getattribute__(attr)
Daniel Scheffler's avatar
Daniel Scheffler committed
568
        else:
Daniel Scheffler's avatar
Daniel Scheffler committed
569
            raise AttributeError("%s object has no attribute '%s'." %(self.__class__.__name__, attr))
Daniel Scheffler's avatar
Daniel Scheffler committed
570
571
572
573
574
575
576


    def __getstate__(self):
        """Return the instance state to be pickled (the array cache is flushed beforehand so it is not pickled)."""
        self.flush_cache()  # avoid pickling cached array data
        state = self.__dict__
        return state


    def __setstate__(self, state):
        """Defines how the attributes of GMS object are unpickled.
583
        NOTE: This method has been implemented because otherwise pickled and unpickled instances show recursion errors
Daniel Scheffler's avatar
Daniel Scheffler committed
584
585
586
587
588
        within __getattr__ when requesting any attribute."""

        self.__dict__ = state


589
    def calc_mask_nodata(self, fromBand=None, overwrite=False):
        """Calculate a nodata mask (values: 0/False = nodata; 1/True = data) and store it as GeoArray.mask_nodata.

        A pixel is flagged as nodata if its value equals GeoArray.nodata in at least one band (or in the band given by
        'fromBand'). NaN nodata values are supported. If no nodata value is known, all pixels are flagged as data.

        :param fromBand:  <int> index of the band to be used (if None, all bands are used)
        :param overwrite: <bool> whether to overwrite existing nodata mask that has already been calculated
        :return:
        """
        if self._mask_nodata is not None and not overwrite:
            return

        assert self.ndim in [2, 3], "Only 2D or 3D arrays are supported. Got a %sD array." % self.ndim

        nodata = self.nodata
        if nodata is None:
            # builtin bool instead of np.bool (alias removed in NumPy 1.24)
            self.mask_nodata = np.ones((self.rows, self.cols), bool)
            return

        arr = self[:, :, fromBand] if self.ndim == 3 and fromBand is not None else self[:]

        if isinstance(nodata, (float, np.floating)) and np.isnan(nodata):
            # NaN never compares equal to anything -> '==' would never flag NaN pixels as nodata
            is_data = ~np.isnan(arr)
        else:
            is_data = arr != nodata

        # 3D: a pixel only counts as data if it differs from the nodata value in all bands
        self.mask_nodata = is_data if is_data.ndim == 2 else np.all(is_data, axis=2)
605
606


Daniel Scheffler's avatar
Daniel Scheffler committed
607
    def set_gdalDataset_meta(self):
        """Retrieve shape, dtype, geotransform, projection, nodata value and band metadata from the file on disk.

        This function is only executed once to avoid overwriting of user defined attributes that are set after object
        instantiation. A nodata value explicitly given at instantiation is not overwritten by the file's value.

        :raises IOError:    if GDAL cannot open the file
        :return:
        """
        if self._gdalDataset_meta_already_set:
            return

        assert self.filePath
        ds = gdal.Open(self.filePath)
        if not ds:
            raise IOError('Error reading image data at %s.' % self.filePath)

        band = None
        try:
            # set private class variables (in order to avoid recursion error)
            self._shape        = tuple([ds.RasterYSize, ds.RasterXSize] + ([ds.RasterCount] if ds.RasterCount > 1 else []))
            self._dtype        = gdal_array.GDALTypeCodeToNumericTypeCode(ds.GetRasterBand(1).DataType)
            self._geotransform = ds.GetGeoTransform()
            # temp conversion to EPSG needed because GDAL seems to modify WKT string when writing file to disk
            # (e.g. using gdal_merge) -> conversion to EPSG and back undoes that
            self._projection   = EPSG2WKT(WKT2EPSG(ds.GetProjection()))
            if 'nodata' not in self._initParams or self._initParams['nodata'] is None:
                band           = ds.GetRasterBand(1)
                self._nodata   = band.GetNoDataValue()   # FIXME this does not support different nodata values within the same file

            # mark as done BEFORE reading the band metadata: accessing self.metadata below would otherwise re-enter
            # this method via the metadata property and open/read the file a second time
            self._gdalDataset_meta_already_set = True

            # read metadata
            for b in range(self.bands):
                band = ds.GetRasterBand(b + 1)
                self.metadata[b] = GeoSeries(band.GetMetadata())

        except Exception:
            self._gdalDataset_meta_already_set = False  # allow a retry on the next call
            raise

        finally:
            ds = band = None  # release GDAL objects
Daniel Scheffler's avatar
Daniel Scheffler committed
633
634
635


    def from_path(self, path, getitem_params=None):
        # type: (str, list) -> np.ndarray
        """Read a GDAL compatible raster image from disk, with respect to the given image position.

        Negative indices count from the end of the respective axis of the image at 'path'. Slice stops are exclusive,
        slice steps are ignored and an integer row/column index yields an axis of length 1.

        :param path:            <str> the file path of the image to read
        :param getitem_params:  <list> a list of indices in the form [row_slice, col_slice, band_slice] or
                                [band_slice]. Rows and columns may be given as slice or integer, bands additionally as
                                list of integer band indices, as band name or as list of band names (via
                                self.bandnames).
        :return out_arr:        <np.ndarray> the output array
        :raises IOError:        if the image cannot be opened
        :raises ValueError:     if the requested subset is out of bounds
        """
        ds = gdal.Open(path)
        if not ds:
            raise IOError('Error reading image data at %s.' % path)
        R, C, B = ds.RasterYSize, ds.RasterXSize, ds.RasterCount
        ds = None

        def is_int(val):
            # numpy integers (e.g. np.int64) are no subclasses of int in Python 3 -> accept them explicitly
            return isinstance(val, (int, np.integer))

        def start_end(given):
            # convert a slice or an integer index to an inclusive (start, end) pair; None means 'not given'
            if isinstance(given, slice):
                return given.start, (given.stop - 1 if given.stop is not None else None)
            if is_int(given):
                return int(given), int(given)
            return None, None

        ## convert getitem_params to subset area to be read ##
        rS, rE, cS, cE, bS, bE, bL = [None] * 7

        if getitem_params:
            if len(getitem_params) >= 2:
                rS, rE = start_end(getitem_params[0])
                cS, cE = start_end(getitem_params[1])
            if len(getitem_params) in [1, 3]:
                givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
                if isinstance(givenB, (tuple, list)):
                    typesInGivenB = [type(i) for i in givenB]
                    assert len(set(typesInGivenB)) == 1, \
                        'Mixed data types within the list of bands are not supported.'
                    if is_int(givenB[0]):
                        bL = [int(i) for i in givenB]
                    elif isinstance(givenB[0], str):
                        bL = [self.bandnames[i] for i in givenB]
                elif isinstance(givenB, str):
                    bL = [self.bandnames[givenB]]
                else:
                    bS, bE = start_end(givenB)

        # set defaults for not given values
        rS = rS if rS is not None else 0
        rE = rE if rE is not None else R - 1
        cS = cS if cS is not None else 0
        cE = cE if cE is not None else C - 1
        bS = bS if bS is not None else 0
        bE = bE if bE is not None else B - 1

        # convert negative to positive indices (relative to the dimensions of the image at 'path')
        # NOTE: this must happen BEFORE the band list is derived from bS/bE, otherwise e.g. [:, :, -2:] would yield
        #       a list mixing negative and positive band indices
        rS, rE = [v if v >= 0 else R + v for v in (rS, rE)]
        cS, cE = [v if v >= 0 else C + v for v in (cS, cE)]
        bS, bE = [v if v >= 0 else B + v for v in (bS, bE)]
        bL = [b if b >= 0 else B + b for b in bL] if bL else list(range(bS, bE + 1))

        # validate subset area bounds to be read
        def msg(v, idx, sz):
            # FIXME numpy raises that error ONLY for the 2nd axis
            return '%s is out of bounds for axis %s with size %s' % (v, idx, sz)

        for val, axIdx, axSize in zip([rS, rE, cS, cE, bS, bE] + bL,
                                      [0, 0, 1, 1, 2, 2] + [2] * len(bL),
                                      [R, R, C, C, B, B] + [B] * len(bL)):
            if not 0 <= val <= axSize - 1:
                raise ValueError(msg(val, axIdx, axSize))

        # read subset area
        n_rows, n_cols = rE - rS + 1, cE - cS + 1
        if bL == list(range(0, B)):
            tempArr = gdalnumeric.LoadFile(path, cS, rS, n_cols, n_rows)
            if tempArr is None:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            # bands, rows, columns => rows, columns, bands
            out_arr = np.swapaxes(np.swapaxes(tempArr, 0, 2), 0, 1) if B > 1 else tempArr
        else:
            ds = gdal.Open(path)
            if ds is None:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            bandArrs = [ds.GetRasterBand(bIdx + 1).ReadAsArray(cS, rS, n_cols, n_rows) for bIdx in bL]
            ds = None

            if any(a is None for a in bandArrs):
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

            if len(bandArrs) == 1:
                out_arr = bandArrs[0]
            elif bandArrs:
                # np.dstack keeps the native data type of the bands (rows, columns, bands)
                out_arr = np.dstack(bandArrs)
            else:
                # empty band range (e.g. [:, :, 2:2])
                out_arr = np.empty((n_rows, n_cols, 0))

        # only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
        if out_arr.shape == self.shape:
            self.arr = out_arr

        return out_arr  # TODO implement check of returned datatype (e.g. NoDataMask should always return np.bool
                        # TODO -> would be np.int8 if an int8 file is read from disk


    def save(self, out_path, fmt='ENVI', creationOptions=None):
        # type: (str, str, list)
        """Write the raster data to disk.

        In-memory arrays are written via an in-memory GDAL dataset (including the band metadata), otherwise the file at
        self.filePath is converted using GDAL Translate. Missing parent directories of out_path are created.

        :param out_path:        <str> output path
        :param fmt:             <str> the output format / GDAL driver code to be used for output creation, e.g. 'ENVI'
                                Refer to http://www.gdal.org/formats_list.html to get a full list of supported formats.
        :param creationOptions: <list> GDAL creation options, e.g. ["QUALITY=80", "REVERSIBLE=YES", "WRITE_METADATA=YES"]
        :return:
        :raises Exception:      if the GDAL driver is not supported or the output file has not been written
        """
        if not self.q:
            print('Writing GeoArray of size %s to %s.' % (self.shape, out_path))
        assert self.ndim in [2, 3], 'Only 2D- or 3D arrays are supported.'

        driver = gdal.GetDriverByName(fmt)
        if driver is None:
            raise Exception("'%s' is not a supported GDAL driver." % fmt)

        # os.path.dirname() returns '' for bare file names -> os.makedirs('') would raise
        out_dir = os.path.dirname(out_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        if self.is_inmem:
            ds = get_GDAL_ds_inmem(self.arr, self.geotransform, self.projection, self.nodata)  # expects rows,columns,bands

            # set band metadata
            if not self.metadata.empty:
                for bidx in range(self.bands):
                    band = ds.GetRasterBand(bidx + 1)
                    band.SetMetadata(self.metadata[bidx].to_dict())
                    band = None

            out_ds = driver.CreateCopy(out_path, ds, options=creationOptions if creationOptions else [])
            out_ds = None  # dereference the output dataset so that GDAL closes it
            ds = None

        else:
            src_ds = gdal.Open(self.filePath)
            if src_ds is None:
                raise IOError('Error reading image data at %s.' % self.filePath)

            gdal_Translate = get_gdal_func('Translate')
            gdal_Translate(out_path, src_ds, format=fmt, creationOptions=creationOptions)
            src_ds = None

        if not os.path.exists(out_path):
            raise Exception(gdal.GetLastErrorMsg())


    def dump(self, out_path):
        # type: (str)
        """Serialize the whole object instance to disk using dill.

        :param out_path:    <str> path of the output file
        """
        import dill

        # dill writes bytes -> the file has to be opened in binary mode (text mode raises a TypeError in Python 3)
        with open(out_path, 'wb') as outF:
            dill.dump(self, outF)


    def _get_plottable_image(self, xlim=None, ylim=None, band=None, boundsMap=None, boundsMapPrj=None, res_factor=None,
                             nodataVal=None, out_prj=None):
        """Return the image (subset) to be plotted together with its geotransform and projection.

        :param xlim:            [start_column, end_column] (end column included; negative values count from the end)
        :param ylim:            [start_row, end_row] (end row included; negative values count from the end)
        :param band:            band index of the band to be returned (all bands if None)
        :param boundsMap:       xmin, ymin, xmax, ymax - if given, xlim and ylim are ignored
        :param boundsMapPrj:    projection of boundsMap (defaults to self.prj)
        :param res_factor:      factor applied to the image dimensions if an image with more than 1e6 pixels is
                                downsampled. If None, the larger image dimension is scaled to 1000 pixels.
                                1 disables downsampling.
        :param nodataVal:       nodata value (defaults to self.nodata)
        :param out_prj:         if given, the image is warped to this projection
        :return:                image2plot, geotransform, projection
        """
        nodataVal = nodataVal if nodataVal is not None else self.nodata

        # handle limits
        if boundsMap:
            boundsMapPrj = boundsMapPrj if boundsMapPrj else self.prj
            image2plot, gt, prj = self.get_mapPos(boundsMap, boundsMapPrj, band2get=band, fillVal=nodataVal)
            no_geotransform = tuple(gt) == (0, 1, 0, 0, 0, -1)
        else:
            cS, cE = xlim if isinstance(xlim, (tuple, list)) else (0, self.columns - 1)
            rS, rE = ylim if isinstance(ylim, (tuple, list)) else (0, self.rows - 1)
            cS, cE = [c if c >= 0 else self.columns + c for c in (cS, cE)]
            rS, rE = [r if r >= 0 else self.rows + r for r in (rS, rE)]

            # the limits are inclusive -> +1 for the (exclusive) slice stop
            image2plot = self[rS:rE + 1, cS:cE + 1, band] if band is not None else self[rS:rE + 1, cS:cE + 1]

            gt, prj = self.geotransform, self.projection
            no_geotransform = tuple(gt) == (0, 1, 0, 0, 0, -1)

            # move the geotransform origin to the upper left corner of the subset
            gt = (gt[0] + cS * gt[1] + rS * gt[2], gt[1], gt[2],
                  gt[3] + cS * gt[4] + rS * gt[5], gt[4], gt[5])

        transOpt = ['SRC_METHOD=NO_GEOTRANSFORM'] if no_geotransform else None

        # sample images with more than 1000*1000 pixels down (relative to the size of the image to plot)
        xdim, ydim = None, None
        rows, cols = image2plot.shape[:2]
        if res_factor != 1. and rows * cols > 1e6:
            if res_factor:
                xdim, ydim = cols * res_factor, rows * res_factor
            else:
                # normalize the larger dimension to 1000 pixels
                scale = max(cols, rows) / 1000.
                xdim, ydim = cols / scale, rows / scale
            xdim, ydim = int(xdim), int(ydim)

        if xdim or ydim or out_prj:
            from ...geo.raster.reproject import warp_ndarray
            image2plot, gt, prj = warp_ndarray(image2plot, gt, prj,
                                               out_XYdims=(xdim, ydim), in_nodata=nodataVal, out_nodata=nodataVal,
                                               transformerOptions=transOpt, out_prj=out_prj, q=True)
            if no_geotransform:
                image2plot = np.flipud(image2plot)
                gt = list(gt)
                gt[3] = 0

            if xdim or ydim:
                print('Note: array has been downsampled to %s x %s for faster visualization.' % (xdim, ydim))

        return image2plot, gt, prj


    def show(self, xlim=None, ylim=None, band=None, boundsMap=None, boundsMapPrj=None, figsize=None,
             interpolation='none', cmap=None, nodataVal=None, res_factor=None, interactive=False):
        """Plots the desired array position into a figure.

        :param xlim:            [start_column, end_column]
        :param ylim:            [start_row, end_row]
        :param band:            the band index of the band to be plotted (if None and interactive==True all bands are
                                shown, otherwise the first band is chosen)
        :param boundsMap:       xmin, ymin, xmax, ymax
        :param boundsMapPrj:    projection of boundsMap (defaults to the projection of the GeoArray)
        :param figsize:         figure size (passed to plt.figure; used as 'fig_inches' in interactive mode)
        :param interpolation:   interpolation method passed to plt.imshow
        :param cmap:            matplotlib colormap (defaults to plt.cm.gray); only a copy of it is modified
        :param nodataVal:       nodata value (defaults to self.nodata); pixels with this value are not shown
        :param res_factor:      resolution factor used for downsampling large images (see _get_plottable_image)
        :param interactive:     <bool> activates interactive plotting based on 'holoviews' library.
                                NOTE: this deactivates the magic '% matplotlib inline' in Jupyter Notebook
        :return:                a holoviews.HoloMap in interactive mode for multi-band images, otherwise None
        """
        import copy

        band = (band if band is not None else 0) if not interactive else band

        # get image to plot
        nodataVal = nodataVal if nodataVal is not None else self.nodata
        image2plot, gt, prj = self._get_plottable_image(xlim, ylim, band, boundsMap=boundsMap,
                                                        boundsMapPrj=boundsMapPrj, res_factor=res_factor,
                                                        nodataVal=nodataVal)

        # set color palette
        # (set_bad/set_over/set_under modify a colormap in place -> work on a copy instead of the shared instance)
        palette = copy.copy(cmap if cmap else plt.cm.gray)
        if nodataVal is not None and np.std(image2plot) != 0:  # do not show nodata
            image2plot = np.ma.masked_equal(image2plot, nodataVal)
            # compressed() excludes the nodata values from the percentile computation
            vmin, vmax = np.percentile(image2plot.compressed(), 2), np.percentile(image2plot.compressed(), 98)
            palette.set_bad('aqua', 0)
        else:
            vmin, vmax = np.percentile(image2plot, 2), np.percentile(image2plot, 98)
        palette.set_over('1')
        palette.set_under('0')

        if interactive and image2plot.ndim == 3:
            import holoviews as hv
            from skimage.exposure import rescale_intensity
            hv.notebook_extension('matplotlib')

            cS, cE = xlim if isinstance(xlim, (tuple, list)) else (0, self.columns - 1)
            rS, rE = ylim if isinstance(ylim, (tuple, list)) else (0, self.rows - 1)

            image2plot = np.array(rescale_intensity(image2plot, in_range=(vmin, vmax)))

            def get_hv_image(b):
                return hv.Image(image2plot[:, :, b] if b is not None else image2plot,
                                bounds=(cS, rS, cE, rE))(style={'cmap': 'gray'},  # FIXME ylabels have the wrong order
                                                          plot={'fig_inches': 4 if figsize is None else figsize,
                                                                'show_grid': True})

            hmap = hv.HoloMap([(b, get_hv_image(b)) for b in range(image2plot.shape[2])], kdims=['band'])

            return hmap

        else:
            if interactive:
                warnings.warn('Currently there is no interactive mode for single-band arrays. '
                              'Switching to standard matplotlib figure..')  # TODO implement zoomable fig

            # show image
            plt.figure(figsize=figsize)
            rows, cols = image2plot.shape[:2]
            plt.imshow(image2plot, palette, interpolation=interpolation, extent=(0, cols, rows, 0),
                       vmin=vmin, vmax=vmax)
            plt.show()


    def show_map(self, xlim=None, ylim=None, band=0, boundsMap=None, boundsMapPrj=None, figsize=None,
                 interpolation='none', cmap=None, nodataVal=None, res_factor=None, return_map=False):
        """Plot the image (warped to EPSG:4326) into a Basemap (transverse mercator projection) with lat/lon grid lines.

        :param xlim:            [start_column, end_column]
        :param ylim:            [start_row, end_row]
        :param band:            the band index of the band to be plotted
        :param boundsMap:       xmin, ymin, xmax, ymax
        :param boundsMapPrj:    projection of boundsMap (defaults to the projection of the GeoArray)
        :param figsize:         figure size passed to plt.figure
        :param interpolation:   interpolation method passed to Basemap.imshow
        :param cmap:            matplotlib colormap (defaults to plt.cm.gray); only a copy of it is modified
        :param nodataVal:       nodata value (defaults to self.nodata); pixels with this value are not shown
        :param res_factor:      resolution factor used for downsampling large images (see _get_plottable_image)
        :param return_map:      <bool> return (fig, ax, m) instead of showing the figure
        :return:                (fig, ax, m) if return_map is True, otherwise None
        """
        import copy
        from mpl_toolkits.basemap import Basemap

        # NOTE: the geotransform is wrapped into a 1-tuple - formatting a 6-tuple with a single '%s' would raise a
        #       TypeError instead of the intended AssertionError
        assert self.geotransform and tuple(self.geotransform) != (0, 1, 0, 0, 0, -1), \
            'A valid geotransform is needed for a map visualization. Got %s.' % (self.geotransform,)
        assert self.projection, 'A projection is needed for a map visualization. Got %s.' % self.projection

        # get image to plot
        nodataVal = nodataVal if nodataVal is not None else self.nodata
        image2plot, gt, prj = self._get_plottable_image(xlim, ylim, band, boundsMap=boundsMap,
                                                        boundsMapPrj=boundsMapPrj, res_factor=res_factor,
                                                        nodataVal=nodataVal, out_prj='epsg:4326')

        # calculate corner coordinates of plot
        UL_XY, UR_XY, LR_XY, LL_XY = [(YX[1], YX[0]) for YX in GeoArray(image2plot, gt, prj).box.boxMapYX]
        center_lon, center_lat = (UL_XY[0] + UR_XY[0]) / 2., (UL_XY[1] + LL_XY[1]) / 2.

        # create map
        fig = plt.figure(figsize=figsize)
        plt.subplots_adjust(left=0.05, right=0.95, top=0.90, bottom=0.05, wspace=0.15, hspace=0.05)
        ax = plt.subplot(111)

        m = Basemap(projection='tmerc', resolution=None, lon_0=center_lon, lat_0=center_lat,
                    urcrnrlon=UR_XY[0], urcrnrlat=UR_XY[1], llcrnrlon=LL_XY[0], llcrnrlat=LL_XY[1])

        # set color palette
        # (set_bad/set_over/set_under modify a colormap in place -> work on a copy instead of the shared instance)
        palette = copy.copy(cmap if cmap else plt.cm.gray)
        if nodataVal is not None and np.std(image2plot) != 0:  # do not show nodata
            image2plot = np.ma.masked_equal(image2plot, nodataVal)
            vmin, vmax = np.percentile(image2plot.compressed(), 2), np.percentile(image2plot.compressed(), 98)
            palette.set_bad('aqua', 0)
        else:
            vmin, vmax = np.percentile(image2plot, 2), np.percentile(image2plot, 98)
        palette.set_over('1')
        palette.set_under('0')

        # add image to map (y-axis must be inverted for basemap)
        m.imshow(np.flipud(image2plot), palette, interpolation=interpolation, vmin=vmin, vmax=vmax)

        # add coordinate grid lines
        parallels = np.arange(-90, 90., 0.25)
        m.drawparallels(parallels, labels=[1, 0, 0, 0], fontsize=12, linewidth=0.4)

        meridians = np.arange(-180., 180., 0.25)
        m.drawmeridians(meridians, labels=[0, 0, 0, 1], fontsize=12, linewidth=0.4)

        if return_map:
            return fig, ax, m
        else:
            plt.show()


    def show_map_utm(self, xlim=None, ylim=None, band=0, figsize=None, interpolation='none', cmap=None, nodataVal=None,
                     res_factor=None, return_map=False):
        """Plot the image into a Basemap (transverse mercator projection) centered at the image center.

        NOTE: This function is still under construction and may not work as expected.

        :param xlim:            [start_column, end_column]
        :param ylim:            [start_row, end_row]
        :param band:            the band index of the band to be plotted
        :param figsize:         figure size passed to plt.figure
        :param interpolation:   interpolation method passed to Basemap.imshow
        :param cmap:            matplotlib colormap (defaults to plt.cm.gray); only a copy of it is modified
        :param nodataVal:       nodata value (defaults to self.nodata); pixels with this value are not shown
        :param res_factor:      resolution factor used for downsampling large images (see _get_plottable_image)
        :param return_map:      <bool> return (fig, ax, m) instead of showing the figure
        :return:                (fig, ax, m) if return_map is True, otherwise None
        """
        import copy
        from mpl_toolkits.basemap import Basemap

        warnings.warn(UserWarning('This function is still under construction and may not work as expected!'))
        # TODO debug this function

        # get image to plot
        # NOTE: keywords are required - the 4th and 5th positional parameters of _get_plottable_image are boundsMap and
        #       boundsMapPrj, not res_factor and nodataVal
        nodataVal = nodataVal if nodataVal is not None else self.nodata
        image2plot, gt, prj = self._get_plottable_image(xlim, ylim, band, res_factor=res_factor, nodataVal=nodataVal)

        # calculate corner coordinates of plot: map coordinates (in prj) -> lon/lat
        cornersYX = GeoArray(image2plot, gt, prj).box.boxMapYX
        UL_YX, UR_YX, _, LL_YX = cornersYX
        UL_XY, UR_XY, _, LL_XY = [transform_any_prj(prj, 'epsg:4326', x, y) for y, x in cornersYX]

        # the center is computed in map coordinates and transformed to lon/lat exactly once
        center_X, center_Y = (UL_YX[1] + UR_YX[1]) / 2., (UL_YX[0] + LL_YX[0]) / 2.
        center_lon, center_lat = transform_any_prj(prj, 'epsg:4326', center_X, center_Y)

        # create map
        fig = plt.figure(figsize=figsize)
        plt.subplots_adjust(left=0.05, right=0.95, top=0.90, bottom=0.05, wspace=0.15, hspace=0.05)
        ax = plt.subplot(111)
        m = Basemap(projection='tmerc', resolution=None, lon_0=center_lon, lat_0=center_lat,
                    urcrnrlon=UR_XY[0], urcrnrlat=UR_XY[1], llcrnrlon=LL_XY[0], llcrnrlat=LL_XY[1],
                    k_0=0.9996, rsphere=(6378137.00, 6356752.314245179), suppress_ticks=False)

        # set color palette
        # (set_bad/set_over/set_under modify a colormap in place -> work on a copy instead of the shared instance)
        palette = copy.copy(cmap if cmap else plt.cm.gray)
        # np.std check: an image consisting only of nodata would leave nothing to compute percentiles from
        if nodataVal is not None and np.std(image2plot) != 0:  # do not show nodata
            image2plot = np.ma.masked_equal(image2plot, nodataVal)
            vmin, vmax = np.percentile(image2plot.compressed(), 2), np.percentile(image2plot.compressed(), 98)
            palette.set_bad('aqua', 0)
        else:
            vmin, vmax = np.percentile(image2plot, 2), np.percentile(image2plot, 98)
        palette.set_over('1')
        palette.set_under('0')

        # add image to map (y-axis must be inverted for basemap)
        m.imshow(np.flipud(image2plot), palette, interpolation=interpolation, vmin=vmin, vmax=vmax)

        # add coordinate grid lines
        parallels = np.arange(-90, 90., 0.25)
        m.drawparallels(parallels, labels=[1, 0, 0, 0], fontsize=12, linewidth=0.4)

        meridians = np.arange(-180., 180., 0.25)
        m.drawmeridians(meridians, labels=[0, 0, 0, 1], fontsize=12, linewidth=0.4)

        if return_map:
            return fig, ax, m
        else:
            plt.show()


    def show_footprint(self):
        """Return a web map (folium.Map) displaying the footprint polygon of the GeoArray in lon/lat coordinates.

        This method is intended to be called from Jupyter Notebook.

        :return:                folium.Map centered at the footprint centroid
        :raises ImportError:    if the libraries 'folium' or 'geojson' are not available
        """
        try:
            import folium
            import geojson
        except ImportError:
            folium = geojson = None

        if folium is None or geojson is None:
            raise ImportError(
                "This method requires the libraries 'folium' and 'geojson'. They can be installed with "
                "the shell command 'pip install folium geojson'.")

        footprint_lonlat = reproject_shapelyGeometry(self.footprint_poly, self.epsg, 4326)

        # folium expects the map location as (lat, lon)
        centroid_lon, centroid_lat = np.array(footprint_lonlat.centroid.coords.xy).flatten()
        web_map = folium.Map(location=(centroid_lat, centroid_lon))

        feature = geojson.Feature(geometry=footprint_lonlat, properties={})
        folium.GeoJson(feature).add_to(web_map)
        return web_map


    def get_mapPos(self, mapBounds, mapBounds_prj, band2get=None, arr_gt=None, arr_prj=None, fillVal=None,
                   rspAlg='near', v=False):  # TODO implement slice for indexing bands
        """Return the array subset covering the given map bounds (as computed by get_array_at_mapPos).

        :param mapBounds:       xmin, ymin, xmax, ymax
        :param mapBounds_prj:   projection of mapBounds
        :param band2get:        band index of the band to be returned (full array if not given)
        :param arr_gt:          geotransform of the array (defaults to self.geotransform)
        :param arr_prj:         projection of the array (defaults to self.projection)
        :param fillVal:         nodata value (defaults to self.nodata)
        :param rspAlg:          <str> Resampling method to use. Available methods are:
                                near, bilinear, cubic, cubicspline, lanczos, average, mode, max, min, med, q1, q2
        :param v:               verbose mode (not related to GeoArray.v; must be explicitly set)
        :return:                sub_arr, sub_gt, sub_prj
        :raises ValueError:     if the GeoArray is in memory and no geotransform or projection is available
        """
        arr_gt = arr_gt if arr_gt else self.geotransform
        arr_prj = arr_prj if arr_prj else self.projection
        fillVal = fillVal if fillVal is not None else self.nodata

        if self.is_inmem and (not arr_gt or not arr_prj):
            raise ValueError('In case of in-mem arrays the respective geotransform and projection of the array '
                             'has to be passed.')

        if v:
            # the '%s' placeholder was previously never filled
            print('%s.get_mapPos() input parameters:' % self.__class__.__name__)
            print('\tmapBounds', mapBounds, '<==>', self.box.boundsMap)
            print('\tEPSG', WKT2EPSG(mapBounds_prj), self.epsg)
            print('\tarr_gt', arr_gt, self.gt)
            print('\tarr_prj', WKT2EPSG(arr_prj), self.epsg)
            print('\tfillVal', fillVal, self.nodata, '\n')

        sub_arr, sub_gt, sub_prj = get_array_at_mapPos(self, arr_gt, arr_prj, mapBounds_prj, mapBounds, fillVal=fillVal,
                                                       rspAlg=rspAlg, out_gsd=(self.xgsd, self.ygsd), band2get=band2get)

        return sub_arr, sub_gt, sub_prj


    def reproject_to_new_grid(self, prototype=None, tgt_prj=None, tgt_xygrid=None, rspAlg='cubic'):
        """Reproject the array and its already computed masks (nodata / bad data) to a target pixel grid.

        :param prototype:   <GeoArray> an instance of GeoArray to be used as pixel grid reference
        :param tgt_prj:     <str> WKT string of the projection (taken from prototype.prj if not given)
        :param tgt_xygrid:  <list> target XY grid in the form [[x1, x2], [y1, y2]] with y1 > y2 (taken from
                            prototype.xygrid_specs if not given); the coordinate differences define the target pixel
                            size
        :param rspAlg:      <str, int> GDAL compatible resampling algorithm code (the masks are always resampled using
                            'near')
        :return:
        """
        assert (tgt_prj and tgt_xygrid) or prototype, "Provide either 'prototype' or 'tgt_prj' and 'tgt_xygrid'!"
        if not tgt_prj:
            tgt_prj = prototype.prj
        if tgt_xygrid is None:
            tgt_xygrid = prototype.xygrid_specs
        assert tgt_xygrid[1][0] > tgt_xygrid[1][1]

        xgrid, ygrid = tgt_xygrid[0], tgt_xygrid[1]

        # target pixel size
        tgt_xgsd = abs(xgrid[0] - xgrid[1])
        tgt_ygsd = abs(ygrid[0] - ygrid[1])

        # footprint bounds in the target projection, snapped to the target pixel grid
        tgt_bounds = reproject_shapelyGeometry(self.box.mapPoly, self.prj, tgt_prj).bounds
        grid_gt = (xgrid[0], tgt_xgsd, 0, max(ygrid), 0, -tgt_ygsd)
        xmin, ymin, xmax, ymax = snap_bounds_to_pixGrid(tgt_bounds, grid_gt, roundAlg='on')

        from ...geo.raster.reproject import warp_ndarray
        in_arr = self[:]
        warped = warp_ndarray(in_arr, self.gt, self.prj, tgt_prj,
                              out_gsd=(tgt_xgsd, tgt_ygsd),
                              out_bounds=(xmin, ymin, xmax, ymax),
                              out_bounds_prj=tgt_prj,
                              rspAlg=rspAlg,
                              in_nodata=self.nodata,
                              CPUs=None,  # if self.mp else 1, # TODO
                              progress=self.progress,
                              q=self.q)
        self.arr, self.gt, self.prj = warped

        # masks are categorical -> always resampled with nearest neighbour
        for private_attr, mask_attr in (('_mask_nodata', 'mask_nodata'), ('_mask_baddata', 'mask_baddata')):
            if getattr(self, private_attr, None) is not None:
                getattr(self, mask_attr).reproject_to_new_grid(prototype=prototype,
                                                               tgt_prj=tgt_prj,
                                                               tgt_xygrid=tgt_xygrid,
                                                               rspAlg='near')


    def read_pointData(self, mapXY_points, mapXY_points_prj=None, band=None):
        """Returns the array values for the given set of X/Y coordinates.
         NOTE: If GeoArray has been instanced with a file path, the function will read the dataset into memory.

        :param mapXY_points:        <np.ndarray, tuple, list> X/Y coordinates of the points of interest: a single
                                    X/Y pair, a sequence of X/Y pairs or a numpy array with the shape [Nx2]
        :param mapXY_points_prj:    <str, int> WKT string or EPSG code of the projection corresponding to the given
                                    coordinates (defaults to the projection of the GeoArray).
        :param band:                <int> the band index of the band of interest. If None, the values of all bands are
                                    returned.
        :return:                    the value(s) at the given positions as returned by the GeoArray indexing
                                    (for a single coordinate pair: self[Y, X, band])
        """
        # bring the input into the [Nx2] layout (a single X/Y pair becomes a [1x2] array)
        mapXY = np.asarray(mapXY_points).reshape(-1, 2)
        prj = mapXY_points_prj if mapXY_points_prj else self.prj

        assert prj, 'A projection is needed for returning image DNs at specific map X/Y coordinates!'
        if not prj_equal(prj1=prj, prj2=self.prj):
            # transform_any_prj returns separate X and Y values -> restore the [Nx2] layout expected below
            mapXY = np.column_stack(transform_any_prj(prj, self.prj, mapXY[:, 0], mapXY[:, 1]))

        imXY = mapXY2imXY(mapXY, self.geotransform)
        # int64 instead of int16: int16 silently overflows for images with more than 32767 rows/columns
        imYX = np.fliplr(np.array(imXY)).astype(np.int64)

        if imYX.size == 2:  # only one coordinate pair
            Y, X = imYX[0].tolist()
            return self[Y, X, band]
        else:  # multiple coordinate pairs
            return self[imYX.T.tolist() + [band]]


    def to_mem(self):
        """Load the complete dataset into memory.

        :return:    the instance itself (with self.arr holding the full array), which allows call chaining
        """
        full_arr = self[:]
        self.arr = full_arr
        return self


    def to_disk(self):