A service used for waiting for a task run to finish.
This service listens for task run events and provides a way to wait for a specific
task run to finish. This is useful for waiting for a task run to finish before
continuing execution.
The service is a singleton and must be running before it is used; it starts
automatically when the singleton instance is first created. A single websocket
connection is used to listen for task run events.
The service can be used to wait for a task run to finish by calling
`TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method
will return when the task run has finished or the timeout has elapsed.
The service will automatically stop when the Python process exits or when the
global loop thread is stopped.
Example:
import asyncio
from uuid import uuid4

from prefect import task
from prefect.task_engine import run_task_async
from prefect.task_runs import TaskRunWaiter


@task
async def test_task():
    await asyncio.sleep(5)
    print("Done!")


async def main():
    task_run_id = uuid4()
    # Keep a reference to the background task: the event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    background_run = asyncio.create_task(
        run_task_async(task=test_task, task_run_id=task_run_id)
    )
    await TaskRunWaiter.wait_for_task_run(task_run_id)
    print("Task run finished")
    del background_run  # the reference is no longer needed once the run is done


if __name__ == "__main__":
    asyncio.run(main())
Source code in src/prefect/task_runs.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
class TaskRunWaiter:
    """
    A service used for waiting for a task run to finish.

    This service listens for task run events and provides a way to wait for a
    specific task run to finish. This is useful for waiting for a task run to
    finish before continuing execution.

    The service is a singleton and must be running before it is used; it starts
    automatically when the singleton instance is first created. A single
    websocket connection is used to listen for task run events.

    The service can be used to wait for a task run to finish by calling
    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The
    method will return when the task run has finished or the timeout has
    elapsed.

    The service will automatically stop when the Python process exits or when
    the global loop thread is stopped.

    Example:
    ```python
    import asyncio
    from uuid import uuid4

    from prefect import task
    from prefect.task_engine import run_task_async
    from prefect.task_runs import TaskRunWaiter


    @task
    async def test_task():
        await asyncio.sleep(5)
        print("Done!")


    async def main():
        task_run_id = uuid4()
        # Keep a reference so the background task cannot be garbage-collected.
        background_run = asyncio.create_task(
            run_task_async(task=test_task, task_run_id=task_run_id)
        )
        await TaskRunWaiter.wait_for_task_run(task_run_id)
        print("Task run finished")
        del background_run


    if __name__ == "__main__":
        asyncio.run(main())
    ```
    """

    # Quoted so the annotation is not evaluated when the class body runs.
    _instance: Optional["Self"] = None
    # Guards creation of the singleton in `instance()`.
    _instance_lock = threading.Lock()

    def __init__(self):
        self.logger = get_logger("TaskRunWaiter")
        # Task running `_consume_events`; set by `start()`, cleared by `stop()`.
        self._consumer_task: Optional[asyncio.Task] = None
        # Task run IDs for which a terminal event has already been received.
        # NOTE(review): ttl=600 is presumably seconds (cachetools-style
        # TTLCache) -- confirm.
        self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
            maxsize=10000, ttl=600
        )
        # One event / callback per task run ID currently being waited on.
        # NOTE(review): a second waiter or callback registered for the same
        # task run ID replaces the first one.
        self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
        self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._observed_completed_task_runs_lock = threading.Lock()
        # Guards both `_completion_events` and `_completion_callbacks`.
        self._completion_events_lock = threading.Lock()
        self._started = False

    def start(self):
        """
        Start the TaskRunWaiter service.

        Does nothing if the service is already started. Otherwise schedules the
        event consumer on the global loop and registers `stop` both as a
        shutdown call of the global loop thread and as an `atexit` handler.

        Raises:
            RuntimeError: If the running event loop is not the global loop
                thread's loop.
        """
        if self._started:
            return
        self.logger.debug("Starting TaskRunWaiter")
        loop_thread = get_global_loop()
        if not asyncio.get_running_loop() == loop_thread._loop:
            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")
        self._loop = loop_thread._loop
        consumer_started = asyncio.Event()
        self._consumer_task = self._loop.create_task(
            self._consume_events(consumer_started)
        )
        # NOTE(review): the returned future is discarded, so this schedules a
        # wait on the loop but `start()` itself does not block until the
        # consumer has connected.
        asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)
        loop_thread.add_shutdown_call(create_call(self.stop))
        atexit.register(self.stop)
        self._started = True

    async def _consume_events(self, consumer_started: asyncio.Event):
        """
        Consume terminal task run events and wake up the matching waiters.

        Args:
            consumer_started: Event that is set once the subscriber context has
                been entered.
        """
        async with get_events_subscriber(
            filter=EventFilter(
                event=EventNameFilter(
                    name=[
                        f"prefect.task-run.{state.name.title()}"
                        for state in TERMINAL_STATES
                    ],
                )
            )
        ) as subscriber:
            consumer_started.set()
            async for event in subscriber:
                try:
                    self.logger.debug(
                        f"Received event: {event.resource['prefect.resource.id']}"
                    )
                    task_run_id = uuid.UUID(
                        event.resource["prefect.resource.id"].replace(
                            "prefect.task-run.", ""
                        )
                    )
                    with self._observed_completed_task_runs_lock:
                        # Cache the task run ID for a short period of time to
                        # avoid unnecessary waits
                        self._observed_completed_task_runs[task_run_id] = True
                    with self._completion_events_lock:
                        # Set the event for the task run ID if one is
                        # registered so the waiting coroutine wakes up
                        finished_event = self._completion_events.get(task_run_id)
                        if finished_event is not None:
                            finished_event.set()
                        # Remove the callback while the lock is held so that it
                        # fires at most once and the dict does not grow forever
                        callback = self._completion_callbacks.pop(task_run_id, None)
                    if callback is not None:
                        # Invoked outside the lock: the lock is not reentrant,
                        # so a callback that calls back into the waiter would
                        # otherwise deadlock.
                        callback()
                except Exception as exc:
                    # Best effort: one bad event or failing callback must not
                    # stop the consumer.
                    self.logger.error(f"Error processing event: {exc}")

    def stop(self):
        """
        Stop the TaskRunWaiter service.

        Cancels the consumer task (if any), clears the class-level singleton
        reference and marks the service as not started.
        """
        self.logger.debug("Stopping TaskRunWaiter")
        if self._consumer_task:
            self._consumer_task.cancel()
            self._consumer_task = None
        self.__class__._instance = None
        self._started = False

    @classmethod
    async def wait_for_task_run(
        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
    ):
        """
        Wait for a task run to finish.

        Note this relies on a websocket connection to receive events from the
        server and will not work with an ephemeral server.

        Args:
            task_run_id: The ID of the task run to wait for.
            timeout: The maximum time to wait for the task run to
                finish. Defaults to None.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            if task_run_id in instance._observed_completed_task_runs:
                return

        # Need to create event in loop thread to ensure it can be set
        # from the loop thread
        finished_event = await from_async.wait_for_call_in_loop_thread(
            create_call(asyncio.Event)
        )
        with instance._completion_events_lock:
            # Cache the event for the task run ID so the consumer can set it
            # when the event is received
            instance._completion_events[task_run_id] = finished_event

        try:
            # Now check one more time whether the task run arrived before we
            # start to wait on it, in case it came in while we were setting up
            # the event above.
            with instance._observed_completed_task_runs_lock:
                if task_run_id in instance._observed_completed_task_runs:
                    return

            with anyio.move_on_after(delay=timeout):
                await from_async.wait_for_call_in_loop_thread(
                    create_call(finished_event.wait)
                )
        finally:
            with instance._completion_events_lock:
                # Remove the event from the cache after it has been waited on
                instance._completion_events.pop(task_run_id, None)

    @classmethod
    def add_done_callback(cls, task_run_id: uuid.UUID, callback):
        """
        Add a callback to be called when a task run finishes.

        If a terminal event for the task run has already been observed, the
        callback is called immediately in the calling thread. Otherwise it is
        called by the event consumer when the event arrives.

        Args:
            task_run_id: The ID of the task run to wait for.
            callback: The callback to call when the task run finishes. It is
                called with no arguments.
        """
        instance = cls.instance()

        with instance._observed_completed_task_runs_lock:
            already_finished = task_run_id in instance._observed_completed_task_runs
        if already_finished:
            # Called outside the lock so the callback may use the waiter again.
            callback()
            return

        with instance._completion_events_lock:
            # Register the callback so the consumer can call it when the
            # terminal event is received
            instance._completion_callbacks[task_run_id] = callback

        # Check one more time: the terminal event may have arrived between the
        # first check and the registration above, in which case the consumer
        # may never see this callback.
        with instance._observed_completed_task_runs_lock:
            finished_meanwhile = task_run_id in instance._observed_completed_task_runs
        if not finished_meanwhile:
            return

        with instance._completion_events_lock:
            # Whoever pops the entry is the one that calls it, so the callback
            # runs exactly once even if the consumer is racing with us.
            pending = instance._completion_callbacks.pop(task_run_id, None)
        if pending is not None:
            pending()

    @classmethod
    def instance(cls):
        """
        Get the singleton instance of TaskRunWaiter.

        The instance is created (and started) on first use.
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls._new_instance()
            return cls._instance

    @classmethod
    def _new_instance(cls):
        """
        Create a new instance and start it on the global loop thread.
        """
        instance = cls()
        if threading.get_ident() == get_global_loop().thread.ident:
            # Already on the global loop thread: start directly.
            instance.start()
        else:
            # Hand `start` over to the loop thread and block until it is done.
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()
        return instance
|
add_done_callback(task_run_id, callback)
classmethod
Add a callback to be called when a task run finishes.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| task_run_id | UUID | The ID of the task run to wait for. | required |
| callback | | The callback to call when the task run finishes. | required |
Source code in src/prefect/task_runs.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
@classmethod
def add_done_callback(cls, task_run_id: uuid.UUID, callback):
    """
    Add a callback to be called when a task run finishes.

    If a terminal event for the task run has already been observed, the
    callback is called immediately in the calling thread. Otherwise it is
    stored so it can be called when the event arrives.

    Args:
        task_run_id: The ID of the task run to wait for.
        callback: The callback to call when the task run finishes. It is
            called with no arguments.
    """
    instance = cls.instance()

    with instance._observed_completed_task_runs_lock:
        already_finished = task_run_id in instance._observed_completed_task_runs
    if already_finished:
        # Called outside the lock: the lock is not reentrant, so a callback
        # that uses the waiter again would otherwise deadlock.
        callback()
        return

    with instance._completion_events_lock:
        # Register the callback so it can be called when the terminal event
        # is received
        instance._completion_callbacks[task_run_id] = callback

    # Check one more time: the terminal event may have arrived between the
    # first check and the registration above.
    with instance._observed_completed_task_runs_lock:
        finished_meanwhile = task_run_id in instance._observed_completed_task_runs
    if not finished_meanwhile:
        return

    with instance._completion_events_lock:
        # Only call the callback if it is still registered, so it is not
        # called a second time when something else already consumed it.
        pending = instance._completion_callbacks.pop(task_run_id, None)
    if pending is not None:
        pending()
|
instance()
classmethod
Get the singleton instance of TaskRunWaiter.
Source code in src/prefect/task_runs.py
221
222
223
224
225
226
227
228
@classmethod
def instance(cls):
    """
    Return the TaskRunWaiter singleton, creating it on first use.

    Creation happens while `_instance_lock` is held, so concurrent callers
    share a single instance.
    """
    with cls._instance_lock:
        if cls._instance is not None:
            return cls._instance
        # First request: build the instance and remember it on the class.
        cls._instance = cls._new_instance()
        return cls._instance
|
start()
Start the TaskRunWaiter service.
Source code in src/prefect/task_runs.py
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
def start(self):
    """
    Start the TaskRunWaiter service.

    A no-op when the service is already started. Otherwise the event consumer
    is scheduled on the global loop and `stop` is registered both as a
    shutdown call of the global loop thread and as an `atexit` handler.

    Raises:
        RuntimeError: If the running event loop is not the global loop
            thread's loop.
    """
    if self._started:
        return

    self.logger.debug("Starting TaskRunWaiter")
    global_thread = get_global_loop()

    on_global_loop = asyncio.get_running_loop() == global_thread._loop
    if not on_global_loop:
        raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

    self._loop = global_thread._loop

    ready = asyncio.Event()
    consumer = self._consume_events(ready)
    self._consumer_task = self._loop.create_task(consumer)
    # NOTE(review): the returned future is discarded, so this call does not
    # make `start()` block until the consumer is ready.
    asyncio.run_coroutine_threadsafe(ready.wait(), self._loop)

    # Make sure the consumer is cancelled on loop shutdown and at exit.
    global_thread.add_shutdown_call(create_call(self.stop))
    atexit.register(self.stop)
    self._started = True
|
stop()
Stop the TaskRunWaiter service.
Source code in src/prefect/task_runs.py
144
145
146
147
148
149
150
151
152
def stop(self):
    """
    Stop the TaskRunWaiter service.

    Cancels the consumer task when there is one, clears the singleton
    reference stored on the class and marks the service as not started.
    """
    self.logger.debug("Stopping TaskRunWaiter")

    consumer = self._consumer_task
    if consumer:
        consumer.cancel()
        self._consumer_task = None

    # Dropping the class-level reference lets the next `instance()` call
    # build a fresh waiter.
    self.__class__._instance = None
    self._started = False
|
wait_for_task_run(task_run_id, timeout=None)
async
classmethod
Wait for a task run to finish.
Note this relies on a websocket connection to receive events from the server
and will not work with an ephemeral server.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| task_run_id | UUID | The ID of the task run to wait for. | required |
| timeout | Optional[float] | The maximum time to wait for the task run to finish. Defaults to None. | None |
Source code in src/prefect/task_runs.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
@classmethod
async def wait_for_task_run(
    cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
):
    """
    Wait for a task run to finish.

    Note this relies on a websocket connection to receive events from the
    server and will not work with an ephemeral server.

    Args:
        task_run_id: The ID of the task run to wait for.
        timeout: The maximum time to wait for the task run to
            finish. Defaults to None.
    """
    waiter = cls.instance()

    def seen_terminal_event() -> bool:
        # True once a terminal event for this task run has been recorded.
        with waiter._observed_completed_task_runs_lock:
            return task_run_id in waiter._observed_completed_task_runs

    if seen_terminal_event():
        return

    # The event has to be created in the loop thread so that it can be set
    # from the loop thread.
    done_signal = await from_async.wait_for_call_in_loop_thread(
        create_call(asyncio.Event)
    )

    with waiter._completion_events_lock:
        # Publish the event under the task run ID so the consumer can set it
        # when the terminal event is received.
        waiter._completion_events[task_run_id] = done_signal

    try:
        # Second look: the terminal event may have come in while the event
        # above was being set up.
        if seen_terminal_event():
            return

        with anyio.move_on_after(delay=timeout):
            await from_async.wait_for_call_in_loop_thread(
                create_call(done_signal.wait)
            )
    finally:
        with waiter._completion_events_lock:
            # Always unregister the event once waiting is over.
            waiter._completion_events.pop(task_run_id, None)
|