Swift Concurrency - Racing async functions with task groups

📆 Published 2021-12-10 ⏰ 4m read

Did you hear that Swift concurrency now back deploys to iOS 13? This means async/await for all of us!

Swift doesn't have a Promise.race-like feature (from JS), which is a bit annoying. Let's write our own helper. I'd like a utility that has race options at least for:

  1. First resolved (the first error or result from the first resolved function)
  2. First value (the first resolved value, ignoring all errors)

Once the race is decided, we should cancel any remaining tasks so they can finish as soon as possible, because we no longer care about their results. We also want to specify priorities per task. With that in mind, let's make a simple struct that can hold a task that we want to schedule to run.

We want to write this as an extension on Task, so it's easy to find. To avoid compiler errors about Task's generic parameters not being inferable, we'll constrain the extension to some dummy values. We're using Never for both because that's where other static methods like Task.sleep live, but it doesn't really matter.

extension Task where Success == Never, Failure == Never {
    struct Blueprint<Output> {
        var priority: TaskPriority = .medium
        var operation: @Sendable () async throws -> Output
    }
}
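To see why the dummy constraints help, here is a small sketch with a hypothetical greeting() helper declared in the same kind of constrained extension. Because Success and Failure are pinned to Never, the call site needs no generic arguments, just like Task.sleep. Declared in an unconstrained extension, the same call would typically fail with a "generic parameter could not be inferred" error.

```swift
// Hypothetical helper, only here to show how a constrained extension is called.
extension Task where Success == Never, Failure == Never {
    static func greeting() -> String {
        "hello from Task"
    }
}

// Success and Failure are fixed to Never by the extension's constraints,
// so Swift can resolve Task.greeting() without explicit generic arguments.
let message = Task.greeting()
print(message)
```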

Task.Blueprint doesn't hold a Task handle; it holds an async closure to run when the race starts. A Task starts executing as soon as it's created, which would give the tasks created first a head start, and we want a fair race. If fairness matters less for your use case, you can just use Task instead of Task.Blueprint.

We'll also mark the closure as @Sendable because it will execute concurrently. The compiler then checks that the closure doesn't capture shared mutable state, so we aren't introducing any data races.
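In practice, @Sendable means a racing closure can't capture and mutate a local var. A minimal sketch of the usual alternative, assuming shared state is needed at all, is to put that state behind an actor. Attempts and runConcurrently() are hypothetical names used only for this demo.

```swift
/// Hypothetical shared state for the demo: the actor serializes access to `count`.
actor Attempts {
    private(set) var count = 0
    func record() { count += 1 }
}

/// Runs ten @Sendable operations concurrently and returns how many were recorded.
func runConcurrently() async -> Int {
    let attempts = Attempts()
    // Declaring `var count = 0` here and writing `count += 1` inside this closure
    // would not compile: a @Sendable closure cannot mutate a captured local variable.
    let operation: @Sendable () async -> Void = {
        await attempts.record()
    }
    // withTaskGroup waits for all ten child tasks before returning.
    await withTaskGroup(of: Void.self) { group in
        for _ in 0..<10 {
            group.addTask(operation: operation)
        }
    }
    return await attempts.count
}
```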

Now let's start racing.

Our "first resolved" race should return the value of the first task to finish, or rethrow its error if that task failed, and then cancel the other running tasks. We care about which task finishes first, not whether it succeeded.

extension Task where Success == Never, Failure == Never {
    static func race<Output>(firstResolved tasks: [Blueprint<Output>]) async throws -> Output {
        // (1)
        assert(!tasks.isEmpty, "You must race at least 1 task.")
        // (2)
        return try await withThrowingTaskGroup(of: Output.self) { group -> Output in
            // (3)
            for task in tasks {
                group.addTask(priority: task.priority) {
                    try await task.operation()
                }
            }
            // (4)
            defer { group.cancelAll() }
            // (5)
            if let firstToResolve = try await group.next() {
                return firstToResolve
            } else {
                fatalError("At least 1 task should be scheduled.")
            }
        }
    }
}
  1. If there are no tasks, this is a logic error, so we should crash. assert only fires in debug builds; in release builds, the fatalError in step 5 catches the same case.
  2. We create a task group so the tasks we race can run concurrently.
  3. We add all the tasks to the group, each with their specified priority.
  4. When the closure exits (after the first task resolves), cancel all the other tasks. Swift task cancellation is cooperative, not automatic, so each task needs to check for cancellation (for example with Task.checkCancellation() or a cancellation handler) and stop early, if required. This matters because the task group doesn't return until every child task has finished, so a task that ignores cancellation delays the whole race. As an aside, since I started using async Swift, my usage of the defer keyword has increased by 10000%.
  5. We wait for the first task to resolve and return its value. If a task throws before any value is returned, group.next() rethrows that error out of the group, and the other tasks are cancelled (thanks to the defer).
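The cancellation caveat from step 4 can be seen in a small, self-contained sketch. It condenses Blueprint and race(firstResolved:) from above, adding an Output: Sendable constraint (which newer Swift versions expect for task group results) and passing each operation to addTask directly. It then races an instant task against a hypothetical slowCount() job that checks for cancellation: the race returns quickly only because slowCount() stops once cancelled. A second race shows an early error being rethrown. slowCount, Boom and demo() are illustration-only names.

```swift
import Foundation

extension Task where Success == Never, Failure == Never {
    struct Blueprint<Output> {
        var priority: TaskPriority = .medium
        var operation: @Sendable () async throws -> Output
    }

    // Same structure as race(firstResolved:) above, plus an `Output: Sendable` constraint.
    static func race<Output: Sendable>(firstResolved tasks: [Blueprint<Output>]) async throws -> Output {
        assert(!tasks.isEmpty, "You must race at least 1 task.")
        return try await withThrowingTaskGroup(of: Output.self) { group -> Output in
            for task in tasks {
                group.addTask(priority: task.priority, operation: task.operation)
            }
            defer { group.cancelAll() }
            guard let first = try await group.next() else {
                fatalError("At least 1 task should be scheduled.")
            }
            return first
        }
    }
}

/// Hypothetical slow job: at least 10 seconds of work in 10 ms chunks,
/// checking for cancellation before each chunk.
func slowCount() async throws -> Int {
    var count = 0
    for _ in 0..<1_000 {
        try Task.checkCancellation()                  // throws CancellationError once cancelled
        try await Task.sleep(nanoseconds: 10_000_000) // also throws if cancelled mid-sleep
        count += 1
    }
    return count
}

struct Boom: Error {} // hypothetical error type for the demo

func demo() async throws -> (value: Int, seconds: TimeInterval, failedWithBoom: Bool) {
    let start = Date()
    // The instant task wins. The group still waits for slowCount(),
    // which finishes early only because it responds to cancellation.
    let value: Int = try await Task.race(firstResolved: [
        .init { 42 },
        .init { try await slowCount() },
    ])
    let seconds = Date().timeIntervalSince(start)

    // A task that fails first makes the whole race throw that error.
    var failedWithBoom = false
    do {
        let _: Int = try await Task.race(firstResolved: [
            .init { throw Boom() },
            .init { try await slowCount() },
        ])
    } catch is Boom {
        failedWithBoom = true
    }
    return (value, seconds, failedWithBoom)
}
```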

Our "first value" race should never throw, instead waiting for the first produced value. If no value is produced (all the tasks encountered errors) we return nil.

extension Task where Success == Never, Failure == Never {
    static func race<Output>(firstValue tasks: [Blueprint<Output>]) async -> Output? {
        return await withThrowingTaskGroup(of: Output.self) { group -> Output? in
            for task in tasks {
                group.addTask(priority: task.priority) {
                    try await task.operation()
                }
            }
            defer { group.cancelAll() }
            while let nextResult = await group.nextResult() {
                switch nextResult {
                case .failure:
                    continue
                case .success(let result):
                    return result
                }
            }
            // If all the racing tasks error, we will reach this point.
            return nil
        }
    }
}

It's pretty similar, but this time we ignore errors until we have a value.
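Here's a usage sketch for race(firstValue:), again condensed into a self-contained block with the same Output: Sendable addition. The mirrors are hypothetical: in the first race, two fail immediately and one succeeds after 50 ms, so the failures are skipped and the successful value wins. In the second, every task fails and the race falls through to nil.

```swift
import Foundation

extension Task where Success == Never, Failure == Never {
    struct Blueprint<Output> {
        var priority: TaskPriority = .medium
        var operation: @Sendable () async throws -> Output
    }

    static func race<Output: Sendable>(firstValue tasks: [Blueprint<Output>]) async -> Output? {
        await withThrowingTaskGroup(of: Output.self) { group -> Output? in
            for task in tasks {
                group.addTask(priority: task.priority, operation: task.operation)
            }
            defer { group.cancelAll() }
            // Skip failures and return the first success.
            while let nextResult = await group.nextResult() {
                if case .success(let value) = nextResult {
                    return value
                }
            }
            return nil // every task failed
        }
    }
}

struct MirrorDown: Error {} // hypothetical error type for the demo

/// Hypothetical lookup against three mirrors: two fail, one succeeds after 50 ms.
func fetchFromMirrors() async -> String? {
    await Task.race(firstValue: [
        .init { throw MirrorDown() },
        .init { throw MirrorDown() },
        .init {
            try await Task.sleep(nanoseconds: 50_000_000)
            return "mirror-c"
        },
    ])
}

/// Every mirror fails, so the race returns nil.
func fetchFromBrokenMirrors() async -> String? {
    await Task.race(firstValue: [
        .init { throw MirrorDown() },
        .init { throw MirrorDown() },
    ])
}
```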

Please remember that these won't be perfectly fair races–it depends on a huge number of factors:

  1. The number of tasks vs cores on the target machine.
  2. The number of async functions called by each task.
  3. The order we add tasks to the group.
  4. How the Dispatch scheduler is feeling today.

Luckily, that's not important to us. We don't mind a slightly unfair result when two tasks finish at almost the same time. We do care when the durations differ a lot: in that case, we always want the faster task to win. If exact timing is important to you, this is not the post for you.

For example, I'm using this to implement a timeout mechanism.

public extension Task where Success == Never, Failure == Never {
    /// Sleep for the specified `TimeInterval`.
    @inlinable static func sleep(duration: TimeInterval) async throws {
        try await sleep(nanoseconds: UInt64(duration*1e9))
    }
}

public struct TimeoutError: Error, LocalizedError {
    /// When the timeout occurred.
    public let occurred: Date = Date()
    public var errorDescription: String? {
        "The operation timed out."
    }
}

/// Run `task`, failing with `TimeoutError` if it has not finished after `delay` seconds.
/// When the timeout fires, `task` is cancelled, but this function does not return until
/// `task` has actually finished, so it should respond to cancellation as soon as possible.
/// - returns: The value if the operation did not time out.
/// - throws: `TimeoutError` if the operation timed out.
public func withTimeout<T>(
    delay: TimeInterval,
    priority: TaskPriority = .medium,
    run task: @Sendable @escaping () async throws -> T
) async throws -> T {
    return try await Task.race(firstResolved: [
        .init {
            try await Task.sleep(duration: delay)
            throw TimeoutError()
        },
        .init(priority: priority) {
            try await task()
        },
    ])
}
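The advice about cancellation matters because the task group waits for every child before withTimeout returns. Here's a minimal sketch that makes this visible, assuming condensed copies of the helpers above: TimeoutError reduced to a bare struct, no priority parameter, the nanosecond sleep inlined, and T: Sendable added. The hypothetical timed(_:) wrapper applies a 0.1 s timeout and measures the total call time for an operation that finishes in time, one that sleeps cooperatively, and one that spins without checking for cancellation.

```swift
import Foundation

extension Task where Success == Never, Failure == Never {
    struct Blueprint<Output> {
        var priority: TaskPriority = .medium
        var operation: @Sendable () async throws -> Output
    }

    static func race<Output: Sendable>(firstResolved tasks: [Blueprint<Output>]) async throws -> Output {
        try await withThrowingTaskGroup(of: Output.self) { group -> Output in
            for task in tasks {
                group.addTask(priority: task.priority, operation: task.operation)
            }
            defer { group.cancelAll() }
            guard let first = try await group.next() else {
                fatalError("At least 1 task should be scheduled.")
            }
            return first
        }
    }
}

struct TimeoutError: Error {}

func withTimeout<T: Sendable>(
    delay: TimeInterval,
    run task: @Sendable @escaping () async throws -> T
) async throws -> T {
    try await Task.race(firstResolved: [
        .init {
            try await Task.sleep(nanoseconds: UInt64(delay * 1e9))
            throw TimeoutError()
        },
        .init(operation: task),
    ])
}

/// Hypothetical helper: runs `operation` with a 0.1 s timeout and reports
/// whether it timed out and how long the call took in total.
func timed<T: Sendable>(
    _ operation: @Sendable @escaping () async throws -> T
) async throws -> (timedOut: Bool, seconds: TimeInterval) {
    let start = Date()
    do {
        _ = try await withTimeout(delay: 0.1, run: operation)
        return (false, Date().timeIntervalSince(start))
    } catch is TimeoutError {
        return (true, Date().timeIntervalSince(start))
    }
}

func demo() async throws -> [(timedOut: Bool, seconds: TimeInterval)] {
    // Finishes immediately, well inside the timeout.
    let fast = try await timed { "done" }
    // Task.sleep throws CancellationError when cancelled, so this returns shortly after 0.1 s.
    let cooperative = try await timed {
        try await Task.sleep(nanoseconds: 5_000_000_000)
    }
    // Spins for 1 s without checking for cancellation, so withTimeout
    // cannot return until the spin ends, even though the timeout is 0.1 s.
    let stubborn = try await timed { () async throws -> Int in
        let deadline = Date().addingTimeInterval(1)
        while Date() < deadline {}
        return 0
    }
    return [fast, cooperative, stubborn]
}
```

The spinning case is only there to illustrate the point; long-running work in real code should call Task.checkCancellation() or check Task.isCancelled periodically.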

Happy racing!

Complete Implementation

public extension Task where Success == Never, Failure == Never {
    /// Blueprint for a task that should be run, but not yet.
    struct Blueprint<Output> {
        public var priority: TaskPriority
        public var operation: @Sendable () async throws -> Output

        public init(
            priority: TaskPriority = .medium,
            operation: @escaping @Sendable () async throws -> Output
        ) {
            self.priority = priority
            self.operation = operation
        }
    }
}
public extension Task where Success == Never, Failure == Never {
    /// Race for the first result by any of the provided tasks.
    ///
    /// This will return the first valid result or throw the first thrown error by any task.
    static func race<Output>(firstResolved tasks: [Blueprint<Output>]) async throws -> Output {
        assert(!tasks.isEmpty, "You must race at least 1 task.")
        return try await withThrowingTaskGroup(of: Output.self) { group -> Output in
            for task in tasks {
                group.addTask(priority: task.priority) {
                    try await task.operation()
                }
            }

            defer { group.cancelAll() }
            if let firstToResolve = try await group.next() {
                return firstToResolve
            } else {
                // There will be at least 1 task.
                fatalError("At least 1 task should be scheduled.")
            }
        }
    }

    /// Race for the first valid value.
    ///
    /// Ignores errors that may be thrown and waits for the first successful result.
    /// If all tasks fail, returns `nil`.
    static func race<Output>(firstValue tasks: [Blueprint<Output>]) async -> Output? {
        return await withThrowingTaskGroup(of: Output.self) { group -> Output? in
            for task in tasks {
                group.addTask(priority: task.priority) {
                    try await task.operation()
                }
            }

            defer { group.cancelAll() }
            while let nextResult = await group.nextResult() {
                switch nextResult {
                case .failure:
                    continue
                case .success(let result):
                    return result
                }
            }

            // If all the racing tasks error, we will reach this point.
            return nil
        }
    }
}
