How to use the BranchPythonOperator

Looking for a way to choose one task or another? Do you want to execute a task based on a condition? Do you have multiple tasks, but only one should be executed if a criterion is true? You’ve come to the right place! The BranchPythonOperator does exactly what you are looking for. It’s common to have DAGs with different execution flows where you want to follow only one, according to a value or a condition. The BranchPythonOperator allows you to choose one branch among the branches of your DAG. To understand it better, it is worth taking a look at the PythonOperator right here. Oh, and before moving forward, if you are looking for a great course to master Airflow, check out this one. Ready? Go!

Use Case

As usual, the best way to understand a feature or concept is to look at a use case. Let’s say you have the following data pipeline:

DAG example for the BranchPythonOperator

Here is a quick explanation of the DAG. The first three tasks train machine learning models. Once they all complete, the “Choosing Best ML” task is triggered. Then, either “Is accurate” or “Is inaccurate” should be executed, depending on the accuracy of the best ML model. Let’s say that if the accuracy is above 5.0 we trigger “Is accurate”; otherwise, “Is inaccurate” runs. How can you do this? Is there a mechanism to achieve this?

Yes, there is! And you know it, otherwise you wouldn’t be here 😅

Give a warm welcome to the BranchPythonOperator!

The BranchPythonOperator

OK, we are happy to meet the BranchPythonOperator, but what does it do?

The BranchPythonOperator allows you to follow a specific path according to a condition. That condition is evaluated in a Python callable. Like the PythonOperator, the BranchPythonOperator executes a Python function, and that function returns the task id of the next task to execute. Look at the example below. Can you guess which task is executed next? “Is accurate” or “Is inaccurate”?

How the BranchPythonOperator works

“Is accurate”! Since the accuracy is greater than 5, the BranchPythonOperator returns the task id “Is accurate”. Therefore, “Is accurate” gets triggered. Great! But wait a minute… what happens to the other task? Does it stay with no status at all?

No!

Here is exactly what you obtain by triggering the DAG:

BranchPythonOperator with status
Task’s status after triggering the DAG

In our case, “Choosing Best ML” and “Is accurate” have succeeded whereas “Is inaccurate” has been skipped. Consequently, downstream tasks that are not returned by the BranchPythonOperator get skipped! Also, tasks following skipped tasks are skipped as well. In the example, if you put a task after “Is inaccurate”, that task will be skipped.
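The skip behaviour described above can be mimicked in plain Python. The sketch below is a toy model, not Airflow's implementation: `resolve_branch` is a hypothetical helper, the accuracy is passed in as an argument for convenience, and the task ids are shortened to `accurate` and `inaccurate`.

```python
def choose_best_model(accuracy):
    """Return the task id to follow, as in the use case above."""
    if accuracy > 5:
        return 'accurate'
    return 'inaccurate'


def resolve_branch(returned, downstream_ids):
    """Toy model of branching: label each direct downstream task.

    `returned` is what the branch callable returned (one task id or a list).
    Tasks that were returned are 'followed', all the others are 'skipped'.
    """
    followed = [returned] if isinstance(returned, str) else list(returned)
    unknown = set(followed) - set(downstream_ids)
    if unknown:
        raise ValueError(f'not a direct downstream task: {sorted(unknown)}')
    return {
        task_id: 'followed' if task_id in followed else 'skipped'
        for task_id in downstream_ids
    }


statuses = resolve_branch(choose_best_model(6), ['accurate', 'inaccurate'])
print(statuses)
```

Keeping the decision in a small function like this makes it easy to check both outcomes before wiring it into a DAG.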

In practice

All right, you know the BranchPythonOperator and you know how it works. What about a bit of code to implement it?

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator

default_args = {
    'start_date': datetime(2020, 1, 1)
}


def _choose_best_model():
    # Hard-coded for the example; returns the task id to follow.
    accuracy = 6
    if accuracy > 5:
        return 'accurate'
    return 'inaccurate'


with DAG('branching', schedule_interval='@daily', default_args=default_args, catchup=False) as dag:
    choose_best_model = BranchPythonOperator(
        task_id='choose_best_model',
        python_callable=_choose_best_model
    )

    accurate = DummyOperator(
        task_id='accurate'
    )

    inaccurate = DummyOperator(
        task_id='inaccurate'
    )

    choose_best_model >> [accurate, inaccurate]

The code above gives you the same data pipeline as shown before. To run the code, install Docker on your computer. Then, go to my beautiful repository to get the docker compose file that will help you run Airflow on your computer. It runs Airflow 2.0, not 1.10.14 😎 Clone the repo and go into it. Create a file branching.py in the folder airflow-data/dags. Copy and paste the code into that file and execute the command docker-compose up -d in the folder docker-airflow.

Pay attention to the arguments of the BranchPythonOperator. It expects a task_id and a python_callable. If you take a look at the Python function _choose_best_model(), you can see the condition that returns the task id, either “accurate” or “inaccurate”. Here, the function returns “accurate”; therefore, the next task to trigger is “accurate”.

Behind the scenes of the BranchPythonOperator

The BranchPythonOperator inherits from the PythonOperator. As a result, parameters of PythonOperator are accessible in the BranchPythonOperator. You can access the context of the task instance to pull XComs. You can give additional arguments through op_kwargs and op_args. If you don’t know what I’m talking about, check this article.

For instance, let’s say you want to fetch the accuracy by pulling an XCom. You can do this:

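def _choose_best_model(ti):
    # Pass a single task id (a string): with a list, xcom_pull returns a list
    # of values and the comparison below would fail.
    accuracy = ti.xcom_pull(key='accuracy', task_ids='task_where_accuracy_was_pushed')
    if accuracy > 5:
        return 'accurate'
    return 'inaccurate'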

ti is the task instance object required for accessing your XComs. Notice that in Airflow 2.0, you don’t have to use the parameter provide_context anymore. The BranchPythonOperator stays the same:

choose_best_model = BranchPythonOperator(
    task_id='choose_best_model',
    python_callable=_choose_best_model
)

Last but not least, each time the BranchPythonOperator runs, two XComs are created:

XComs generated by the BranchPythonOperator

The first XCom indicates which task to follow next. The second XCom holds the value returned by the Python callable (by default, any returned value creates an XCom). Be careful! XComs are not automatically removed. Therefore, the more BranchPythonOperators you trigger, the more XComs will be stored. You can reduce that number to one by setting the parameter do_xcom_push=False. Again, if you don’t know what this parameter is, go check out my other post.

choose_best_model = BranchPythonOperator(
    task_id='choose_best_model',
    python_callable=_choose_best_model,
    do_xcom_push=False
)

In any case, don’t forget to remove your XComs from time to time, as they are NOT automatically removed. If you have hundreds of DAGs, your DB won’t be happy.

The BranchPythonOperator and the Skipped status

When you trigger the BranchPythonOperator, one task is triggered next and the others are skipped. There is one little catch that you have to be aware of. Let’s say you have the following DAG.

The problem with the skipped status

Can you guess what the status of the task “storing” will be?

The task will be skipped. YES, skipped!

Why? Because, by default, “storing” expects all of its parents to succeed before it gets triggered. Since one of its parents, “Is inaccurate”, has been skipped, the task gets skipped as well. The status is propagated to the downstream tasks. How can we solve this?

By changing the trigger rule of the “storing” task. I won’t go into the details of trigger rules, but they let you modify the conditions under which your tasks get triggered. The default is “all_success”, meaning a task is triggered when all of its upstream tasks have succeeded. In this specific case, we have to change that.

    storing = DummyOperator(
        task_id='storing',
        trigger_rule='none_failed_or_skipped'
    )
    choose_best_model >> [accurate, inaccurate] >> storing

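Notice the trigger rule on “storing”: none_failed_or_skipped. With it, the task gets triggered as long as no parent has failed and at least one parent has succeeded, so a skipped parent no longer blocks it. Note that from Airflow 2.2 onwards this rule is named none_failed_min_one_success, and the old name is deprecated.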
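To see why the trigger rule matters, here is a simplified model of the two rules in plain Python. It is a sketch under assumptions, not Airflow's scheduler logic: `resolve_state` is a hypothetical helper, it only covers these two rules, and it assumes every upstream task has already finished.

```python
def resolve_state(trigger_rule, upstream_states):
    """Toy model: decide what happens to a task given its parents' final states."""
    failed = any(s in ('failed', 'upstream_failed') for s in upstream_states)
    skipped = sum(s == 'skipped' for s in upstream_states)
    succeeded = sum(s == 'success' for s in upstream_states)

    if trigger_rule == 'all_success':
        if failed:
            return 'upstream_failed'
        if skipped:
            # A single skipped parent is enough to skip the task.
            return 'skipped'
        return 'run'

    if trigger_rule == 'none_failed_or_skipped':
        if failed:
            return 'upstream_failed'
        if succeeded == 0:
            # Every parent was skipped, so there is nothing to store.
            return 'skipped'
        return 'run'

    raise ValueError(f'rule not covered by this sketch: {trigger_rule}')


# "accurate" succeeded, "inaccurate" was skipped by the branch.
parents = ['success', 'skipped']
print(resolve_state('all_success', parents))
print(resolve_state('none_failed_or_skipped', parents))
```

With the default rule the skipped parent wins, whereas the second rule lets “storing” run because one parent succeeded and none failed.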

How to execute multiple tasks with the BranchPythonOperator

So far, we’ve seen how to execute one task after the BranchPythonOperator. What if we want to execute several?

Simple!

You just have to return multiple task ids!

Here is an example:

def _choose_best_model():
    accuracy = 9
    if accuracy > 7:
        # Returning a list follows several branches at once.
        return ['super_accurate', 'accurate']
    elif accuracy > 5:
        return 'accurate'
    return 'inaccurate'


with DAG('branching', schedule_interval='@daily', default_args=default_args, catchup=False) as dag:
    # choose_best_model, accurate, inaccurate and storing are defined as in the previous examples.
    super_accurate = DummyOperator(
        task_id='super_accurate'
    )

    choose_best_model >> [super_accurate, accurate, inaccurate] >> storing

As you can see from the code, in _choose_best_model, we return two task ids, “super_accurate” and “accurate”, if the accuracy is greater than 7. That gives us:

The BranchPythonOperator with multiple task ids

The BranchPythonOperator in action!
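The thresholds in a callable like this are easy to get wrong at the boundaries, so a quick check in plain Python can help. In this sketch the accuracy is passed in as an argument (an assumption made for testing, since the version above hard-codes it), and `as_task_list` is a hypothetical helper that normalises the return value to a list.

```python
def choose_best_model(accuracy):
    """Same branching logic as above, with the accuracy as a parameter."""
    if accuracy > 7:
        return ['super_accurate', 'accurate']
    if accuracy > 5:
        return 'accurate'
    return 'inaccurate'


def as_task_list(returned):
    """Normalise a single task id or a list of task ids to a list."""
    return [returned] if isinstance(returned, str) else list(returned)


# The comparisons are strict, so 7 and 5 fall into the lower branch.
for accuracy in (9, 7, 6, 5):
    print(accuracy, as_task_list(choose_best_model(accuracy)))
```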

Conclusion

That’s it for the BranchPythonOperator. I hope you enjoyed what you’ve learned. The BranchPythonOperator is super useful. Use it to build complex data pipelines, and be careful with the little catches that we saw in this article. If you want to learn more about Airflow, go check out my course The Complete Hands-On Introduction to Apache Airflow right here. Or, if you already know Airflow and want to go much further, enrol in my 12-hour course here.

Have a great day! 🙂 


