Coverage for gws-app/gws/gis/gdalx/__init__.py: 75%
370 statements
coverage.py v7.11.0, created at 2025-10-16 23:09 +0200
1"""GDAL/OGR wrapper."""
3from typing import Optional, Iterable, cast
5import contextlib
7from osgeo import gdal
8from osgeo import ogr
9from osgeo import osr
11import gws
12import gws.base.shape
13import gws.lib.crs
14import gws.lib.bounds
15import gws.lib.datetimex as datetimex
class Error(gws.Error):
    """Raised when a GDAL/OGR operation in this module fails."""
class DriverInfo(gws.Data):
    """Information about a single GDAL driver, as returned by ``drivers()``."""

    # position of the driver in GDAL's driver list
    index: int
    # driver short name, e.g. 'GTiff'
    name: str
    # human-readable driver name
    longName: str
    # driver metadata (GetMetadata), e.g. capabilities and file extensions
    metaData: dict
class _DriverInfoCache(gws.Data):
    """Driver information collected once by ``_fetch_driver_infos``."""

    # all drivers, in GDAL order
    infos: list[DriverInfo]
    # file extension -> list of driver short names declaring it
    extToName: dict
    # short names of drivers with vector capability
    vectorNames: set[str]
    # short names of drivers with raster capability
    rasterNames: set[str]
class _DataSetOptions(gws.Data):
    """Options for opening or creating a DataSet, collected by the ``open_*`` functions."""

    # file path ('' for in-memory datasets)
    path: str
    # 'r' (read), 'a' (update) or 'w' (create)
    mode: str
    # driver short name, '' to derive it from the path extension
    driver: str
    # encoding for string fields; falsy to return raw bytes
    encoding: str
    # CRS to assume when the data does not declare one
    defaultCrs: gws.Crs
    # return geometries as EWKT text instead of shapes
    geometryAsText: bool
    # extra keyword options passed through to GDAL
    gdalOpts: dict
def drivers() -> list[DriverInfo]:
    """Return information about every GDAL driver available at runtime."""

    return _fetch_driver_infos().infos
def open_raster(
        path: str,
        mode: str = 'r',
        driver: str = '',
        default_crs: Optional[gws.Crs] = None,
        **opts,
) -> 'RasterDataSet':
    """Open or create a raster DataSet at a path.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write)
        driver: Driver name, if omitted, will be suggested from the path extension.
        default_crs: Default CRS for geometries (fallback to Webmercator).
        opts: Options for gdal.OpenEx/CreateDataSource.

    Returns:
        RasterDataSet object.

    Raises:
        Error: If the mode is invalid, no raster driver fits, or the dataset cannot be opened.
    """

    options = _DataSetOptions(
        path=path,
        mode=mode,
        driver=driver,
        defaultCrs=default_crs,
        gdalOpts=opts,
    )
    ds = _open(options, need_raster=True)
    return cast(RasterDataSet, ds)
def open_vector(
        path: str,
        mode: str = 'r',
        driver: str = '',
        encoding: Optional[str] = 'utf8',
        default_crs: Optional[gws.Crs] = None,
        geometry_as_text: bool = False,
        **opts,
) -> 'VectorDataSet':
    """Open or create a vector DataSet at a path.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write)
        driver: Driver name, if omitted, will be suggested from the path extension.
        encoding: If not None, strings will be automatically decoded.
        default_crs: Default CRS for geometries (fallback to Webmercator).
        geometry_as_text: Don't interpret geometry, extract raw WKT.
        opts: Options for gdal.OpenEx/CreateDataSource.

    Returns:
        VectorDataSet object.

    Raises:
        Error: If the mode is invalid, no vector driver fits, or the dataset cannot be opened.
    """

    ds = _open(
        _DataSetOptions(
            path=path,
            mode=mode,
            driver=driver,
            encoding=encoding,
            defaultCrs=default_crs,
            geometryAsText=geometry_as_text,
            gdalOpts=opts,
        ),
        need_raster=False,
    )
    return cast(VectorDataSet, ds)
