metadata.py 8.64 KB
Newer Older
1 2 3 4
# -*- coding: utf-8 -*-

import os
from pprint import pformat
5
from copy import deepcopy
6
from typing import Union  # noqa F401  # flake8 issue
7 8

from geopandas import GeoDataFrame, GeoSeries
9
import numpy as np
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
try:
    from osgeo import gdal
except ImportError:
    import gdal


# Keys of the 'ENVI' metadata domain that GDAL_Metadata.read_from_file() skips, i.e., that are
# stored neither as global nor as band-specific metadata.
autohandled_meta = [
    'bands', 'byte_order', 'coordinate_system_string',
    'data_type', 'file_type', 'header_offset',
    'interleave', 'lines', 'samples',
]


class GDAL_Metadata(object):
    """Container for the metadata of a raster dataset as read via GDAL.

    The metadata is held in two dictionaries:

    * ``global_meta``: maps a key to a single value,
    * ``band_meta``:   maps a key to a list containing one value per band.

    Instances can be subsetted by band and/or by metadata key, either via get_subset() or via indexing
    (e.g., ``meta[0]``, ``meta[-1]``, ``meta[1:]``, ``meta[[0, 2]]``, ``meta['key']``, ``meta[['key1', 'key2']]``).
    """

    def __init__(self, filePath='', nbands=1):
        """Create a new instance and optionally fill it from a file.

        :param filePath:    path of the file to read the metadata from (nothing is read if empty)
        :param nbands:      number of bands (replaced by the band count of the file if filePath is given)
        """
        # privates
        self._global_meta = dict()
        self._band_meta = dict()

        self.bands = nbands
        self.filePath = filePath
        self.fileFormat = ''

        if filePath:
            self.read_from_file(filePath)

    @classmethod
    def from_file(cls, filePath):
        """Return a new instance filled with the metadata of the given file."""
        # use cls instead of the hard-coded class name so that subclasses get instances of their own type
        return cls(filePath=filePath)

    @classmethod
    def from_DataFrame(cls, dataframe):
        # type: (GeoDataFrame) -> 'GDAL_Metadata'
        """Return a new instance with one band per column of the given (non-empty) dataframe.

        NOTE: The dataframe is only attached as attribute 'df'; global_meta and band_meta are not populated from it.

        :raises ValueError: if the given dataframe is empty
        """
        if dataframe.empty:
            raise ValueError('DataFrame must not be empty.')

        GDMD = cls(nbands=dataframe.shape[1])
        GDMD.df = dataframe

        return GDMD

    def to_DataFrame(self):
        """Return the metadata as GeoDataFrame with one row per metadata key and one column per band."""
        df = GeoDataFrame(columns=range(self.bands))

        # add global meta (the same value is repeated for each band)
        for k, v in self.global_meta.items():
            df.loc[k] = GeoSeries(dict(zip(df.columns, [v] * len(df.columns))))

        # add band meta
        for k, v in self.band_meta.items():
            df.loc[k] = GeoSeries(dict(zip(df.columns, v)))

        return df

    @property
    def global_meta(self):
        """Return the dictionary of global metadata (key -> single value)."""
        return self._global_meta

    @global_meta.setter
    def global_meta(self, meta_dict):
        if not isinstance(meta_dict, dict):
            raise TypeError("Expected type 'dict', received '%s'." % type(meta_dict))

        self._global_meta = meta_dict  # TODO convert strings to useful types

    @property
    def band_meta(self):
        """Return the dictionary of band-specific metadata (key -> list with one value per band)."""
        return self._band_meta

    @band_meta.setter
    def band_meta(self, meta_dict):
        if not isinstance(meta_dict, dict):
            raise TypeError("Expected type 'dict', received '%s'." % type(meta_dict))

        for k, v in meta_dict.items():
            if not isinstance(v, list):
                raise TypeError('The values of the given dictionary must be lists. Received %s for %s.' % (type(v), k))
            if len(v) != self.bands:
                raise ValueError("The length of the given lists must be equal to the number of bands. "
                                 "Received a list with %d items for '%s'." % (len(v), k))

        self._band_meta = meta_dict  # TODO convert strings to useful types

    @property
    def all_meta(self):
        """Return a new dictionary merging global and band metadata (band metadata wins on equal keys)."""
        all_meta = self.global_meta.copy()
        all_meta.update(self.band_meta)
        return all_meta

    @staticmethod
    def _strip_braces(value):
        """Return the given string without surrounding whitespace and without ONE pair of enclosing curly braces."""
        value = value.strip()
        # slice instead of split() so that braces occurring inside the value do not truncate it
        if value.startswith('{'):
            value = value[1:]
        if value.endswith('}'):
            value = value[:-1]
        return value.strip()

    @classmethod
    def _split_ENVI_list(cls, value):
        """Split a comma-separated ENVI list string like '{ 1, 2, 3 }' into its stripped items ('1', '2', '3')."""
        return [item.strip() for item in cls._strip_braces(value).split(',')]

    @classmethod
    def _align_band_meta(cls, band_dicts):
        """Merge a sequence of per-band metadata dictionaries into a dictionary of per-band lists.

        Each returned list has one entry per band. Bands that do not provide a key get None at their position so that
        the values of the remaining bands stay aligned with their band index.
        """
        nbands = len(band_dicts)
        aligned = dict()

        for idx, band_dict in enumerate(band_dicts):
            for k, v in band_dict.items():
                if k not in aligned:
                    aligned[k] = [None] * nbands
                aligned[k][idx] = cls._convert_param_from_str(v)

        return aligned

    @staticmethod
    def _convert_param_from_str(param_value):
        """Convert a metadata string to int or float if possible, else return it stripped from enclosing braces."""
        try:
            return int(param_value)  # NOTE: int('0.34') causes ValueError: invalid literal for int() with base 10
        except ValueError:
            pass

        try:
            return float(param_value)
        except ValueError:
            return GDAL_Metadata._strip_braces(param_value)

    def _convert_param_to_ENVI_str(self, param_value):
        """Convert a metadata value to its string representation within an ENVI header.

        Integers and floats (including numpy scalars) are converted to strings, lists are converted to
        '{ item1,\\nitem2 }'. All other values are returned unchanged.
        """
        if isinstance(param_value, (int, np.integer)):
            return str(param_value)

        elif isinstance(param_value, (float, np.floating)):
            # str() keeps the full precision whereas '%f' truncates to 6 decimals (1e-7 -> '0.000000')
            return str(param_value)

        elif isinstance(param_value, list):
            # str() ensures that items which are returned unchanged (e.g., None) do not break the join
            return '{ ' + ',\n'.join([str(self._convert_param_to_ENVI_str(i)) for i in param_value]) + ' }'

        else:
            return param_value

    def read_from_file(self, filePath):
        """Read the metadata of the given file via GDAL and return them as dictionary (see all_meta).

        For ENVI files, the 'ENVI' metadata domain is read (except the keys listed in autohandled_meta). Values
        consisting of as many comma-separated items as there are bands are stored as band metadata, all others as
        global metadata. For all other formats, the default metadata domains of the dataset and of each band are read.

        :param filePath:    path of the file to read the metadata from
        :raises FileNotFoundError: if the given path does not exist
        :raises Exception:  if GDAL is not able to open the file
        """
        # NOTE: deliberately kept as an assertion to not change the exception type raised to callers
        assert ' ' not in filePath, "The given path contains whitespaces. This is not supported by GDAL."

        if not os.path.exists(filePath):
            raise FileNotFoundError(filePath)

        ds = gdal.Open(filePath)

        try:
            if not ds:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

            self.bands = ds.RasterCount
            self.fileFormat = ds.GetDriver().GetDescription()

            ###############
            # ENVI format #
            ###############

            if self.fileFormat == 'ENVI':
                metadict = ds.GetMetadata('ENVI') or dict()

                for k, v in metadict.items():
                    if k in autohandled_meta:
                        continue

                    # both enclosing braces are removed before splitting so that also the values of
                    # single-band files (e.g., '{ 5 }') are converted to numbers
                    item_list = self._split_ENVI_list(v)

                    if len(item_list) == self.bands:
                        # band meta parameter
                        self.band_meta[k] = [self._convert_param_from_str(item_str) for item_str in item_list]

                    else:
                        # global meta parameter
                        self.global_meta[k] = self._convert_param_from_str(v)

            #####################
            # remaining formats #
            #####################

            else:
                # read global domain metadata
                self.global_meta = ds.GetMetadata()

                # read band domain metadata
                band_dicts = []
                for b in range(self.bands):
                    band = ds.GetRasterBand(b + 1)
                    band_dicts.append(band.GetMetadata() or dict())
                    del band

                # keys missing in some bands are filled with None instead of shifting the values of the other bands
                self.band_meta.update(self._align_band_meta(band_dicts))

        finally:
            del ds

        return self.all_meta

    def __repr__(self):
        return 'Metadata: \n\n' + pformat(self.all_meta)

    def to_ENVI_metadict(self):
        """Return all metadata as dictionary with the values converted to ENVI header strings."""
        return {k: self._convert_param_to_ENVI_str(v) for k, v in self.all_meta.items()}

    def get_subset(self, bands2extract=None, keys2extract=None):
        # type: (Union[slice, list, np.ndarray], Union[str, list]) -> 'GDAL_Metadata'
        """Return a deep copy of this instance reduced to the given bands and/or metadata keys.

        :param bands2extract:   bands to keep (slice, list or numpy array of band indices or a boolean mask)
        :param keys2extract:    metadata key or list of metadata keys to keep
        :raises TypeError:      if bands2extract has an unsupported type
        :raises IndexError:     if bands2extract contains band indices that are out of range
        :raises ValueError:     if none of the given keys exists
        """
        meta_sub = deepcopy(self)

        # subset bands
        if bands2extract is not None:
            if isinstance(bands2extract, list):
                bands2extract = np.array(bands2extract)
            elif isinstance(bands2extract, (np.ndarray, slice)):
                pass  # all fine
            else:
                raise TypeError(bands2extract)

            # Resolve the selection against the actual number of bands. This handles open-ended and oversized slices,
            # negative indices and boolean masks alike and always yields the true number of selected bands.
            band_idxs = np.atleast_1d(np.arange(self.bands)[bands2extract]).ravel().tolist()

            # index the Python lists directly so that the values keep their native types (no numpy scalars)
            for k, v in list(meta_sub.band_meta.items()):
                meta_sub.band_meta[k] = [v[i] for i in band_idxs]

            meta_sub.bands = len(band_idxs)

        # subset metadata keys
        if keys2extract:
            keys2extract = [keys2extract] if isinstance(keys2extract, str) else keys2extract

            for k in list(meta_sub.global_meta):
                if k not in keys2extract:
                    del meta_sub.global_meta[k]

            for k in list(meta_sub.band_meta):
                if k not in keys2extract:
                    del meta_sub.band_meta[k]

            if not meta_sub.all_meta:
                raise ValueError(keys2extract, 'The given metadata keys do not exist.')

        return meta_sub

    def __getitem__(self, given):
        """Return a subset by band (int, slice, list of ints, 1D numpy array) or by key (str, list of str)."""
        if isinstance(given, (int, np.integer)):
            # passed as list so that negative indices count from the end and out-of-range indices raise IndexError
            return self.get_subset(bands2extract=[int(given)])
        elif isinstance(given, slice):
            return self.get_subset(bands2extract=given)
        elif isinstance(given, str):
            return self.get_subset(keys2extract=given)
        elif isinstance(given, list):
            if given and isinstance(given[0], str):
                return self.get_subset(keys2extract=given)
            elif given and isinstance(given[0], (int, np.integer)):
                return self.get_subset(bands2extract=given)
            else:
                raise TypeError(given, 'Given list must contain string or integer items.')
        elif isinstance(given, np.ndarray):
            if given.ndim != 1:
                raise TypeError(given, 'Given numpy array must be one-dimensional.')
            return self.get_subset(bands2extract=given)
        else:
            raise TypeError(given)