GeoArray.py 20.8 KB
Newer Older
Daniel Scheffler's avatar
Daniel Scheffler committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'


import numpy as np
import os
import warnings
from matplotlib import pyplot as plt

# custom
from shapely.geometry import Polygon, box
from osgeo import gdal_array
try:
    from osgeo import gdal
    from osgeo import gdalnumeric
except ImportError:
    import gdal
    import gdalnumeric



from ...map        import coord_calc  as CC
from ...map        import coord_grid  as CG
from ...map        import coord_trafo as CT
from ...map        import projection  as PRJ
from ...map.vector import topology    as TOPO
from ...map.raster.reproject import warp_ndarray



class GeoArray(object):
    def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None):
        # type: (Any, tuple, str, list) -> GeoArray
        """

        :param path_or_array:   a numpy.ndarray or a valid file path
        :param geotransform:    GDAL geotransform of the given array or file on disk
        :param projection:      projection of the given array or file on disk as WKT string
                                (only needed if GeoArray is instanced with an array)
        :param bandnames:       names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds']
        """
        if type(path_or_array) not in [str, np.ndarray, type(self)]:
            raise ValueError("%s parameter 'arg' takes only string "
                             "or np.ndarray types. Got %s." %(self.__class__.__name__,type(path_or_array)))

        self.arg           = path_or_array
        self.arr           = self.arg if isinstance(self.arg, np.ndarray)       else None
        self.filePath      = self.arg if isinstance(self.arg, str) and self.arg else None
        self._arr_cache    = None
        self._geotransform = None
        self._projection   = None
        self._shape        = None
        self._dtype        = None

        if isinstance(self.arg, str):
            if not os.path.exists(self.filePath):
                raise FileNotFoundError(self.filePath)

            ds = gdal.Open(self.filePath)
            if not ds:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            bands = ds.RasterCount
            ds    = None
        else:
            self._shape = self.arr.shape
            self._dtype = self.arr.dtype
            bands = self.arr.shape[2] if len(self.arr.shape) == 3 else 1

        if bandnames:
            assert len(bandnames) == bands, \
                'Number of given bandnames does not match number of bands in array.'
            assert len(list(set([type(b) for b in bandnames]))) == 1 and type(bandnames[0] == 'str'), \
                "'bandnames must be a set of strings. Got other datetypes in there.'"
            self.bandnames = {band: i for i, band in enumerate(bandnames)}
            assert len(self.bandnames) == bands, 'Bands must not have the same name.'
        else:
            self.bandnames = {'B%s' % band: i for i, band in enumerate(range(1, bands + 1))}

        if geotransform:
            self.geotransform = geotransform
        if projection:
            self.projection   = projection


    @property
    def is_inmem(self):
        return isinstance(self.arr, np.ndarray)


    @property
    def shape(self):
        if self._shape:
            return self._shape
        else:
            self.set_gdalDataset_meta()
            return self._shape


    @property
    def dtype(self):
        if self._dtype:
            return self._dtype
        else:
            self.set_gdalDataset_meta()
            return self._dtype


    @property
    def geotransform(self):
        if self._geotransform:
            return self._geotransform
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._geotransform
        else:
            raise AttributeError("Attribute 'geotransform' has not been set yet.")


    @geotransform.setter
    def geotransform(self, gt):
        if self.filePath:
            assert self.geotransform == gt, "Cannot set %s.geotransform to the given value because it does not " \
                                            "match the geotransform from the file on disk." %self.__class__.__name__
        else:
            self._geotransform = gt


    @property
    def projection(self):
        if self._projection:
            return self._projection
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._projection
        else:
            raise AttributeError("Attribute 'projection' has not been set yet.")


    @projection.setter
    def projection(self, prj):
        if self.filePath:
            assert self.projection == prj, "Cannot set %s.projection to the given value because it does not " \
                                            "match the projection from the file on disk." %self.__class__.__name__
        else:
            self._projection = prj


    def __getitem__(self, given):
        # TODO check if array cache contains the needed slice and return data from there
        # FIXME negative slices do not work
        if self.is_inmem:
            # this is only the case if GeoArray has been instanced with an array to GeoArray.to_mem has been executed
            return self.arr[given]
        else:
            if isinstance(given, tuple):
                # behave like a numpy array
                getitem_params = given
            elif isinstance(given, int) or isinstance(given, slice):
                # return only one band
                getitem_params = [given]
            elif isinstance(given, str):
                # behave like a dictionary
                if self.bandnames:
                    if given not in self.bandnames: raise ValueError("'%s' is not a known band." % given)
                    getitem_params = [self.bandnames[given]]
            else:
                raise ValueError

            self._arr_cache = self.from_path(self.arg, getitem_params)

            return self._arr_cache


    def __getattr__(self, attr):
        # check if the requested attribute can not be present because GeoArray has been instanced with an array
        if attr not in self.__dict__ and not self.is_inmem and attr in ['shape','dtype','geotransform', 'projection']:
            self.set_gdalDataset_meta()

        if attr in self.__dict__:
            return self.__dict__[attr]
        else:
            raise AttributeError('%s object has not attribute %s.' %(self.__class__.__name__, attr))


    def __getstate__(self):
        """Defines how the attributes of GMS object are pickled."""

        # clean array cache in order to avoid cache pickling
        self.flush_cache()
        return self.__dict__


    def __setstate__(self, state):
        """Defines how the attributes of GMS object are unpickled.
        NOTE: This method has been implmemnted because otherwise pickled and unplickled instances show recursion errors
        within __getattr__ when requesting any attribute."""

        self.__dict__ = state


    def set_gdalDataset_meta(self):
        assert self.filePath
        ds = gdal.Open(self.filePath)
        # set private class variables (in order to avoid recursion error)
        self._shape        = tuple([ds.RasterYSize, ds.RasterXSize] + ([ds.RasterCount] if ds.RasterCount>1 else []))
        self._dtype        = gdal_array.GDALTypeCodeToNumericTypeCode(ds.GetRasterBand(1).DataType)
        self._geotransform = ds.GetGeoTransform()
        self._projection   = ds.GetProjection()
        ds = None


    def from_path(self, path, getitem_params=None):
        """Read image data (or a subset of it) from a raster file on disk via GDAL.

        :param path:            path of the raster file to read
        :param getitem_params:  sequence of indexing parameters as produced by GeoArray.__getitem__:
                                  - 1 item:   band selector only
                                  - 2 items:  (rows, columns)
                                  - 3 items:  (rows, columns, bands)
                                rows/columns may be given as int or slice; bands as int, slice, band name
                                or list/tuple of band indices or band names. Slice steps are ignored and
                                negative indices are rejected.
        :return:                the read data as numpy array; 2D if a single band is read
        :raises IOError:        if the file cannot be opened
        :raises ValueError:     if the requested subset exceeds the image dimensions
        """

        def _inclusive_bounds(index):
            """Return (start, end) with an inclusive end for an int or slice; None means 'not given'."""
            if isinstance(index, slice):
                return index.start, (index.stop - 1 if index.stop is not None else None)
            if isinstance(index, int):
                return index, index
            return None, None

        ds = gdal.Open(path)
        if not ds: raise IOError('Error reading image data at %s.' %path)
        R, C, B = ds.RasterYSize, ds.RasterXSize, ds.RasterCount
        ds = None

        # convert getitem_params to subset area to be read
        rS, rE, cS, cE, bS, bE, bL = [None] * 7

        if getitem_params:
            if len(getitem_params) >= 2:
                rS, rE = _inclusive_bounds(getitem_params[0])
                cS, cE = _inclusive_bounds(getitem_params[1])
            if len(getitem_params) in [1, 3]:
                # the band selector is the only item or the third one
                givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
                if isinstance(givenB, (slice, int)):
                    bS, bE = _inclusive_bounds(givenB)
                elif type(givenB) in [tuple, list]:
                    typesInGivenB = [type(i) for i in givenB]
                    assert len(list(set(typesInGivenB))) == 1, \
                        'Mixed data types within the list of bands are not supported.'
                    if isinstance(givenB[0], int):
                        bL = list(givenB)
                    elif isinstance(givenB[0], str):
                        bL = [self.bandnames[i] for i in givenB]
                elif type(givenB) in [str]:
                    bL = [self.bandnames[givenB]]

        # set defaults for not given values
        rS = rS if rS is not None else 0
        rE = rE if rE is not None else R - 1
        cS = cS if cS is not None else 0
        cE = cE if cE is not None else C - 1
        bS = bS if bS is not None else 0
        bE = bE if bE is not None else B - 1
        bL = list(range(bS, bE + 1)) if not bL else bL

        # validate subset area bounds to be read
        msg = lambda v, idx, sz: '%s is out of bounds for axis %s with size %s' %(v, idx, sz)
        for val, axIdx, axSize in zip([rS,rE,cS,cE,bS,bE], [0,0,1,1,2,2], [R,R,C,C,B,B]):
            if not 0 <= val <= axSize - 1: raise ValueError(msg(val,axIdx,axSize))
        # band indices given as list/tuple bypass bS/bE -> validate them separately
        for bIdx in bL:
            if not 0 <= bIdx <= B - 1: raise ValueError(msg(bIdx, 2, B))

        n_cols, n_rows = cE - cS + 1, rE - rS + 1

        # read subset area
        if bL == list(range(0, B)):
            # all bands are requested in their original order -> read them at once
            tempArr = gdalnumeric.LoadFile(path, cS, rS, n_cols, n_rows)
            if tempArr is None:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            # move the first axis of the multi band result to the end
            out_arr = np.swapaxes(np.swapaxes(tempArr, 0, 2), 0, 1) if B > 1 else tempArr
        else:
            ds = gdal.Open(path)
            if ds is None:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
            if len(bL) == 1:
                band = ds.GetRasterBand(bL[0] + 1)  # GDAL band numbers start at 1
                out_arr= band.ReadAsArray(cS, rS, n_cols, n_rows)
                band = None
            else:
                # stack the bands as they are read instead of filling a pre-allocated np.empty array:
                # np.empty defaults to float64 and would silently change the data type
                band_arrays = []
                for bIdx in bL:
                    band      = ds.GetRasterBand(bIdx + 1)
                    band_data = band.ReadAsArray(cS, rS, n_cols, n_rows)
                    band      = None
                    if band_data is None:
                        raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
                    band_arrays.append(band_data)
                out_arr = np.dstack(band_arrays)

            ds = None

        if out_arr is None:
            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

        # only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
        if out_arr.shape==self.shape:
            self.arr = out_arr

        return out_arr


    def save(self):
        raise NotImplementedError


    def dump(self):
        raise NotImplementedError


    def show(self, band=0, figsize=None):
        """Plot the image data with matplotlib.

        :param band:     index of the band to be shown (only used if the data has more than 2 dimensions)
        :param figsize:  passed to matplotlib.pyplot.figure
        """
        # TODO implement slice

        plt.figure(figsize=figsize)

        if len(self.shape) > 2:
            image = self[:, :, band]
        else:
            image = self[:, :]  # FIXME

        plt.imshow(image)
        plt.show()


    def get_mapPos(self, mapBounds, mapBounds_prj, bandslist=None, arr_gt=None, arr_prj=None, fillVal=0): # TODO implement slice for indexing bands
        if not self.is_inmem:
            arr_gt  = arr_gt  if arr_gt  else self.gdalDataset_meta['geotransform']
            arr_prj = arr_prj if arr_prj else self.gdalDataset_meta['projection']
        else:
            if not arr_gt or not arr_prj:
                raise ValueError('In case of in-mem arrays the respective geotransform and projection of the array has to '
                                 'be passed.')

        shape = self.shape
        R, C  = shape[:2]
        B     = shape[2] if len(shape) > 2 else 1

        #if not bandslist or (bandslist and sorted(bandslist)==list(range(B))): # FIXME activate if bug is fixed
        if PRJ.prj_equal(arr_prj,mapBounds_prj):   ############# WORKAROUND FOR corrupt return values in case of 3D array input that has to be warped
            """process whole array at once"""
            sub_arr, sub_gt, sub_prj = get_array_at_mapPos(self, arr_gt, arr_prj, mapBounds, mapBounds_prj,
                                                           fillVal=fillVal)
        else:
            """process bands separately"""
            out_arrs = []
            #for b in bandslist: # FIXME activate if bug is fixed
            for b in range(B):
                sub_arr, sub_gt, sub_prj = get_array_at_mapPos(self, arr_gt, arr_prj, mapBounds, mapBounds_prj,
                                                               band2get=b, fillVal=fillVal)
                out_arrs.append(sub_arr)

            sub_arr = np.dstack(out_arrs) if len(out_arrs) > 1 else sub_arr


        return sub_arr, sub_gt, sub_prj


    def to_mem(self):
        """Reads the whole dataset into memory and sets self.arr to the read data.
        """
        self.arr = self[:]
        return self


    def _to_disk(self):
        """Sets self.arr back to None if GeoArray has been instanced with a file path
        and the whole datadaset has been read."""

        if self.filePath and os.path.isfile(self.filePath):
            self.arr = None
        else:
            warnings.warn('GeoArray object cannot be turned into disk mode because this asserts that GeoArray.filePath '
                          'contains a valid file path. Got %s.' %self.filePath)
        return self


    def flush_cache(self):
        self._arr_cache = None



