Coverage for gws-app/gws/plugin/postgres/provider.py: 71%

104 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-16 23:09 +0200

1"""Postgres database provider.""" 

2 

3from typing import Optional 

4 

5import os 

6import re 

7 

8import gws.base.database 

9import gws.lib.crs 

10import gws.lib.extent 

11import gws.lib.net 

12import gws.lib.sa as sa 

13 

# NOTE(review): presumably registers this module as the 'postgres' database
# provider type with the gws extension mechanism — confirm in gws.ext.
gws.ext.new.databaseProvider('postgres')

15 

16 

class Config(gws.base.database.provider.Config):
    """Postgres/PostGIS database provider configuration."""

    # connection_url() also accepts a 'dbname' key as a fallback for this value.
    database: Optional[str]
    """Database name."""
    # When set, connection_url() builds a network URL (host, port, credentials).
    host: Optional[str]
    """Database host."""
    # Only placed into the URL by connection_url() when 'host' is set.
    port: int = 5432
    """Database port."""
    username: Optional[str]
    """Username."""
    password: Optional[str]
    """Password."""
    # Used by connection_url() only when 'host' is not set; it then requires the
    # PGSERVICEFILE environment variable to point to an existing file.
    serviceName: Optional[str]
    """Service name from pg_services file."""
    # Extra libpq keywords, added to the URL query on top of a default
    # application_name of 'GWS'.
    options: Optional[dict]
    """Libpq connection options."""

34 

35 

