How do you reason about a probabilistic distributed system?

In which I am stunted upon by coin flips

Wasn’t too long ago that I felt pretty good about my knowledge of distributed systems. All someone really needed in order to understand them, I thought, was a thorough understanding of the Paxos protocol and a willingness to reshape their brain in the image of TLA⁺. Maybe add a dash of conflict-free replicated datatypes, just so you know what “eventual consistency” means. Past that it’s just some optimizations and mashups which come easily to your TLA⁺-addled brain.

This belief proved surprisingly robust over a number of years, even surviving an aborted attempt at analyzing the Nano cryptocurrency. It was only after encountering the Snowflake family of consensus protocols that I realized my theory just wasn’t up to the challenge. The issue was probability: Snowflake protocols reach consensus by iteratively polling sets of other nodes at random, and the argument that consensus is eventually reached is a statistical argument deriving an upper bound on the probability of failure.

I didn’t dislike probability & statistics; I just tried to keep my distance as much as possible. All the algorithms in distributed systems I’d encountered so far involved nondeterminism, sure, but not probability. I’d assumed nondeterminism was just a more flexible way of reasoning about probability. This assumption would prove to be a source of great unnecessary confusion as I learned the art of reasoning about probabilistic distributed systems, so I’ll do you a favor and give you the core lesson of this entire post in one sentence:

You cannot model probability with nondeterminism, and you cannot model nondeterminism with probability.

Models: they’re good, folks!

Have you ever been writing some multithreaded code, happily plugging in a mutex here, a semaphore there, or even just using some nice message-passing primitives to make your threads all get along? Maybe you’ll be familiar, then, with what often comes next. A scratch at the back of your mind, a thought - “oh, wait…” - as you realize something weird will happen if thread \(A\) manages to reach some step before thread \(B\) has finished its assigned task. No worries! Slap on another WaitHandle, problem solved. Except the problem wasn’t solved. Not really. You consider it a bit more - what if thread \(C\) comes in with a message at this inopportune time? You realize with dawning horror you’re actually tracing cracks in the foundation. Patch them with mutexes! Semaphores! Anything! Alas, you are beyond help. It’s around this time that your brain, catching a glimpse of the infinite plane of combinatorial state explosion, wisely ducks its head back down for the day and leaves you with a woozy, fuzzy, clenching feeling for having the gall to ask it to fix all this.

I’ve felt like this many times, and formal models are the only cure I’ve ever found. Your brain isn’t built to hold massive state spaces in its working memory, so don’t even try. Let a model checking program churn through all those states to find the bugs. At this point I won’t even touch a multithreaded program or distributed system without whipping up a quick TLA⁺ spec of its desired workings. I just specify all the possible events in the system, how those events affect the system state, what things I always want to remain true (the invariants), then let the model checker rip. In TLA⁺, we model concurrency with nondeterminism; in a concurrent system, we have no idea whether thread \(A\) will execute a step before thread \(B\). We can represent this with a nondeterministic state machine as follows:

A nondeterministic state machine. Start state 0 has transition labeled A to state 1, and labeled B to state 2. State 1 has a transition labeled B to state 3. State 2 has a transition labeled A to state 4.

So you’ll be in state \(s_3\) if thread \(A\) executes its step before thread \(B\), and state \(s_4\) if thread \(B\) executes its step before thread \(A\). Maybe \(s_3\) and \(s_4\) are even the same state, who knows. The model checker will explore both of these possible execution orders, and in a well-designed concurrent system we should never end up in a bad state just because of a certain order of execution.

Readers might wonder how exactly this models concurrency, where steps can happen uh, concurrently. The short answer is you have to ensure all the steps in your model are atomic or independent: either impossible in the real world for two of your steps to happen at the exact same time (for example, by assuming use of a lower-level hardware synchronization primitive) or impossible for execution of one step to directly affect the same variables as another step (for example, if the steps are executed on different computers within a timespan less than the network latency between them). If the steps in your model satisfy this requirement, checking all possible execution orders accurately models concurrency. If they don’t, you need to break the steps down further so they do. This model nicely captures & exposes all that is difficult about concurrency.
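
To make this concrete, here’s a minimal sketch of the two-thread machine in PRISM, a probabilistic modeling language we’ll meet properly in the next section (the module and variable names are mine; each command has a single outcome taken with certainty, so this is pure nondeterminism):

mdp

module TwoThreads
	// 0: thread has not yet taken its step, 1: it has
	AStep : [0 .. 1] init 0;
	BStep : [0 .. 1] init 0;

	// Either enabled command may fire next; the scheduler decides
	[] AStep = 0 -> (AStep' = 1);
	[] BStep = 0 -> (BStep' = 1);

	// Loopback once both threads have stepped
	[] AStep = 1 & BStep = 1 -> (AStep' = AStep) & (BStep' = BStep);
endmodule

A model checker explores both interleavings; note that in this encoding states \(s_3\) and \(s_4\) happen to collapse into one.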

What questions can we ask about this sort of model? The most important questions are reachability queries - can we reach a bad state (two caches disagreeing on a value, deadlock, dogs & cats living together, etc.) from the starting state? These questions are called safety properties, and if they are answered in the negative then the system is safe. Another type of query is something like “are we always guaranteed to eventually end up in a good state?” These are called liveness properties. Turns out these two types of questions can get you pretty far in concurrent & distributed systems. Definitely far enough to make a whole career out of writing rock-solid software in places others would falter. However, these questions also have a drawback: their answers are absolute. True or false. No probability involved, no room for nuance.

What if one of the threads flips a coin, and if it’s heads it does one thing, tails another? Entire state spaces, bifurcated by a probabilistic event. Maybe those state spaces contain further coin flips, or other types of randomness. In this system your questions might change from the form “is it possible to reach a bad state” to “what is the probability of reaching a bad state?” Unfortunately these types of questions just cannot be answered within the nondeterministic model used above. You cannot model probability with nondeterminism. We must use a new type of model, a state machine that handles probability directly.

Leaving the beautiful pure discrete realm

TLA⁺ can’t handle probability at this time, so we’d have to use a specialized modeling language like PRISM, which handles probabilistic state machines. Let’s look at the standard hello-world example for probabilistic state machines: the 1976 Knuth-Yao method for simulating a fair six-sided die with a series of coin flips. This is really quite a neat problem and I encourage you to ponder it for a second before seeing how they did it! Any sequence of \(n\) coin flips will give you an event which has probability \(\frac{1}{2^n}\) of occurring. Simulating a fair six-sided die requires generating an event with probability \(\frac{1}{6}\) of occurring. You might then reason this problem is impossible, because \(2^n\) is never evenly divisible by \(6\) for any \(n\) (this follows from the uniqueness of prime factorization). Indeed, there is no way to simulate a six-sided die with a finite number of coin flips. We have to use an algorithm which is not guaranteed to ever terminate, although it is vanishingly unlikely to run forever. Here it is:

A state machine where each transition is labeled with either H or T for heads or tails. The state machine fans out like a four-level full binary tree from the start state, with the exception of the paths only flipping heads or only flipping tails. Starting from state 0 H goes to state 1, then H goes to state 3; however, from state 3 only T goes to one of the six termination states while H goes back to state 1 to form an infinite loop. There is an analogous loop on the T half of the tree.
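
In PRISM this state machine looks as follows; this is essentially the well-known Knuth-Yao dice example from the PRISM documentation, where s tracks our position in the coin-flip tree and d holds the simulated die value (0 meaning not yet decided):

dtmc

module die
	// s: position in the coin-flip tree, d: simulated die value
	s : [0 .. 7] init 0;
	d : [0 .. 6] init 0;

	// Interior nodes: each coin flip is a 50/50 branch
	[] s = 0 -> 0.5 : (s' = 1) + 0.5 : (s' = 2);
	[] s = 1 -> 0.5 : (s' = 3) + 0.5 : (s' = 4);
	[] s = 2 -> 0.5 : (s' = 5) + 0.5 : (s' = 6);
	// States 3 and 6 can loop back instead of terminating
	[] s = 3 -> 0.5 : (s' = 1) + 0.5 : (s' = 7) & (d' = 1);
	[] s = 4 -> 0.5 : (s' = 7) & (d' = 2) + 0.5 : (s' = 7) & (d' = 3);
	[] s = 5 -> 0.5 : (s' = 7) & (d' = 4) + 0.5 : (s' = 7) & (d' = 5);
	[] s = 6 -> 0.5 : (s' = 2) + 0.5 : (s' = 7) & (d' = 6);
	// Loopback once the die value is decided
	[] s = 7 -> (s' = 7);
endmodule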

You can see that if you somehow only flip heads, or only flip tails, you’ll never reach one of the accepting states (here labeled with the die number they represent). There are some fun ways to contextualize the probability of flipping only heads or only tails a certain number of times in a row. For example, there are only around \(2^{268}\) subatomic particles in the observable universe; if you manage to flip heads 268 times in a row, that’s the same as picking the correct subatomic particle out of a universe-wide random draw. Maybe go look at the Hubble Ultra-Deep Field as you ponder this probability. Another way: assuming you’re between the ages of 25 and 34 and live in the USA, your annual all-cause mortality rate is about 129/100,000. Assuming deaths are uniformly distributed throughout the year, this means your chances of dying today are about 1 in 283,000 - the same as 18-19 all-heads or all-tails coin flips in a row. What I’m saying is that you really, really shouldn’t worry about having to flip the coin very many times.

This probabilistic state machine model we’ve created is called a Discrete-Time Markov Chain, or DTMC. In DTMCs, every transition has an associated probability, and for every state the probabilities of all out-flowing transitions must sum to one (accepting states can be thought of as having a loopback with probability 1). The above rumination on termination probabilities is summed up in the long run theorem: in the long run, every path in a finite Markov chain ends in an absorbing state, which is a state (or group of states) from which there is an entrance but no exit. What questions can we ask of DTMCs? The most interesting one - the reason why we’re here - is “what is the probability of eventually reaching a certain state?” The long run theorem tells us we have a 100% chance of eventually reaching one of the Knuth-Yao state machine’s accepting states. What about the probability of ending up in a specific accepting state? It should be \(\frac{1}{6}\). Is it?
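
As an aside, the long run theorem’s promise is something PRISM can check directly against the model above: the property query P=? [ F s = 7 ] asks for the probability of eventually reaching an absorbing state, and should return 1. The per-face probabilities can be queried the same way (e.g. P=? [ F s = 7 & d = 1 ]), but let’s first try to derive one by hand.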

Let’s try to reason this out with basic probability. What are the chances of ending up in accepting state \(1\)? Well, you can get there by flipping \(HHT\). The probability of that happening is \(\frac{1}{2} \cdot \frac{1}{2} \cdot \frac{1}{2} = \frac{1}{8} \). But you can also get there by flipping \(HHHHT\). The probability of that happening is \(\frac{1}{2^5} = \frac{1}{32} \). We have to add this to the first probability, so now our probability is \(\frac{1}{8} + \frac{1}{32} = \frac{5}{32}\). But we can also get there by flipping \(HHHHHHT\), with probability \(\frac{1}{2^7} = \frac{1}{128}\). I’m sure you can see where this is going. We’re dealing with something truly horrific: an infinite sum over an infinite set of ever-longer paths. If we repeat this process a few more times we can see it numerically converging to \(\frac{1}{6}\), or \(0.16666\ldots\), but how do we get a nice closed-form solution? To avoid sending my readers through a math-heavy meatgrinder from which few would emerge, I’ve pushed the algorithm’s explanation to this post’s appendix (or you can just get PRISM to calculate it and never think about this again). For now just know you can write the DTMC as a matrix and the problem reduces to solving a simple system of linear equations. The Knuth-Yao state machine indeed accurately simulates a fair six-sided die.
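
In fact, for this particular machine the horror is tamer than it looks: every path to accepting state \(1\) consists of \(3 + 2k\) flips for some \(k \geq 0\) (each extra trip around the heads-side loop adds two flips), so the total is a plain geometric series: \(\sum_{k=0}^{\infty} \left(\frac{1}{2}\right)^{3+2k} = \frac{1}{8} \cdot \frac{1}{1 - \frac{1}{4}} = \frac{1}{6}\). Most DTMCs don’t decompose into such a tidy closed form, though, which is why the general algorithm in the appendix exists.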

Okay, so now we can model probabilistic systems. Are we done? Sadly no. Remember the other half of this post’s lesson: you cannot model nondeterminism with probability. Let’s go back to our concurrent system where either thread \(A\) or thread \(B\) can take a step; we don’t know which will execute first. How do we model this in a DTMC? “Easy!” you might say. “Each thread has a \(\frac{1}{2}\) chance of going first, or a \(\frac{1}{n}\) chance if there are \(n\) threads in the system. Plug in these probabilities and fire up the model checker!” Bzzt. Wrong. This is probably the most conceptually difficult part of this post. By modeling your threads in this way, you are making an assumption about the thread scheduler. Your assumptions are the foundation of your model; they must accurately correspond to the system you’re reasoning about. If they don’t, your model is useless. It’ll be able to generate a lot of nice-looking numbers that bear absolutely no relation to reality. In this case we can’t assume, in general, that the thread scheduler will assign each thread processor time with uniform probability. That assumption makes even less sense in a fully distributed system with processes running on separate computers connected by a network.

So what can we assume about the scheduler? Well, nothing - that’s why we need nondeterminism. It enables us to explore what happens under every possible scheduling system. So how do we model that in a DTMC? We can’t. We need something new. Something that combines probability with the power of nondeterminism.

An automaton to surpass DTMCs

Meet the Markov Decision Process (MDP). It’s a prickly entity, prone to sucking up your mind’s comprehension ability as you muddle through treatises on probabilistic temporal logic, and your computer’s memory as the model checker explores its depths. At first glance, it’s literally just a DTMC with nondeterministic steps. The tricky part is how that changes what questions you can ask about the model.

Let’s think of a very simple system which uses both probability and nondeterminism. We have two threads, \(A\) and \(B\), which are scheduled nondeterministically. We also have two coins: an unfair one \(U\) with a \(\frac{3}{4}\) chance of landing on heads and a \(\frac{1}{4}\) chance of landing on tails, and a fair one \(F\) with a 50/50 chance of landing on heads or tails. Whichever thread goes first grabs the unfair coin \(U\) and gives it a flip. The thread going second grabs the remaining coin \(F\) and flips it as well. Here’s how this looks as an MDP; we label states with \(A_x B_y\), where \(x\) and \(y\) are one of \(\_\), \(H\), or \(T\) to represent not-yet-flipped, heads, or tails respectively for threads \(A\) and \(B\):

An MDP state machine where each transition between states is a fork. The two nondeterministic transitions from the start state are labeled A and B, each forking into two further branches of probability 3/4 and 1/4 respectively, leading to four states in total. Each of those four states is then connected to two out of four possible final states by transitions with probability 1/2 each.

Each MDP state transition has two components: a label, so the nondeterministic scheduler can pick (“decide” to take) that step, and some number of outcomes forking off of it with associated probabilities (which must sum to 1). We can see how this is a hybrid of DTMCs and nondeterministic state machines: if each transition only has a single fork with probability 1, the MDP reduces to our basic nondeterministic state machine; if each state only has a single outgoing transition with some associated probabilities, the MDP reduces to a DTMC.

Here’s how our MDP looks in PRISM code:

mdp

module StrangeCoinGame
	// Not-yet-flipped: 0, heads: 1, tails: 2
	AFlip : [0 .. 2] init 0;
	BFlip : [0 .. 2] init 0;

	// Choose one of the threads to go first and flip the unfair coin
	[] AFlip = 0 & BFlip = 0 -> 0.75 : (AFlip' = 1) + 0.25 : (AFlip' = 2);
	[] AFlip = 0 & BFlip = 0 -> 0.75 : (BFlip' = 1) + 0.25 : (BFlip' = 2);

	// The second thread flips the fair coin
	[] AFlip != 0 & BFlip = 0 -> 0.5 : (BFlip' = 1) + 0.5 : (BFlip' = 2);
	[] AFlip = 0 & BFlip != 0 -> 0.5 : (AFlip' = 1) + 0.5 : (AFlip' = 2);

	// Loopback in accepting states
	[] AFlip != 0 & BFlip != 0 -> (AFlip' = AFlip) & (BFlip' = BFlip);
endmodule

Now for the grand finale! What questions can we ask of this model? Let’s start with something simple, like “what is the probability of reaching state \(A_H B_T\)?” I encourage readers to mull this one over for a bit. If \(A\) goes before \(B\), the probability of reaching state \(A_H B_T\) is \(\frac{3}{4} \cdot \frac{1}{2} = \frac{3}{8}\). However, if \(B\) goes before \(A\), the probability of reaching state \(A_H B_T\) is \(\frac{1}{4} \cdot \frac{1}{2} = \frac{1}{8}\). Remember we don’t know or assume anything about the probability of \(A\) going before \(B\) or vice-versa. How, then, are we supposed to answer the question “what is the probability of reaching state \(A_H B_T\)?” We can’t! It’s an invalid question! The only questions we can ask about MDPs are questions about the maximum or minimum probabilities of reaching a state, across all possible execution orders. This makes more sense when you consider how MDPs are model checked: conceptually, every possible way of resolving the nondeterminism induces its own DTMC (expensive!), and the model checker finds the reachability probability in each of those DTMCs and takes the global max or min. So the claims about your system become “it will reach a bad state with at most X% probability” or “it always has at least a Y% probability of success”.
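
In PRISM’s property language these become Pmax=? and Pmin=? queries. As a sketch (recall that in the model above heads is 1 and tails is 2), the following two properties should come back as \(\frac{3}{8}\) and \(\frac{1}{8}\) respectively, matching our hand analysis:

// Max over all schedulers: A grabs the unfair coin first
Pmax=? [ F AFlip = 1 & BFlip = 2 ]

// Min over all schedulers: B grabs the unfair coin first
Pmin=? [ F AFlip = 1 & BFlip = 2 ]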

Working with distributed systems requires a catastrophic mindset. If there’s an ordering of events that could cause your system to fail, you must assume it will happen and evaluate your design within that regime (at scale, it’s a certainty the behavior will occur sooner or later). So when dealing with probability, the ordering of events that gives the highest chance of failure is your chance of failure. And that’s how you reason about a probabilistic distributed system.

Melting the snow

Actually using MDPs to analyze the Snowflake protocols is deserving of its own post, since this one is getting quite long and I still have to write the above-promised appendix. Instead I’ll just throw a bunch of links at you. Here is the original paper presenting the Snowflake family of protocols, posted pseudonymously on IPFS by “Team Rocket” (almost certainly Emin Gün Sirer et al., let’s be real); it is being productized by a company called Ava Labs. Here is a good writeup & summary of the protocols by Murat Demirbas, a professor who researches distributed systems at SUNY Buffalo.

Sarah Jamie Lewis is working on analyzing the Snowflake protocols with MDPs and also a type of model we didn’t cover called Continuous-Time Markov Chains (CTMCs) - perhaps CTMCs will be the subject of another post. She’s developing an interesting attack called Snowfall using Byzantine response delays, detailed here. Her formal models can all be found in this git repo.

For myself, I’ve modeled the most basic Snowflake protocol (called Slush) as an MDP in PRISM here. Look forward to a future post on what I learned - model checking MDPs is very expensive and the model is difficult to scale! CTMCs are apparently more scalable than MDPs, although I still don’t understand them very well, so what they lose in model fidelity is unknown to me.

Finally, if you’re interested in the math & algorithms behind DTMCs, I can’t recommend this paper enough: Model Checking Meets Probability: A Gentle Introduction by Joost-Pieter Katoen, a professor at RWTH Aachen University. Without this paper I would never have been able to understand this material & write this post. Alternatively if you’re in search of an enormous tome containing all that is known about formal models and the checking thereof, see the Handbook of Model Checking, just published in 2018.

Corrections

Ron Pressler correctly points out here that it isn’t nondeterminism per se which fails at modeling probability, but rather our inability to express properties over the domain of all system behaviors (beyond \(\forall\) and \(\exists\)). If we could write a TLA⁺ function that sums the value (or takes the max/min) of variables across every single possible system behavior, we could reason usefully about probability in TLA⁺. Unfortunately we cannot do that at this time.

Appendix: calculating reachability probabilities

What follows will be a condensed & simplified version of the algorithm presented in the paper Model Checking Meets Probability: A Gentle Introduction. Recall our Knuth-Yao DTMC:

A state machine where each transition is labeled with either H or T for heads or tails. The state machine fans out like a four-level full binary tree from the start state, with the exception of the paths only flipping heads or only flipping tails. Starting from state 0 H goes to state 1, then H goes to state 3; however, from state 3 only T goes to one of the six termination states while H goes back to state 1 to form an infinite loop. There is an analogous loop on the T half of the tree.

Let’s try to calculate the probability of reaching accepting state \(2\) from the starting state. We use a simple recursive algorithm to convert this into an easily solved system of linear equations. First, some definitions:

  • \(P(s, t)\) is the probability associated with the transition between states \(s\) and \(t\)
  • \(G\) is the set of goal states for which we want to calculate the reachability probability
  • \(x_s\) is the probability of reaching \(G\) from a specific state \(s\)

Our objective is to find \(x_s\) where \(s\) is the start state and \(G\) is the set \(\{2\}\) containing only the accepting state \(2\), but in order to do that we have to find \(x_s\) for every state \(s\) in the DTMC. We do it with three simple rules:

  1. Base case 1: if \(s \in G\), then \(x_s = 1\)
  2. Base case 2: if \(G\) is not reachable from \(s\), then \(x_s = 0\)
  3. Recursive case: otherwise, \(x_s = \sum_{t \notin G} P(s, t) \cdot x_t + \sum_{u \in G} P(s, u)\)

The equation in the recursive case looks fairly horrific, but worry not - we’ll get there. For the base cases it’s trivial to mark all the states in \(G\) as 1, and we can run a breadth-first search backwards from the states in \(G\) to find all the states from which \(G\) is reachable and mark the others as 0. For the recursive case it’s easiest to think about splitting this into two sub-cases: for a given \(s\), the first \(\sum_{t \notin G}\) is the probability of reaching \(G\) in a roundabout way by going through some intermediate state(s) (this is recursive since it depends on the probability of reaching \(G\) from those states). The second \(\sum_{u \in G}\) is the probability of reaching \(G\) directly, in a single step. Add them together and you get \(x_s\).

Let’s apply this algorithm to our example; here are the base cases:

  1. \(x_2 = 1\)
  2. \(x_{s_2} = x_{s_5} = x_{s_6} = x_1 = x_3 = x_4 = x_5 = x_6 = 0\)

For \(x_{s_0}, x_{s_1}, x_{s_3},\) and \(x_{s_4}\) we have:

  • \(x_{s_0} = \frac{1}{2} x_{s_1} + \frac{1}{2} x_{s_2}\)
  • \(x_{s_1} = \frac{1}{2} x_{s_3} + \frac{1}{2} x_{s_4}\)
  • \(x_{s_3} = \frac{1}{2} x_{s_1} + \frac{1}{2} x_{1}\)
  • \(x_{s_4} = \frac{1}{2} x_3 + \frac{1}{2} \)

This is a system of linear equations! Using Gaussian elimination to solve for \(x_{s_0}\), we see that it indeed equals \(\frac{1}{6}\). The above paper explains how to mechanically translate this into a matrix that can be solved with a quick call to a linear algebra library, but I hope you see how this algorithm works conceptually.
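
For a system this small you can even skip the matrix and just back-substitute. Since \(x_1 = x_3 = x_{s_2} = 0\) and \(x_2 = 1\):

  • \(x_{s_3} = \frac{1}{2} x_{s_1}\) and \(x_{s_4} = \frac{1}{2}\)
  • \(x_{s_1} = \frac{1}{2} \cdot \frac{1}{2} x_{s_1} + \frac{1}{2} \cdot \frac{1}{2}\), which solves to \(x_{s_1} = \frac{1}{3}\)
  • \(x_{s_0} = \frac{1}{2} \cdot \frac{1}{3} + \frac{1}{2} \cdot 0 = \frac{1}{6}\)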