# GeoArray.py
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'


import numpy as np
import os
import warnings
from matplotlib import pyplot as plt

# custom
from shapely.geometry import Polygon, box
from osgeo import gdal_array
try:
    from osgeo import gdal
    from osgeo import gdalnumeric
except ImportError:
    import gdal
    import gdalnumeric


# internal modules
from ...geo.coord_calc       import get_corner_coordinates, calc_FullDataset_corner_positions
from ...geo.coord_grid       import snap_bounds_to_pixGrid
from ...geo.coord_trafo      import mapXY2imXY, imXY2mapXY, transform_any_prj
from ...geo.projection       import prj_equal
from ...geo.vector.topology  import get_overlap_polygon
from ...geo.raster.reproject import warp_ndarray



class GeoArray(object):
    def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None):
        # type: (Any, tuple, str, list) -> GeoArray
        """

        :param path_or_array:   a numpy.ndarray or a valid file path
        :param geotransform:    GDAL geotransform of the given array or file on disk
        :param projection:      projection of the given array or file on disk as WKT string
                                (only needed if GeoArray is instanced with an array)
        :param bandnames:       names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds']
        """
        if type(path_or_array) not in [str, np.ndarray, type(self)]:
            raise ValueError("%s parameter 'arg' takes only string "
                             "or np.ndarray types. Got %s." %(self.__class__.__name__,type(path_or_array)))

        self.arg           = path_or_array
        self.arr           = self.arg if isinstance(self.arg, np.ndarray)       else None
        self.filePath      = self.arg if isinstance(self.arg, str) and self.arg else None
        self._arr_cache    = None
        self._geotransform = None
        self._projection   = None
        self._shape        = None
        self._dtype        = None

        if isinstance(self.arg, str):
            if not os.path.exists(self.filePath):
                raise FileNotFoundError(self.filePath)

            ds = gdal.Open(self.filePath)
            if not ds:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            bands = ds.RasterCount
            ds    = None
        else:
            self._shape = self.arr.shape
            self._dtype = self.arr.dtype
            bands = self.arr.shape[2] if len(self.arr.shape) == 3 else 1

        if bandnames:
            assert len(bandnames) == bands, \
                'Number of given bandnames does not match number of bands in array.'
            assert len(list(set([type(b) for b in bandnames]))) == 1 and type(bandnames[0] == 'str'), \
                "'bandnames must be a set of strings. Got other datetypes in there.'"
            self.bandnames = {band: i for i, band in enumerate(bandnames)}
            assert len(self.bandnames) == bands, 'Bands must not have the same name.'
        else:
            self.bandnames = {'B%s' % band: i for i, band in enumerate(range(1, bands + 1))}

        if geotransform:
            self.geotransform = geotransform
        if projection:
            self.projection   = projection


    @property
    def is_inmem(self):
        return isinstance(self.arr, np.ndarray)


    @property
    def shape(self):
        if self._shape:
            return self._shape
        else:
            self.set_gdalDataset_meta()
            return self._shape


    @property
    def dtype(self):
        if self._dtype:
            return self._dtype
        else:
            self.set_gdalDataset_meta()
            return self._dtype


    @property
    def geotransform(self):
        if self._geotransform:
            return self._geotransform
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._geotransform
        else:
            raise AttributeError("Attribute 'geotransform' has not been set yet.")


    @geotransform.setter
    def geotransform(self, gt):
        if self.filePath:
            assert self.geotransform == gt, "Cannot set %s.geotransform to the given value because it does not " \
                                            "match the geotransform from the file on disk." %self.__class__.__name__
        else:
            self._geotransform = gt


    @property
    def projection(self):
        if self._projection:
            return self._projection
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._projection
        else:
            raise AttributeError("Attribute 'projection' has not been set yet.")


    @projection.setter
    def projection(self, prj):
        if self.filePath:
            assert self.projection == prj, "Cannot set %s.projection to the given value because it does not " \
                                            "match the projection from the file on disk." %self.__class__.__name__
        else:
            self._projection = prj


    def __getitem__(self, given):
        # TODO check if array cache contains the needed slice and return data from there
        # FIXME negative slices do not work
        if self.is_inmem:
            # this is only the case if GeoArray has been instanced with an array to GeoArray.to_mem has been executed
            return self.arr[given]
        else:
            if isinstance(given, tuple):
                # behave like a numpy array
                getitem_params = given
            elif isinstance(given, int) or isinstance(given, slice):
                # return only one band
                getitem_params = [given]
            elif isinstance(given, str):
                # behave like a dictionary
                if self.bandnames:
                    if given not in self.bandnames: raise ValueError("'%s' is not a known band." % given)
                    getitem_params = [self.bandnames[given]]
            else:
                raise ValueError

            self._arr_cache = self.from_path(self.arg, getitem_params)

            return self._arr_cache


    def __getattr__(self, attr):
        # check if the requested attribute can not be present because GeoArray has been instanced with an array
        if attr not in self.__dict__ and not self.is_inmem and attr in ['shape','dtype','geotransform', 'projection']:
            self.set_gdalDataset_meta()

        if attr in self.__dict__:
            return self.__dict__[attr]
        else:
            raise AttributeError('%s object has not attribute %s.' %(self.__class__.__name__, attr))


    def __getstate__(self):
        """Defines how the attributes of GMS object are pickled."""

        # clean array cache in order to avoid cache pickling
        self.flush_cache()
        return self.__dict__


    def __setstate__(self, state):
        """Defines how the attributes of GMS object are unpickled.
        NOTE: This method has been implmemnted because otherwise pickled and unplickled instances show recursion errors
        within __getattr__ when requesting any attribute."""

        self.__dict__ = state


    def set_gdalDataset_meta(self):
        """Read shape, data type, geotransform and projection from the file at self.filePath.

        The values are written to the private attributes _shape, _dtype, _geotransform and _projection.
        The shape is (rows, cols) for single band files and (rows, cols, bands) otherwise; the data type is
        derived from the first band.

        :raises AssertionError: if self.filePath is not set
        :raises Exception:      if the file cannot be opened by GDAL
        """
        assert self.filePath
        ds = gdal.Open(self.filePath)
        if not ds:
            # fail with the GDAL error message instead of an AttributeError on None further down
            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

        try:
            # set private class variables (in order to avoid recursion error)
            n_bands            = ds.RasterCount
            self._shape        = tuple([ds.RasterYSize, ds.RasterXSize] + ([n_bands] if n_bands > 1 else []))
            self._dtype        = gdal_array.GDALTypeCodeToNumericTypeCode(ds.GetRasterBand(1).DataType)
            self._geotransform = ds.GetGeoTransform()
            self._projection   = ds.GetProjection()
        finally:
            ds = None  # drop the dataset reference


    def from_path(self, path, getitem_params=None):
        """Read the image data (or a subset of it) from the raster file at the given path.

        :param path:            path of the raster file to be read
        :param getitem_params:  sequence describing the subset to be read:
                                - 1 element:  bands
                                - 2 elements: rows, columns
                                - 3 elements: rows, columns, bands
                                rows/columns may be given as int or slice; bands as int, slice, band name or
                                as list/tuple of band indices or band names. Slice steps are ignored.
                                The whole dataset is read if not given.
        :return:                the read data as numpy array; self.arr is set to it if its shape equals
                                self.shape
        :raises IOError:        if the file cannot be opened
        :raises ValueError:     if the requested subset exceeds the dimensions of the dataset
        """
        ds = gdal.Open(path)
        if not ds: raise IOError('Error reading image data at %s.' %path)
        R, C, B = ds.RasterYSize, ds.RasterXSize, ds.RasterCount
        ds = None

        def bounds_of(idx):
            """Return (first, last) index of an int or a slice (None where not given)."""
            if isinstance(idx, slice):
                return idx.start, (idx.stop - 1 if idx.stop is not None else None)
            if isinstance(idx, int):
                return idx, idx
            return None, None

        # convert getitem_params to subset area to be read
        rS, rE, cS, cE, bS, bE, bL = [None] * 7

        if getitem_params:
            if len(getitem_params) >= 2:
                givenR, givenC = getitem_params[:2]
                rS, rE = bounds_of(givenR)
                cS, cE = bounds_of(givenC)
            if len(getitem_params) in [1, 3]:
                givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
                if isinstance(givenB, (slice, int)):
                    bS, bE = bounds_of(givenB)
                elif type(givenB) in [tuple, list]:
                    typesInGivenB = [type(i) for i in givenB]
                    assert len(list(set(typesInGivenB))) == 1, \
                        'Mixed data types within the list of bands are not supported.'
                    if isinstance(givenB[0], int):
                        bL = list(givenB)
                    elif isinstance(givenB[0], str):
                        bL = [self.bandnames[i] for i in givenB]
                elif type(givenB) in [str]:
                    bL = [self.bandnames[givenB]]

        # set defaults for not given values
        rS = rS if rS is not None else 0
        rE = rE if rE is not None else R - 1
        cS = cS if cS is not None else 0
        cE = cE if cE is not None else C - 1
        bS = bS if bS is not None else 0
        bE = bE if bE is not None else B - 1
        bL = list(range(bS, bE + 1)) if not bL else bL

        # validate subset area bounds to be read
        msg = lambda v, idx, sz: '%s is out of bounds for axis %s with size %s' %(v, idx, sz)
        for val, axIdx, axSize in zip([rS,rE,cS,cE,bS,bE], [0,0,1,1,2,2], [R,R,C,C,B,B]):
            if not 0 <= val <= axSize - 1: raise ValueError(msg(val,axIdx,axSize))
        # explicitly listed bands are validated too (otherwise an invalid index fails late on a missing band)
        for bIdx in bL:
            if not 0 <= bIdx <= B - 1: raise ValueError(msg(bIdx, 2, B))

        nRows, nCols = rE - rS + 1, cE - cS + 1

        # read subset area
        if bL == list(range(0, B)):
            tempArr = gdalnumeric.LoadFile(path, cS, rS, nCols, nRows)
            if tempArr is None:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            # move the band axis from the first to the last position
            out_arr = np.swapaxes(np.swapaxes(tempArr, 0, 2), 0, 1) if B > 1 else tempArr
        else:
            ds = gdal.Open(path)
            if ds is None:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            try:
                if len(bL) == 1:
                    out_arr = ds.GetRasterBand(bL[0] + 1).ReadAsArray(cS, rS, nCols, nRows)
                else:
                    bandArrs = []
                    for bIdx in bL:
                        bandArr = ds.GetRasterBand(bIdx + 1).ReadAsArray(cS, rS, nCols, nRows)
                        if bandArr is None:
                            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
                        bandArrs.append(bandArr)
                    # stack the bands instead of filling a pre-allocated float64 array so that the data type
                    # of the read bands is kept
                    out_arr = np.dstack(bandArrs) if bandArrs else np.empty((nRows, nCols, 0))
            finally:
                ds = None

        if out_arr is None:
            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

        # only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
        if out_arr.shape==self.shape:
            self.arr = out_arr

        return out_arr


    def save(self):
        raise NotImplementedError


    def dump(self):
        raise NotImplementedError


    def show(self, band=0, figsize=None):
        """Plot the image data in a new matplotlib figure.

        :param band:    index of the band to be shown (only used for arrays with more than two dimensions)
        :param figsize: passed to matplotlib.pyplot.figure
        """
        # TODO implement slice

        plt.figure(figsize=figsize)
        if len(self.shape) > 2:
            data2show = self[:, :, band]
        else:
            data2show = self[:, :]
        plt.imshow(data2show) # FIXME
        #plt.imshow(self[:30,:30,band] if is_3D else self[:30,:30],interpolation='none')
        plt.show()


    def get_mapPos(self, mapBounds, mapBounds_prj, bandslist=None, arr_gt=None, arr_prj=None, fillVal=0): # TODO implement slice for indexing bands
        """Return the image data covering the given map bounds.

        :param mapBounds:       the requested map bounds, passed through to get_array_at_mapPos
        :param mapBounds_prj:   projection of the requested map bounds
        :param bandslist:       currently ignored: always all bands are returned (see FIXME notes below)
        :param arr_gt:          GDAL geotransform of the image data
                                (default: the geotransform of the instance)
        :param arr_prj:         projection of the image data
                                (default: the projection of the instance)
        :param fillVal:         fill value, passed through to get_array_at_mapPos
        :return:                sub_arr, sub_gt, sub_prj as returned by get_array_at_mapPos
        :raises ValueError:     if the data is held in memory and neither the arguments nor the instance
                                provide geotransform and projection
        """
        if not self.is_inmem:
            # use the metadata of the file on disk
            # (the formerly referenced attribute 'gdalDataset_meta' does not exist)
            arr_gt  = arr_gt  if arr_gt  else self.geotransform
            arr_prj = arr_prj if arr_prj else self.projection
        else:
            # fall back to the values that have been set on the instance
            arr_gt  = arr_gt  if arr_gt  else self._geotransform
            arr_prj = arr_prj if arr_prj else self._projection
            if not arr_gt or not arr_prj:
                raise ValueError('In case of in-mem arrays the respective geotransform and projection of the array has to '
                                 'be passed.')

        shape = self.shape
        R, C  = shape[:2]
        B     = shape[2] if len(shape) > 2 else 1

        #if not bandslist or (bandslist and sorted(bandslist)==list(range(B))): # FIXME activate if bug is fixed
        if prj_equal(arr_prj,mapBounds_prj):   ############# WORKAROUND FOR corrupt return values in case of 3D array input that has to be warped
            # process whole array at once
            sub_arr, sub_gt, sub_prj = get_array_at_mapPos(self, arr_gt, arr_prj, mapBounds, mapBounds_prj,
                                                           fillVal=fillVal)
        else:
            # process bands separately
            out_arrs = []
            #for b in bandslist: # FIXME activate if bug is fixed
            for b in range(B):
                sub_arr, sub_gt, sub_prj = get_array_at_mapPos(self, arr_gt, arr_prj, mapBounds, mapBounds_prj,
                                                               band2get=b, fillVal=fillVal)
                out_arrs.append(sub_arr)

            sub_arr = np.dstack(out_arrs) if len(out_arrs) > 1 else sub_arr


        return sub_arr, sub_gt, sub_prj


    def to_mem(self):
        """Reads the whole dataset into memory and sets self.arr to the read data.
        """
        self.arr = self[:]
        return self


    def _to_disk(self):
        """Sets self.arr back to None if GeoArray has been instanced with a file path
        and the whole datadaset has been read."""

        if self.filePath and os.path.isfile(self.filePath):
            self.arr = None
        else:
            warnings.warn('GeoArray object cannot be turned into disk mode because this asserts that GeoArray.filePath '
                          'contains a valid file path. Got %s.' %self.filePath)
        return self


    def flush_cache(self):
        self._arr_cache = None



