Coverage for gws-app/gws/base/database/provider.py: 80%
177 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-16 22:59 +0200
1import threading
2from typing import Optional, cast
4import gws
5import gws.lib.sa as sa
7from . import connection
class Config(gws.Config):
    """Database provider"""

    # A falsy value disables the shared metadata cache in Object.reflect_schema;
    # presumably the Duration string is converted to seconds by gws — confirm.
    schemaCacheLifeTime: gws.Duration = '3600'
    """Life time for schema caches."""
    # When False (the default), engines are created with sa.NullPool and the
    # `pool` options below are ignored (see Object.engine_options).
    withPool: Optional[bool] = False
    """Use connection pooling"""
    # Keys read by Object.engine_options: 'disabled' (bool), 'pre_ping' (bool),
    # 'size', 'recycle', 'timeout' (ints, passed as SQLAlchemy pool_* options).
    pool: Optional[dict]
    """Options for connection pooling."""
# Per-thread storage used by ``Object`` to remember the database connection(s)
# currently open in this thread, together with their nesting counters, so that
# nested ``connect()`` calls in one thread can reuse an already open connection.
_thread_local: threading.local = threading.local()
class Object(gws.DatabaseProvider):
    """SQLAlchemy-based database provider.

    Holds a SQLAlchemy engine (``saEngine``) and a per-schema map of reflected
    metadata (``saMetaMap``). Connections handed out by ``connect`` are
    reference-counted per thread and per provider: nested ``connect`` calls on
    the same provider in the same thread share one SQLAlchemy connection, which
    ``_close_connection`` closes once its open count drops to zero. Different
    providers never share a connection.
    """

    saEngine: sa.Engine
    saMetaMap: dict[str, sa.MetaData]

    def __getstate__(self):
        # Leave the live engine and the reflected metadata out of the pickled
        # state; ``engine()`` recreates the engine lazily and ``activate``
        # resets both.
        return gws.u.omit(vars(self), 'saMetaMap', 'saEngine')

    def configure(self):
        # Init a dummy (non-pooling) engine just to check things early.
        self.saEngine = self.create_engine(poolclass=sa.NullPool)
        self.saMetaMap = {}

    def activate(self):
        """Create the runtime engine and reset the reflected-metadata map."""
        self.saEngine = self.create_engine()
        self.saMetaMap = {}

    def engine(self):
        """Return the SQLAlchemy engine, creating it first if it is not set."""
        eng = getattr(self, 'saEngine', None)
        if eng is None:
            eng = self.saEngine = self.create_engine()
        return eng

    def create_engine(self, **kwargs):
        """Create a SQLAlchemy engine for ``self.url()``.

        Args:
            **kwargs: Extra ``sa.create_engine`` options. They take precedence
                over the options computed by ``engine_options``.
        """
        return sa.create_engine(self.url(), **self.engine_options(**kwargs))

    def engine_options(self, **kwargs):
        """Fill in ``sa.create_engine`` options from the provider config.

        Explicit ``kwargs`` are never overwritten (only ``setdefault`` is used).

        Returns:
            The (updated) ``kwargs`` dict.
        """
        if self.root.app.developer_option('db.engine_echo'):
            kwargs.setdefault('echo', True)
            kwargs.setdefault('echo_pool', True)

        # Pooling is opt-in: without `withPool` every checkout opens a fresh
        # DBAPI connection.
        if self.cfg('withPool') is False:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        pool = self.cfg('pool') or {}
        if pool.get('disabled') is True:
            kwargs.setdefault('poolclass', sa.NullPool)
            return kwargs

        if pool.get('pre_ping') is True:
            kwargs.setdefault('pool_pre_ping', True)

        # config key -> SQLAlchemy engine option (integer values only)
        int_options = (
            ('size', 'pool_size'),
            ('recycle', 'pool_recycle'),
            ('timeout', 'pool_timeout'),
        )
        for cfg_key, sa_key in int_options:
            val = pool.get(cfg_key)
            if isinstance(val, int):
                kwargs.setdefault(sa_key, val)

        return kwargs

    def reflect_schema(self, schema: str):
        """Reflect ``schema`` into ``saMetaMap`` unless it is already there.

        When ``schemaCacheLifeTime`` is non-zero, the reflected metadata is also
        kept in the shared object cache. The cache key includes this provider's
        uid, so providers that expose a schema with the same name (on different
        databases) do not receive each other's metadata.
        """
        if schema in self.saMetaMap:
            return

        def _load():
            md = sa.MetaData(schema=schema)

            # introspecting the whole schema is generally faster
            # but what if we only need a single table from a big schema?
            # @TODO add options for reflection

            gws.debug.time_start(f'AUTOLOAD {self.uid=} {schema=}')
            try:
                with self.connect() as conn:
                    md.reflect(conn.saConn, schema, resolve_fks=False, views=True)
            finally:
                # keep the debug timer balanced even if reflection fails
                gws.debug.time_end()
            return md

        life_time = self.cfg('schemaCacheLifeTime', 0)
        if not life_time:
            self.saMetaMap[schema] = _load()
            return

        cache_key = f'database_metadata_schema_{self.uid}_{schema}'
        self.saMetaMap[schema] = gws.u.get_cached_object(cache_key, life_time, _load)

    def connect(self):
        """Open (or reuse) this thread's connection and wrap it in a ``connection.Object``."""
        conn = self._open_connection()
        return connection.Object(self, conn)

    def _connection_states(self) -> dict:
        """Return this thread's ``{provider uid: (sa connection, open count)}`` map."""
        states = getattr(_thread_local, 'connectionStates', None)
        if states is None:
            states = {}
            _thread_local.connectionStates = states
        return states

    def _sa_connection(self) -> sa.Connection | None:
        """Return this provider's open connection in the current thread, if any."""
        conn, _ = self._connection_states().get(self.uid, (None, 0))
        return conn

    def _open_connection(self) -> sa.Connection:
        """Return this provider's thread connection, opening it if needed, and bump its count."""
        states = self._connection_states()
        conn, cc = states.get(self.uid, (None, 0))

        if conn is None:
            assert cc == 0
            conn = self.engine().connect()
        else:
            assert cc > 0

        states[self.uid] = (conn, cc + 1)
        return conn

    def _close_connection(self):
        """Decrement the open count; close the connection when it reaches zero."""
        states = self._connection_states()
        conn, cc = states.get(self.uid, (None, 0))
        assert conn is not None
        assert cc > 0

        if cc == 1:
            # Unregister first, so a failing close() cannot leave a dead
            # connection behind for the next connect() in this thread.
            del states[self.uid]
            conn.close()
        else:
            states[self.uid] = (conn, cc - 1)

    def table(self, table, **kwargs):
        """Return the ``sa.Table`` for ``table`` (a Table or a name for ``split_table_name``).

        Raises:
            sa.Error: if the table does not exist.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {str(table)} not found')
        return tab

    def count(self, table):
        """Return the number of rows in ``table``, or 0 if the table does not exist."""
        tab = self._sa_table(table)
        if tab is None:
            return 0
        sql = sa.select(sa.func.count()).select_from(tab)
        with self.connect() as conn:
            return conn.fetch_int(sql)

    def has_table(self, table_name: str):
        """Return True if ``table_name`` exists in the (reflected) database."""
        return self._sa_table(table_name) is not None

    def _sa_table(self, tab_or_name) -> sa.Table | None:
        """Resolve a Table or table name to a reflected ``sa.Table`` (None if missing)."""
        if isinstance(tab_or_name, sa.Table):
            return tab_or_name

        schema, name = self.split_table_name(tab_or_name)
        self.reflect_schema(schema)

        sm = self.saMetaMap[schema]
        if sm is None:
            raise sa.Error(f'schema {schema!r} not found')

        # see _get_table_key in sqlalchemy/sql/schema.py
        table_key = schema + '.' + name
        return sm.tables.get(table_key)

    def column(self, table, column_name):
        """Return the ``sa.Column`` ``column_name`` of ``table``.

        Raises:
            sa.Error: if the table or the column does not exist.
        """
        tab = self.table(table)
        try:
            return tab.columns[column_name]
        except KeyError:
            raise sa.Error(f'column {str(table)}.{column_name!r} not found') from None

    def has_column(self, table, column_name):
        """Return True if ``table`` exists and has a column ``column_name``."""
        tab = self._sa_table(table)
        return tab is not None and column_name in tab.columns

    def select_text(self, sql, **kwargs):
        """Run a raw SQL query with bound parameters ``kwargs``; return rows as dicts.

        On ``sa.Error`` the transaction is rolled back and the error re-raised.
        """
        with self.connect() as conn:
            try:
                return [gws.u.to_dict(r) for r in conn.execute(sa.text(sql), kwargs)]
            except sa.Error:
                conn.rollback()
                raise

    def execute_text(self, sql, **kwargs):
        """Run a raw SQL statement with bound parameters ``kwargs`` and commit.

        On ``sa.Error`` the transaction is rolled back and the error re-raised.
        Returns the SQLAlchemy result object.
        NOTE(review): the result outlives the ``connect()`` context here; reading
        rows from it after return may not work — confirm callers only use e.g.
        ``rowcount``.
        """
        with self.connect() as conn:
            try:
                res = conn.execute(sa.text(sql), kwargs)
                conn.commit()
                return res
            except sa.Error:
                conn.rollback()
                raise

    # Upper-cased SQLAlchemy type class name -> gws attribute type.
    SA_TO_ATTR = {
        # common: sqlalchemy.sql.sqltypes
        'BIGINT': gws.AttributeType.int,
        'BOOLEAN': gws.AttributeType.bool,
        'CHAR': gws.AttributeType.str,
        'DATE': gws.AttributeType.date,
        'DOUBLE_PRECISION': gws.AttributeType.float,
        'INTEGER': gws.AttributeType.int,
        'NUMERIC': gws.AttributeType.float,
        'REAL': gws.AttributeType.float,
        'SMALLINT': gws.AttributeType.int,
        'TEXT': gws.AttributeType.str,
        # 'UUID': ...,
        'VARCHAR': gws.AttributeType.str,
        # postgres specific: sqlalchemy.dialects.postgresql.types
        # 'JSON': ...,
        # 'JSONB': ...,
        # 'BIT': ...,
        'BYTEA': gws.AttributeType.bytes,
        # 'CIDR': ...,
        # 'INET': ...,
        # 'MACADDR': ...,
        # 'MACADDR8': ...,
        # 'MONEY': ...,
        'TIME': gws.AttributeType.time,
        'TIMESTAMP': gws.AttributeType.datetime,
    }

    # @TODO proper support for Z/M geoms

    # Geometry type name (including M/Z/ZM variants) -> gws geometry type.
    SA_TO_GEOM = {
        'POINT': gws.GeometryType.point,
        'POINTM': gws.GeometryType.point,
        'POINTZ': gws.GeometryType.point,
        'POINTZM': gws.GeometryType.point,
        'LINESTRING': gws.GeometryType.linestring,
        'LINESTRINGM': gws.GeometryType.linestring,
        'LINESTRINGZ': gws.GeometryType.linestring,
        'LINESTRINGZM': gws.GeometryType.linestring,
        'POLYGON': gws.GeometryType.polygon,
        'POLYGONM': gws.GeometryType.polygon,
        'POLYGONZ': gws.GeometryType.polygon,
        'POLYGONZM': gws.GeometryType.polygon,
        'MULTIPOINT': gws.GeometryType.multipoint,
        'MULTIPOINTM': gws.GeometryType.multipoint,
        'MULTIPOINTZ': gws.GeometryType.multipoint,
        'MULTIPOINTZM': gws.GeometryType.multipoint,
        'MULTILINESTRING': gws.GeometryType.multilinestring,
        'MULTILINESTRINGM': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZ': gws.GeometryType.multilinestring,
        'MULTILINESTRINGZM': gws.GeometryType.multilinestring,
        'MULTIPOLYGON': gws.GeometryType.multipolygon,
        # 'GEOMETRYCOLLECTION': gws.GeometryType.geometrycollection,
        # 'CURVE': gws.GeometryType.curve,
    }

    # Fallbacks for native types missing from SA_TO_ATTR.
    UNKNOWN_TYPE = gws.AttributeType.str
    UNKNOWN_ARRAY_TYPE = gws.AttributeType.strlist

    def describe(self, table):
        """Build a ``gws.DataSetDescription`` for ``table``.

        The first column with a geometry type (as set by ``describe_column``)
        becomes the data set's geometry column.

        Raises:
            sa.Error: if the table does not exist.
        """
        tab = self._sa_table(table)
        if tab is None:
            raise sa.Error(f'table {table!r} not found')

        schema = tab.schema
        name = tab.name

        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.join_table_name(schema or '', name),
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=name,
            schema=schema,
        )

        for n, sa_col in enumerate(cast(list[sa.Column], tab.columns)):
            # pass the caller's `table` through unchanged: subclasses may
            # override describe_column and rely on it
            col = self.describe_column(table, sa_col.name)
            col.columnIndex = n
            desc.columns.append(col)
            desc.columnMap[col.name] = col

        for col in desc.columns:
            if col.geometryType:
                desc.geometryName = col.name
                desc.geometryType = col.geometryType
                desc.geometrySrid = col.geometrySrid
                break

        return desc

    def describe_column(self, table, column_name) -> gws.ColumnDescription:
        """Build a ``gws.ColumnDescription`` for ``table.column_name``.

        ``nativeType`` is the upper-cased SQLAlchemy type class name; ``type`` is
        looked up in ``SA_TO_ATTR``, falling back to ``UNKNOWN_TYPE``. Geometry
        fields are left empty here.

        Raises:
            sa.Error: if the table or the column does not exist.
        """
        sa_col = self.column(table, column_name)

        col = gws.ColumnDescription(
            columnIndex=0,
            comment=str(sa_col.comment or ''),
            default=sa_col.default,
            geometrySrid=0,
            geometryType='',
            isAutoincrement=bool(sa_col.autoincrement),
            isNullable=bool(sa_col.nullable),
            isPrimaryKey=bool(sa_col.primary_key),
            isUnique=bool(sa_col.unique),
            hasDefault=sa_col.server_default is not None,
            name=str(sa_col.name),
            nativeType='',
            type='',
        )

        col.nativeType = type(sa_col.type).__name__.upper()
        col.type = self.SA_TO_ATTR.get(col.nativeType, self.UNKNOWN_TYPE)

        return col
318##