
Are you looking for a way to choose one task or another? Do you want to execute a task based on a condition? Do you have multiple tasks, but only one should run when a criterion is met? You’ve come to the right place! The BranchPythonOperator does precisely that. DAGs often have several execution flows, and you want to follow only one of them depending on a value or a condition. The BranchPythonOperator lets you choose one branch among the branches of your DAG. Since it builds on the PythonOperator, it is worth being familiar with that operator first. Ready? Go!

Use case

Using a use case is the best way to understand a concept. Let’s say you have the following data pipeline:

[Figure: Airflow DAG with the BranchPythonOperator]

A quick explanation of the DAG: the first three tasks train machine learning models. Once they complete, choose_best_model runs, and then either is_accurate or is_inaccurate is triggered, depending on the accuracy of the best machine learning model. If the accuracy is above 5.0, is_accurate is triggered; otherwise, is_inaccurate is.

How can you do this? Is there a mechanism to choose one task or another according to a condition or a value?

Yes! With the BranchPythonOperator!

The BranchPythonOperator

The BranchPythonOperator allows you to follow a specific path in your DAG according to a condition. That condition is evaluated inside a Python callable. Like the PythonOperator, the BranchPythonOperator executes a Python function; the difference is that this function must return a single task ID or a list of task IDs corresponding to the task(s) to run. Look at the example below.
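The branching logic itself is ordinary Python, so it can be sketched and tried without Airflow installed. This is a minimal sketch: the default accuracy of 6 is an assumption standing in for a real model score, and the function is called directly here, whereas in a DAG Airflow would call it for you.

```python
def choose_best_model(accuracy=6):
    """Return the task ID of the branch to follow.

    The default accuracy of 6 is a placeholder for a real model score.
    """
    if accuracy > 5:
        return "is_accurate"
    return "is_inaccurate"


if __name__ == "__main__":
    # Call the function directly to see which task ID it selects.
    print(choose_best_model())
```

The return value is just a string (or a list of strings) naming the task(s) to run next.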

Can you guess which task will run after choose_best_model?

[Figure: BranchPythonOperator example]

is_accurate! As the accuracy is greater than 5, the callable returns the task ID is_accurate (note that it’s a string). Therefore, is_accurate gets triggered. Great! But wait a minute, what happens to the other task? Does it stay with no status at all?

No!

Here is precisely what you obtain by triggering the DAG:

[Figure: task statuses after triggering the DAG]

choose_best_model and is_accurate succeed, whereas is_inaccurate is skipped. By default, the direct downstream tasks not returned by the BranchPythonOperator are skipped! Tasks following skipped tasks are skipped as well. In the example, if you had a task after is_inaccurate, that task would have been skipped too. You can change this behavior with trigger rules.
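To make the skipping behavior concrete, here is a simplified model in plain Python. It is not Airflow's scheduler code: it assumes every task that runs succeeds and that every task uses the default trigger rule, and the task notify placed after is_inaccurate is hypothetical.

```python
from collections import deque


def simulate_branch(downstream, branch_task, chosen):
    """Return {task_id: state} for the branch task and everything after it.

    downstream:  dict mapping a task ID to its direct downstream task IDs
    branch_task: task ID of the branching task
    chosen:      task ID (or list of task IDs) returned by the branch callable
    """
    chosen = {chosen} if isinstance(chosen, str) else set(chosen)
    states = {branch_task: "success"}
    queue = deque()

    # Direct downstream tasks: the chosen ones run, the others are skipped.
    for task_id in downstream.get(branch_task, []):
        states[task_id] = "success" if task_id in chosen else "skipped"
        queue.append(task_id)

    # Default trigger rule: a skipped upstream task skips the task after it.
    while queue:
        current = queue.popleft()
        for task_id in downstream.get(current, []):
            if states[current] == "skipped":
                states[task_id] = "skipped"
            else:
                states.setdefault(task_id, "success")
            queue.append(task_id)
    return states


if __name__ == "__main__":
    dag = {
        "choose_best_model": ["is_accurate", "is_inaccurate"],
        "is_inaccurate": ["notify"],  # hypothetical task after is_inaccurate
    }
    for task_id, state in simulate_branch(dag, "choose_best_model", "is_accurate").items():
        print(task_id, state)
```

In this model, notify ends up skipped only because its upstream task was skipped, which is the cascade described above.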

How to use the BranchPythonOperator?

Implementing the BranchPythonOperator is easy:

from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime

def _choose_best_model():
  accuracy = 6
  if accuracy > 5:
    return 'accurate'
  return 'inaccurate'

with DAG('branching', start_date=datetime(2023, 1, 1), schedule='@daily', catchup=False):
  choose_best_model = BranchPythonOperator(
    task_id='choose_best_model',
    python_callable=_choose_best_model
  )

  accurate = EmptyOperator(
    task_id='accurate'
  )

  inaccurate = EmptyOperator(
    task_id='inaccurate'
  )

  choose_best_model >> [accurate, inaccurate]

The Taskflow API version:

from airflow.decorators import dag, task
from datetime import datetime

@dag(
  start_date=datetime(2023, 1, 1),
  schedule='@daily',
  catchup=False
)
def branching():

  @task.branch
  def choose_best_model(accuracy=6):
    if accuracy > 5:
      return 'is_accurate'
    return 'is_inaccurate'

  @task
  def is_accurate():
    pass

  @task
  def is_inaccurate():
    pass

  choose_best_model() >> [is_accurate(), is_inaccurate()]

branching()

Regardless of your implementation, the BranchPythonOperator always expects a Python function that returns a single task ID or a list of task IDs.

The BranchPythonOperator with XComs

Where does the accuracy come from? Usually, the value you use for the BranchPythonOperator’s condition comes from another task through XComs. If you are not familiar with XComs, it is worth reading up on them first.

Let me show you an example:

from airflow.decorators import dag, task
from datetime import datetime

@dag(
  start_date=datetime(2023, 1, 1),
  schedule='@daily',
  catchup=False
)
def branching():

  @task
  def ml_a():
    return 6

  @task.branch
  def choose_best_model(accuracy):
    if accuracy > 5:
      return 'is_accurate'
    return 'is_inaccurate'

  @task
  def is_accurate():
    pass

  @task
  def is_inaccurate():
    pass

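  accuracy = ml_a()
  # Wire both branches downstream so the returned task ID exists in the DAG
  choose_best_model(accuracy) >> [is_accurate(), is_inaccurate()]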

branching()

The task ml_a returns the value 6. That creates an XCom with that value, which is then passed to choose_best_model. Looking at the XComs, you will see:

[Figure: XComs created by the DAG]

ml_a produces the first XCom with the key return_value and the value 6. Interestingly, the BranchPythonOperator creates not one but two XComs! The first has the key skipmixin_key, so the Airflow Scheduler knows which tasks to run and which to skip. The second, with the key return_value, holds the returned task ID and is optional.

To remove this optional XCom, I recommend setting do_xcom_push=False.

@task.branch(do_xcom_push=False)
def choose_best_model(accuracy):
  if accuracy > 5:
    return 'is_accurate'
  return 'is_inaccurate'

Now, what if you want to compare multiple accuracies? Pull multiple XComs at once!

# without Taskflow

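# Same imports and DAG arguments as in the first example
def _choose_best_model(ti=None):
  # Assumes ml_a, ml_b and ml_c each pushed an XCom under the key 'accuracy'
  accuracies = ti.xcom_pull(key='accuracy', task_ids=['ml_a', 'ml_b', 'ml_c'])
  max_accuracy = max(accuracies)
  if max_accuracy > 5:
    return 'accurate'
  return 'inaccurate'

with DAG('branching', start_date=datetime(2023, 1, 1), schedule='@daily', catchup=False):
  choose_best_model = BranchPythonOperator(
    task_id='choose_best_model',
    python_callable=_choose_best_model
  )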

# with Taskflow

@task.branch(do_xcom_push=False)
def choose_best_model(ml_a, ml_b, ml_c):
  max_accuracy = max([ml_a, ml_b, ml_c])
  if max_accuracy > 5:
    return 'is_accurate'
  return 'is_inaccurate'

choose_best_model(3, 4, 6)

The pitfall

There is something to pay attention to with the BranchPythonOperator. Suppose you have a task store that you want to run regardless of the BranchPythonOperator’s result.

What do you think will happen to the task store, placed after both is_accurate and is_inaccurate, if you run this DAG?

The task is skipped! Fortunately, there is an easy way to solve that: the trigger rule none_failed_min_one_success.

This trigger rule ensures that your task runs if none of its direct upstream tasks has failed and at least one has succeeded, which is the case here thanks to is_accurate. To apply that trigger rule to store:

# without Taskflow

EmptyOperator(task_id='store', trigger_rule='none_failed_min_one_success')

# with Taskflow

@task(trigger_rule='none_failed_min_one_success')
def store():
  pass
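To see why the default rule skips store while none_failed_min_one_success lets it run, the two rules can be sketched as plain predicates over the upstream task states. This is an illustrative model written for this tutorial, not Airflow's implementation; the function names simply mirror the trigger rule names.

```python
def all_success(upstream_states):
    """Default rule: every direct upstream task must have succeeded."""
    return all(state == "success" for state in upstream_states)


def none_failed_min_one_success(upstream_states):
    """No direct upstream task failed, and at least one succeeded."""
    no_failure = all(
        state not in ("failed", "upstream_failed") for state in upstream_states
    )
    return no_failure and "success" in upstream_states


if __name__ == "__main__":
    # States of is_accurate and is_inaccurate once the branch is resolved.
    upstream = ["success", "skipped"]
    print("all_success:", all_success(upstream))
    print("none_failed_min_one_success:", none_failed_min_one_success(upstream))
```

With one branch succeeded and the other skipped, only the second predicate holds, which matches the behavior described for store.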

Running the DAG again, store is no longer skipped: it runs once the selected branch succeeds.

How to choose multiple tasks?

So far, you’ve been choosing a single task with the BranchPythonOperator. However, you can choose to run more than one. How? By returning a list of task IDs:

@task.branch()
def choose_best_model():
  return ['task_a', 'task_b', 'task_c']

Simple.
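In practice the list is usually computed rather than hard-coded. The sketch below, in plain Python, selects one task per model whose accuracy exceeds a threshold. The task IDs deploy_<model> and no_deployment are hypothetical names chosen for illustration, and the fallback makes sure the list returned is never empty.

```python
def choose_accurate_models(accuracies, threshold=5):
    """Return the task IDs to run, one per sufficiently accurate model.

    accuracies: dict mapping a model name to its accuracy
    The deploy_* and no_deployment task IDs are hypothetical.
    """
    selected = [
        f"deploy_{model}"
        for model, accuracy in accuracies.items()
        if accuracy > threshold
    ]
    # Fall back to a dedicated task so a task ID is always returned.
    return selected if selected else ["no_deployment"]


if __name__ == "__main__":
    print(choose_accurate_models({"ml_a": 3, "ml_b": 6, "ml_c": 7}))
```

Every task ID in the returned list must exist downstream of the branching task for this to work in a DAG.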

BranchPythonOperator best practices

Here are a few best practices:

  • Always return at least one task id, or you will get an error.
  • Ensure that the tasks exist for the task ids returned by the BranchPythonOperator or you will get an error.
  • Set do_xcom_push=False to avoid creating additional XComs.
  • You will almost always use the trigger rule none_failed_min_one_success for the task merging your branches.
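The first two practices can be checked in a unit test before the DAG ever runs. Here is a minimal sketch of such a check in plain Python; checked_branch is a hypothetical helper written for this example, not part of Airflow, and the list of downstream task IDs is something you would supply yourself.

```python
def checked_branch(result, downstream_task_ids):
    """Normalize a branch result to a list and validate it.

    result:              task ID, list of task IDs, or None
    downstream_task_ids: task IDs that exist after the branching task
    """
    task_ids = [result] if isinstance(result, str) else list(result or [])
    if not task_ids:
        raise ValueError("the branch callable returned no task ID")
    unknown = sorted(set(task_ids) - set(downstream_task_ids))
    if unknown:
        raise ValueError(f"unknown task IDs: {unknown}")
    return task_ids


if __name__ == "__main__":
    branches = ["is_accurate", "is_inaccurate"]
    print(checked_branch("is_accurate", branches))
    try:
        checked_branch("is_acurate", branches)  # typo on purpose
    except ValueError as error:
        print("rejected:", error)
```

Catching a misspelled task ID in a test is cheaper than discovering it when the branching task fails in production.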

BranchPythonOperator in action!

Here is the video that covers the BranchPythonOperator:

Conclusion

That’s it for the BranchPythonOperator. I hope you enjoyed this tutorial. The BranchPythonOperator is useful because it lets your data pipelines handle new use cases.

Have a great day!
