Coverage for gws-app / gws / base / database / provider.py: 78%
186 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-03 10:12 +0100
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-03 10:12 +0100
1import threading
2from typing import Optional, cast
4import gws
5import gws.lib.sa as sa
7from . import connection
class Config(gws.Config):
    """Database provider"""

    # Life time of the schema metadata cache used by Object.inspect_schema():
    # reflected metadata is stored via gws.u.get_cached_object for this long.
    # A falsy value disables that cache and reflects the schema directly.
    # NOTE(review): unit presumably seconds (gws.Duration) — confirm.
    schemaCacheLifeTime: gws.Duration = '3600'
    """Life time for schema caches."""
    # When this is exactly False (the default), Object.engine_options() selects
    # sa.NullPool and the ``pool`` options below are not consulted.
    withPool: Optional[bool] = False
    """Use connection pooling"""
    # Keys read by Object.engine_options(): ``disabled`` (True -> NullPool),
    # ``pre_ping`` (True -> pool_pre_ping), and integer ``size``, ``recycle``,
    # ``timeout`` (passed as pool_size, pool_recycle, pool_timeout to
    # sa.create_engine). Other keys are not read there.
    pool: Optional[dict]
    """Options for connection pooling."""
# Per-thread state used by Object's connection helpers (_open_connection,
# _close_connection, _sa_connection) so that nested connect() calls within
# one thread reuse an already open SQLAlchemy connection instead of opening
# a new one.
_thread_local: threading.local = threading.local()
class Object(gws.DatabaseProvider):
    """Base database provider built on a SQLAlchemy engine.

    Open connections are tracked per thread *and* per provider instance:
    nested ``connect()`` calls on the same provider in the same thread reuse
    one SQLAlchemy connection, while different providers (i.e. different
    engines/databases) never share a connection.
    """

    saEngine: sa.Engine
    # Reflected ``sa.MetaData`` per schema name, filled by inspect_schema().
    saMetaMap: dict[str, sa.MetaData]

    def __getstate__(self):
        # The engine and reflected metadata are not pickled: activate()
        # recreates both, and engine() lazily recreates the engine.
        return gws.u.omit(vars(self), 'saMetaMap', 'saEngine')

    def configure(self):
        # Build a throw-away, non-pooling engine at configuration time,
        # just to check the URL/engine options.
        self.saEngine = self.create_engine(poolclass=sa.NullPool)
        self.saMetaMap = {}

    def activate(self):
        # Replace the configuration-time engine with the real one (options
        # from engine_options()) and start with no reflected metadata.
        self.saEngine = self.create_engine()
        self.saMetaMap = {}

    def engine(self):
        """Return the SQLAlchemy engine, creating it if it is missing (e.g. after unpickling)."""
        eng = getattr(self, 'saEngine', None)
        if eng is None:
            eng = self.saEngine = self.create_engine()
        return eng

    def create_engine(self, **kwargs):
        """Create a SQLAlchemy engine for ``self.url()``.

        Args:
            **kwargs: Explicit ``sa.create_engine`` options; they take
                precedence over the defaults added by engine_options().
        """
        return sa.create_engine(self.url(), **self.engine_options(**kwargs))

    def engine_options(self, **kwargs):
        """Fill in ``sa.create_engine`` options from the provider configuration.

        Values already present in ``kwargs`` are never overridden.

        Returns:
            The updated ``kwargs`` dict.
        """
        if self.root.app.developer_option('db.engine_echo'):
            kwargs.setdefault('echo', True)
            kwargs.setdefault('echo_pool', True)

        # Pooling is off unless ``withPool`` is something other than False.
        if self.cfg('withPool') is False:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        pool = self.cfg('pool') or {}
        if pool.get('disabled') is True:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        if pool.get('pre_ping') is True:
            kwargs.setdefault('pool_pre_ping', True)

        # Integer-valued pool settings map 1:1 onto create_engine() options.
        for key, option in (('size', 'pool_size'), ('recycle', 'pool_recycle'), ('timeout', 'pool_timeout')):
            value = pool.get(key)
            if isinstance(value, int):
                kwargs.setdefault(option, value)

        return kwargs

    def inspect_schema(self, schema, options=None):
        """Reflect ``schema`` into ``saMetaMap`` unless it is already there.

        Args:
            schema: Schema name.
            options: Optional object with ``refresh`` (force a fresh
                reflection, bypassing the shared cache) and ``cacheLifeTime``
                (overrides ``schemaCacheLifeTime``; falsy disables caching).
        """
        refresh = bool(options and options.refresh)
        if refresh:
            self.saMetaMap.pop(schema, None)

        if schema in self.saMetaMap:
            return

        def _load():
            md = sa.MetaData(schema=schema)

            # introspecting the whole schema is generally faster
            # but what if we only need a single table from a big schema?
            # @TODO add options for reflection

            gws.debug.time_start(f'AUTOLOAD {self.uid=} {schema=}')
            try:
                with self.connect() as conn:
                    md.reflect(conn.saConn, schema, resolve_fks=False, views=True)
            finally:
                # Keep the debug timer balanced even if reflection fails.
                gws.debug.time_end()
            return md

        life_time = self.cfg('schemaCacheLifeTime', 0)
        if options and options.cacheLifeTime is not None:
            life_time = options.cacheLifeTime

        if refresh or not life_time:
            # A refresh must not be answered from the (possibly stale) shared
            # cache. The shared cache entry itself is left as is.
            self.saMetaMap[schema] = _load()
            return

        # The cache key includes the provider uid: different providers may
        # point to different databases that contain equally named schemas.
        cache_key = f'database_metadata_{self.uid}_schema_{schema}'
        self.saMetaMap[schema] = gws.u.get_cached_object(cache_key, life_time, _load)

    def connect(self):
        """Open (or reuse) this thread's connection for this provider.

        Returns:
            A ``connection.Object`` wrapping the underlying SQLAlchemy connection.
        """
        return connection.Object(self, self._open_connection())

    def _thread_connections(self) -> dict:
        """Return the calling thread's registry of open connections.

        Maps ``id(provider)`` to a mutable ``[connection, open_count]`` pair.
        Keying by provider keeps different engines/databases from sharing
        one connection within a thread.
        """
        registry = getattr(_thread_local, 'providerConnections', None)
        if registry is None:
            registry = {}
            _thread_local.providerConnections = registry
        return registry

    def _sa_connection(self) -> sa.Connection | None:
        """Return this provider's open connection in the calling thread, if any."""
        entry = self._thread_connections().get(id(self))
        return entry[0] if entry is not None else None

    def _open_connection(self) -> sa.Connection:
        """Take one reference to this thread's connection, opening it if needed."""
        registry = self._thread_connections()
        key = id(self)
        entry = registry.get(key)
        if entry is None:
            conn = self.engine().connect()
            registry[key] = [conn, 1]
            return conn
        assert entry[1] > 0
        entry[1] += 1
        return entry[0]

    def _close_connection(self):
        """Release one reference taken by _open_connection(); close on the last one."""
        registry = self._thread_connections()
        key = id(self)
        entry = registry.get(key)
        assert entry is not None and entry[1] > 0
        if entry[1] > 1:
            entry[1] -= 1
            return
        # Unregister before closing so that a failing close() cannot leave a
        # dead connection behind for later _open_connection() calls.
        del registry[key]
        entry[0].close()

    def table(self, table, **kwargs):
        """Return the ``sa.Table`` for ``table``; raise ``sa.Error`` if it does not exist."""
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {str(table)} not found')
        return tab

    def count(self, table):
        """Return the number of rows in ``table``, or 0 if the table does not exist."""
        tab = self._sa_table(table)
        if tab is None:
            return 0
        sql = sa.select(sa.func.count()).select_from(tab)
        with self.connect() as conn:
            return conn.fetch_int(sql)

    def has_schema(self, schema):
        """Check whether ``schema`` is one of the database's schema names."""
        return schema in self.schema_names()

    def schema_names(self):
        """Return the schema names reported by the SQLAlchemy inspector."""
        inspector = sa.inspect(self.engine())
        return inspector.get_schema_names()

    def has_table(self, table_name: str):
        """Check whether ``table_name`` resolves to a reflected table."""
        return self._sa_table(table_name) is not None

    def _sa_table(self, tab_or_name) -> sa.Table | None:
        """Resolve an ``sa.Table`` or a table name to a reflected ``sa.Table``.

        Names are split with split_table_name(); the schema is reflected on
        first use. Returns None if the schema has no such table.
        """
        if isinstance(tab_or_name, sa.Table):
            return tab_or_name
        schema, name = self.split_table_name(tab_or_name)
        self.inspect_schema(schema)
        # see _get_table_key in sqlalchemy/sql/schema.py
        table_key = schema + '.' + name
        sm = self.saMetaMap.get(schema)
        if sm is None:
            raise sa.Error(f'schema {schema!r} not found')
        return sm.tables.get(table_key)

    def column(self, table, column_name):
        """Return the ``sa.Column`` ``column_name`` of ``table``; raise ``sa.Error`` if missing."""
        tab = self.table(table)
        try:
            return tab.columns[column_name]
        except KeyError:
            raise sa.Error(f'column {str(table)}.{column_name!r} not found')

    def has_column(self, table, column_name):
        """Check whether ``table`` exists and has a column ``column_name``."""
        tab = self._sa_table(table)
        return tab is not None and column_name in tab.columns

    def select_text(self, sql, **kwargs):
        """Run textual SQL with bound parameters ``kwargs`` and return the rows as dicts.

        On ``sa.Error`` the transaction is rolled back and the error re-raised.
        """
        with self.connect() as conn:
            try:
                return [gws.u.to_dict(r) for r in conn.execute(sa.text(sql), kwargs)]
            except sa.Error:
                conn.rollback()
                raise

    def execute_text(self, sql, **kwargs):
        """Run textual SQL with bound parameters ``kwargs`` and commit.

        On ``sa.Error`` the transaction is rolled back and the error re-raised.

        NOTE(review): the result object is returned after the ``with`` block
        has exited, so fetching rows from it may fail once the connection is
        closed — confirm for callers that need rows rather than e.g. rowcount.
        """
        with self.connect() as conn:
            try:
                res = conn.execute(sa.text(sql), kwargs)
                conn.commit()
                return res
            except sa.Error:
                conn.rollback()
                raise

    # Upper-cased SQLAlchemy type class name -> gws attribute type.
    SA_TO_ATTR = {
        # common: sqlalchemy.sql.sqltypes
        'BIGINT': gws.AttributeType.int,
        'BOOLEAN': gws.AttributeType.bool,
        'CHAR': gws.AttributeType.str,
        'DATE': gws.AttributeType.date,
        'DOUBLE_PRECISION': gws.AttributeType.float,
        'INTEGER': gws.AttributeType.int,
        'NUMERIC': gws.AttributeType.float,
        'REAL': gws.AttributeType.float,
        'SMALLINT': gws.AttributeType.int,
        'TEXT': gws.AttributeType.str,
        # 'UUID': ...,
        'VARCHAR': gws.AttributeType.str,
        # postgres specific: sqlalchemy.dialects.postgresql.types
        # 'JSON': ...,
        # 'JSONB': ...,
        # 'BIT': ...,
        'BYTEA': gws.AttributeType.bytes,
        # 'CIDR': ...,
        # 'INET': ...,
        # 'MACADDR': ...,
        # 'MACADDR8': ...,
        # 'MONEY': ...,
        'TIME': gws.AttributeType.time,
        'TIMESTAMP': gws.AttributeType.datetime,
    }

    # @TODO proper support for Z/M geoms

    # Geometry type name (incl. M/Z/ZM variants) -> gws geometry type.
    # Not used by this base class; presumably consumed by subclasses.
    SA_TO_GEOM = {
        'POINT': gws.GeometryType.point,
        'POINTM': gws.GeometryType.point,
        'POINTZ': gws.GeometryType.point,
        'POINTZM': gws.GeometryType.point,
        'LINESTRING': gws.GeometryType.linestring,
        'LINESTRINGM': gws.GeometryType.linestring,
        'LINESTRINGZ': gws.GeometryType.linestring,
        'LINESTRINGZM': gws.GeometryType.linestring,
        'POLYGON': gws.GeometryType.polygon,
        'POLYGONM': gws.GeometryType.polygon,
        'POLYGONZ': gws.GeometryType.polygon,
        'POLYGONZM': gws.GeometryType.polygon,
        'MULTIPOINT': gws.GeometryType.multipoint,
        'MULTIPOINTM': gws.GeometryType.multipoint,
        'MULTIPOINTZ': gws.GeometryType.multipoint,
        'MULTIPOINTZM': gws.GeometryType.multipoint,
        'MULTILINESTRING': gws.GeometryType.multilinestring,
        'MULTILINESTRINGM': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZ': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZM': gws.GeometryType.multilinestring,
        'MULTIPOLYGON': gws.GeometryType.multipolygon,
        # 'GEOMETRYCOLLECTION': gws.GeometryType.geometrycollection,
        # 'CURVE': gws.GeometryType.curve,
    }

    # Fallback attribute types for column types missing from SA_TO_ATTR.
    UNKNOWN_TYPE = gws.AttributeType.str
    UNKNOWN_ARRAY_TYPE = gws.AttributeType.strlist

    def describe(self, table):
        """Build a ``gws.DataSetDescription`` for ``table``.

        Columns are described via describe_column() in table order. The first
        column with a geometry type determines the dataset's geometry fields.

        Raises:
            sa.Error: If the table does not exist.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {table!r} not found')

        schema = tab.schema
        name = tab.name

        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.join_table_name(schema or '', name),
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=name,
            schema=schema,
        )

        for n, sa_col in enumerate(cast(list[sa.Column], tab.columns)):
            # Pass the caller's ``table`` through unchanged: overrides of
            # describe_column() may rely on the original table reference.
            col = self.describe_column(table, sa_col.name)
            col.columnIndex = n
            desc.columns.append(col)
            desc.columnMap[col.name] = col

        for col in desc.columns:
            if col.geometryType:
                desc.geometryName = col.name
                desc.geometryType = col.geometryType
                desc.geometrySrid = col.geometrySrid
                break

        return desc

    def describe_column(self, table, column_name) -> gws.ColumnDescription:
        """Build a ``gws.ColumnDescription`` for one column of ``table``.

        The native type is the upper-cased SQLAlchemy type class name, mapped
        through SA_TO_ATTR (falling back to UNKNOWN_TYPE). This base version
        leaves the geometry fields empty; presumably subclasses fill them in.

        Raises:
            sa.Error: If the table or column does not exist.
        """
        sa_col = self.column(table, column_name)

        col = gws.ColumnDescription(
            columnIndex=0,
            comment=str(sa_col.comment or ''),
            default=sa_col.default,
            geometrySrid=0,
            geometryType='',
            isAutoincrement=bool(sa_col.autoincrement),
            isNullable=bool(sa_col.nullable),
            isPrimaryKey=bool(sa_col.primary_key),
            isUnique=bool(sa_col.unique),
            hasDefault=sa_col.server_default is not None,
            name=str(sa_col.name),
            nativeType='',
            type='',
        )

        col.nativeType = type(sa_col.type).__name__.upper()
        col.type = self.SA_TO_ATTR.get(col.nativeType, self.UNKNOWN_TYPE)

        return col
330##