def _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=None, fillVal=0):
    """
    NOTE: asserts that mapBounds have the same projection like the coordinates in arr_gt

    :param arr:
    :param mapBounds:
    :param arr_gt:
    :param band2clip:        band index of the band to be returned (full array if not given)
    :param fillVal:
    :return:
    """

    # assertions
    assert isinstance(arr_gt, (tuple,list))
    assert isinstance(band2clip, int) or band2clip is None

    # get array metadata
    rows, cols             = arr.shape[:2]
    bands                  = arr.shape[2] if len(arr.shape) == 3 else 1
    arr_dtype              = arr.dtype
    ULxy, LLxy, LRxy, URxy = CC.get_corner_coordinates(gt=arr_gt, rows=rows, cols=cols)
    arrBounds              = ULxy[0], LRxy[1], LRxy[0], ULxy[1]

    # snap mapBounds to the grid of the array
    mapBounds              = CG.snap_bounds_to_pixGrid(mapBounds, arr_gt)
    xmin, ymin, xmax, ymax = mapBounds

    # get out_gt and out_prj
    out_gt               = list(arr_gt)
    out_gt[0], out_gt[3] = xmin, ymax

    # get image area to read
    cS, rS = [int(i)   for i in CT.mapXY2imXY((xmin, ymax), arr_gt)]
    cE, rE = [int(i)-1 for i in CT.mapXY2imXY((xmax, ymin), arr_gt)]

    if 0 <= rS <= rows - 1 and 0 <= rE <= rows - 1 and 0 <= cS <= cols - 1 and 0 <= cE <= rows - 1:
        """requested area is within the input array"""
        if bands==1:
            out_arr = arr[rS:rE + 1, cS:cE + 1]
        else:
            out_arr = arr[rS:rE + 1, cS:cE + 1, band2clip] if band2clip is not None else arr[rS:rE + 1, cS:cE + 1, :]
    else:
        """requested area is not completely within the input array"""
        # create array according to size of mapBounds + fill with nodata
        tgt_rows  = int(abs((ymax - ymin) / arr_gt[5]))
        tgt_cols  = int(abs((xmax - xmin) / arr_gt[1]))
        tgt_bands = bands if band2clip is None else 1
        tgt_shape = (tgt_rows, tgt_cols, tgt_bands) if tgt_bands > 1 else (tgt_rows, tgt_cols)
        out_arr   = np.full(tgt_shape, fillVal, arr_dtype)

        # calculate image area to be read from input array
        overlap_poly = TOPO.get_overlap_polygon(box(*arrBounds), box(*mapBounds))['overlap poly']
        assert overlap_poly, 'The input array and the requested map area have no overlap - most likely due to a bug.'
        xmin_in, ymin_in, xmax_in, ymax_in = overlap_poly.bounds
        cS_in, rS_in = [int(i)   for i in CT.mapXY2imXY((xmin_in, ymax_in), arr_gt)]
        cE_in, rE_in = [int(i)-1 for i in CT.mapXY2imXY((xmax_in, ymin_in), arr_gt)]  # -1 because max values do not represent pixel origins

        # read a subset of the input array
        if bands == 1:
            data = arr[rS_in:rE_in + 1, cS_in:cE_in + 1]
        else:
            data = arr[rS_in:rE_in + 1, cS_in:cE_in + 1, band2clip] if band2clip is not None else \
                   arr[rS_in:rE_in + 1, cS_in:cE_in + 1, :]

        # calculate correct area of out_arr to be filled and fill it with read data  from input array
        cS_out, rS_out = [int(i)   for i in CT.mapXY2imXY((xmin_in, ymax_in), out_gt)]
        cE_out, rE_out = [int(i)-1 for i in CT.mapXY2imXY((xmax_in, ymin_in), out_gt)]  # -1 because max values do not represent pixel origins

        # fill newy created array with read data from input array
        if tgt_bands==1:
            out_arr[rS_out:rE_out + 1, cS_out:cE_out + 1]   = data
        else:
            out_arr[rS_out:rE_out + 1, cS_out:cE_out + 1,:] = data

    return out_arr, out_gt