def _open(dso: _DataSetOptions, need_raster):
    """Open or create a DataSet as described by ``dso``.

    ``dso.mode`` (default 'r') and ``dso.defaultCrs`` (default Webmercator)
    are normalized in place.

    Args:
        dso: Dataset options.
        need_raster: True for a raster dataset, False for a vector one.

    Returns:
        A RasterDataSet or VectorDataSet wrapping the GDAL dataset.

    Raises:
        Error: If the mode is invalid, no suitable driver is found,
            or the dataset cannot be opened or created.
    """

    dso.mode = dso.mode or 'r'
    # compare whole strings: a substring test (`in 'rwa'`) would accept e.g. 'rw' or 'wa'
    if dso.mode not in ('r', 'w', 'a'):
        raise Error(f'invalid open mode {dso.mode!r}')

    gdal.UseExceptions()

    drv = _driver_from_args(dso.path, dso.driver, need_raster)
    dso.defaultCrs = dso.defaultCrs or gws.lib.crs.WEBMERCATOR
    ds_class = RasterDataSet if need_raster else VectorDataSet

    if dso.mode == 'w':
        # NOTE(review): CreateDataSource is an ogr.Driver method; for raster
        # drivers (gdal.Driver) confirm this call is supported
        gd = drv.CreateDataSource(dso.path, **dso.gdalOpts)
        if gd is None:
            raise Error(f'cannot create {dso.path!r}')
        return ds_class(dso, gd)

    flags = gdal.OF_VERBOSE_ERROR
    flags |= gdal.OF_UPDATE if dso.mode == 'a' else gdal.OF_READONLY
    flags |= gdal.OF_RASTER if need_raster else gdal.OF_VECTOR

    gd = gdal.OpenEx(dso.path, flags, **dso.gdalOpts)
    if gd is None:
        raise Error(f'cannot open {dso.path!r}')
    return ds_class(dso, gd)
def open_from_image(image: gws.Image, bounds: gws.Bounds) -> 'RasterDataSet':
    """Create an in-memory raster DataSet from an Image.

    The pixel array is copied into a GDAL 'MEM' dataset with one 8-bit
    (GDT_Byte) band per channel, georeferenced to cover ``bounds``.

    Args:
        image: Image object
        bounds: geographic bounds

    Returns:
        An in-memory RasterDataSet; its default CRS is the CRS of ``bounds``.
    """

    gdal.UseExceptions()

    img_array = image.to_array()
    # a single-channel image may come as a (rows, cols) array: treat it as one band
    if img_array.ndim == 2:
        img_array = img_array[:, :, None]
    height, width, band_count = img_array.shape[0], img_array.shape[1], img_array.shape[2]

    drv = gdal.GetDriverByName('MEM')
    gd = drv.Create('', width, height, band_count, gdal.GDT_Byte)
    for band in range(band_count):
        gd.GetRasterBand(band + 1).WriteArray(img_array[:, :, band])

    ext = bounds.extent

    # geotransform with origin at (minx, maxy) and a negative pixel height,
    # i.e. the first array row is the top edge of the bounds
    res_x = (ext[2] - ext[0]) / width
    res_y = (ext[1] - ext[3]) / height
    gd.SetGeoTransform((ext[0], res_x, 0, ext[3], 0, res_y))
    gd.SetSpatialRef(_srs_from_srid(bounds.crs.srid))

    # defaultCrs lets RasterDataSet.bounds() fall back to the image CRS
    # if the SRS cannot be mapped back to an srid
    dso = _DataSetOptions(path='', defaultCrs=bounds.crs)
    return RasterDataSet(dso, gd)
197##
class _DataSet:
    """Base class for GDAL-backed datasets.

    Can be used as a context manager; the dataset is closed on exit.
    """

    gdDataset: gdal.Dataset
    gdDriver: gdal.Driver
    dso: _DataSetOptions
    driverName: str

    def __init__(self, dso: _DataSetOptions, gd_dataset):
        self.gdDataset = gd_dataset
        self.gdDriver = self.gdDataset.GetDriver()
        self.driverName = self.gdDriver.GetDescription()
        self.dso = dso

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    def close(self):
        """Flush pending writes and release the GDAL dataset.

        Closing an already closed dataset is a no-op, so an explicit
        ``close()`` inside a ``with`` block does not break ``__exit__``.
        """
        if self.gdDataset is None:
            return
        self.gdDataset.FlushCache()
        # drop our reference so that GDAL can close the dataset
        setattr(self, 'gdDataset', None)

    def crs(self) -> Optional[gws.Crs]:
        """Return the CRS declared by the dataset, or None if it has none or it cannot be identified."""
        srid = _srid_from_srs(self.gdDataset.GetSpatialRef())
        return gws.lib.crs.get(srid) if srid else None
