Are you looking for a way to choose one task or another? Do you want to execute a task based on a condition? Do you have multiple tasks, but only one should be executed if a criterion is met? You’ve come to the right place! The BranchPythonOperator does precisely what you are looking for. It’s common to have DAGs with different execution flows where you want to follow only one, depending on a value or a condition. The BranchPythonOperator lets you choose one branch among the branches of your DAG. To better understand it, it is worth looking at the PythonOperator here first. Before moving forward, if you are looking for a great course to master Airflow, check out this one. Ready? Go!
A use case is the best way to understand a concept. Let’s say you have a data pipeline in which three tasks train machine learning models, a task choose_best_model picks the best one, and two tasks, is_accurate and is_inaccurate, follow it.
A quick explanation about the DAG: once the three training tasks are complete, choose_best_model runs, then either is_accurate or is_inaccurate is triggered, depending on the accuracy of the best machine learning model. If the accuracy is above 5.0, we trigger is_accurate; otherwise, we trigger is_inaccurate.
How can you do this? Is there a mechanism to choose one task or another according to a condition or a value?
Yes! With the BranchPythonOperator!
The BranchPythonOperator allows you to follow a specific path in your DAG according to a condition. That condition is evaluated inside a Python callable. Like the PythonOperator, the BranchPythonOperator executes a Python function, and that function returns a single task ID or a list of task IDs corresponding to the task(s) to run. For example, suppose choose_best_model computes an accuracy of 6 and returns 'is_accurate' if the accuracy is greater than 5, and 'is_inaccurate' otherwise.
Can you guess which task will run after choose_best_model? is_accurate! As the accuracy is greater than 5, the BranchPythonOperator returns the task ID is_accurate (note that it’s a string). Therefore, is_accurate gets triggered. Great! But wait a minute, what happens to the other task? Does it stay with no status at all?
Here is what you get when you trigger the DAG: choose_best_model and is_accurate succeed, whereas is_inaccurate is skipped. By default, downstream tasks not returned by the BranchPythonOperator are skipped! Tasks following skipped tasks are skipped as well: in the example, if you had a task after is_inaccurate, that task would have been skipped too. You can change this behavior with trigger rules.
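To make this skipping behavior concrete, here is a toy, pure-Python model of it. It is a sketch under assumptions, not how Airflow’s scheduler works internally: the DOWNSTREAM mapping, the run_branch function, and the extra report task placed after is_inaccurate are all hypothetical, non-skipped tasks are assumed to succeed, and every task after the branch is assumed to keep the default trigger rule.

```python
# Toy model of branch skipping. NOT Airflow's implementation.
# Hypothetical DAG: choose_best_model -> [is_accurate, is_inaccurate],
# and is_inaccurate -> report (an extra task added for illustration).
DOWNSTREAM = {
    "choose_best_model": ["is_accurate", "is_inaccurate"],
    "is_accurate": [],
    "is_inaccurate": ["report"],
    "report": [],
}


def run_branch(branch_task, chosen):
    """Return task_id -> state after the branch task returns `chosen`."""
    if isinstance(chosen, str):
        chosen = [chosen]  # a single task ID behaves like a one-item list
    states = {branch_task: "success"}
    to_visit = []
    # Direct downstream tasks: the chosen ones run, the others are skipped.
    for task_id in DOWNSTREAM[branch_task]:
        states[task_id] = "success" if task_id in chosen else "skipped"
        to_visit.append(task_id)
    # Propagate: a task whose (single) upstream was skipped is skipped too;
    # otherwise we assume it runs and succeeds.
    while to_visit:
        parent = to_visit.pop()
        for child in DOWNSTREAM[parent]:
            states[child] = "skipped" if states[parent] == "skipped" else "success"
            to_visit.append(child)
    return states


print(run_branch("choose_best_model", "is_accurate"))
# {'choose_best_model': 'success', 'is_accurate': 'success',
#  'is_inaccurate': 'skipped', 'report': 'skipped'}
```

Calling run_branch with ['is_inaccurate'] instead flips the result: is_accurate is skipped while is_inaccurate and report succeed.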
How to use the BranchPythonOperator?
Implementing the BranchPythonOperator is easy:
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime


def _choose_best_model():
    accuracy = 6
    # Return the task ID of the branch to follow
    if accuracy > 5:
        return 'accurate'
    return 'inaccurate'


with DAG('branching', start_date=datetime(2023, 1, 1), schedule='@daily', catchup=False):

    choose_best_model = BranchPythonOperator(
        task_id='choose_best_model',
        python_callable=_choose_best_model
    )

    accurate = EmptyOperator(
        task_id='accurate'
    )

    inaccurate = EmptyOperator(
        task_id='inaccurate'
    )

    # The returned task IDs should point to tasks directly downstream of the branch
    choose_best_model >> [accurate, inaccurate]
The Taskflow API version:
from airflow.decorators import dag, task
from datetime import datetime


@dag(
    start_date=datetime(2023, 1, 1),
    schedule='@daily',
    catchup=False
)
def branching():

    @task.branch
    def choose_best_model(accuracy=6):
        if accuracy > 5:
            return 'is_accurate'
        return 'is_inaccurate'

    @task
    def is_accurate():
        pass

    @task
    def is_inaccurate():
        pass

    choose_best_model() >> [is_accurate(), is_inaccurate()]


branching()
Regardless of your implementation, the BranchPythonOperator always expects a Python function that returns a single task ID or a list of task IDs.
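Because the branching decision is plain Python, you can check its logic without running Airflow at all. Here is a minimal sketch, assuming you copy the decision into an undecorated function; turning accuracy into a parameter is an assumption made for testability, since the tutorial’s version hard-codes it.

```python
# The branching decision as a plain function, testable without Airflow.
# `accuracy` is a parameter here (an assumption); the tutorial hard-codes 6.
def choose_best_model(accuracy: float) -> str:
    # Return the task ID of the branch to follow, just like the callable
    # given to the BranchPythonOperator or @task.branch.
    if accuracy > 5:
        return "is_accurate"
    return "is_inaccurate"


for acc in (6, 5, 2.5):
    print(acc, "->", choose_best_model(acc))
```

Note the boundary: an accuracy of exactly 5 goes to is_inaccurate, because the condition uses > rather than >=.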
The BranchPythonOperator with XComs
Where does the accuracy come from? Usually, the value you use for the BranchPythonOperator’s condition comes from another task through XComs. If you are not familiar with XComs, look at the tutorial here.
Let me show you an example:
from airflow.decorators import dag, task
from datetime import datetime


@dag(
    start_date=datetime(2023, 1, 1),
    schedule='@daily',
    catchup=False
)
def branching():

    @task
    def ml_a():
        # The return value is pushed as an XCom with the key return_value
        return 6

    @task.branch
    def choose_best_model(accuracy):
        if accuracy > 5:
            return 'is_accurate'
        return 'is_inaccurate'

    @task
    def is_accurate():
        pass

    @task
    def is_inaccurate():
        pass

    accuracy = ml_a()
    # The branch targets must exist in the DAG, downstream of choose_best_model
    choose_best_model(accuracy) >> [is_accurate(), is_inaccurate()]


branching()
ml_a returns the value 6. That creates an XCom with that value, which is then passed to choose_best_model. Looking at the XComs, you will see the following.
ml_a produces the first XCom with the key return_value and the value 6. Interestingly, the BranchPythonOperator creates not one but two XComs! One with the key skipmixin_key, so the Airflow scheduler knows which tasks to run and which to skip. Then, a second, optional XCom with the key return_value that holds the returned task ID.
To remove this optional XCom, I recommend setting do_xcom_push=False:
@task.branch(do_xcom_push=False)
def choose_best_model(accuracy):
    if accuracy > 5:
        return 'is_accurate'
    return 'is_inaccurate'
Now, what if you want to compare multiple accuracies? Pull multiple XComs at once!
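# without Taskflow
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from datetime import datetime


def _choose_best_model(ti=None):
    # Assumes ml_a, ml_b and ml_c return their accuracy (key return_value);
    # pulling from several task IDs returns a list with one value per task
    accuracies = ti.xcom_pull(task_ids=['ml_a', 'ml_b', 'ml_c'])
    max_accuracy = max(accuracies)
    if max_accuracy > 5:
        return 'accurate'
    return 'inaccurate'


with DAG('branching', start_date=datetime(2023, 1, 1), schedule='@daily', catchup=False):
    choose_best_model = BranchPythonOperator(
        task_id='choose_best_model',
        python_callable=_choose_best_model
    )

# with Taskflow
from airflow.decorators import task


@task.branch(do_xcom_push=False)
def choose_best_model(ml_a, ml_b, ml_c):
    max_accuracy = max([ml_a, ml_b, ml_c])
    if max_accuracy > 5:
        return 'is_accurate'
    return 'is_inaccurate'

choose_best_model(3, 4, 6)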
There is something to pay attention to with the BranchPythonOperator. Suppose you have a task store, downstream of both is_accurate and is_inaccurate, that you want to run regardless of the BranchPythonOperator’s result.
What do you think will happen to the task store if you run this DAG?
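The task is skipped! With the default trigger rule, all_success, store runs only if all of its upstream tasks succeed, and is_inaccurate was skipped. Fortunately, there is an easy way to solve that: the trigger rule none_failed_min_one_success.
This trigger rule ensures that your task runs if none of its direct upstream tasks has failed and at least one has succeeded, which is the case here with is_accurate. To apply that trigger rule to store, set the trigger_rule parameter: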
# without Taskflow
EmptyOperator(task_id='store', trigger_rule='none_failed_min_one_success')

# with Taskflow
@task(trigger_rule='none_failed_min_one_success')
def store():
    pass
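Here is a toy comparison of the two trigger rules, applied to store’s upstream states after the branch picked is_accurate. It is a simplified sketch based on the rule descriptions in this tutorial, not Airflow’s scheduler code: should_run is a hypothetical helper, and states such as upstream_failed are ignored.

```python
# Toy evaluation of two trigger rules. NOT Airflow's scheduler code;
# it ignores states such as upstream_failed or removed.
def should_run(rule, upstream_states):
    if rule == "all_success":
        # Every direct upstream task must have succeeded
        return all(state == "success" for state in upstream_states)
    if rule == "none_failed_min_one_success":
        # No upstream failure, and at least one upstream success
        no_failure = all(state != "failed" for state in upstream_states)
        return no_failure and any(state == "success" for state in upstream_states)
    raise ValueError(f"rule not covered by this sketch: {rule}")


# store's upstream tasks after the branch picked is_accurate
upstream = ["success", "skipped"]  # is_accurate, is_inaccurate
print(should_run("all_success", upstream))                  # False: store does not run
print(should_run("none_failed_min_one_success", upstream))  # True: store runs
```

The second rule still refuses to run store if every branch was skipped or if one of them failed, which is usually what you want for a task that merges branches.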
Running the DAG again, store now runs even though is_inaccurate is skipped.
How to choose multiple tasks?
So far, you’ve been choosing a single task with the BranchPythonOperator. However, you can choose to run more than one. How? By returning a list of task IDs:
@task.branch()
def choose_best_model():
    return ['task_a', 'task_b', 'task_c']
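The list does not have to be hard-coded; the callable can build it from data. Here is a plain-Python sketch of such a decision, assuming one hypothetical downstream task per model named deploy_<model> and reusing is_inaccurate as a fallback so that at least one task ID is always returned.

```python
# Hypothetical: choose one downstream task per model above the threshold.
# The "deploy_<model>" task IDs are made up for illustration.
def choose_models(accuracies, threshold=5):
    selected = [f"deploy_{name}" for name, acc in accuracies.items() if acc > threshold]
    # Fall back to a single task so the callable never returns an empty result
    return selected or "is_inaccurate"


print(choose_models({"ml_a": 6, "ml_b": 4, "ml_c": 7}))  # ['deploy_ml_a', 'deploy_ml_c']
print(choose_models({"ml_a": 1, "ml_b": 2, "ml_c": 3}))  # is_inaccurate
```

In a DAG, every ID this function can return would need a matching task downstream of the branch.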
BranchPythonOperator best practices
Here are a few best practices:
- Always return at least one task ID, or you will get an error.
- Make sure every task ID returned by the BranchPythonOperator matches an existing task, or you will get an error.
- Set do_xcom_push=False to avoid creating additional XComs.
- You will almost always use the trigger rule none_failed_min_one_success for the task merging your branches.
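If you want to catch the first two mistakes before Airflow does, you can check a branch result yourself in a unit test. The helper below is hypothetical, not part of Airflow; it assumes you know the set of task IDs directly downstream of your branching task.

```python
# Hypothetical helper (not part of Airflow): check a branch result against
# the set of task IDs directly downstream of the branching task.
def validate_branch_result(result, downstream_task_ids):
    # Normalize a single task ID (str) or a list of IDs into a list
    task_ids = [result] if isinstance(result, str) else list(result or [])
    if not task_ids:
        raise ValueError("the branch callable returned no task ID")
    unknown = [t for t in task_ids if t not in downstream_task_ids]
    if unknown:
        raise ValueError(f"unknown task IDs: {unknown}")
    return task_ids


print(validate_branch_result("is_accurate", {"is_accurate", "is_inaccurate"}))
```

Running it against each possible output of your callable is a cheap way to keep the returned IDs and the DAG in sync.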
That’s it for the BranchPythonOperator. I hope you enjoyed this tutorial. The BranchPythonOperator is useful because it lets you solve new use cases in your data pipelines. If you want to learn more about Airflow, check out my courses here.
Have a great day!