def _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=None, fillVal=0):
    """Clip the given array to the given map bounds (snapped to the pixel grid of the array).

    Output pixels that are not covered by the input array are set to fillVal.

    NOTE: asserts that mapBounds have the same projection like the coordinates in arr_gt

    :param arr:              the array to be clipped (has to provide shape, dtype and 2D/3D indexing)
    :param mapBounds:        the map bounds to clip to; unpacked as xmin, ymin, xmax, ymax after snapping
    :param arr_gt:           GDAL geotransform of arr (tuple or list)
    :param band2clip:        band index of the band to be returned (full array if not given)
    :param fillVal:          value for output pixels outside of the input array
    :return:                 out_arr, out_gt (the clipped array and its geotransform as list)
    """

    # assertions
    assert isinstance(arr_gt, (tuple,list))
    assert isinstance(band2clip, int) or band2clip is None

    # get array metadata
    rows, cols             = arr.shape[:2]
    bands                  = arr.shape[2] if len(arr.shape) == 3 else 1
    arr_dtype              = arr.dtype
    ULxy, LLxy, LRxy, URxy = get_corner_coordinates(gt=arr_gt, rows=rows, cols=cols)
    arrBounds              = ULxy[0], LRxy[1], LRxy[0], ULxy[1]

    # snap mapBounds to the grid of the array
    mapBounds              = snap_bounds_to_pixGrid(mapBounds, arr_gt)
    xmin, ymin, xmax, ymax = mapBounds

    # get out_gt and out_prj
    out_gt               = list(arr_gt)
    out_gt[0], out_gt[3] = xmin, ymax

    # get image area to read
    cS, rS = [int(i)   for i in mapXY2imXY((xmin, ymax), arr_gt)]
    cE, rE = [int(i)-1 for i in mapXY2imXY((xmax, ymin), arr_gt)]  # -1 because max values do not represent pixel origins

    # NOTE: the column end index has to be validated against the number of columns (formerly: rows)
    if 0 <= rS <= rows - 1 and 0 <= rE <= rows - 1 and 0 <= cS <= cols - 1 and 0 <= cE <= cols - 1:
        # requested area is within the input array
        if bands==1:
            out_arr = arr[rS:rE + 1, cS:cE + 1]
        else:
            out_arr = arr[rS:rE + 1, cS:cE + 1, band2clip] if band2clip is not None else arr[rS:rE + 1, cS:cE + 1, :]
    else:
        # requested area is not completely within the input array
        # create array according to size of mapBounds + fill with nodata
        tgt_rows  = int(abs((ymax - ymin) / arr_gt[5]))
        tgt_cols  = int(abs((xmax - xmin) / arr_gt[1]))
        tgt_bands = bands if band2clip is None else 1
        tgt_shape = (tgt_rows, tgt_cols, tgt_bands) if tgt_bands > 1 else (tgt_rows, tgt_cols)
        out_arr   = np.full(tgt_shape, fillVal, arr_dtype)

        # calculate image area to be read from input array
        overlap_poly = get_overlap_polygon(box(*arrBounds), box(*mapBounds))['overlap poly']
        assert overlap_poly, 'The input array and the requested geo area have no overlap - most likely due to a bug.'
        xmin_in, ymin_in, xmax_in, ymax_in = overlap_poly.bounds
        cS_in, rS_in = [int(i)   for i in mapXY2imXY((xmin_in, ymax_in), arr_gt)]
        cE_in, rE_in = [int(i)-1 for i in mapXY2imXY((xmax_in, ymin_in), arr_gt)]  # -1 because max values do not represent pixel origins

        # read a subset of the input array
        if bands == 1:
            data = arr[rS_in:rE_in + 1, cS_in:cE_in + 1]
        else:
            data = arr[rS_in:rE_in + 1, cS_in:cE_in + 1, band2clip] if band2clip is not None else \
                   arr[rS_in:rE_in + 1, cS_in:cE_in + 1, :]

        # calculate correct area of out_arr to be filled and fill it with read data  from input array
        cS_out, rS_out = [int(i)   for i in mapXY2imXY((xmin_in, ymax_in), out_gt)]
        cE_out, rE_out = [int(i)-1 for i in mapXY2imXY((xmax_in, ymin_in), out_gt)]  # -1 because max values do not represent pixel origins

        # fill newly created array with read data from input array
        if tgt_bands==1:
            out_arr[rS_out:rE_out + 1, cS_out:cE_out + 1]   = data
        else:
            out_arr[rS_out:rE_out + 1, cS_out:cE_out + 1,:] = data

    return out_arr, out_gt


