Skip to content

prefect.task_runs

TaskRunWaiter

A service used for waiting for a task run to finish.

This service listens for task run events and provides a way to wait for a specific task run to finish. This is useful for waiting for a task run to finish before continuing execution.

The service is a singleton that must be running before use; it starts automatically when the singleton instance is first created (for example, on the first call to `TaskRunWaiter.wait_for_task_run`). A single websocket connection is used to listen for task run events.

The service can be used to wait for a task run to finish by calling `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method returns when the task run has finished or when the timeout has elapsed; a timeout does not raise an error.

The service will automatically stop when the Python process exits or when the global loop thread is stopped.

Example:

import asyncio
from uuid import uuid4

from prefect import task
from prefect.task_engine import run_task_async
from prefect.task_runs import TaskRunWaiter


@task
async def test_task():
    await asyncio.sleep(5)
    print("Done!")


async def main():
    """Start a task run in the background and block until it reaches a terminal state."""
    task_run_id = uuid4()
    # Keep a reference to the background task: the event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    background = asyncio.create_task(
        run_task_async(task=test_task, task_run_id=task_run_id)
    )

    await TaskRunWaiter.wait_for_task_run(task_run_id)
    print("Task run finished")

    # Await the background task so any exception it raised is surfaced
    # instead of being silently dropped.
    await background


if __name__ == "__main__":
    asyncio.run(main())
Source code in `src/prefect/task_runs.py` (lines 19–240)
class TaskRunWaiter:
    """
    A service used for waiting for a task run to finish.

    This service listens for task run events and provides a way to wait for a specific
    task run to finish. This is useful for waiting for a task run to finish before
    continuing execution.

    The service is a singleton that must be running before use; it starts
    automatically when the singleton instance is first created (for example, on the
    first call to `TaskRunWaiter.wait_for_task_run` or
    `TaskRunWaiter.add_done_callback`). A single websocket connection is used to
    listen for task run events.

    The service can be used to wait for a task run to finish by calling
    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method
    will return when the task run has finished or the timeout has elapsed (it does
    not raise on timeout).

    The service will automatically stop when the Python process exits or when the
    global loop thread is stopped.

    Example:
    ```python
    import asyncio
    from uuid import uuid4

    from prefect import task
    from prefect.task_engine import run_task_async
    from prefect.task_runs import TaskRunWaiter


    @task
    async def test_task():
        await asyncio.sleep(5)
        print("Done!")


    async def main():
        task_run_id = uuid4()
        # Keep a reference so the background task cannot be garbage-collected
        # before it finishes.
        background = asyncio.create_task(
            run_task_async(task=test_task, task_run_id=task_run_id)
        )

        await TaskRunWaiter.wait_for_task_run(task_run_id)
        print("Task run finished")

        # Surface any exception raised by the background task.
        await background


    if __name__ == "__main__":
        asyncio.run(main())
    ```
    """

    _instance: Optional[Self] = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self.logger = get_logger("TaskRunWaiter")
        self._consumer_task: Optional[asyncio.Task] = None
        # Task run IDs whose terminal event has been seen, cached (maxsize=10000,
        # ttl=600) so waiters that arrive after the event return immediately.
        self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
            maxsize=10000, ttl=600
        )
        # One event per awaited task run ID, shared by all concurrent waiters on
        # that ID so a later waiter never replaces an earlier waiter's event.
        self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
        # Number of in-flight `wait_for_task_run` calls sharing each event; the
        # event is dropped only when the last of them finishes.
        self._completion_event_waiters: Dict[uuid.UUID, int] = {}
        # Callbacks registered per task run ID; removed once they are invoked.
        self._completion_callbacks: Dict[uuid.UUID, list[Callable]] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._observed_completed_task_runs_lock = threading.Lock()
        # Guards `_completion_events`, `_completion_event_waiters` and
        # `_completion_callbacks`. Not reentrant: never call user callbacks while
        # holding it.
        self._completion_events_lock = threading.Lock()
        self._started = False

    def start(self):
        """
        Start the TaskRunWaiter service.

        Must be called on the global loop thread. Schedules the event consumer on
        the global loop and registers `stop` to run at loop shutdown and at
        interpreter exit. Calling it again once started is a no-op.

        Raises:
            RuntimeError: If not called from the global loop thread's running loop.
        """
        if self._started:
            return
        self.logger.debug("Starting TaskRunWaiter")
        loop_thread = get_global_loop()

        if not asyncio.get_running_loop() == loop_thread._loop:
            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

        self._loop = loop_thread._loop

        consumer_started = asyncio.Event()
        self._consumer_task = self._loop.create_task(
            self._consume_events(consumer_started)
        )
        # NOTE(review): the returned future is discarded, so `start` does not
        # actually block until the subscription is established.
        asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)

        loop_thread.add_shutdown_call(create_call(self.stop))
        atexit.register(self.stop)
        self._started = True

    async def _consume_events(self, consumer_started: asyncio.Event):
        """Subscribe to terminal task-run events and wake the matching waiters."""
        async with get_events_subscriber(
            filter=EventFilter(
                event=EventNameFilter(
                    name=[
                        f"prefect.task-run.{state.name.title()}"
                        for state in TERMINAL_STATES
                    ],
                )
            )
        ) as subscriber:
            consumer_started.set()
            async for event in subscriber:
                try:
                    self.logger.debug(
                        f"Received event: {event.resource['prefect.resource.id']}"
                    )
                    task_run_id = uuid.UUID(
                        event.resource["prefect.resource.id"].replace(
                            "prefect.task-run.", ""
                        )
                    )

                    with self._observed_completed_task_runs_lock:
                        # Cache the task run ID for a short period of time to avoid
                        # unnecessary waits
                        self._observed_completed_task_runs[task_run_id] = True
                    with self._completion_events_lock:
                        # Wake every coroutine waiting on this task run.
                        finished_event = self._completion_events.get(task_run_id)
                        if finished_event is not None:
                            finished_event.set()
                        # Claim the callbacks so each one runs exactly once and
                        # the mapping does not grow without bound.
                        callbacks = self._completion_callbacks.pop(task_run_id, [])
                except Exception as exc:
                    self.logger.error(f"Error processing event: {exc}")
                    continue

                # Run callbacks outside the (non-reentrant) lock so a callback
                # that uses this service cannot deadlock the consumer.
                self._run_callbacks(task_run_id, callbacks)

    def _run_callbacks(self, task_run_id: uuid.UUID, callbacks: list) -> None:
        """Invoke completion callbacks, logging rather than propagating errors."""
        for callback in callbacks:
            try:
                callback()
            except Exception as exc:
                self.logger.error(
                    f"Error in completion callback for task run {task_run_id}: {exc}"
                )

    def _is_observed_completed(self, task_run_id: uuid.UUID) -> bool:
        """Return whether a terminal event for `task_run_id` is in the recent cache."""
        with self._observed_completed_task_runs_lock:
            return task_run_id in self._observed_completed_task_runs

    def stop(self):
        """
        Stop the TaskRunWaiter service.
        """
        self.logger.debug("Stopping TaskRunWaiter")
        if self._consumer_task:
            self._consumer_task.cancel()
            self._consumer_task = None
        self.__class__._instance = None
        self._started = False

    @classmethod
    async def wait_for_task_run(
        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
    ):
        """
        Wait for a task run to finish.

        Note this relies on a websocket connection to receive events from the server
        and will not work with an ephemeral server.

        Several coroutines may wait on the same task run concurrently; they share
        one completion event, so a waiter that times out or is cancelled does not
        strand the others.

        Args:
            task_run_id: The ID of the task run to wait for.
            timeout: The maximum time to wait for the task run to
                finish. Defaults to None (wait indefinitely). When the timeout
                elapses the method returns normally without raising.
        """
        instance = cls.instance()
        if instance._is_observed_completed(task_run_id):
            return

        # Need to create event in loop thread to ensure it can be set
        # from the loop thread
        new_event = await from_async.wait_for_call_in_loop_thread(
            create_call(asyncio.Event)
        )
        with instance._completion_events_lock:
            # Join an existing waiter's event instead of overwriting it; otherwise
            # the earlier waiter's event would never be set.
            finished_event = instance._completion_events.setdefault(
                task_run_id, new_event
            )
            instance._completion_event_waiters[task_run_id] = (
                instance._completion_event_waiters.get(task_run_id, 0) + 1
            )

        try:
            # Now check one more time whether the task run arrived before we start to
            # wait on it, in case it came in while we were setting up the event above.
            if instance._is_observed_completed(task_run_id):
                return

            with anyio.move_on_after(delay=timeout):
                await from_async.wait_for_call_in_loop_thread(
                    create_call(finished_event.wait)
                )
        finally:
            with instance._completion_events_lock:
                remaining = instance._completion_event_waiters.get(task_run_id, 1) - 1
                if remaining > 0:
                    instance._completion_event_waiters[task_run_id] = remaining
                else:
                    # Last waiter out removes the shared event from the cache.
                    instance._completion_event_waiters.pop(task_run_id, None)
                    instance._completion_events.pop(task_run_id, None)

    @classmethod
    def add_done_callback(cls, task_run_id: uuid.UUID, callback):
        """
        Add a callback to be called when a task run finishes.

        Each registration is invoked at most once. The callback runs immediately,
        in the calling thread, if the task run is already known to have finished.
        Otherwise the event consumer runs it when the terminal event arrives.
        Multiple callbacks may be registered for the same task run.

        Args:
            task_run_id: The ID of the task run to wait for.
            callback: A zero-argument callable to call when the task run finishes.
        """
        instance = cls.instance()
        if instance._is_observed_completed(task_run_id):
            callback()
            return

        with instance._completion_events_lock:
            instance._completion_callbacks.setdefault(task_run_id, []).append(callback)

        # The terminal event may have been processed between the check above and
        # the registration, in which case the consumer found nothing to run.
        if not instance._is_observed_completed(task_run_id):
            return
        with instance._completion_events_lock:
            pending = instance._completion_callbacks.get(task_run_id, [])
            try:
                pending.remove(callback)
            except ValueError:
                # The consumer already claimed this callback and runs it.
                return
            if not pending:
                instance._completion_callbacks.pop(task_run_id, None)
        callback()

    @classmethod
    def instance(cls):
        """
        Get the singleton instance of TaskRunWaiter.
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls._new_instance()
            return cls._instance

    @classmethod
    def _new_instance(cls):
        """Create an instance and start it on the global loop thread."""
        instance = cls()

        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance

`add_done_callback(task_run_id, callback)` *classmethod*

Add a callback to be called when a task run finishes.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `task_run_id` | `UUID` | The ID of the task run to wait for. | *required* |
| `callback` | | The callback to call when the task run finishes. | *required* |
Source code in `src/prefect/task_runs.py` (lines 201–219)
@classmethod
def add_done_callback(cls, task_run_id: uuid.UUID, callback):
    """
    Add a callback to be called when a task run finishes.

    If the task run has already been observed as finished, the callback is
    invoked immediately in the calling thread. Otherwise it is stored so the
    event consumer can invoke it when the task run's terminal event arrives.
    Only one callback is stored per task run ID; registering another replaces it.

    Args:
        task_run_id: The ID of the task run to wait for.
        callback: A zero-argument callable to call when the task run finishes.
    """
    instance = cls.instance()
    with instance._observed_completed_task_runs_lock:
        already_finished = task_run_id in instance._observed_completed_task_runs

    if already_finished:
        # Invoke after releasing the lock: `threading.Lock` is not reentrant, so
        # a callback that touches this service would otherwise deadlock.
        callback()
        return

    with instance._completion_events_lock:
        # Cache the callback for the task run ID so the consumer can call it
        # when the event is received
        instance._completion_callbacks[task_run_id] = callback

instance() classmethod

Get the singleton instance of TaskRunWaiter.

Source code in `src/prefect/task_runs.py` (lines 221–229)
@classmethod
def instance(cls):
    """
    Return the process-wide TaskRunWaiter, building it on first access.

    Creation happens under the class-level lock so concurrent callers from
    different threads all receive the same object.
    """
    with cls._instance_lock:
        existing = cls._instance
        if existing is None:
            existing = cls._instance = cls._new_instance()
        return existing

start()

Start the TaskRunWaiter service.

Source code in `src/prefect/task_runs.py` (lines 83–105)
def start(self) -> None:
    """
    Start the TaskRunWaiter service.

    Must be called while the global loop thread's event loop is running on the
    current thread; otherwise a RuntimeError is raised. Schedules the event
    consumer task on that loop and registers `stop` both as a loop-thread
    shutdown call and as an `atexit` handler. Returns immediately if the
    service has already been started.
    """
    if self._started:
        return
    self.logger.debug("Starting TaskRunWaiter")
    loop_thread = get_global_loop()

    # The consumer task must live on the global loop, the same loop on which
    # the completion events it sets are created.
    if not asyncio.get_running_loop() == loop_thread._loop:
        raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

    self._loop = loop_thread._loop

    consumer_started = asyncio.Event()
    self._consumer_task = self._loop.create_task(
        self._consume_events(consumer_started)
    )
    # NOTE(review): the returned future is discarded, so this does not make
    # `start` wait until the consumer's subscription is established.
    asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)

    # Ensure the consumer is cancelled on loop shutdown or interpreter exit.
    loop_thread.add_shutdown_call(create_call(self.stop))
    atexit.register(self.stop)
    self._started = True

stop()

Stop the TaskRunWaiter service.

Source code in `src/prefect/task_runs.py` (lines 144–153)
def stop(self):
    """
    Shut down the TaskRunWaiter service.

    Cancels the background event consumer if one is running, clears the
    class-level singleton reference, and marks this instance as not started.
    Calling it again after the service has stopped is harmless.
    """
    self.logger.debug("Stopping TaskRunWaiter")
    consumer = self._consumer_task
    if consumer:
        consumer.cancel()
        self._consumer_task = None
    type(self)._instance = None
    self._started = False

`wait_for_task_run(task_run_id, timeout=None)` *async* *classmethod*

Wait for a task run to finish.

Note: this relies on a websocket connection to receive events from the server and will not work with an ephemeral server. If the timeout elapses, the method returns without raising.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `task_run_id` | `UUID` | The ID of the task run to wait for. | *required* |
| `timeout` | `Optional[float]` | The maximum time to wait for the task run to finish. Defaults to None. | `None` |
Source code in `src/prefect/task_runs.py` (lines 155–199)
@classmethod
async def wait_for_task_run(
    cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
) -> None:
    """
    Wait for a task run to finish.

    Note this relies on a websocket connection to receive events from the server
    and will not work with an ephemeral server.

    Returns immediately if the task run's terminal event is still in the
    recently-observed cache. If `timeout` elapses first, the method returns
    normally without raising.

    Args:
        task_run_id: The ID of the task run to wait for.
        timeout: The maximum time to wait for the task run to
            finish. Defaults to None.
    """
    instance = cls.instance()
    with instance._observed_completed_task_runs_lock:
        if task_run_id in instance._observed_completed_task_runs:
            return

    # Need to create event in loop thread to ensure it can be set
    # from the loop thread
    finished_event = await from_async.wait_for_call_in_loop_thread(
        create_call(asyncio.Event)
    )
    with instance._completion_events_lock:
        # Cache the event for the task run ID so the consumer can set it
        # when the event is received
        # NOTE(review): a second concurrent waiter on the same task_run_id
        # overwrites this entry, so the first waiter's event is never set.
        # Whichever waiter finishes first also pops the other's event in the
        # `finally` below. Consider sharing one event per ID with a waiter count.
        instance._completion_events[task_run_id] = finished_event

    try:
        # Now check one more time whether the task run arrived before we start to
        # wait on it, in case it came in while we were setting up the event above.
        with instance._observed_completed_task_runs_lock:
            if task_run_id in instance._observed_completed_task_runs:
                return

        # `move_on_after` cancels the wait silently when `timeout` elapses.
        with anyio.move_on_after(delay=timeout):
            await from_async.wait_for_call_in_loop_thread(
                create_call(finished_event.wait)
            )
    finally:
        with instance._completion_events_lock:
            # Remove the event from the cache after it has been waited on
            instance._completion_events.pop(task_run_id, None)