gdalnumeric.py 1.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
# -*- coding: utf-8 -*-
__author__ = "Daniel Scheffler"

try:
    from osgeo import gdal
    from osgeo import gdalnumeric
    from osgeo import gdalconst
except ImportError:
    import gdal
10
    import gdalnumeric  # FIXME this will import this __module__
11
12
13
14
15
16
17
    import gdalconst


def OpenNumPyArray(array):
    """Wrap a NumPy array into an in-memory GDAL dataset.

    This function emulates the functionality of gdalnumeric.OpenNumPyArray(), which is not available in GDAL
    versions below 2.1.0 (?).

    :param array:   <numpy.ndarray> either 2D in the shape (rows, columns) or 3D in the shape
                    (bands, rows, columns)
    :return:        <gdal.Dataset> dataset created with the GDAL 'MEM' driver, holding one raster band per
                    entry of the first array axis (a single band for 2D input)
    :raises ValueError:   if the array is neither 2D nor 3D
    :raises TypeError:    if the array's dtype cannot be translated into a GDAL data type
    :raises RuntimeError: if GDAL fails to create the in-memory dataset
    """
    if array.ndim == 2:
        rows, cols = array.shape
        bands = 1
        band_arrays = [array]
    elif array.ndim == 3:
        bands, rows, cols = array.shape
        # The band axis is the FIRST axis, so band i is array[i] with shape (rows, cols). This also covers
        # 3D input with a single band, which must still be written as a 2D slice.
        band_arrays = [array[i] for i in range(bands)]
    else:
        raise ValueError('OpenNumPyArray() currently only supports 2D and 3D arrays. Given array shape is %s.'
                         % str(array.shape))

    # get output datatype
    gdal_dtype = gdalnumeric.NumericTypeCodeToGDALTypeCode(array.dtype)  # FIXME not all datatypes can be translated
    if gdal_dtype is None:
        # explicit raise instead of `assert`, which would be stripped when running with `python -O`
        raise TypeError('Datatype %s is currently not supported by OpenNumPyArray().' % array.dtype)

    mem_drv = gdal.GetDriverByName('MEM')
    mem_ds = mem_drv.Create('/vsimem/tmp/memfile.mem', cols, rows, bands, gdal_dtype)

    if mem_ds is None:
        raise RuntimeError(gdal.GetLastErrorMsg())

    for bandNr, band_array in enumerate(band_arrays):
        band = mem_ds.GetRasterBand(bandNr + 1)  # GDAL band indices are 1-based
        band.WriteArray(band_array)
        band = None  # drop the band reference before moving on

    mem_ds.FlushCache()  # commit written band data to the (in-memory) dataset
    return mem_ds


def get_gdalnumeric_func(funcName):
    """Resolve a name from ``gdalnumeric``, falling back to this module's own definitions.

    The attribute provided by ``gdalnumeric`` always takes precedence. Only if ``gdalnumeric`` lacks it is a
    module-level object of the same name from this file returned (for instance the ``OpenNumPyArray``
    emulation). Note that the fallback returns any global of that name, not only functions.

    :param funcName:    <str> name of the attribute/function to look up
    :return:            the resolved object
    :raises AttributeError: if neither ``gdalnumeric`` nor this module defines ``funcName``
    """
    try:
        return getattr(gdalnumeric, funcName)
    except AttributeError:
        module_namespace = globals()
        if funcName not in module_namespace:
            raise AttributeError("'gdalnumeric' has no attribute '%s'." % funcName)
        return module_namespace[funcName]