Coverage for gws-app/gws/base/database/connection.py: 70%

74 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-16 23:09 +0200

1import gws 

2import gws.lib.sa as sa 

3 

4 

class Object(gws.DatabaseConnection):
    """Database connection wrapper around a SQLAlchemy ``Connection``.

    Holds the owning provider (``db``) and the underlying SQLAlchemy
    connection (``saConn``). Usable as a context manager: leaving the
    ``with`` block calls :meth:`close`.
    """

    # Provider that created this connection; ``close`` delegates to it.
    db: gws.DatabaseProvider
    # Underlying SQLAlchemy connection that every statement runs on.
    saConn: sa.Connection

    def __init__(self, db: gws.DatabaseProvider, conn: sa.Connection):
        self.db = db
        self.saConn = conn

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close; returning None lets any in-flight exception propagate.
        self.close()

    def close(self):
        """Release this connection through the provider's ``_close_connection`` hook."""
        # The hook is private on the provider; looked up via getattr,
        # presumably to avoid private-member access warnings.
        getattr(self.db, '_close_connection')()

    @staticmethod
    def _as_statement(sql):
        """Wrap a raw SQL string in ``sa.text``; return any other statement unchanged."""
        if isinstance(sql, str):
            return sa.text(sql)
        return sql

    def execute(self, stmt, params=None, execution_options=None):
        """Execute ``stmt`` on the underlying connection with no transaction handling.

        Args:
            stmt: A SQLAlchemy executable (raw strings are not converted here).
            params: Bind parameters passed through to ``Connection.execute``.
            execution_options: Passed through to ``Connection.execute``.

        Returns:
            The SQLAlchemy result object.
        """
        return self.saConn.execute(stmt, params, execution_options=execution_options)

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self.saConn.commit()

    def rollback(self):
        """Roll back the current transaction on the underlying connection."""
        self.saConn.rollback()

    def exec(self, sql, **params):
        """Execute ``sql`` with keyword bind parameters, without commit or rollback.

        Args:
            sql: A raw SQL string (wrapped in ``sa.text``) or a SQLAlchemy statement.
            **params: Bind parameters.

        Returns:
            The SQLAlchemy result object.
        """
        return self.saConn.execute(self._as_statement(sql), params)

    def exec_commit(self, sql, **params):
        """Execute ``sql`` and commit; on any failure roll back and re-raise.

        Args:
            sql: A raw SQL string (wrapped in ``sa.text``) or a SQLAlchemy statement.
            **params: Bind parameters.

        Returns:
            The SQLAlchemy result object of the executed statement.
        """
        stmt = self._as_statement(sql)
        try:
            res = self.saConn.execute(stmt, params)
            self.saConn.commit()
            return res
        except Exception:
            # Covers failures of both execute and commit.
            self.saConn.rollback()
            raise

    def exec_rollback(self, sql, **params):
        """Execute ``sql`` and always roll back afterwards (read-only style execution).

        Args:
            sql: A raw SQL string (wrapped in ``sa.text``) or a SQLAlchemy statement.
            **params: Bind parameters.

        Returns:
            The SQLAlchemy result object. Callers read rows from it *after*
            the rollback has run. NOTE(review): this relies on the driver
            having buffered the rows client-side; confirm that no
            streaming/server-side cursors are used with this method.
        """
        stmt = self._as_statement(sql)
        try:
            return self.saConn.execute(stmt, params)
        finally:
            self.saConn.rollback()

    def fetch_all(self, stmt, **params):
        """Return all result rows as a list of dicts (column name -> value)."""
        return [r._asdict() for r in self.exec_rollback(stmt, **params)]

    def fetch_first(self, stmt, **params):
        """Return the first result row as a dict, or ``None`` if there are no rows."""
        res = self.exec_rollback(stmt, **params)
        r = res.first()
        # Compare against None explicitly: a Row is tuple-like, so a
        # zero-column row would be falsy under a plain truthiness test.
        return r._asdict() if r is not None else None

    def fetch_scalars(self, stmt, **params):
        """Return the first column of every result row as a list."""
        res = self.exec_rollback(stmt, **params)
        return list(res.scalars().all())

    def fetch_strings(self, stmt, **params):
        """Return the first column of every row converted to ``str`` (``None`` -> ``''``)."""
        res = self.exec_rollback(stmt, **params)
        return [_to_str(s) for s in res.scalars().all()]

    def fetch_ints(self, stmt, **params):
        """Return the first column of every row as ints.

        Raises:
            ValueError: if any value is not an ``int``.
        """
        res = self.exec_rollback(stmt, **params)
        return [_to_int(s) for s in res.scalars().all()]

    def fetch_scalar(self, stmt, **params):
        """Return the first column of the first row, or ``None`` if there are no rows."""
        res = self.exec_rollback(stmt, **params)
        return res.scalar()

    def fetch_string(self, stmt, **params):
        """Return the first scalar as ``str``, or ``None`` if it is NULL or absent."""
        res = self.exec_rollback(stmt, **params)
        s = res.scalar()
        return _to_str(s) if s is not None else None

    def fetch_int(self, stmt, **params):
        """Return the first scalar as ``int``, or ``None`` if it is NULL or absent.

        Raises:
            ValueError: if the value is present but not an ``int``.
        """
        res = self.exec_rollback(stmt, **params)
        s = res.scalar()
        return _to_int(s) if s is not None else None

88 

89 

90## 

91 

92 

93def _to_int(s) -> int: 

94 if isinstance(s, int): 

95 return s 

96 raise ValueError(f'db: expected int, got {s=}') 

97 

98 

99def _to_str(s) -> str: 

100 if s is None: 

101 return '' 

102 return str(s)