class RasterDataSet(_DataSet):
    """A raster (GDAL) DataSet."""

    def create_copy(self, path: str, driver: str = '', strict=False, **opts):
        """Write a copy of this DataSet, including its metadata, to ``path``.

        Args:
            path: Destination path.
            driver: Driver name; if omitted, it is derived from the path extension.
            strict: Passed to GDAL ``CreateCopy`` as the strict flag (1/0).
            opts: Extra keyword options for ``CreateCopy``.
        """

        gdal.UseExceptions()

        target_driver = _driver_from_args(path, driver, need_raster=True)
        target = target_driver.CreateCopy(path, self.gdDataset, int(bool(strict)), **opts)
        target.SetMetadata(self.gdDataset.GetMetadata())
        target.FlushCache()
        # release the last reference so that GDAL closes the new dataset
        del target

    def bounds(self) -> gws.Bounds:
        """Return the extent of the raster, computed from its geotransform.

        Uses the dataset CRS, or the default CRS from the open options if the
        dataset CRS cannot be determined.
        """

        origin_x, pixel_w, _, origin_y, _, pixel_h = self.gdDataset.GetGeoTransform()
        cols = self.gdDataset.RasterXSize
        rows = self.gdDataset.RasterYSize

        # for a north-up raster (negative pixel height) this yields (minx, miny, maxx, maxy)
        extent = (
            origin_x,
            origin_y + pixel_h * rows,
            origin_x + pixel_w * cols,
            origin_y,
        )

        crs = self.crs() or self.dso.defaultCrs
        return gws.lib.bounds.from_extent(extent, crs, always_xy=True)
class VectorDataSet(_DataSet):
    """A vector (OGR) DataSet."""

    @contextlib.contextmanager
    def transaction(self):
        """Run the ``with`` body inside a dataset transaction.

        The transaction is committed when the body completes. If the body
        (or the commit) raises, it is rolled back and the exception re-raised.
        """
        self.gdDataset.StartTransaction()
        try:
            yield self
            self.gdDataset.CommitTransaction()
        except BaseException:
            self.gdDataset.RollbackTransaction()
            raise

    def create_layer(
            self,
            name: str,
            columns: dict[str, gws.AttributeType],
            geometry_type: Optional[gws.GeometryType] = None,
            crs: Optional[gws.Crs] = None,
            overwrite=False,
            *options,
    ) -> 'VectorLayer':
        """Create a new layer.

        Args:
            name: Layer name.
            columns: Attribute columns (name -> attribute type). Each type must be
                supported (see ``is_attribute_supported``), otherwise a KeyError is raised.
            geometry_type: Geometry type; if omitted, the layer geometry type is unknown
                and no SRS is set.
            crs: Geometry CRS; defaults to the dataset's default CRS.
            overwrite: Pass 'OVERWRITE=YES' to the driver.
            options: Additional layer creation options ('KEY=VALUE' strings).

        Returns:
            The new layer.
        """
        opts = list(options)
        if overwrite:
            opts.append('OVERWRITE=YES')

        geom_type = ogr.wkbUnknown
        srs = None

        if geometry_type:
            geom_type = _GEOM_TO_OGR.get(geometry_type)
            if not geom_type:
                gws.log.warning(f'gdal: unsupported {geometry_type=}')
                geom_type = ogr.wkbUnknown
            crs = crs or self.dso.defaultCrs
            srs = _srs_from_srid(crs.srid)

        gd_layer = self.gdDataset.CreateLayer(
            name,
            geom_type=geom_type,
            srs=srs,
            options=opts,
        )

        for col_name, col_type in columns.items():
            fdefn = ogr.FieldDefn(col_name, _ATTR_TO_OGR[col_type])
            if col_type == gws.AttributeType.bool:
                # bools are stored in integer fields; the boolean subtype keeps
                # them distinguishable from plain integers when read back
                fdefn.SetSubType(ogr.OFSTBoolean)
            gd_layer.CreateField(fdefn)

        return VectorLayer(self, gd_layer)

    def layers(self) -> list['VectorLayer']:
        """Return all layers of the dataset, in index order."""
        cnt = self.gdDataset.GetLayerCount()
        return [VectorLayer(self, self.gdDataset.GetLayerByIndex(n)) for n in range(cnt)]

    def layer(self, name_or_index: str | int) -> Optional['VectorLayer']:
        """Get a layer by name (str) or index (int), or None if not found."""
        gd_layer = None
        if isinstance(name_or_index, int):
            gd_layer = self.gdDataset.GetLayerByIndex(name_or_index)
        elif isinstance(name_or_index, str):
            gd_layer = self.gdDataset.GetLayerByName(name_or_index)
        return VectorLayer(self, gd_layer) if gd_layer else None

    def require_layer(self, name_or_index: str | int) -> 'VectorLayer':
        """Get a layer by name or index, raise an error if not found."""
        la = self.layer(name_or_index)
        if la:
            return la
        raise Error(f'layer {name_or_index} not found')