def get_array_at_mapPos(arr, arr_gt, arr_prj, mapBounds, mapBounds_prj, band2get=None, fillVal=0):
    """Return the part of the given array that covers the given map bounds.

    If the map bounds are given in the projection of the array, the array is only clipped. Otherwise the
    needed subset of the array is clipped and then warped to the projection of the map bounds.

    :param arr:                 array-like object (has to provide shape, dtype and slicing)
    :param arr_gt:              GDAL geotransform of arr
    :param arr_prj:             projection of arr
    :param mapBounds:           requested map bounds as (xmin, ymin, xmax, ymax)
    :param mapBounds_prj:       projection of the requested map bounds
    :param band2get:            band index of the band to be returned (full array if not given)
    :param fillVal:             value used for areas that are not covered by arr
    :return:                    tuple (array, geotransform, projection) of the returned subset
    """
    # check if requested bounds have the same projection like the array
    samePrj = PRJ.prj_equal(arr_prj, mapBounds_prj)

    if samePrj:
        out_prj         = arr_prj
        out_arr, out_gt = _clip_array_at_mapPos(arr, mapBounds, arr_gt, band2clip=band2get, fillVal=fillVal)

    else:
        # calculate requested corner coordinates in the same projection like the input array (bounds are not sufficient due to projection rotation)
        xmin, ymin, xmax, ymax = mapBounds
        ULxy, URxy, LRxy, LLxy = (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)
        ULxy, URxy, LRxy, LLxy = [CT.transform_any_prj(mapBounds_prj, arr_prj, *xy) for xy in
                                  [ULxy, URxy, LRxy, LLxy]]
        # the bounds are buffered by arr_gt[1] before the subset is read
        mapBounds_arrPrj       = Polygon([ULxy, URxy, LRxy, LLxy]).buffer(arr_gt[1]).bounds

        # read subset of input array as temporary data (that has to be reprojected later)
        temp_arr, temp_gt = _clip_array_at_mapPos(arr, mapBounds_arrPrj, arr_gt, band2clip=band2get, fillVal=fillVal)

        # eliminate no data area for faster warping
        try:
            # 1 where data is present in all bands, 0 where at least one band equals fillVal
            oneBandArr     = np.all(np.where(temp_arr == fillVal, 0, 1), axis=2) \
                                if len(temp_arr.shape) > 2 else np.where(temp_arr == fillVal, 0, 1)
            corners        = [(i[1], i[0]) for i in
                                CC.calc_FullDataset_corner_positions(oneBandArr, assert_four_corners=False)]
            bounds         = [int(i) for i in Polygon(corners).bounds]
            cS, rS, cE, rE = bounds

            # compute everything first and assign at the very end, so that a failure in between cannot
            # leave temp_arr and temp_gt inconsistent with each other
            clipped_arr = temp_arr[rS:rE + 1, cS:cE + 1]
            # NOTE(review): the new origin is no longer truncated to int - truncation would shift the
            #               georeference for grids that are not aligned to whole map units; confirm
            #               that CT.imXY2mapXY returns exact map coordinates
            new_ULx, new_ULy = CT.imXY2mapXY((cS, rS), temp_gt)

            temp_arr               = clipped_arr
            temp_gt[0], temp_gt[3] = new_ULx, new_ULy
        except Exception:
            # best effort only: continue with the unclipped temporary data
            warnings.warn('Could not eliminate no data area for faster warping. '
                          'Result will not be affected but processing takes a bit longer..')

        #from matplotlib import pyplot as plt
        #plt.figure()
        #plt.imshow(temp_arr[:,:])

        # calculate requested map bounds in the target projection, snapped to the output array grid
        mapBounds = CG.snap_bounds_to_pixGrid(mapBounds, arr_gt)
        xmin, ymin, xmax, ymax = mapBounds
        out_gt   = list(arr_gt)
        out_gt[0], out_gt[3] = xmin, ymax
        out_rows = int(abs((ymax - ymin) / arr_gt[5]))
        out_cols = int(abs((xmax - xmin) / arr_gt[1])) # FIXME using out_gt and outRowsCols is a workaround for not beeing able to pass output extent in the OUTPUT projection

        # reproject temporary data to target projection (the projection of mapBounds)
        out_arr, out_gt, out_prj = warp_ndarray(temp_arr, temp_gt, arr_prj, mapBounds_prj,
                                                in_nodata=fillVal, out_nodata=fillVal, out_gt=out_gt,
                                                outRowsCols=(out_rows, out_cols), outExtent_within=True,rsp_alg=0)  # FIXME resampling alg

    return out_arr, out_gt, out_prj