Coverage for gws-app/gws/base/database/provider.py: 80%

177 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-16 22:59 +0200

1import threading 

2from typing import Optional, cast 

3 

4import gws 

5import gws.lib.sa as sa 

6 

7from . import connection 

8 

9 

class Config(gws.Config):
    """Database provider configuration."""

    schemaCacheLifeTime: gws.Duration = '3600'
    """Lifetime for cached schema metadata. A zero value disables the cache (schemas are reflected on every first access)."""
    withPool: Optional[bool] = False
    """Use connection pooling. When false, connections are not pooled and ``pool`` is ignored."""
    pool: Optional[dict]
    """Options for connection pooling (used only when ``withPool`` is true). Recognized keys: ``disabled`` (bool), ``pre_ping`` (bool), ``size``, ``recycle`` and ``timeout`` (int)."""

19 

20 

# Per-thread storage for database connection bookkeeping; each thread
# sees its own independent set of attributes on this object.
_thread_local: threading.local = threading.local()

22 

23 

class Object(gws.DatabaseProvider):
    """Database provider backed by an SQLAlchemy engine.

    Connections are reference-counted per thread *and* per provider: nested
    ``connect()`` calls on the same provider within one thread reuse a single
    SQLAlchemy connection, while different providers never share one.
    """

    saEngine: sa.Engine
    # Reflected SQLAlchemy metadata, keyed by schema name.
    saMetaMap: dict[str, sa.MetaData]

    def __getstate__(self):
        # Engine and reflected metadata are excluded from the pickled state;
        # both are recreated on demand (see ``engine`` and ``_meta_map``).
        return gws.u.omit(vars(self), 'saMetaMap', 'saEngine')

    def configure(self):
        # init a dummy (non-pooling) engine just to check things
        self.saEngine = self.create_engine(poolclass=sa.NullPool)
        self.saMetaMap = {}

    def activate(self):
        # create the real engine, pooling according to the configuration
        self.saEngine = self.create_engine()
        self.saMetaMap = {}

    def engine(self):
        """Return the SQLAlchemy engine, creating it if it is missing (e.g. after unpickling)."""
        eng = getattr(self, 'saEngine', None)
        if eng is None:
            eng = self.create_engine()
            self.saEngine = eng
        return eng

    def create_engine(self, **kwargs):
        """Create a new SQLAlchemy engine for ``self.url()``.

        Args:
            **kwargs: Extra ``sa.create_engine`` options; they take precedence
                over options derived from the configuration.
        """
        return sa.create_engine(self.url(), **self.engine_options(**kwargs))

    def engine_options(self, **kwargs):
        """Build keyword options for ``sa.create_engine``.

        Options already present in ``kwargs`` are never overridden.

        Pooling is disabled (``NullPool``) when ``withPool`` is False or
        ``pool.disabled`` is True. Otherwise the ``pool`` config keys
        ``pre_ping``, ``size``, ``recycle`` and ``timeout`` are mapped to
        ``pool_pre_ping``, ``pool_size``, ``pool_recycle`` and ``pool_timeout``.
        """
        if self.root.app.developer_option('db.engine_echo'):
            kwargs.setdefault('echo', True)
            kwargs.setdefault('echo_pool', True)

        if self.cfg('withPool') is False:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        pool = self.cfg('pool') or {}

        if pool.get('disabled') is True:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        if pool.get('pre_ping') is True:
            kwargs.setdefault('pool_pre_ping', True)

        # integer-valued pool options: config key -> create_engine keyword
        for cfg_key, opt_name in (
                ('size', 'pool_size'),
                ('recycle', 'pool_recycle'),
                ('timeout', 'pool_timeout'),
        ):
            val = pool.get(cfg_key)
            if isinstance(val, int):
                kwargs.setdefault(opt_name, val)

        return kwargs

    def _meta_map(self) -> dict[str, sa.MetaData]:
        """Return ``saMetaMap``, recreating it if it is missing (e.g. after unpickling)."""
        mm = getattr(self, 'saMetaMap', None)
        if mm is None:
            mm = {}
            self.saMetaMap = mm
        return mm

    def reflect_schema(self, schema: str):
        """Reflect all tables and views of ``schema`` into ``saMetaMap``.

        Does nothing if the schema has already been reflected. When
        ``schemaCacheLifeTime`` is set, the reflected metadata is additionally
        cached via ``gws.u.get_cached_object``.
        """
        meta_map = self._meta_map()
        if schema in meta_map:
            return

        def _load():
            md = sa.MetaData(schema=schema)

            # introspecting the whole schema is generally faster
            # but what if we only need a single table from a big schema?
            # @TODO add options for reflection

            gws.debug.time_start(f'AUTOLOAD {self.uid=} {schema=}')
            try:
                with self.connect() as conn:
                    md.reflect(conn.saConn, schema, resolve_fks=False, views=True)
            finally:
                # keep the debug timer balanced even if reflection fails
                gws.debug.time_end()
            return md

        life_time = self.cfg('schemaCacheLifeTime', 0)
        if not life_time:
            meta_map[schema] = _load()
            return

        # The cache key includes the provider uid: different providers
        # (i.e. different databases) commonly share schema names like 'public'.
        cache_key = f'database_metadata_schema_{self.uid}_{schema}'
        meta_map[schema] = gws.u.get_cached_object(cache_key, life_time, _load)

    def connect(self):
        """Open (or reuse) this thread's connection for this provider and wrap it.

        Nested calls in the same thread share one SQLAlchemy connection
        (see ``_open_connection``). Presumably the wrapper calls
        ``_close_connection`` when it is closed — confirm in ``connection.Object``.
        """
        conn = self._open_connection()
        return connection.Object(self, conn)

    def _connection_state(self) -> dict:
        """Return this provider's connection state for the current thread.

        The state is a dict ``{'conn': sa.Connection | None, 'count': int}``,
        stored per provider uid so that providers never share a connection.
        """
        states = getattr(_thread_local, 'providerConnections', None)
        if states is None:
            states = {}
            _thread_local.providerConnections = states
        return states.setdefault(self.uid, {'conn': None, 'count': 0})

    def _sa_connection(self) -> sa.Connection | None:
        """Return this thread's open connection for this provider, or None."""
        return self._connection_state()['conn']

    def _open_connection(self) -> sa.Connection:
        """Return this thread's connection for this provider, opening it if needed, and bump its reference count."""
        st = self._connection_state()

        if st['conn'] is None:
            assert st['count'] == 0
            st['conn'] = self.engine().connect()
        else:
            assert st['count'] > 0

        st['count'] += 1
        return st['conn']

    def _close_connection(self):
        """Drop one reference to this thread's connection; close it when the last reference goes."""
        st = self._connection_state()
        conn = st['conn']
        assert conn is not None
        assert st['count'] > 0

        if st['count'] > 1:
            st['count'] -= 1
            return

        # Last reference: reset the state even if close() raises, so the
        # next open in this thread starts from a clean slate.
        try:
            conn.close()
        finally:
            st['conn'] = None
            st['count'] = 0

    def table(self, table, **kwargs):
        """Return the ``sa.Table`` for ``table`` (a Table or a table name).

        Raises:
            sa.Error: if the table does not exist.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {str(table)} not found')
        return tab

    def count(self, table):
        """Return the number of rows in ``table``, or 0 if the table does not exist."""
        tab = self._sa_table(table)
        if tab is None:
            return 0
        sql = sa.select(sa.func.count()).select_from(tab)
        with self.connect() as conn:
            return conn.fetch_int(sql)

    def has_table(self, table_name: str):
        """Return True if the table exists."""
        return self._sa_table(table_name) is not None

    def _sa_table(self, tab_or_name) -> sa.Table | None:
        """Resolve a Table or a table name to a reflected ``sa.Table`` (None if not found).

        Raises:
            sa.Error: if no metadata is available for the table's schema.
        """
        if isinstance(tab_or_name, sa.Table):
            return tab_or_name
        schema, name = self.split_table_name(tab_or_name)
        self.reflect_schema(schema)
        sm = self._meta_map().get(schema)
        if sm is None:
            raise sa.Error(f'schema {schema!r} not found')
        # see _get_table_key in sqlalchemy/sql/schema.py
        table_key = schema + '.' + name
        return sm.tables.get(table_key)

    def column(self, table, column_name):
        """Return the ``sa.Column`` ``column_name`` of ``table``.

        Raises:
            sa.Error: if the table or the column does not exist.
        """
        tab = self.table(table)
        try:
            return tab.columns[column_name]
        except KeyError:
            raise sa.Error(f'column {str(table)}.{column_name!r} not found')

    def has_column(self, table, column_name):
        """Return True if ``table`` exists and has a column ``column_name``."""
        tab = self._sa_table(table)
        return tab is not None and column_name in tab.columns

    def select_text(self, sql, **kwargs):
        """Run a raw SQL query with bound parameters and return rows as dicts; roll back on error."""
        with self.connect() as conn:
            try:
                return [gws.u.to_dict(r) for r in conn.execute(sa.text(sql), kwargs)]
            except sa.Error:
                conn.rollback()
                raise

    def execute_text(self, sql, **kwargs):
        """Run a raw SQL statement with bound parameters and commit; roll back on error."""
        with self.connect() as conn:
            try:
                res = conn.execute(sa.text(sql), kwargs)
                conn.commit()
                return res
            except sa.Error:
                conn.rollback()
                raise

    # Maps upper-cased SQLAlchemy type class names to attribute types.
    SA_TO_ATTR = {
        # common: sqlalchemy.sql.sqltypes
        'BIGINT': gws.AttributeType.int,
        'BOOLEAN': gws.AttributeType.bool,
        'CHAR': gws.AttributeType.str,
        'DATE': gws.AttributeType.date,
        'DOUBLE_PRECISION': gws.AttributeType.float,
        'INTEGER': gws.AttributeType.int,
        'NUMERIC': gws.AttributeType.float,
        'REAL': gws.AttributeType.float,
        'SMALLINT': gws.AttributeType.int,
        'TEXT': gws.AttributeType.str,
        # 'UUID': ...,
        'VARCHAR': gws.AttributeType.str,
        # postgres specific: sqlalchemy.dialects.postgresql.types
        # 'JSON': ...,
        # 'JSONB': ...,
        # 'BIT': ...,
        'BYTEA': gws.AttributeType.bytes,
        # 'CIDR': ...,
        # 'INET': ...,
        # 'MACADDR': ...,
        # 'MACADDR8': ...,
        # 'MONEY': ...,
        'TIME': gws.AttributeType.time,
        'TIMESTAMP': gws.AttributeType.datetime,
    }

    # @TODO proper support for Z/M geoms

    # Maps geometry type names (including M/Z/ZM variants) to geometry types.
    SA_TO_GEOM = {
        'POINT': gws.GeometryType.point,
        'POINTM': gws.GeometryType.point,
        'POINTZ': gws.GeometryType.point,
        'POINTZM': gws.GeometryType.point,
        'LINESTRING': gws.GeometryType.linestring,
        'LINESTRINGM': gws.GeometryType.linestring,
        'LINESTRINGZ': gws.GeometryType.linestring,
        'LINESTRINGZM': gws.GeometryType.linestring,
        'POLYGON': gws.GeometryType.polygon,
        'POLYGONM': gws.GeometryType.polygon,
        'POLYGONZ': gws.GeometryType.polygon,
        'POLYGONZM': gws.GeometryType.polygon,
        'MULTIPOINT': gws.GeometryType.multipoint,
        'MULTIPOINTM': gws.GeometryType.multipoint,
        'MULTIPOINTZ': gws.GeometryType.multipoint,
        'MULTIPOINTZM': gws.GeometryType.multipoint,
        'MULTILINESTRING': gws.GeometryType.multilinestring,
        'MULTILINESTRINGM': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZ': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZM': gws.GeometryType.multilinestring,
        'MULTIPOLYGON': gws.GeometryType.multipolygon,
        # 'GEOMETRYCOLLECTION': gws.GeometryType.geometrycollection,
        # 'CURVE': gws.GeometryType.curve,
    }

    # Fallback attribute types for native types not listed in SA_TO_ATTR.
    UNKNOWN_TYPE = gws.AttributeType.str
    UNKNOWN_ARRAY_TYPE = gws.AttributeType.strlist

    def describe(self, table):
        """Return a ``gws.DataSetDescription`` for ``table``.

        The first column with a geometry type (if any) determines the
        description's geometry name, type and SRID.

        Raises:
            sa.Error: if the table does not exist.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {table!r} not found')

        schema = tab.schema
        name = tab.name

        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.join_table_name(schema or '', name),
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=name,
            schema=schema,
        )

        for n, sa_col in enumerate(cast(list[sa.Column], tab.columns)):
            col = self.describe_column(table, sa_col.name)
            col.columnIndex = n
            desc.columns.append(col)
            desc.columnMap[col.name] = col

        for col in desc.columns:
            if col.geometryType:
                desc.geometryName = col.name
                desc.geometryType = col.geometryType
                desc.geometrySrid = col.geometrySrid
                break

        return desc

    def describe_column(self, table, column_name) -> gws.ColumnDescription:
        """Return a ``gws.ColumnDescription`` for one column of ``table``.

        The attribute type is looked up in ``SA_TO_ATTR`` by the upper-cased
        SQLAlchemy type class name, falling back to ``UNKNOWN_TYPE``. Geometry
        fields are left empty here.

        Raises:
            sa.Error: if the table or the column does not exist.
        """
        sa_col = self.column(table, column_name)

        col = gws.ColumnDescription(
            columnIndex=0,
            comment=str(sa_col.comment or ''),
            default=sa_col.default,
            geometrySrid=0,
            geometryType='',
            isAutoincrement=bool(sa_col.autoincrement),
            isNullable=bool(sa_col.nullable),
            isPrimaryKey=bool(sa_col.primary_key),
            isUnique=bool(sa_col.unique),
            hasDefault=sa_col.server_default is not None,
            name=str(sa_col.name),
            nativeType='',
            type='',
        )

        col.nativeType = type(sa_col.type).__name__.upper()
        col.type = self.SA_TO_ATTR.get(col.nativeType, self.UNKNOWN_TYPE)

        return col

316 

317 

318##