class VectorLayer:
    """A layer of a VectorDataSet."""

    name: str
    dso: _DataSetOptions
    gdLayer: ogr.Layer
    gdDefn: ogr.FeatureDefn

    def __init__(self, ds: VectorDataSet, gd_layer: ogr.Layer):
        self.gdLayer = gd_layer
        self.gdDefn = self.gdLayer.GetLayerDefn()
        self.name = self.gdDefn.GetName()
        self.dso = ds.dso

    def describe(self) -> gws.DataSetDescription:
        """Describe the layer's columns.

        Columns are listed in this order:

        * the FID column, if the driver reports one, as an integer primary key;
        * attribute fields whose OGR type is supported (others are skipped);
        * geometry fields.

        With several geometry fields, the description's geometry properties
        come from the last one.
        """
        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.name,
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=self.name,
            schema='',
        )

        cols = []

        fid_col = self.gdLayer.GetFIDColumn()
        if fid_col:
            cols.append(
                gws.ColumnDescription(
                    name=fid_col,
                    type=_OGR_TO_ATTR[ogr.OFTInteger],
                    nativeType=ogr.OFTInteger,
                    isPrimaryKey=True,
                    columnIndex=0,
                )
            )

        for i in range(self.gdDefn.GetFieldCount()):
            fdef = self.gdDefn.GetFieldDefn(i)
            typ = fdef.GetType()
            if typ not in _OGR_TO_ATTR:
                continue
            attr_type = _OGR_TO_ATTR[typ]
            if typ == ogr.OFTInteger and fdef.GetSubType() == ogr.OFSTBoolean:
                # booleans are integer fields with a boolean subtype
                attr_type = gws.AttributeType.bool
            cols.append(
                gws.ColumnDescription(
                    name=fdef.GetName(),
                    type=attr_type,
                    nativeType=typ,
                    columnIndex=i,
                )
            )

        for i in range(self.gdDefn.GetGeomFieldCount()):
            gfdef: ogr.GeomFieldDefn = self.gdDefn.GetGeomFieldDefn(i)
            typ = gfdef.GetType()
            cols.append(
                gws.ColumnDescription(
                    name=gfdef.GetName() or 'geom',
                    type=gws.AttributeType.geometry,
                    nativeType=typ,
                    columnIndex=i,
                    # GT_Flatten drops Z/M flags, so e.g. 3D points are reported as points
                    geometryType=_OGR_TO_GEOM.get(ogr.GT_Flatten(typ)) or gws.GeometryType.geometry,
                    geometrySrid=_srid_from_srs(gfdef.GetSpatialRef()) or self.dso.defaultCrs.srid,
                )
            )

        desc.columns = cols
        desc.columnMap = {c.name: c for c in cols}

        for c in cols:
            # NB take the last geom
            if c.geometryType:
                desc.geometryName = c.name
                desc.geometryType = c.geometryType
                desc.geometrySrid = c.geometrySrid

        return desc

    def insert(self, records: list[gws.FeatureRecord]) -> list[int]:
        """Insert records as new features.

        Args:
            records: Records to insert. The shape is written if the layer has a
                geometry column; a non-zero integer ``uid`` is used as the feature id.
                Attributes that are None or missing are left unset.

        Returns:
            Feature ids of the inserted features, as reported by OGR.

        Raises:
            Error: If an attribute value cannot be set on its field.
        """
        desc = self.describe()
        fids = []

        for rec in records:
            gd_feature = ogr.Feature(self.gdDefn)
            if desc.geometryType and rec.shape:
                gd_feature.SetGeometry(ogr.CreateGeometryFromWkt(rec.shape.to_wkt(), _srs_from_srid(rec.shape.crs.srid)))

            # string uids (as produced by _feature_record) are not used as FIDs
            if rec.uid and isinstance(rec.uid, int):
                gd_feature.SetFID(rec.uid)

            for col in desc.columns:
                if col.geometryType or col.isPrimaryKey:
                    continue
                val = rec.attributes.get(col.name)
                if val is None:
                    continue
                try:
                    _attr_to_ogr(gd_feature, int(col.nativeType), col.columnIndex, val, self.dso.encoding)
                except Exception as exc:
                    raise Error(f'field cannot be set: {col.name=} {val=}') from exc

            self.gdLayer.CreateFeature(gd_feature)
            fids.append(gd_feature.GetFID())

        return fids

    def count(self, force=False):
        """Return the feature count; ``force`` asks the driver to compute it even if expensive."""
        return self.gdLayer.GetFeatureCount(force=1 if force else 0)

    def get_all(self) -> list[gws.FeatureRecord]:
        """Return all features as records."""
        return list(self.iter_features())

    def iter_features(self) -> Iterable[gws.FeatureRecord]:
        """Iterate over all features, restarting the layer's reading from the beginning."""
        self.gdLayer.ResetReading()

        while True:
            gd_feature = self.gdLayer.GetNextFeature()
            if not gd_feature:
                break
            yield self._feature_record(gd_feature)

    def get(self, fid: int) -> Optional[gws.FeatureRecord]:
        """Return the feature with the given id as a record, or None if there is none."""
        gd_feature = self.gdLayer.GetFeature(fid)
        if gd_feature:
            return self._feature_record(gd_feature)
        return None

    def _feature_record(self, gd_feature):
        """Convert an OGR feature into a FeatureRecord.

        All fields are read into ``attributes`` and the feature id becomes the
        (string) ``uid``. Only the last geometry is used: it becomes ``shape``,
        or ``ewkt`` when the dataset was opened with ``geometry_as_text``.
        """
        rec = gws.FeatureRecord(
            attributes={},
            shape=None,
            meta={'layerName': self.name},
            uid=str(gd_feature.GetFID()),
        )

        for i in range(gd_feature.GetFieldCount()):
            gd_field_defn: ogr.FieldDefn = gd_feature.GetFieldDefnRef(i)
            name = gd_field_defn.GetName()
            val = _attr_from_ogr(gd_feature, gd_field_defn.type, i, self.dso.encoding)
            rec.attributes[name] = val

        cnt = gd_feature.GetGeomFieldCount()
        if cnt > 0:
            # NB take the last geom
            # @TODO multigeometry support
            gd_geom = gd_feature.GetGeomFieldRef(cnt - 1)
            if gd_geom:
                srid = _srid_from_srs(gd_geom.GetSpatialReference()) or self.dso.defaultCrs.srid
                wkt = gd_geom.ExportToWkt()
                if self.dso.geometryAsText:
                    rec.ewkt = f'SRID={srid};{wkt}'
                else:
                    rec.shape = gws.base.shape.from_wkt(wkt, gws.lib.crs.get(srid))

        return rec
