Refactoring Python if Statements

Bryan Lott published on
10 min, 1946 words

Note: I'm not recommending the following refactoring to everyone, nor am I saying it was even a good idea. It did, however, work out quite well for me and my team.

if pred_1 == 'first':
    output_data = output_data._replace(output_data_1=pred_3, output_data_2=pred_4)
    if pred_2 and self.state_management_3 in pred_2:
        output_data = merge_namedtuple(output_data, self.process_data_1())
    elif pred_3 == state_management_2.state1:
        if pred_4:
            output_data = merge_namedtuple(output_data, self.process_data_2())
        else:
            output_data = merge_namedtuple(output_data, self.process_data_3())
            if orig_data_2 in [state_management_1.STATE1, state_management_1.STATE2]:
                output_data = output_data._replace(output_data_1=orig_data_1, output_data_2=orig_data_2)
    elif pred_3 in [state_management_2.state2, state_management_2.state3]:
        if pred_4:
            output_data = merge_namedtuple(output_data, self.process_data_2())
        else:
            output_data = merge_namedtuple(output_data, self.process_data_3())
            output_data = output_data._replace(output_data_1=state_management_2.state1, output_data_2=state_management_1.STATE1)
elif pred_1 == 'second':
    if pred_2 and self.state_management_3 in pred_2:
        output_data = merge_namedtuple(output_data, self.process_data_1())
    elif pred_3 == state_management_2.state2:
        if pred_4 in [state_management_1.STATE3, state_management_1.STATE4, state_management_1.STATE5]:
            if pred_5:
                output_data = merge_namedtuple(output_data, self.process_data_4())
        elif pred_4 in [state_management_1.STATE7, state_management_1.STATE8, state_management_1.STATE6, state_management_1.STATE9]:
            output_data = merge_namedtuple(output_data, self.process_data_2())
            if pred_4 == output_data.state_2:
                output_data = output_data._replace(state_1=None, state_2=None)
    elif pred_3 == state_management_2.state1 and not pred_4 and pred_5:
        output_data = merge_namedtuple(output_data, self.process_data_4())
        output_data = output_data._replace(state_1=output_data.sub_data.state_1, state_2=output_data.sub_data.state_2)
    elif pred_3 == state_management_2.state1 and pred_4 and pred_6:
        output_data = merge_namedtuple(output_data, self.process_data_2())

Still with me?

...

Yeah, pretty sure I lost about 90% of you with that if-block.

This isn't the actual code that I ended up refactoring, but it's convoluted enough to give you the same gut-wrenching feeling. I've also done my best to anonymize it to protect the guilty.

The White Whale

Story time!

I ran into this block of code on my 3rd or 4th day on the job and recoiled in horror. At the same time, I was fascinated that this was how some admittedly convoluted but mission-critical business logic had been encoded in Python. The "standard" refactor for an if-elif-elif-elif block in Python is to reach for a dictionary (a hashmap, for those not familiar with Python): encode the predicates as the keys and the functions to be called as the values.
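For a flat if-elif chain that tests a single value, that standard refactor looks roughly like this. It's a minimal sketch, and the handler names (handle_first, handle_second, handle_unknown) are hypothetical, not taken from the original system:

```python
def handle_first():
    return "processing 'first'"


def handle_second():
    return "processing 'second'"


def handle_unknown():
    return "nothing to do"


# Before: one branch per value of a single predicate.
def dispatch_if(pred_1):
    if pred_1 == "first":
        return handle_first()
    elif pred_1 == "second":
        return handle_second()
    return handle_unknown()


# After: the predicate values become keys, the branch bodies become values.
HANDLERS = {
    "first": handle_first,
    "second": handle_second,
}


def dispatch_dict(pred_1):
    # .get() supplies the fallback that the trailing return used to provide
    return HANDLERS.get(pred_1, handle_unknown)()


print(dispatch_dict("first"))  # processing 'first'
print(dispatch_dict("third"))  # nothing to do
```

This works well when every branch tests the same value in the same way. That's exactly what the if statement above doesn't do.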

Great... but... does that mean we need a nested dictionary here and, more importantly, would that actually make the code more readable?

# we start with the first "layer":
{'first': {},
 'second': {}}

# cool, that was easy... now the second "layer":
{'first': {pred_2: {},
           pred_3: {},
           }}

# uhhhh... we're going off the rails here...

The problem is that the layers aren't symmetrical. Each branch checks its own mix of conditions over an amalgamation of predicates, states, and even types (some are lists, others are enums, others are booleans).

So... un-refactorable?

I'll be honest, at this point I was wondering if it really was un-refactorable.

The One That Got Away

I left the code as-is for quite some time. Months, actually. It was always sitting in the back of my head, though, like an itch you can't quite scratch. Refactoring this bit of code was always on my to-do list, but I could never seem to find the time to really think about it. On top of that, it encoded mission-critical business logic, and refactoring that always carries inherent risk. There were tests around the code, but they turned out to have major holes.

A Light on the Horizon

Tests...

The golden rule of refactoring is that you never do it without a good set of tests to verify old and new logic.

Hah! Of course! Write unit tests around the entire if-statement-of-crazy. If nothing else, I figured this would give me a deeper understanding of the business logic and the all-crucial "why" behind some of the crazy logic. While I was doing this, I sat down with a number of my coworkers to fill in the gaps in my knowledge. As always, ask the people who were there first!
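One way to pin the behavior down before changing anything is a characterization test. You enumerate every combination of input values you care about, then assert that the old and new logic agree on each one. Here is a sketch of that idea. The old_decide and new_decide functions are small hypothetical stand-ins for the real if statement and its replacement:

```python
import itertools


# Hypothetical stand-in for the legacy if/elif logic.
def old_decide(mode, flag, level):
    if mode == "first":
        if flag:
            return "process_1"
        elif level in (1, 2):
            return "process_2"
    elif mode == "second" and not flag:
        return "process_3"
    return None


# Hypothetical refactor: the same rules as an ordered table of (predicate, result).
RULES = [
    (lambda mode, flag, level: mode == "first" and flag, "process_1"),
    (lambda mode, flag, level: mode == "first" and level in (1, 2), "process_2"),
    (lambda mode, flag, level: mode == "second" and not flag, "process_3"),
]


def new_decide(mode, flag, level):
    for predicate, result in RULES:
        if predicate(mode, flag, level):
            return result
    return None


# Every value each input can take, including "missing" (None).
DOMAINS = {
    "mode": ["first", "second", None],
    "flag": [True, False, None],
    "level": [0, 1, 2, 3, None],
}


def find_mismatches():
    """Return every input combination where the old and new logic disagree."""
    mismatches = []
    for combo in itertools.product(*DOMAINS.values()):
        if old_decide(*combo) != new_decide(*combo):
            mismatches.append(combo)
    return mismatches


print(find_mismatches())  # []
```

Because every combination is checked, the main way a regression can slip through is a value missing from DOMAINS, not a forgotten hand-picked test case.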

Darkness

While I was chatting with and interviewing my coworkers about this bit of code, an ominous phrase kept popping up...

This isn't the first time someone's tried to refactor that if statement.

...great...

After some introspection and some poking and prodding at my ego, I finally decided that it'd make a fun side project to try. If I didn't succeed, no worries: at least I'd have added a ton of tests around it, so the next time it needed to change, we could change it safely!

The Hunt

After a number of false starts, thrown-out code, and late nights, I hit upon an interesting idea.

To look up a key in a Python dictionary, the get method does two things. First, it hashes the key you're looking up and uses that hash to jump to a slot in the table; it doesn't compare against every key in the dictionary. Second, for each stored key it finds there with a matching hash, it does an equality comparison between that key and the lookup key. That second round is what handles hash collisions: two different keys can share a hash, so a matching hash alone doesn't prove the keys are equal.
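A quick way to see both steps is a key type whose hash always collides. This is a minimal sketch, and CollidingKey is a made-up class for illustration:

```python
class CollidingKey:
    """Toy key: every instance hashes to the same value."""

    eq_calls = 0  # class-wide counter of __eq__ invocations

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return 42  # deliberately identical for every instance

    def __eq__(self, other):
        CollidingKey.eq_calls += 1
        return isinstance(other, CollidingKey) and self.name == other.name


table = {CollidingKey("a"): 1, CollidingKey("b"): 2}

# Both keys share a hash, but equality still tells them apart.
print(table[CollidingKey("b")])    # 2
print(CollidingKey("c") in table)  # False
print(CollidingKey.eq_calls > 0)   # True: __eq__ resolved the collisions
```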

At the same time, I was looking into using namedtuples as a way to encode multiple values into the key. They're immutable, and hashable as long as their contents are. Their fields can be read by name, and _replace hands you a modified copy. That last part is important for the readability of the final code.
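For reference, this is what a namedtuple gives you as a dictionary key. Key and its fields are hypothetical:

```python
from collections import namedtuple

Key = namedtuple("Key", ["mode", "state", "flag"])

routes = {Key("first", "state1", True): "process_2"}

k = Key(mode="first", state="state1", flag=True)
print(k.mode)     # readable field access: first
print(routes[k])  # an equal namedtuple finds the entry: process_2

# "Setting" a field returns a new tuple; the original is untouched.
k2 = k._replace(flag=False)
print(k.flag, k2.flag)  # True False

try:
    k.flag = False  # namedtuple fields can't be reassigned
except AttributeError:
    print("immutable")
```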

Shot and a Miss

That was my next thought: encode all the logic into a namedtuple "template" and compare the actual data against the keys in the dictionary.

Well, there's a problem with that. At runtime, we don't always have all the data a given branch of the if statement needs. In fact, some of the data may be set but not matter for that branch.

The upshot is that different branches of the if statement care about different pieces of the data, and rank their importance differently.

A New Drawing Board

So, again, I was stumped. If you compare two namedtuples with one another, all the data has to match or they aren't considered equal. Even if I could solve that problem, there was still the hash lookup in the dictionary to deal with. If two namedtuples don't hold exactly the same data, their hashes will almost certainly differ, so looking one up in a dictionary fails before equality even comes into play.
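Here's that failure in miniature. It uses a plain namedtuple with None standing in for "don't care"; the field names are hypothetical, and defaults= assumes Python 3.7 or later:

```python
from collections import namedtuple

Plain = namedtuple("Plain", ["mode", "state", "flag"], defaults=(None, None, None))

# A "template" that only cares about mode.
template = Plain(mode="first")
lookup = {template: "process_1"}

data = Plain(mode="first", state="state2", flag=True)

print(template == data)  # False: every field must match
print(data in lookup)    # False: the hashes almost certainly differ, and == would fail anyway
```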

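Digging into Python's hash magic methods, I hit upon a simple idea: make all the hashes the same. That way Python is forced to use the equality function every time it compares the namedtuple being looked up against the dictionary keys. The trade-off is that every key collides, so each lookup becomes a linear walk over the keys. That's acceptable for a small decision table.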
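Here is a minimal sketch of the trick, assuming a two-field template where None means "match anything". Wild and wildcard_eq are illustrative names. This simplified equality also skips the list handling the real version needs:

```python
from collections import namedtuple

Wild = namedtuple("Wild", ["mode", "state"], defaults=(None, None))


def wildcard_eq(template, other):
    # None in the template matches anything; other fields must be equal.
    return all(t is None or t == o for t, o in zip(template, other))


Wild.__hash__ = lambda self: 1  # every instance gets the same hash...
Wild.__eq__ = wildcard_eq       # ...so dict lookups fall through to __eq__

lookup = {Wild(mode="first"): "process_1"}
print(lookup.get(Wild(mode="first", state="state2")))   # process_1
print(lookup.get(Wild(mode="second", state="state2")))  # None
```

Note that wildcard_eq is asymmetric: the template has to be the left operand. As far as I can tell from CPython's dict implementation, a lookup evaluates stored_key == lookup_key, which puts the stored templates on the left. Treat that as an implementation detail rather than a language guarantee.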

A New Hope

Thankfully, Python doesn't enforce private visibility; even a class's magic methods can be reassigned from the outside. That means I can create my own custom namedtuple that keeps the immutability and readable field names, but with a custom equality function.

It all boils down to this bit of code, which creates the Template namedtuple. Templates do double duty. Partially filled-out Templates become the keys of the lookup dictionary. A Template filled out with the runtime data is what gets compared against those keys.

from collections import namedtuple

# Decision matrix used in determining which data process to kick off
Template = namedtuple("Template", ['pred1', 'pred2', 'pred3', 'pred4', 'pred5', 'pred6', 'pred7'])
# Default any non-filled-out fields to None
Template.__new__.__defaults__ = (None,) * len(Template._fields)
# Give every instance the same hash so dict lookups fall through to __eq__
Template.__hash__ = lambda self: 1
# Set our custom equality function (rd_eq, below, must be defined before this line runs)
Template.__eq__ = rd_eq
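On Python 3.7 and later, namedtuple accepts a defaults argument, so the __new__.__defaults__ assignment can be folded into the constructor call. Here is a sketch of just that part. Equality is left as plain tuple equality because rd_eq is defined further down:

```python
from collections import namedtuple

FIELDS = ["pred1", "pred2", "pred3", "pred4", "pred5", "pred6", "pred7"]

# defaults= (Python 3.7+) replaces the Template.__new__.__defaults__ assignment
Template = namedtuple("Template", FIELDS, defaults=(None,) * len(FIELDS))
# Same hash for every instance, so dict lookups rely on __eq__
Template.__hash__ = lambda self: 1

t = Template(pred1="first", pred4=True)
print(t)
```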

Custom Equality

"Custom Equality" is such a great term, and it really threw the more science-minded folks for a loop!

To ease their trepidation, I told them that I wasn't really redefining equality; I was making each Template work as a template rather than as an exact match. If a field is filled in on the template, it's important; if it isn't, it doesn't matter and the data can hold anything there.

Note that the code below is more complex than a simple template match. If you look back at the original code, there's a lot of if x in ['a', 'b', 'c', 'd'] thrown around. There's also at least one check of whether every value in one list shows up in another. That makes comparing against a template messy.

def rd_eq(template, other):
    """Compare fields in template to fields in other where template.field is not None.
    template and other are expected to be namedtuples, tuples, or lists of the same type.
    Order of the elements in each *does* matter.
    If a template element is a list, a single value in other must be one of its values,
    and a list in other must contain all of its values.
    :param template: namedtuple, non-null values are used for comparison
    :param other: namedtuple, all/some/none values can be filled out and will be compared with non-null values from template
    :return: True if the non-null values in template "match" in other, False otherwise
    """
    # both sides must be the same type and the same length
    if type(template) is not type(other) or len(template) != len(other):
        return False
    for template_val, other_val in zip(template, other):
        if template_val is None:
            continue  # only "filled out" template values matter
        if isinstance(template_val, list):
            if isinstance(other_val, list):
                # every value in the template's list must also appear in other's list
                if not set(template_val).issubset(other_val):
                    return False
            elif other_val not in template_val:
                # a single value must be one of the template's allowed values
                return False
        elif template_val != other_val:
            return False
    return True

It's worth reading through a few times, but in a nutshell:

  • The template and the data must be the same namedtuple type; otherwise there's no match.
  • Fields left as None in the template are skipped; only filled-in fields are compared.
  • If the template field is a list and the value being compared against isn't, make sure that value is in the list.
  • If the template field is a list and the value being compared against is also a list, make sure every value in the template's list exists in the other list.
  • Anything else must be exactly equal.
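To make those rules concrete, here's a condensed, self-contained version of rd_eq run against a three-field template. The field names and state values are hypothetical. The list-vs-list check uses set.issubset, which is equivalent to comparing the size of the intersection with the size of the template's set:

```python
from collections import namedtuple

Rule = namedtuple("Rule", ["mode", "states", "flag"], defaults=(None, None, None))


def rd_eq(template, other):
    """Template-style match: None in the template means 'don't care'."""
    if type(template) is not type(other) or len(template) != len(other):
        return False
    for t_val, o_val in zip(template, other):
        if t_val is None:
            continue  # field not filled out, so anything matches
        if isinstance(t_val, list):
            if isinstance(o_val, list):
                # every template value must appear in the data's list
                if not set(t_val).issubset(o_val):
                    return False
            elif o_val not in t_val:
                # a single data value must be one of the template's allowed values
                return False
        elif t_val != o_val:
            return False
    return True


template = Rule(mode="first", states=["state2", "state3"])

print(rd_eq(template, Rule("first", "state3", True)))             # True: value in list
print(rd_eq(template, Rule("first", ["state2", "state3", "x"])))  # True: list superset
print(rd_eq(template, Rule("first", ["state2"])))                 # False: state3 missing
print(rd_eq(template, Rule("second", "state2")))                  # False: mode differs
```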

The Final Frontier

The last step in the crazy mess is to get our custom dictionary up and running.

First, we set up the list of templates in order from most-restrictive to least.

from collections import OrderedDict, namedtuple

# Just a coupling of data function and basic log message to send
FnRunner = namedtuple("FnRunner", ['fn', 'log_message'])

# this runs inside a method of the owning class, hence self
decisions = [
    (Template(pred2=['val1']),
     FnRunner(fn=self.process_data_2, log_message='Running data process 2')),
    (Template(pred1='first',
              pred3=[state_management_2.state1, state_management_2.state3, state_management_2.state4],
              pred4=True),
     FnRunner(fn=self.process_data_1, log_message='Running data process 1')),
    # <snip>
]

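Then we shove them into an ordered dictionary to preserve the restrictiveness ordering. (Fun fact: because every key shares the same hash, a lookup in CPython checks the colliding keys in the order they were inserted, so the first matching template wins. That falls out of how CPython probes its hash table rather than from any documented OrderedDict guarantee.)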

self.decision_tree = OrderedDict(decisions)
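If you'd rather not rely on the dictionary's lookup order at all, you can spell out "most restrictive first, first match wins" with a plain loop over the list of pairs. This is an alternative technique, not what the code above does. The names first_match and matches are hypothetical, and matches is a simplified None-wildcard check rather than the full rd_eq:

```python
from collections import namedtuple

Rule = namedtuple("Rule", ["mode", "flag"], defaults=(None, None))


def matches(template, data):
    # Simplified template match: None means "don't care".
    return all(t is None or t == d for t, d in zip(template, data))


def first_match(decisions, data, default):
    """Return the value paired with the first template that matches data."""
    for template, runner in decisions:
        if matches(template, data):
            return runner
    return default


# Ordered from most restrictive to least restrictive.
decisions = [
    (Rule(mode="first", flag=True), "process_1"),
    (Rule(mode="first"), "process_2"),
]

print(first_match(decisions, Rule("first", True), "noop"))   # process_1
print(first_match(decisions, Rule("first", False), "noop"))  # process_2
print(first_match(decisions, Rule("second", True), "noop"))  # noop
```

The constant-hash dictionary ends up doing the same linear walk; the loop just makes the ordering explicit.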

From there, we get the data to compare against, shove it into a Template namedtuple, and fire off the comparison:

# get the realtime data to compare against the templates
# (assumes get_decision_data() returns a dict keyed by Template field names)
runner_data = Template(**get_decision_data())

# create a noop runner in case nothing in self.decision_tree matches
runner_noop = FnRunner(fn=lambda *args, **kwargs: RunnerResults(),
                       log_message="Nothing to do, ignoring")

# look up the correct data runner to use
# this is where rd_eq is called and everything is bashed against one another
data_runner = self.decision_tree.get(runner_data, runner_noop)

# finally, unpack the runner and call the function we found
fn, log_message = data_runner
output_data = fn()
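Putting the pieces together, here's a small runnable sketch of the whole dispatch under the same approach. It combines a constant hash, template-style equality, an OrderedDict of templates, and a no-op fallback. The field names, the process_* functions, and their return values are hypothetical. The equality is a simplified None-wildcard stand-in for the full rd_eq:

```python
from collections import OrderedDict, namedtuple

Template = namedtuple("Template", ["mode", "flag"], defaults=(None, None))


def template_eq(template, other):
    # Simplified stand-in for rd_eq: None in the template means "don't care".
    return all(t is None or t == o for t, o in zip(template, other))


Template.__hash__ = lambda self: 1  # force every lookup through __eq__
Template.__eq__ = template_eq

FnRunner = namedtuple("FnRunner", ["fn", "log_message"])


def process_1():
    return "output from process 1"


def process_2():
    return "output from process 2"


# Most restrictive first: both templates match Template("first", True).
decision_tree = OrderedDict([
    (Template(mode="first", flag=True), FnRunner(process_1, "Running data process 1")),
    (Template(mode="first"), FnRunner(process_2, "Running data process 2")),
])

runner_noop = FnRunner(fn=lambda: None, log_message="Nothing to do, ignoring")


def run(data):
    fn, log_message = decision_tree.get(data, runner_noop)
    return fn(), log_message


print(run(Template("first", True)))   # ('output from process 1', 'Running data process 1')
print(run(Template("first", False)))  # ('output from process 2', 'Running data process 2')
print(run(Template("second", True)))  # (None, 'Nothing to do, ignoring')
```

One thing to watch when building the table: inserting a key also goes through __eq__. Suppose a less restrictive template were inserted before a more restrictive one that it "matches". The dictionary would treat the two as the same key, and the later value would silently overwrite the earlier one. Ordering from most restrictive to least avoids that with this simplified equality.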

Conclusion

Once I was done with all of this craziness and had a working prototype that passed at least some of the tests, I bounced it off the team. Their reactions were a pretty mixed bag, everything from disbelief that I had actually refactored the if-statement-of-crazy to curiosity about how I had done it.

To be honest, I'm still not 100% satisfied with the result, but the feedback in the following weeks was nothing but positive. That's telling, because around the same time, the business logic now encoded in the dictionary changed drastically. The change was estimated to take 2-4 weeks with the old if statement. With the new one, it was done in 2 days.

I think the increase in productivity speaks for itself.