class Object(gws.base.database.provider.Object):
    """Postgres/PostGIS database provider."""

    def url(self):
        """Return the connection URL built from this provider's configuration."""
        return connection_url(self.config)

    # Matches ``name`` or ``schema.name``. Each part is either a double-quoted
    # identifier (embedded quotes doubled as ``""``) or a run of characters
    # containing neither quotes nor dots.
    _RE_TABLE_NAME = r'''(?x)
        ^
        (
            ( " (?P<a1> ([^"] | "")+ ) " )
            |
            (?P<a2> [^".]+ )
        )
        (
            \.
            (
                ( " (?P<b1> ([^"] | "")+ ) " )
                |
                (?P<b2> [^".]+ )
            )
        )?
        $
    '''

    _DEFAULT_SCHEMA = 'public'

    def split_table_name(self, table_name):
        """Split a possibly schema-qualified table name into ``(schema, name)``.

        Quoted parts are unquoted (a doubled ``""`` becomes ``"``). A name
        without a schema part gets the default schema ``public``.

        Raises:
            ValueError: if ``table_name`` does not match the expected syntax.
        """
        m = re.match(self._RE_TABLE_NAME, table_name.strip())
        if not m:
            raise ValueError(f'invalid table name {table_name!r}')

        d = m.groupdict()
        s = d['a1'] or d['a2']
        t = d['b1'] or d['b2']
        if not t:
            # a single part is the table name, not the schema
            s, t = self._DEFAULT_SCHEMA, s

        return s.replace('""', '"'), t.replace('""', '"')

    def join_table_name(self, schema, name):
        """Return ``schema.name``, resolving the schema from ``name`` if not given.

        NOTE(review): the parts are joined unquoted, so a name containing a dot
        or a quote does not round-trip through ``split_table_name``.
        """
        if schema:
            return schema + '.' + name
        schema, name2 = self.split_table_name(name)
        return schema + '.' + name2

    def table_bounds(self, table):
        """Return the bounds of the table's geometry column, or ``None``.

        ``None`` is returned when the table has no geometry column, when the
        extent is empty, or when the extent cannot be parsed.
        """
        desc = self.describe(table)
        if not desc.geometryName:
            return None

        tab = self.table(table)
        sql = sa.select(sa.func.ST_Extent(tab.columns.get(desc.geometryName)))
        with self.connect() as conn:
            box = conn.execute(sql).scalar_one()

        if not box:
            # ST_Extent is an aggregate: with no (non-NULL) geometries it yields NULL
            return None

        extent = gws.lib.extent.from_box(box)
        if extent:
            return gws.Bounds(extent=extent, crs=gws.lib.crs.get(desc.geometrySrid))
        return None

    def describe_column(self, table, column_name):
        """Describe a column, refining Postgres ARRAY and GEOMETRY types.

        ARRAY columns whose items map to str/int/float become ``strlist`` /
        ``intlist`` / ``floatlist``; other arrays get ``UNKNOWN_ARRAY_TYPE``.
        GEOMETRY columns get their geometry type and SRID filled in.
        """
        col = super().describe_column(table, column_name)

        if col.nativeType == 'ARRAY':
            sa_col = self.column(table, column_name)
            # SQLAlchemy's ARRAY type exposes its element type as ``item_type``
            it = getattr(sa_col.type, 'item_type', None)
            # SA_TO_ATTR is defined on the base provider class (not visible here)
            ia = self.SA_TO_ATTR.get(type(it).__name__.upper())
            if ia == gws.AttributeType.str:
                col.type = gws.AttributeType.strlist
            elif ia == gws.AttributeType.int:
                col.type = gws.AttributeType.intlist
            elif ia == gws.AttributeType.float:
                col.type = gws.AttributeType.floatlist
            else:
                col.type = self.UNKNOWN_ARRAY_TYPE
            return col

        if col.nativeType == 'GEOMETRY':
            typ, srid = self._get_geom_type_and_srid(table, column_name)
            col.type = gws.AttributeType.geometry
            col.geometryType = self.SA_TO_GEOM.get(typ, gws.GeometryType.geometry)
            col.geometrySrid = srid

        return col

    def _get_geom_type_and_srid(self, table, column_name):
        """Return ``(geometry_type, srid)`` for a geometry column.

        The column's type modifier is used when it gives both a concrete
        geometry type and a positive SRID. Otherwise the ``geometry_columns``
        view is consulted, e.g. for constraint-based columns.

        Returns:
            ``('GEOMETRY', -1)`` when neither source identifies the column.
        """
        sa_table = self.table(table)
        sa_col = self.column(table, column_name)

        # The attributes may exist but be None; normalize so .upper() and > work.
        typ = (getattr(sa_col.type, 'geometry_type', None) or '').upper()
        srid = getattr(sa_col.type, 'srid', None) or 0

        if typ != 'GEOMETRY' and srid > 0:
            return typ, srid

        # not a typmod, possibly constraint-based. Query "geometry_columns"...

        gcs = getattr(self, '_geometry_columns_cache', None)
        if gcs is None:
            # Queried once per provider object. An empty result is cached as well,
            # so the view is not re-queried for every column.
            gcs = self.select_text('''
                SELECT
                    f_table_schema,
                    f_table_name,
                    f_geometry_column,
                    type,
                    srid
                FROM public.geometry_columns
            ''')
            setattr(self, '_geometry_columns_cache', gcs)

        for gc in gcs:
            if (
                gc['f_table_schema'] == sa_table.schema
                and gc['f_table_name'] == sa_table.name
                and gc['f_geometry_column'] == sa_col.name
            ):
                return gc['type'], gc['srid']

        return 'GEOMETRY', -1

152 

153 

154## 

155 

156def connection_url(cfg: gws.Config): 

157 # https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING 

158 # https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS 

159 

160 defaults = { 

161 'application_name': 'GWS', 

162 } 

163 

164 params = gws.u.merge(defaults, cfg.get('options')) 

165 

166 p = cfg.get('host') 

167 if p: 

168 return gws.lib.net.make_url( 

169 scheme='postgresql', 

170 username=cfg.get('username'), 

171 password=cfg.get('password'), 

172 hostname=p, 

173 port=cfg.get('port'), 

174 path=cfg.get('database') or cfg.get('dbname') or '', 

175 params=params, 

176 ) 

177 

178 p = cfg.get('serviceName') 

179 if p: 

180 s = os.getenv('PGSERVICEFILE') 

181 if not s or not os.path.isfile(s): 

182 raise sa.Error(f'PGSERVICEFILE {s!r} not found') 

183 

184 params['service'] = p 

185 

186 return gws.lib.net.make_url( 

187 scheme='postgresql', 

188 hostname='', 

189 path=cfg.get('database') or cfg.get('dbname') or '', 

190 params=params, 

191 )