476##
def _driver_from_args(path, driver_name, need_raster):
    """Return the GDAL (raster) or OGR (vector) driver to use for a path.

    Args:
        path: File path; its extension is used when ``driver_name`` is empty.
        driver_name: Explicit driver short name, or '' to derive it from ``path``.
        need_raster: True if a raster driver is needed, False for a vector driver.

    Raises:
        Error: If no driver (or more than one) matches the extension, or the
            driver lacks the requested raster/vector capability.
    """
    import os

    di = _fetch_driver_infos()

    if not driver_name:
        # only the file name can carry the extension (directories may contain dots);
        # try it verbatim first, then lower-cased, so 'data.SHP' works like 'data.shp'
        file_name = os.path.basename(path)
        ext = file_name.rsplit('.', 1)[-1] if '.' in file_name else ''
        names = di.extToName.get(ext) or di.extToName.get(ext.lower())
        if not names:
            raise Error(f'no default driver found for {path!r}')
        if len(names) > 1:
            if ext.lower() in ('tif', 'tiff'):
                driver_name = 'GTiff'
            else:
                raise Error(f'multiple drivers found for {path!r}: {names}')
        else:
            driver_name = names[0]

    is_vector = driver_name in di.vectorNames
    is_raster = driver_name in di.rasterNames

    if need_raster:
        if not is_raster:
            raise Error(f'driver {driver_name!r} is not raster')
        return gdal.GetDriverByName(driver_name)

    if not is_vector:
        raise Error(f'driver {driver_name!r} is not vector')
    return ogr.GetDriverByName(driver_name)
