Do you need to run one task or another depending on a condition? Do you have several tasks, of which only one should run when a criterion is met? You’ve come to the right place! The BranchPythonOperator does precisely that. It’s common to have DAGs with several execution flows where only one should be followed, depending on a value or a condition. The BranchPythonOperator lets you choose one branch among the branches of your DAG. It builds on the PythonOperator, so it is worth being familiar with that operator first. Ready? Go!
Use case
Using a use case is the best way to understand a concept. Let’s say you have the following data pipeline:
A quick explanation of the DAG: the first three tasks train machine learning models. Once they are complete, choose_best_model runs, and then either is_accurate or is_inaccurate is triggered, depending on the accuracy of the best model. If the accuracy is above 5, is_accurate is triggered; otherwise, is_inaccurate is.
How can you do this? Is there a mechanism to choose one task or another according to a condition or a value?
Yes! With the BranchPythonOperator!
The BranchPythonOperator
The BranchPythonOperator allows you to follow a specific path in your DAG according to a condition. That condition is evaluated inside a Python callable. Like the PythonOperator, the BranchPythonOperator executes a Python function, but that function must return a single task ID or a list of task IDs corresponding to the task(s) to run. Look at the example below.
Can you guess which task will run after choose_best_model?
is_accurate! As the accuracy is greater than 5, the BranchPythonOperator returns the task ID is_accurate (note that it’s a string). Therefore, is_accurate gets triggered. Great! But wait a minute, what happens to the other task? Does it stay with no status at all?
No!
Here is precisely what you obtain by triggering the DAG:
choose_best_model and is_accurate succeed, whereas is_inaccurate is skipped. By default, downstream tasks not returned by the BranchPythonOperator are skipped! Tasks following skipped tasks are skipped as well. In the example, if you had a task after is_inaccurate, that task would have been skipped too. You can change this behavior with trigger rules.
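To make the skip propagation concrete, here is a small plain-Python sketch. It does not use Airflow and is not how the scheduler is implemented; it only mimics the rule described above, assuming every task keeps the default all_success trigger rule. The function simulate_branch and the extra task notify are hypothetical names used for illustration.

```python
def simulate_branch(downstream, branch_task, chosen):
    """Return a {task_id: state} dict after `branch_task` picks `chosen`.

    `downstream` maps each task ID to the task IDs that follow it, with
    keys listed in topological order (upstream tasks first).
    Assumes every task uses the default all_success trigger rule.
    """
    chosen = {chosen} if isinstance(chosen, str) else set(chosen)

    # Build the reverse mapping: task -> its direct upstream tasks.
    upstream = {task: [] for task in downstream}
    for task, children in downstream.items():
        for child in children:
            upstream[child].append(task)

    states = {}
    for task in downstream:  # relies on the topological key order
        if task == branch_task:
            states[task] = "success"
        elif branch_task in upstream[task] and task not in chosen:
            # Direct child of the branch task that was not returned.
            states[task] = "skipped"
        elif all(states[up] == "success" for up in upstream[task]):
            states[task] = "success"
        else:
            # Under all_success, a skipped parent skips the child too.
            states[task] = "skipped"
    return states


pipeline = {
    "choose_best_model": ["is_accurate", "is_inaccurate"],
    "is_accurate": [],
    "is_inaccurate": ["notify"],  # hypothetical task after is_inaccurate
    "notify": [],
}
states = simulate_branch(pipeline, "choose_best_model", "is_accurate")
for task, state in states.items():
    print(task, state)
```

In this sketch, notify ends up skipped because its only parent, is_inaccurate, was skipped.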
How to use the BranchPythonOperator?
Implementing the BranchPythonOperator is easy:
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime

def _choose_best_model():
    accuracy = 6
    if accuracy > 5:
        return 'accurate'
    return 'inaccurate'

with DAG('branching', start_date=datetime(2023, 1, 1), schedule='@daily', catchup=False):
    choose_best_model = BranchPythonOperator(
        task_id='choose_best_model',
        python_callable=_choose_best_model
    )
    accurate = EmptyOperator(
        task_id='accurate'
    )
    inaccurate = EmptyOperator(
        task_id='inaccurate'
    )
    choose_best_model >> [accurate, inaccurate]
The Taskflow API version:
from airflow.decorators import dag, task
from datetime import datetime

@dag(
    start_date=datetime(2023, 1, 1),
    schedule='@daily',
    catchup=False
)
def branching():

    @task.branch
    def choose_best_model(accuracy=6):
        if accuracy > 5:
            return 'is_accurate'
        return 'is_inaccurate'

    @task
    def is_accurate():
        pass

    @task
    def is_inaccurate():
        pass

    choose_best_model() >> [is_accurate(), is_inaccurate()]

branching()
Regardless of your implementation, the BranchPythonOperator always expects a Python function that returns a single task ID or a list of task IDs.
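Because the decision is plain Python, you can check it without running a DAG. Here is a minimal sketch, assuming you keep the logic in an undecorated helper. The threshold parameter is an illustrative addition that does not appear in the examples above.

```python
def choose_best_model(accuracy, threshold=5):
    """Return the task ID to follow, using the same rule as the DAG examples."""
    if accuracy > threshold:
        return "is_accurate"
    return "is_inaccurate"


# The comparison is strict: an accuracy of exactly 5 goes to is_inaccurate.
for accuracy in (4, 5, 6):
    print(accuracy, choose_best_model(accuracy))
```

Testing the boundary value this way catches an off-by-one in the condition before the DAG ever runs.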
The BranchPythonOperator with XComs
Where does the accuracy come from? Usually, the value you use for the BranchPythonOperator’s condition comes from another task through XComs. If you don’t know what an XCom is, take a look at the XCom tutorial first.
Let me show you an example:
from airflow.decorators import dag, task
from datetime import datetime

@dag(
    start_date=datetime(2023, 1, 1),
    schedule='@daily',
    catchup=False
)
def branching():

    @task
    def ml_a():
        return 6

    @task.branch
    def choose_best_model(accuracy):
        if accuracy > 5:
            return 'is_accurate'
        return 'is_inaccurate'

    @task
    def is_accurate():
        pass

    @task
    def is_inaccurate():
        pass

    accuracy = ml_a()
    # The returned task IDs must exist downstream of the branch task
    choose_best_model(accuracy) >> [is_accurate(), is_inaccurate()]

branching()
The task ml_a returns the value 6. That creates an XCom with that value, which is then shared with choose_best_model. Looking at the XComs, you will see:
ml_a produces the first XCom with the key return_value and the value 6. Interestingly, the BranchPythonOperator creates not one but two XComs! The first, with the key skipmixin_key, tells the Airflow Scheduler which tasks to run and which to skip. The second, which is optional, stores the value returned by the function.
To remove this optional XCom, I recommend setting do_xcom_push=False.
@task.branch(do_xcom_push=False)
def choose_best_model(accuracy):
    if accuracy > 5:
        return 'is_accurate'
    return 'is_inaccurate'
Now, what if you want to compare multiple accuracies? Pull multiple XComs at once!
# without Taskflow (same imports as in the first example)
def _choose_best_model(ti=None):
    # Assumes ml_a, ml_b and ml_c each pushed an XCom with the key 'accuracy'
    accuracies = ti.xcom_pull(key='accuracy', task_ids=['ml_a', 'ml_b', 'ml_c'])
    max_accuracy = max(accuracies)
    if max_accuracy > 5:
        return 'accurate'
    return 'inaccurate'

with DAG('branching', start_date=datetime(2023, 1, 1), schedule='@daily', catchup=False):
    choose_best_model = BranchPythonOperator(
        task_id='choose_best_model',
        python_callable=_choose_best_model
    )

# with Taskflow
@task.branch(do_xcom_push=False)
def choose_best_model(ml_a, ml_b, ml_c):
    max_accuracy = max([ml_a, ml_b, ml_c])
    if max_accuracy > 5:
        return 'is_accurate'
    return 'is_inaccurate'

# Inside the DAG function; in practice, pass the outputs of the training tasks
choose_best_model(3, 4, 6)
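One thing to watch when comparing several values: max() fails on an empty list and cannot compare None with a number. The sketch below is a defensive variant, assuming some accuracies may be missing, for instance because a training task pushed nothing. The function name pick_branch, the threshold parameter, and the choice to fall back to is_inaccurate are all illustrative assumptions.

```python
def pick_branch(accuracies, threshold=5):
    """Choose a branch from a list of accuracies that may contain None."""
    # Drop missing values so max() only compares numbers.
    valid = [a for a in accuracies if a is not None]
    if not valid:
        # No model reported an accuracy: treat the run as inaccurate.
        return "is_inaccurate"
    return "is_accurate" if max(valid) > threshold else "is_inaccurate"


print(pick_branch([3, 4, 6]))
print(pick_branch([3, None, 4]))
print(pick_branch([]))
```

Whether a missing accuracy should lead to is_inaccurate or fail the task is a design choice; failing loudly may be preferable if every model is expected to report a value.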
The pitfall
There is something to pay attention to with the BranchPythonOperator. Suppose you have a task store that you want to run regardless of the BranchPythonOperator’s result.
What do you think will happen to the task store if you run this DAG?
The task is skipped! By default, a task runs only if all of its upstream tasks succeeded (the all_success trigger rule), and here one of them, is_inaccurate, is skipped. Fortunately, there is an easy way to solve that: the trigger rule none_failed_min_one_success.
This trigger rule ensures that your task runs if none of its direct upstream tasks have failed and at least one has succeeded, which is the case here thanks to is_accurate. To apply that trigger rule to store:
# without Taskflow
EmptyOperator(task_id='store', trigger_rule='none_failed_min_one_success')

# with Taskflow
@task(trigger_rule='none_failed_min_one_success')
def store():
    pass
Run the DAG again and store is executed after is_accurate instead of being skipped.
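To see why the trigger rule makes the difference, here is a plain-Python sketch of the two rules involved. It is a simplified model for illustration, not Airflow's implementation. It only covers these two rules and assumes all upstream tasks have finished. The function name should_run is hypothetical.

```python
def should_run(trigger_rule, upstream_states):
    """Decide whether a task runs, given the states of its direct upstream tasks."""
    failed = any(s in ("failed", "upstream_failed") for s in upstream_states)
    succeeded = sum(s == "success" for s in upstream_states)
    if trigger_rule == "all_success":
        # Every direct upstream task must have succeeded.
        return succeeded == len(upstream_states)
    if trigger_rule == "none_failed_min_one_success":
        # Skipped upstream tasks are tolerated; failures are not.
        return not failed and succeeded >= 1
    raise ValueError(f"rule not modelled in this sketch: {trigger_rule}")


after_branch = ["success", "skipped"]  # is_accurate, is_inaccurate
print(should_run("all_success", after_branch))                  # False
print(should_run("none_failed_min_one_success", after_branch))  # True
```

With one branch skipped, all_success can never be satisfied, which is why store needs the other rule.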
How to choose multiple tasks?
So far, you’ve been choosing a single task with the BranchPythonOperator. However, you can choose to run more than one. How? By returning a list of task IDs:
@task.branch()
def choose_best_model():
    return ['task_a', 'task_b', 'task_c']
Simple.
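The list does not have to be hard-coded. As a sketch, the function below builds it from a dictionary of accuracies and returns one task ID per model above the threshold. The deploy_<name> task IDs, the threshold parameter, and the fallback to is_inaccurate are hypothetical. In a real DAG, each returned ID must match an existing downstream task.

```python
def choose_models(accuracies, threshold=5):
    """Return one task ID per model whose accuracy is above the threshold."""
    selected = [
        f"deploy_{name}"  # hypothetical naming scheme for downstream tasks
        for name, accuracy in accuracies.items()
        if accuracy > threshold
    ]
    # Fall back to a single task so the list is never empty.
    return selected or ["is_inaccurate"]


print(choose_models({"ml_a": 3, "ml_b": 7, "ml_c": 6}))
print(choose_models({"ml_a": 3}))
```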
BranchPythonOperator best practices
Here are a few best practices:
- Always return at least one task ID. Depending on your Airflow version, returning nothing either raises an error or skips every downstream task.
- Ensure that every task ID returned by the BranchPythonOperator matches an existing task, or you will get an error. The returned tasks should be directly downstream of the BranchPythonOperator.
- Set do_xcom_push=False to avoid creating additional XComs.
- You will almost always use the trigger rule none_failed_min_one_success for the task merging your branches.
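The first two practices can be checked in a unit test before the DAG is deployed. Here is a minimal sketch in plain Python, deliberately stricter than what Airflow itself may enforce since it always rejects an empty result. The helper validate_branch_result is a hypothetical name, not part of Airflow.

```python
def validate_branch_result(result, downstream_task_ids):
    """Return the branch result as a list, or raise if it is unusable."""
    # Accept both shapes a branch callable may return: a string or a list.
    task_ids = [result] if isinstance(result, str) else list(result or [])
    if not task_ids:
        raise ValueError("the branch callable returned no task ID")
    unknown = sorted(set(task_ids) - set(downstream_task_ids))
    if unknown:
        raise ValueError(f"unknown downstream task IDs: {unknown}")
    return task_ids


downstream = {"is_accurate", "is_inaccurate"}
print(validate_branch_result("is_accurate", downstream))

try:
    validate_branch_result("is_acurate", downstream)  # typo on purpose
except ValueError as error:
    print(error)
```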
Conclusion
That’s it for the BranchPythonOperator. I hope you enjoyed this tutorial. The BranchPythonOperator is useful because it lets you handle new use cases in your data pipelines. If you want to learn more about Airflow, go check out my courses.
Have a great day!