def get_array_at_mapPos(arr, arr_gt, arr_prj, mapBounds, mapBounds_prj, band2get=None, fillVal=0):
    """Return the part of the given array that covers the given map bounds.

    If the map bounds are given in the projection of the array, the array is only clipped. Otherwise a
    subset of the array is read and warped into the projection of the map bounds.

    :param arr:                 the input array (has to provide shape, dtype and 2D/3D indexing)
    :param arr_gt:              GDAL geotransform of arr
    :param arr_prj:             projection of arr
    :param mapBounds:           the requested map bounds; unpacked as xmin, ymin, xmax, ymax
    :param mapBounds_prj:       projection of mapBounds
    :param band2get:            band index of the band to be returned (full array if not given)
    :param fillVal:             value for output pixels without data; also treated as nodata value of arr
    :return:                    out_arr, out_gt, out_prj
    """
    # check if requested bounds have the same projection like the array
    samePrj = prj_equal(arr_prj, mapBounds_prj)

    if samePrj:
        out_prj         = arr_prj
        out_arr, out_gt = _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=band2get, fillVal=fillVal)

    else:
        # calculate requested corner coordinates in the same projection like the input array (bounds are not sufficient due to projection rotation)
        xmin, ymin, xmax, ymax = mapBounds
        ULxy, URxy, LRxy, LLxy = (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)
        ULxy, URxy, LRxy, LLxy = [transform_any_prj(mapBounds_prj, arr_prj, *xy) for xy in
                                  [ULxy, URxy, LRxy, LLxy]]
        mapBounds_arrPrj       = Polygon([ULxy, URxy, LRxy, LLxy]).buffer(arr_gt[1]).bounds

        # read subset of input array as temporary data (that has to be reprojected later)
        temp_arr, temp_gt = _clip_array_at_mapPos(arr, mapBounds_arrPrj, arr_gt, band2clip=band2get, fillVal=fillVal)

        # eliminate no data area for faster warping
        try:
            oneBandArr     = np.all(np.where(temp_arr == fillVal, 0, 1), axis=2) \
                                if len(temp_arr.shape) > 2 else np.where(temp_arr == fillVal, 0, 1)
            corners        = [(i[1], i[0]) for i in
                                calc_FullDataset_corner_positions(oneBandArr, assert_four_corners=False)]
            bounds         = [int(i) for i in Polygon(corners).bounds]
            cS, rS, cE, rE = bounds

            # the new origin is calculated BEFORE anything is changed: if one of the steps above fails,
            # temp_arr and temp_gt stay untouched and therefore consistent with each other
            # NOTE(review): int() truncates the origin to whole map units - looks lossy for grids with
            #               fractional origins; confirm that this is intended
            new_ULx, new_ULy = [int(i) for i in imXY2mapXY((cS, rS), temp_gt)]
        except Exception:  # best effort only; do not swallow KeyboardInterrupt / SystemExit
            warnings.warn('Could not eliminate no data area for faster warping. '
                          'Result will not be affected but processing takes a bit longer..')
        else:
            temp_arr               = temp_arr[rS:rE + 1, cS:cE + 1]
            temp_gt[0], temp_gt[3] = new_ULx, new_ULy

        #from matplotlib import pyplot as plt
        #plt.figure()
        #plt.imshow(temp_arr[:,:])

        # calculate requested geo bounds in the target projection, snapped to the output array grid
        mapBounds = snap_bounds_to_pixGrid(mapBounds, arr_gt)
        xmin, ymin, xmax, ymax = mapBounds
        out_gt   = list(arr_gt)
        out_gt[0], out_gt[3] = xmin, ymax
        out_rows = int(abs((ymax - ymin) / arr_gt[5]))
        out_cols = int(abs((xmax - xmin) / arr_gt[1])) # FIXME using out_gt and outRowsCols is a workaround for not beeing able to pass output extent in the OUTPUT projection

        # reproject temporary data to target projection (the projection of mapBounds)
        out_arr, out_gt, out_prj = warp_ndarray(temp_arr, temp_gt, arr_prj, mapBounds_prj,
                                                in_nodata=fillVal, out_nodata=fillVal, out_gt=out_gt,
                                                outRowsCols=(out_rows, out_cols), outExtent_within=True,rsp_alg=0)  # FIXME resampling alg

    return out_arr, out_gt, out_prj