# Driver information, filled on the first call to _fetch_driver_infos.
_di_cache: Optional[_DriverInfoCache] = None


def _fetch_driver_infos() -> _DriverInfoCache:
    """Return driver information, enumerating GDAL drivers on first use.

    For each driver, the info is recorded and its declared file extensions
    (DMD_EXTENSIONS) are indexed. Drivers whose metadata has DCAP_VECTOR or
    DCAP_RASTER set to 'YES' are added to the vector/raster name sets.
    """
    global _di_cache

    if _di_cache:
        return _di_cache

    cache = _di_cache = _DriverInfoCache(
        infos=[],
        extToName={},
        vectorNames=set(),
        rasterNames=set(),
    )

    for index in range(gdal.GetDriverCount()):
        gd_driver = gdal.GetDriver(index)
        info = DriverInfo(
            index=index,
            name=str(gd_driver.ShortName),
            longName=str(gd_driver.LongName),
            metaData=dict(gd_driver.GetMetadata() or {}),
        )
        cache.infos.append(info)

        meta = info.metaData
        for ext in meta.get(gdal.DMD_EXTENSIONS, '').split():
            cache.extToName.setdefault(ext, []).append(info.name)
        if meta.get('DCAP_VECTOR') == 'YES':
            cache.vectorNames.add(info.name)
        if meta.get('DCAP_RASTER') == 'YES':
            cache.rasterNames.add(info.name)

    return _di_cache
# Cache of SRS name -> srid (0 if unknown), filled by _srid_from_srs.
_name_to_srid: dict = {}


def _srs_from_srid(srid):
    """Create an OSR spatial reference from an EPSG code."""
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(srid)
    return srs


def _srid_from_srs(srs):
    """Return the EPSG code of an OSR spatial reference as an int, or 0 if it cannot be determined.

    If the SRS has no authority code (e.g. ESRI-style WKT from a .prj file),
    GDAL's EPSG auto-identification is tried. Results are cached by SRS name.
    """
    if not srs:
        return 0

    name = srs.GetName()
    if not name:
        wkt = srs.ExportToWkt()
        gws.log.warning(f'gdalx: no name for SRS {wkt!r}')
        return 0

    if name in _name_to_srid:
        return _name_to_srid[name]

    code = srs.GetAuthorityCode(None) or _identify_epsg_code(srs)

    # GetAuthorityCode returns a string; normalize to int so callers always get an int
    srid = int(code) if code and str(code).isdigit() else 0
    if not srid:
        wkt = srs.ExportToWkt()
        gws.log.warning(f'gdalx: no srid for SRS {wkt!r}')

    _name_to_srid[name] = srid
    return srid


def _identify_epsg_code(srs):
    """Best-effort guess of the EPSG code of an SRS without an authority node; None on failure."""
    try:
        # work on a clone: AutoIdentifyEPSG modifies the SRS in place
        probe = srs.Clone()
        if probe.AutoIdentifyEPSG() != 0:
            return None
        return probe.GetAuthorityCode(None)
    except Exception:
        # depending on the exception mode, failures raise instead of returning an error code
        return None
