Coverage for gws-app / gws / base / database / provider.py: 78%

186 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-03 10:12 +0100

1import threading 

2from typing import Optional, cast 

3 

4import gws 

5import gws.lib.sa as sa 

6 

7from . import connection 

8 

9 

class Config(gws.Config):
    """Database provider"""

    # Read by Object.inspect_schema via cfg('schemaCacheLifeTime', 0):
    # a falsy value makes the provider reflect the schema directly,
    # any other value is passed to gws.u.get_cached_object as the life time.
    schemaCacheLifeTime: gws.Duration = '3600'
    """Life time for schema caches."""
    # Read by Object.engine_options: when this is exactly False,
    # the engine is created with sa.NullPool and 'pool' is ignored.
    withPool: Optional[bool] = False
    """Use connection pooling"""
    # Keys read by Object.engine_options (all optional):
    #   disabled (True -> sa.NullPool), pre_ping (True -> pool_pre_ping),
    #   size, recycle, timeout (ints -> pool_size, pool_recycle, pool_timeout).
    # Other keys are not consulted there.
    pool: Optional[dict]
    """Options for connection pooling."""

19 

20 

# Per-thread storage for the currently open connection (`_connection`)
# and its nesting depth (`_connectionCount`); see Object._open_connection
# and Object._close_connection. This object is module-level, so it is
# shared by all provider instances in this module.
_thread_local = threading.local()

22 

23 

