53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
class Transaction(ContextModel):
    """
    A base model for transaction state.

    Used as a context manager: ``__enter__`` registers the transaction in the
    ``__var__`` context variable and ``__exit__`` resolves it (commit, roll
    back, or hand off to a parent transaction) according to ``commit_mode``.
    A value passed to ``stage`` is written to ``store`` under ``key`` on
    ``commit``; persistence and locking only happen when both ``store`` and
    ``key`` are set.
    """

    # Record store and key for this transaction's record; both are optional.
    store: Optional[RecordStore] = None
    key: Optional[str] = None
    # Nested transactions that exited inside this one (see ``reset``); they are
    # committed and rolled back together with this transaction.
    children: List["Transaction"] = Field(default_factory=list)
    # ``None`` means: inherit from the parent on enter, otherwise LAZY.
    commit_mode: Optional[CommitMode] = None
    # ``None`` means: inherit from the parent on enter, otherwise READ_COMMITTED.
    isolation_level: Optional[IsolationLevel] = IsolationLevel.READ_COMMITTED
    state: TransactionState = TransactionState.PENDING
    # Hooks are called with this transaction as their only argument (run_hook).
    on_commit_hooks: List[Callable[["Transaction"], None]] = Field(default_factory=list)
    on_rollback_hooks: List[Callable[["Transaction"], None]] = Field(
        default_factory=list
    )
    # When False, an existing record for ``key`` marks the transaction as
    # COMMITTED in ``begin``.
    overwrite: bool = False
    logger: Union[logging.Logger, logging.LoggerAdapter] = Field(
        default_factory=partial(get_logger, "transactions")
    )
    # Values set via ``set``; deep-copied from the parent on enter.
    _stored_values: Dict[str, Any] = PrivateAttr(default_factory=dict)
    # Value passed to ``stage``; written to the store on ``commit``.
    _staged_value: Any = None
    # Holds the currently active transaction (see ``get_active``).
    __var__: ContextVar = ContextVar("transaction")

    def set(self, name: str, value: Any) -> None:
        """Store ``value`` under ``name`` in this transaction's value store."""
        self._stored_values[name] = value

    def get(self, name: str, default: Any = NotSet) -> Any:
        """
        Return the value stored under ``name``.

        Args:
            name: key previously passed to ``set`` (values are also inherited
                from the parent transaction on enter).
            default: returned when ``name`` is unknown; if omitted, an unknown
                key is an error.

        Raises:
            ValueError: if ``name`` is unknown and no ``default`` was given.
        """
        if name not in self._stored_values:
            if default is not NotSet:
                return default
            raise ValueError(f"Could not retrieve value for unknown key: {name}")
        return self._stored_values[name]

    def is_committed(self) -> bool:
        """Return True if the transaction is in the COMMITTED state."""
        return self.state == TransactionState.COMMITTED

    def is_rolled_back(self) -> bool:
        """Return True if the transaction is in the ROLLED_BACK state."""
        return self.state == TransactionState.ROLLED_BACK

    def is_staged(self) -> bool:
        """Return True if the transaction is in the STAGED state."""
        return self.state == TransactionState.STAGED

    def is_pending(self) -> bool:
        """Return True if the transaction is in the PENDING state."""
        return self.state == TransactionState.PENDING

    def is_active(self) -> bool:
        """Return True if the transaction is in the ACTIVE state."""
        return self.state == TransactionState.ACTIVE

    def __enter__(self):
        """
        Activate this transaction in the current context.

        Copies stored values, and any unset commit mode / isolation level, from
        the parent returned by ``get_transaction()`` (if any), checks that the
        store supports the isolation level, calls ``begin`` and registers
        ``self`` in ``__var__``.

        Raises:
            RuntimeError: if this transaction has already been entered.
            ValueError: if the store does not support the isolation level.
        """
        if self._token is not None:
            raise RuntimeError(
                "Context already entered. Context enter calls cannot be nested."
            )

        parent = get_transaction()
        if parent:
            self._stored_values = copy.deepcopy(parent._stored_values)

        # set default commit behavior; either inherit from parent or set a default of lazy
        if self.commit_mode is None:
            self.commit_mode = parent.commit_mode if parent else CommitMode.LAZY

        # set default isolation level; either inherit from parent or set a default of read committed
        if self.isolation_level is None:
            self.isolation_level = (
                parent.isolation_level if parent else IsolationLevel.READ_COMMITTED
            )

        assert self.isolation_level is not None, "Isolation level was not set correctly"
        if (
            self.store
            and self.key
            and not self.store.supports_isolation_level(self.isolation_level)
        ):
            raise ValueError(
                f"Isolation level {self.isolation_level.name} is not supported by record store type {self.store.__class__.__name__}"
            )

        # this needs to go before begin, which could set the state to committed
        self.state = TransactionState.ACTIVE
        self.begin()
        self._token = self.__var__.set(self)
        return self

    def __exit__(self, *exc_info):
        """
        Deactivate this transaction and resolve it.

        If the block raised, the transaction is rolled back, reset, and the
        exception re-raised. Otherwise EAGER transactions commit immediately; a
        nested transaction then defers to its parent; a top-level transaction
        is rolled back under OFF (a no-op if already committed) or committed
        under LAZY.

        Raises:
            RuntimeError: if called without a matching ``__enter__``.
        """
        exc_type, exc_val, _ = exc_info
        if not self._token:
            raise RuntimeError(
                "Asymmetric use of context. Context exit called without an enter."
            )
        if exc_type:
            self.rollback()
            self.reset()
            raise exc_val

        if self.commit_mode == CommitMode.EAGER:
            self.commit()

        # if parent, let them take responsibility
        if self.get_parent():
            self.reset()
            return

        if self.commit_mode == CommitMode.OFF:
            # if no one took responsibility to commit, rolling back
            # note that rollback returns if already committed
            self.rollback()
        elif self.commit_mode == CommitMode.LAZY:
            # no one left to take responsibility for committing
            self.commit()

        self.reset()

    def begin(self):
        """
        Prepare the transaction once it is ACTIVE.

        Under SERIALIZABLE isolation (with a store and key) the lock for ``key``
        is acquired. Unless ``overwrite`` is set, an existing record for ``key``
        moves the transaction straight to COMMITTED, which makes later
        ``stage`` calls no-ops and ``commit`` return False.
        """
        if (
            self.store
            and self.key
            and self.isolation_level == IsolationLevel.SERIALIZABLE
        ):
            self.logger.debug(f"Acquiring lock for transaction {self.key!r}")
            self.store.acquire_lock(self.key)
        if (
            not self.overwrite
            and self.store
            and self.key
            and self.store.exists(key=self.key)
        ):
            self.state = TransactionState.COMMITTED

    def read(self) -> Optional[BaseResult]:
        """Return the stored record's ``result`` for ``key``, or None if absent."""
        if self.store and self.key:
            record = self.store.read(key=self.key)
            if record is not None:
                return record.result
        return None

    def reset(self) -> None:
        """
        Unregister this transaction from ``__var__``.

        If there is a parent, this transaction is added to its children; a
        rolled-back child also rolls back the parent.
        """
        parent = self.get_parent()

        if parent:
            # parent takes responsibility
            parent.add_child(self)

        if self._token:
            self.__var__.reset(self._token)
            self._token = None

        # do this below reset so that get_transaction() returns the relevant txn
        if parent and self.state == TransactionState.ROLLED_BACK:
            parent.rollback()

    def add_child(self, transaction: "Transaction") -> None:
        """Register ``transaction`` as a child to commit/roll back with this one."""
        self.children.append(transaction)

    def get_parent(self) -> Optional["Transaction"]:
        """
        Return the transaction that was active when this one was entered.

        Returns None when there was no active transaction at enter time, or
        when this transaction is not currently entered (no context token), in
        which case no parent link is recorded.
        """
        # Previously this dereferenced ``self._token`` unconditionally and
        # raised AttributeError when called outside the ``with`` block.
        if self._token is None:
            return None
        prev_var = self._token.old_value
        # Token.MISSING is a sentinel: compare by identity, not equality.
        if prev_var is Token.MISSING:
            return None
        return prev_var

    def commit(self) -> bool:
        """
        Commit children, run commit hooks and write the staged value.

        On success the transaction becomes COMMITTED, the SERIALIZABLE lock (if
        any) is released, and True is returned. Returns False if the
        transaction was already committed or rolled back, or if any step
        raised; in the latter case the error is logged and the transaction is
        rolled back.
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            # NOTE(review): on the ROLLED_BACK path, rollback() has already
            # released the lock in its ``finally`` block, so this is a second
            # release; whether that is safe depends on the store — confirm.
            if (
                self.store
                and self.key
                and self.isolation_level == IsolationLevel.SERIALIZABLE
            ):
                self.logger.debug(f"Releasing lock for transaction {self.key!r}")
                self.store.release_lock(self.key)

            return False

        try:
            for child in self.children:
                child.commit()

            for hook in self.on_commit_hooks:
                self.run_hook(hook, "commit")

            if self.store and self.key:
                self.store.write(key=self.key, result=self._staged_value)
            self.state = TransactionState.COMMITTED
            if (
                self.store
                and self.key
                and self.isolation_level == IsolationLevel.SERIALIZABLE
            ):
                self.logger.debug(f"Releasing lock for transaction {self.key!r}")
                self.store.release_lock(self.key)
            return True
        except SerializationError as exc:
            if self.logger:
                self.logger.warning(
                    f"Encountered an error while serializing result for transaction {self.key!r}: {exc}"
                    " Code execution will continue, but the transaction will not be committed.",
                )
            self.rollback()
            return False
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while committing transaction {self.key!r}",
                    exc_info=True,
                )
            self.rollback()
            return False

    def run_hook(self, hook: Callable[["Transaction"], None], hook_type: str) -> None:
        """
        Call ``hook(self)``, logging its start and outcome.

        Logging is skipped when the hook has a falsy ``log_on_run`` attribute.
        Any exception raised by the hook is logged (if enabled) and re-raised.
        """
        hook_name = _get_hook_name(hook)
        # Undocumented way to disable logging for a hook. Subject to change.
        should_log = getattr(hook, "log_on_run", True)

        if should_log:
            self.logger.info(f"Running {hook_type} hook {hook_name!r}")

        try:
            hook(self)
        except Exception:
            if should_log:
                self.logger.error(
                    f"An error was encountered while running {hook_type} hook {hook_name!r}",
                )
            raise
        else:
            if should_log:
                self.logger.info(
                    f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
                )

    def stage(
        self,
        value: BaseResult,
        on_rollback_hooks: Optional[List] = None,
        on_commit_hooks: Optional[List] = None,
    ) -> None:
        """
        Stage a value to be committed later.

        Records ``value`` for the next ``commit``, appends the given hooks and
        moves the transaction to STAGED. Does nothing if the transaction is
        already COMMITTED (e.g. ``begin`` found an existing record).
        """
        on_commit_hooks = on_commit_hooks or []
        on_rollback_hooks = on_rollback_hooks or []

        if self.state != TransactionState.COMMITTED:
            self._staged_value = value
            self.on_rollback_hooks += on_rollback_hooks
            self.on_commit_hooks += on_commit_hooks
            self.state = TransactionState.STAGED

    def rollback(self) -> bool:
        """
        Run rollback hooks (in reverse order), then roll back children.

        Returns True once the transaction is ROLLED_BACK. Returns False if it
        was already committed or rolled back, or if an error was raised while
        rolling back (the error is logged). Past the early return, the
        SERIALIZABLE lock (if any) is always released.
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            return False

        try:
            for hook in reversed(self.on_rollback_hooks):
                self.run_hook(hook, "rollback")

            self.state = TransactionState.ROLLED_BACK

            for child in reversed(self.children):
                child.rollback()

            return True
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while rolling back transaction {self.key!r}",
                    exc_info=True,
                )
            return False
        finally:
            if (
                self.store
                and self.key
                and self.isolation_level == IsolationLevel.SERIALIZABLE
            ):
                self.logger.debug(f"Releasing lock for transaction {self.key!r}")
                self.store.release_lock(self.key)

    @classmethod
    def get_active(cls: Type[Self]) -> Optional[Self]:
        """Return the transaction currently set in ``__var__``, or None."""
        return cls.__var__.get(None)
|