def _attr_from_ogr(gd_feature: ogr.Feature, gtype: int, idx: int, encoding: str = 'utf8'):
    """Read a field value from an OGR feature as a Python value.

    Args:
        gd_feature: OGR feature.
        gtype: OGR field type (``ogr.OFT...``) of the field.
        idx: Field index.
        encoding: Encoding used to decode string fields; if falsy, they are returned as bytes.

    Returns:
        The field value, or None if the field is null, a date/time value is
        invalid, or the field type is not supported.
    """
    if gd_feature.IsFieldNull(idx):
        return None

    if gtype == ogr.OFTString:
        b = gd_feature.GetFieldAsBinary(idx)
        if encoding:
            return b.decode(encoding)
        return bytes(b)

    if gtype in {ogr.OFTDate, ogr.OFTTime, ogr.OFTDateTime}:
        # python GetFieldAsDateTime appears to use float seconds, as in
        # GetFieldAsDateTime (int i, int *pnYear, int *pnMonth, int *pnDay, int *pnHour, int *pnMinute, float *pfSecond, int *pnTZFlag)
        #
        v = gd_feature.GetFieldAsDateTime(idx)
        sec, fsec = divmod(v[5], 1)
        try:
            return datetimex.new(v[0], v[1], v[2], v[3], v[4], int(sec), int(fsec * 1e6), tz=_tzflag_to_tz(v[6]))
        except ValueError:
            # NOTE(review): time-only fields presumably report year/month/day 0
            # and end up here, returning None — confirm
            return None

    # NB boolean and float32 are field *subtypes* (ogr.OFST...), not field types:
    # their numeric values collide with OFTIntegerList and OFTRealList,
    # so the subtype must be read from the field definition.
    if gtype == ogr.OFTInteger:
        val = gd_feature.GetFieldAsInteger(idx)
        if gd_feature.GetFieldDefnRef(idx).GetSubType() == ogr.OFSTBoolean:
            return val != 0
        return val
    if gtype == ogr.OFTInteger64:
        # GetFieldAsInteger would clamp 64-bit values to 32 bits
        return gd_feature.GetFieldAsInteger64(idx)
    if gtype == ogr.OFTIntegerList:
        return gd_feature.GetFieldAsIntegerList(idx)
    if gtype == ogr.OFTInteger64List:
        return gd_feature.GetFieldAsInteger64List(idx)
    if gtype == ogr.OFTReal:
        return gd_feature.GetFieldAsDouble(idx)
    if gtype == ogr.OFTRealList:
        return gd_feature.GetFieldAsDoubleList(idx)
    if gtype == ogr.OFTStringList:
        return gd_feature.GetFieldAsStringList(idx)
    if gtype == ogr.OFTBinary:
        return gd_feature.GetFieldAsBinary(idx)

    return None
606def _tzflag_to_tz(tzflag):
607 # see gdal/ogr/ogrutils.cpp OGRGetISO8601DateTime
609 if tzflag == 0 or tzflag == 1:
610 return ''
611 if tzflag == 100:
612 return 'UTC'
613 if tzflag % 4 != 0:
614 # @TODO
615 raise Error(f'unsupported timezone {tzflag=}')
616 hrs = (100 - tzflag) // 4
617 return f'Etc/GMT{hrs:+}'
def _attr_to_ogr(gd_feature: ogr.Feature, gtype: int, idx: int, value, encoding):
    """Set a field of an OGR feature from a Python value.

    Args:
        gd_feature: OGR feature.
        gtype: OGR field type (``ogr.OFT...``) of the field.
        idx: Field index.
        value: Value to set (not None).
        encoding: Currently unused.

    Returns:
        Whatever the underlying OGR setter returns.
    """
    if gtype == ogr.OFTDate:
        return gd_feature.SetField(idx, datetimex.to_iso_date_string(value))
    if gtype == ogr.OFTTime:
        return gd_feature.SetField(idx, datetimex.to_iso_time_string(value))
    if gtype == ogr.OFTDateTime:
        return gd_feature.SetField(idx, datetimex.to_iso_string(value))

    # NB boolean and float32 are field *subtypes* (ogr.OFST...), not field types:
    # their numeric values collide with OFTIntegerList and OFTRealList,
    # so the subtype must be read from the field definition.
    if gtype == ogr.OFTInteger:
        if gd_feature.GetFieldDefnRef(idx).GetSubType() == ogr.OFSTBoolean:
            return gd_feature.SetField(idx, 1 if value else 0)
        return gd_feature.SetField(idx, int(value))
    if gtype == ogr.OFTInteger64:
        return gd_feature.SetField(idx, int(value))

    # SetField has no list overload, lists need the dedicated setters
    if gtype in {ogr.OFTIntegerList, ogr.OFTInteger64List}:
        return gd_feature.SetFieldInteger64List(idx, [int(x) for x in value])
    if gtype == ogr.OFTReal:
        return gd_feature.SetField(idx, float(value))
    if gtype == ogr.OFTRealList:
        return gd_feature.SetFieldDoubleList(idx, [float(x) for x in value])
    if gtype == ogr.OFTStringList:
        return gd_feature.SetFieldStringList(idx, [str(x) for x in value])

    return gd_feature.SetField(idx, value)