class Object(gws.DatabaseProvider):
    """Database provider backed by an SQLAlchemy engine.

    Keeps one engine (``saEngine``) and a per-schema map of reflected
    metadata (``saMetaMap``). Connections are reference-counted per thread:
    nested ``connect()`` calls in one thread reuse the same underlying
    SQLAlchemy connection, which is closed when the outermost one is released.
    """

    saEngine: sa.Engine
    saMetaMap: dict[str, sa.MetaData]

    def __getstate__(self):
        """Return the instance state, omitting the engine and the metadata map."""
        return gws.u.omit(vars(self), 'saMetaMap', 'saEngine')

    def configure(self):
        """Create an unpooled engine and reset the metadata map."""
        # init a dummy engine just to check things
        self.saEngine = self.create_engine(poolclass=sa.NullPool)
        self.saMetaMap = {}

    def activate(self):
        """Create the engine with the configured options and reset the metadata map."""
        self.saEngine = self.create_engine()
        self.saMetaMap = {}

    def engine(self):
        """Return the current engine, creating it first if it is missing."""
        eng = getattr(self, 'saEngine', None)
        if eng is not None:
            return eng
        self.saEngine = self.create_engine()
        return self.saEngine

    def create_engine(self, **kwargs):
        """Create a new SQLAlchemy engine for ``self.url()``.

        Args:
            **kwargs: Engine options; they take precedence over the
                configured ones (see ``engine_options``).
        """
        eng = sa.create_engine(self.url(), **self.engine_options(**kwargs))
        # setattr(eng, '_connection_cls', connection.Object)
        return eng

    def engine_options(self, **kwargs):
        """Complete ``kwargs`` with engine options derived from the configuration.

        Only ``setdefault`` is used, so explicitly passed options are never
        overwritten.

        Returns:
            The same ``kwargs`` dict, with defaults filled in.
        """
        if self.root.app.developer_option('db.engine_echo'):
            kwargs.setdefault('echo', True)
            kwargs.setdefault('echo_pool', True)

        # pooling switched off altogether: ignore the 'pool' options
        if self.cfg('withPool') is False:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        pool = self.cfg('pool') or {}
        p = pool.get('disabled')
        if p is True:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        p = pool.get('pre_ping')
        if p is True:
            kwargs.setdefault('pool_pre_ping', True)
        p = pool.get('size')
        if isinstance(p, int):
            kwargs.setdefault('pool_size', p)
        p = pool.get('recycle')
        if isinstance(p, int):
            kwargs.setdefault('pool_recycle', p)
        p = pool.get('timeout')
        if isinstance(p, int):
            kwargs.setdefault('pool_timeout', p)

        return kwargs

    def inspect_schema(self, schema, options=None):
        """Reflect ``schema`` into ``saMetaMap`` unless it is already there.

        Args:
            schema: Schema name.
            options: Optional object; ``options.refresh`` forces a new
                reflection, ``options.cacheLifeTime`` (when not None)
                overrides the configured ``schemaCacheLifeTime``.
        """
        if options and options.refresh:
            self.saMetaMap.pop(schema, None)

        if schema in self.saMetaMap:
            return

        def _load():
            md = sa.MetaData(schema=schema)

            # introspecting the whole schema is generally faster
            # but what if we only need a single table from a big schema?
            # @TODO add options for reflection

            gws.debug.time_start(f'AUTOLOAD {self.uid=} {schema=}')
            try:
                with self.connect() as conn:
                    md.reflect(conn.saConn, schema, resolve_fks=False, views=True)
            finally:
                # always close the timer, even if the reflection fails
                gws.debug.time_end()
            return md

        life_time = self.cfg('schemaCacheLifeTime', 0)
        if options and options.cacheLifeTime is not None:
            life_time = options.cacheLifeTime
        if not life_time:
            self.saMetaMap[schema] = _load()
        else:
            # The key includes the provider uid so that two providers that
            # both expose a schema of the same name do not share metadata.
            cache_key = f'database_metadata_schema_{self.uid}_{schema}'
            self.saMetaMap[schema] = gws.u.get_cached_object(cache_key, life_time, _load)

    def connect(self):
        """Open (or reuse) the thread's connection and wrap it in ``connection.Object``."""
        conn = self._open_connection()
        return connection.Object(self, conn)

    def _sa_connection(self) -> sa.Connection | None:
        """Return the SQLAlchemy connection open in this thread, or None."""
        return getattr(_thread_local, '_connection', None)

    def _open_connection(self) -> sa.Connection:
        """Return the thread's connection, opening it on first use.

        Each call increments the thread's connection counter and must be
        paired with one ``_close_connection`` call.
        """
        # NOTE(review): the thread-local state is not keyed by provider, so a
        # nested connect() on a different provider in the same thread would
        # get this connection too - confirm this cannot happen.
        conn = getattr(_thread_local, '_connection', None)
        cc = getattr(_thread_local, '_connectionCount', 0)

        if conn is None:
            assert cc == 0
            conn = self.engine().connect()
            setattr(_thread_local, '_connection', conn)
        else:
            assert cc > 0

        setattr(_thread_local, '_connectionCount', cc + 1)
        # gws.log.debug(f'db.connect: open: {cc + 1}')
        return conn

    def _close_connection(self):
        """Release one reference to the thread's connection.

        The underlying connection is closed when the last reference is
        released. The thread-local state is reset even if closing fails,
        so that the next ``connect()`` opens a fresh connection.
        """
        conn = getattr(_thread_local, '_connection', None)
        cc = getattr(_thread_local, '_connectionCount', 0)
        assert conn is not None
        assert cc > 0
        # gws.log.debug(f'db.connect: close: {cc}')
        if cc == 1:
            try:
                conn.close()
            finally:
                setattr(_thread_local, '_connection', None)
                setattr(_thread_local, '_connectionCount', 0)
        else:
            setattr(_thread_local, '_connectionCount', cc - 1)

    def table(self, table, **kwargs):
        """Return the ``sa.Table`` for a table object or name.

        Raises:
            sa.Error: If the table is not found.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {str(table)} not found')
        return tab

    def count(self, table):
        """Return the number of rows in ``table``, or 0 if the table is not found."""
        tab = self._sa_table(table)
        if tab is None:
            return 0
        sql = sa.select(sa.func.count()).select_from(tab)
        with self.connect() as conn:
            return conn.fetch_int(sql)

    def has_schema(self, schema):
        """Check whether ``schema`` is among the schema names reported by the database."""
        return schema in self.schema_names()

    def schema_names(self):
        """Return the schema names reported by the engine inspector."""
        inspector = sa.inspect(self.engine())
        return inspector.get_schema_names()

    def has_table(self, table_name: str):
        """Check whether the table can be found in its reflected schema."""
        tab = self._sa_table(table_name)
        return tab is not None

    def _sa_table(self, tab_or_name) -> sa.Table | None:
        """Resolve a table object or name to an ``sa.Table``.

        Returns:
            The table, or None if the reflected schema does not contain it.

        Raises:
            sa.Error: If no metadata is available for the schema.
        """
        if isinstance(tab_or_name, sa.Table):
            return tab_or_name
        schema, name = self.split_table_name(tab_or_name)
        self.inspect_schema(schema)
        # see _get_table_key in sqlalchemy/sql/schema.py
        table_key = schema + '.' + name
        sm = self.saMetaMap.get(schema)
        if sm is None:
            raise sa.Error(f'schema {schema!r} not found')
        return sm.tables.get(table_key)

    def column(self, table, column_name):
        """Return the column ``column_name`` of ``table``.

        Raises:
            sa.Error: If the table or the column is not found.
        """
        tab = self.table(table)
        try:
            return tab.columns[column_name]
        except KeyError as exc:
            raise sa.Error(f'column {str(table)}.{column_name!r} not found') from exc

    def has_column(self, table, column_name):
        """Check whether ``table`` is found and has a column ``column_name``."""
        tab = self._sa_table(table)
        return tab is not None and column_name in tab.columns

    def select_text(self, sql, **kwargs):
        """Execute a textual SQL statement and return its rows as dicts.

        Args:
            sql: SQL text.
            **kwargs: Bound parameters for the statement.

        Raises:
            sa.Error: Re-raised after a rollback.
        """
        with self.connect() as conn:
            try:
                return [gws.u.to_dict(r) for r in conn.execute(sa.text(sql), kwargs)]
            except sa.Error:
                conn.rollback()
                raise

    def execute_text(self, sql, **kwargs):
        """Execute a textual SQL statement, commit, and return the result.

        Args:
            sql: SQL text.
            **kwargs: Bound parameters for the statement.

        Raises:
            sa.Error: Re-raised after a rollback.
        """
        with self.connect() as conn:
            try:
                res = conn.execute(sa.text(sql), kwargs)
                conn.commit()
                return res
            except sa.Error:
                conn.rollback()
                raise

    # Maps upper-cased SQLAlchemy type class names to attribute types.
    SA_TO_ATTR = {
        # common: sqlalchemy.sql.sqltypes
        'BIGINT': gws.AttributeType.int,
        'BOOLEAN': gws.AttributeType.bool,
        'CHAR': gws.AttributeType.str,
        'DATE': gws.AttributeType.date,
        'DOUBLE_PRECISION': gws.AttributeType.float,
        'INTEGER': gws.AttributeType.int,
        'NUMERIC': gws.AttributeType.float,
        'REAL': gws.AttributeType.float,
        'SMALLINT': gws.AttributeType.int,
        'TEXT': gws.AttributeType.str,
        # 'UUID': ...,
        'VARCHAR': gws.AttributeType.str,
        # postgres specific: sqlalchemy.dialects.postgresql.types
        # 'JSON': ...,
        # 'JSONB': ...,
        # 'BIT': ...,
        'BYTEA': gws.AttributeType.bytes,
        # 'CIDR': ...,
        # 'INET': ...,
        # 'MACADDR': ...,
        # 'MACADDR8': ...,
        # 'MONEY': ...,
        'TIME': gws.AttributeType.time,
        'TIMESTAMP': gws.AttributeType.datetime,
    }

    # @TODO proper support for Z/M geoms

    # Maps geometry type names (including Z/M variants) to geometry types.
    SA_TO_GEOM = {
        'POINT': gws.GeometryType.point,
        'POINTM': gws.GeometryType.point,
        'POINTZ': gws.GeometryType.point,
        'POINTZM': gws.GeometryType.point,
        'LINESTRING': gws.GeometryType.linestring,
        'LINESTRINGM': gws.GeometryType.linestring,
        'LINESTRINGZ': gws.GeometryType.linestring,
        'LINESTRINGZM': gws.GeometryType.linestring,
        'POLYGON': gws.GeometryType.polygon,
        'POLYGONM': gws.GeometryType.polygon,
        'POLYGONZ': gws.GeometryType.polygon,
        'POLYGONZM': gws.GeometryType.polygon,
        'MULTIPOINT': gws.GeometryType.multipoint,
        'MULTIPOINTM': gws.GeometryType.multipoint,
        'MULTIPOINTZ': gws.GeometryType.multipoint,
        'MULTIPOINTZM': gws.GeometryType.multipoint,
        'MULTILINESTRING': gws.GeometryType.multilinestring,
        'MULTILINESTRINGM': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZ': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZM': gws.GeometryType.multilinestring,
        'MULTIPOLYGON': gws.GeometryType.multipolygon,
        # 'GEOMETRYCOLLECTION': gws.GeometryType.geometrycollection,
        # 'CURVE': gws.GeometryType.curve,
    }

    # Fallback type for native types missing from SA_TO_ATTR.
    UNKNOWN_TYPE = gws.AttributeType.str
    UNKNOWN_ARRAY_TYPE = gws.AttributeType.strlist

    def describe(self, table):
        """Build a ``gws.DataSetDescription`` for ``table``.

        Columns are described with ``describe_column`` in table order; the
        geometry fields are taken from the first column that has a
        ``geometryType``.

        Raises:
            sa.Error: If the table is not found.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {table!r} not found')

        schema = tab.schema
        name = tab.name

        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.join_table_name(schema or '', name),
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=name,
            schema=schema,
        )

        for n, sa_col in enumerate(cast(list[sa.Column], tab.columns)):
            col = self.describe_column(table, sa_col.name)
            col.columnIndex = n
            desc.columns.append(col)
            desc.columnMap[col.name] = col

        for col in desc.columns:
            if col.geometryType:
                desc.geometryName = col.name
                desc.geometryType = col.geometryType
                desc.geometrySrid = col.geometrySrid
                break

        return desc

    def describe_column(self, table, column_name) -> gws.ColumnDescription:
        """Build a ``gws.ColumnDescription`` for one column.

        ``nativeType`` is the upper-cased class name of the SQLAlchemy column
        type; ``type`` is looked up in ``SA_TO_ATTR`` and falls back to
        ``UNKNOWN_TYPE``. Geometry fields are left empty here.

        Raises:
            sa.Error: If the table or the column is not found.
        """
        sa_col = self.column(table, column_name)

        col = gws.ColumnDescription(
            columnIndex=0,
            comment=str(sa_col.comment or ''),
            default=sa_col.default,
            geometrySrid=0,
            geometryType='',
            isAutoincrement=bool(sa_col.autoincrement),
            isNullable=bool(sa_col.nullable),
            isPrimaryKey=bool(sa_col.primary_key),
            isUnique=bool(sa_col.unique),
            hasDefault=sa_col.server_default is not None,
            name=str(sa_col.name),
            nativeType='',
            type='',
        )

        col.nativeType = type(sa_col.type).__name__.upper()
        col.type = self.SA_TO_ATTR.get(col.nativeType, self.UNKNOWN_TYPE)

        return col

328 

329 

330##