def is_attribute_supported(typ):
    """Return True if columns of the given attribute type can be created in an OGR layer."""
    supported = typ in _ATTR_TO_OGR
    return supported
# GWS attribute type -> OGR field type, used when creating layer columns.
# NB bool has no OGR field type of its own and is stored in an integer field.
_ATTR_TO_OGR = {
    gws.AttributeType.bool: ogr.OFTInteger,
    gws.AttributeType.bytes: ogr.OFTBinary,
    gws.AttributeType.date: ogr.OFTDate,
    gws.AttributeType.datetime: ogr.OFTDateTime,
    gws.AttributeType.float: ogr.OFTReal,
    gws.AttributeType.floatlist: ogr.OFTRealList,
    gws.AttributeType.int: ogr.OFTInteger,
    gws.AttributeType.intlist: ogr.OFTIntegerList,
    gws.AttributeType.str: ogr.OFTString,
    gws.AttributeType.strlist: ogr.OFTStringList,
    gws.AttributeType.time: ogr.OFTTime,
}

# OGR field type -> GWS attribute type, used when describing layers.
# Field types missing here are skipped by VectorLayer.describe.
_OGR_TO_ATTR = {
    ogr.OFTBinary: gws.AttributeType.bytes,
    ogr.OFTDate: gws.AttributeType.date,
    ogr.OFTDateTime: gws.AttributeType.datetime,
    ogr.OFTReal: gws.AttributeType.float,
    ogr.OFTRealList: gws.AttributeType.floatlist,
    ogr.OFTInteger: gws.AttributeType.int,
    ogr.OFTIntegerList: gws.AttributeType.intlist,
    ogr.OFTInteger64: gws.AttributeType.int,
    ogr.OFTInteger64List: gws.AttributeType.intlist,
    ogr.OFTString: gws.AttributeType.str,
    ogr.OFTStringList: gws.AttributeType.strlist,
    ogr.OFTTime: gws.AttributeType.time,
}

# GWS geometry type -> OGR geometry type.
_GEOM_TO_OGR = {
    gws.GeometryType.curve: ogr.wkbCurve,
    gws.GeometryType.geometrycollection: ogr.wkbGeometryCollection,
    gws.GeometryType.linestring: ogr.wkbLineString,
    gws.GeometryType.multicurve: ogr.wkbMultiCurve,
    gws.GeometryType.multilinestring: ogr.wkbMultiLineString,
    gws.GeometryType.multipoint: ogr.wkbMultiPoint,
    gws.GeometryType.multipolygon: ogr.wkbMultiPolygon,
    gws.GeometryType.multisurface: ogr.wkbMultiSurface,
    gws.GeometryType.point: ogr.wkbPoint,
    gws.GeometryType.polygon: ogr.wkbPolygon,
    gws.GeometryType.polyhedralsurface: ogr.wkbPolyhedralSurface,
    gws.GeometryType.surface: ogr.wkbSurface,
}

# OGR geometry type -> GWS geometry type: the exact inverse of _GEOM_TO_OGR,
# which maps every GWS type to a distinct OGR type.
_OGR_TO_GEOM = {ogr_type: geom_type for geom_type, ogr_type in _GEOM_TO_OGR.items()}