The majority of this lecture is borrowed or adapted from Session Types for Rust.

Along with typestate, the theme of these last two lectures is high-level invariants of computational systems: specifically, invariants that we usually express in words and check with unit tests, but that, by being clever with Rust, we can verify statically with the type system. Last lecture, we looked at state machines and how to enforce state transitions. Today, we’re going to look at a related topic: two-party communication protocols.

Communication protocols arise when multiple independently acting entities need to coordinate to perform some action. This is particularly common in distributed systems and networking (e.g. a client and server establishing a secure session via TLS), but the idea applies more generally to situations like:

  • A customer using an ATM or vending machine.
  • A chatbot holding a conversation.
  • A Bitcoin miner broadcasting a mined block.

In this lecture, we are going to consider the case of two-party communication protocols, i.e. where exactly two entities are communicating.

Session type formalism

Let’s consider the example of an ATM. In plain English, we can describe a simple protocol for deposits and withdrawals:

  • The client communicates his/her ID to the ATM
  • The ATM then answers either ok or err
    • In the first case, the client then requests either a deposit or a withdrawal
      • For a deposit, the client first sends an amount, then the ATM responds with the updated balance
      • For a withdrawal, the client sends the amount to withdraw, and the ATM responds with either ok or err to indicate whether or not the transaction was successful
    • If the ATM answers err, then the session terminates.

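Now our question is: how can we turn this plain English into something more formal, more like a program? Last lecture, we didn’t have this issue because state machines have a well-studied formalism (a set of states and the transitions between them). The right formalism is a less settled question for communication protocols, but today we’re going to explore one particular answer: session types. Session types have a storied history since their introduction in 1994, but my goal today is to acquaint you with the core intuitions behind them and to explore an implementation in Rust. To get a sense of how they work, here’s a formalization of the ATM protocol from the view of the ATM:

$$
\begin{aligned}
\text{ATM} &= \text{recv}\ \text{id};\ \text{choose}\{\text{ok}: \text{ATM}_{\text{auth}},\ \text{err}: \varepsilon\} \\
\text{ATM}_{\text{auth}} &= \text{offer}\{\text{deposit}: \text{ATM}_{\text{deposit}},\ \text{withdraw}: \text{ATM}_{\text{withdraw}}\} \\
\text{ATM}_{\text{deposit}} &= \text{recv}\ \text{nat};\ \text{send}\ \text{nat};\ \varepsilon \\
\text{ATM}_{\text{withdraw}} &= \text{recv}\ \text{nat};\ \text{choose}\{\text{ok}: \varepsilon,\ \text{err}: \varepsilon\}
\end{aligned}
$$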

Session types have four core operations: send/receive, which indicate sending a message of a particular type to the other party or receiving one from it, and choose/offer, which indicate a branch point where the current party either chooses one of several labeled sub-protocols itself (choose) or offers to let the other party select one (offer). For example, the top line declares the initial protocol for the ATM: first receive a user ID, then either enter the authorized ATM sub-protocol (ok) or respond with err.

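A few things to note here. First, session types specify communication behavior but not logical behavior: a user ID could be rejected for any number of reasons not specified in the type, and the protocol simply allows for an error to occur at that point in the dialogue. Second, semicolons indicate a sequence of actions, e.g. an ID must be received before choosing ok/err. Lastly, $\varepsilon$ indicates the termination of the dialogue. We can formalize this grammar as:

$$
\sigma ::= \varepsilon \mid \text{send}\ \tau;\ \sigma \mid \text{recv}\ \tau;\ \sigma \mid \text{choose}\{\ell_1: \sigma_1, \ldots, \ell_n: \sigma_n\} \mid \text{offer}\{\ell_1: \sigma_1, \ldots, \ell_n: \sigma_n\}
$$

where $\sigma$ ranges over session types, $\tau$ over message types (such as id and nat), and each $\ell_i$ is a branch label.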

Duality

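Note that I carefully said the above session type described the two-party protocol from the view of the ATM. There should also be a corresponding session type describing the protocol from the view of the customer (also called the “dual”), so that both can ultimately be implemented and verified against the appropriate session type. We could manually create the dual as follows:

$$
\begin{aligned}
\text{Client} &= \text{send}\ \text{id};\ \text{offer}\{\text{ok}: \text{Client}_{\text{auth}},\ \text{err}: \varepsilon\} \\
\text{Client}_{\text{auth}} &= \text{choose}\{\text{deposit}: \text{Client}_{\text{deposit}},\ \text{withdraw}: \text{Client}_{\text{withdraw}}\} \\
\text{Client}_{\text{deposit}} &= \text{send}\ \text{nat};\ \text{recv}\ \text{nat};\ \varepsilon \\
\text{Client}_{\text{withdraw}} &= \text{send}\ \text{nat};\ \text{offer}\{\text{ok}: \varepsilon,\ \text{err}: \varepsilon\}
\end{aligned}
$$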

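An astute reader may notice that the client looks suspiciously similar to the ATM, except with every send/recv and choose/offer flipped. This is by design! Given our particular formulation of session types, we can systematically construct the dual of any session type $\sigma$, written with a bar as $\overline{\sigma}$. The rules are:

$$
\begin{aligned}
\overline{\varepsilon} &= \varepsilon \\
\overline{\text{send}\ \tau;\ \sigma} &= \text{recv}\ \tau;\ \overline{\sigma} \\
\overline{\text{recv}\ \tau;\ \sigma} &= \text{send}\ \tau;\ \overline{\sigma} \\
\overline{\text{choose}\{\ell_1: \sigma_1, \ldots, \ell_n: \sigma_n\}} &= \text{offer}\{\ell_1: \overline{\sigma_1}, \ldots, \ell_n: \overline{\sigma_n}\} \\
\overline{\text{offer}\{\ell_1: \sigma_1, \ldots, \ell_n: \sigma_n\}} &= \text{choose}\{\ell_1: \overline{\sigma_1}, \ldots, \ell_n: \overline{\sigma_n}\}
\end{aligned}
$$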

From these rules, we can simply define our client as $\text{Client} = \overline{\text{ATM}}$. Not only does this reduce the work of defining the protocol, it ensures that each party’s protocol is consistent with the other’s (that this is true is a metatheoretical fact that can be proven about the dual construction rules).
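To make the dual construction concrete, here is a minimal sketch that represents session types as ordinary runtime data rather than as types. The `Session` enum, the string names for message types and labels, and the `dual` function are hypothetical illustrations, not part of the Rust library built later in this lecture. `dual` swaps send/recv and choose/offer, keeps ε as ε, and recurses into the rest of the session; the checks confirm that dualizing the ATM protocol yields the hand-written client protocol, and that dualizing twice gives back the original.

```rust
// Hypothetical runtime representation of (non-recursive) session types.
// Message types and branch labels are plain strings, purely for illustration.
#[derive(Debug, PartialEq)]
enum Session {
    End,                                  // ε
    Send(&'static str, Box<Session>),     // send τ; σ
    Recv(&'static str, Box<Session>),     // recv τ; σ
    Choose(Vec<(&'static str, Session)>), // choose{ℓ1: σ1, ..., ℓn: σn}
    Offer(Vec<(&'static str, Session)>),  // offer{ℓ1: σ1, ..., ℓn: σn}
}

use Session::*;

// Dual construction: swap send/recv and choose/offer, keep ε, recurse.
fn dual(s: &Session) -> Session {
    match s {
        End => End,
        Send(t, k) => Recv(*t, Box::new(dual(k))),
        Recv(t, k) => Send(*t, Box::new(dual(k))),
        Choose(bs) => Offer(bs.iter().map(|(l, k)| (*l, dual(k))).collect()),
        Offer(bs) => Choose(bs.iter().map(|(l, k)| (*l, dual(k))).collect()),
    }
}

// The ATM protocol, viewed from the ATM.
fn atm() -> Session {
    let deposit = Recv("nat", Box::new(Send("nat", Box::new(End))));
    let withdraw = Recv("nat", Box::new(Choose(vec![("ok", End), ("err", End)])));
    let auth = Offer(vec![("deposit", deposit), ("withdraw", withdraw)]);
    Recv("id", Box::new(Choose(vec![("ok", auth), ("err", End)])))
}

// The client protocol, written out by hand.
fn client() -> Session {
    let deposit = Send("nat", Box::new(Recv("nat", Box::new(End))));
    let withdraw = Send("nat", Box::new(Offer(vec![("ok", End), ("err", End)])));
    let auth = Choose(vec![("deposit", deposit), ("withdraw", withdraw)]);
    Send("id", Box::new(Offer(vec![("ok", auth), ("err", End)])))
}

fn main() {
    println!("{}", dual(&atm()) == client()); // prints true
    println!("{}", dual(&dual(&atm())) == atm()); // prints true
}
```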

Recursion

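One other important construct missing from our list of session type primitives is recursion. For example, once an ATM client is authenticated, we would like them to be able to execute transactions repeatedly without reauthenticating. For this, we can introduce a recursive session type, written with $\mu$, that looks similar to recursion in the lambda calculus:

$$
\text{ATM}_{\text{auth}} = \mu\alpha.\ \text{offer}\{\text{deposit}: \text{recv}\ \text{nat};\ \text{send}\ \text{nat};\ \alpha,\ \ \text{withdraw}: \text{recv}\ \text{nat};\ \text{choose}\{\text{ok}: \alpha,\ \text{err}: \alpha\}\}
$$

Essentially, $\mu\alpha.\,\sigma$ marks the start of a recursive protocol and binds a session type variable $\alpha$; a usage of that variable inside $\sigma$ indicates a return to that start. Here, each $\alpha$ reference means returning to the start of the auth protocol. We can extend our grammar:

$$
\sigma ::= \ldots \mid \mu\alpha.\,\sigma \mid \alpha
$$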

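Along with dual rules:

$$
\begin{aligned}
\overline{\mu\alpha.\,\sigma} &= \mu\alpha.\,\overline{\sigma} \\
\overline{\alpha} &= \alpha
\end{aligned}
$$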
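The same kind of runtime sketch extends to recursion. As before, the `Session` enum and `dual` function are hypothetical illustrations; the two new variants encode μα.σ as `Rec` and a variable reference α as `Var`, with variable names as strings. Dualizing leaves the binder and the variables untouched and only dualizes the body. The example encodes an authorized ATM protocol in which both the deposit and the withdraw branches loop back to the start.

```rust
// Hypothetical runtime representation of session types, now with recursion.
#[derive(Debug, PartialEq)]
enum Session {
    End,
    Send(&'static str, Box<Session>),
    Recv(&'static str, Box<Session>),
    Choose(Vec<(&'static str, Session)>),
    Offer(Vec<(&'static str, Session)>),
    Rec(&'static str, Box<Session>), // μα.σ: binds the variable named by the string
    Var(&'static str),               // α: return to the enclosing Rec with this name
}

use Session::*;

fn dual(s: &Session) -> Session {
    match s {
        End => End,
        Send(t, k) => Recv(*t, Box::new(dual(k))),
        Recv(t, k) => Send(*t, Box::new(dual(k))),
        Choose(bs) => Offer(bs.iter().map(|(l, k)| (*l, dual(k))).collect()),
        Offer(bs) => Choose(bs.iter().map(|(l, k)| (*l, dual(k))).collect()),
        // Recursion rules: binder and variables unchanged, body dualized.
        Rec(a, k) => Rec(*a, Box::new(dual(k))),
        Var(a) => Var(*a),
    }
}

// Authorized ATM protocol where every transaction loops back to the start.
fn atm_auth() -> Session {
    let deposit = Recv("nat", Box::new(Send("nat", Box::new(Var("a")))));
    let withdraw = Recv("nat", Box::new(Choose(vec![("ok", Var("a")), ("err", Var("a"))])));
    Rec("a", Box::new(Offer(vec![("deposit", deposit), ("withdraw", withdraw)])))
}

// Its expected dual, written out by hand.
fn client_auth() -> Session {
    let deposit = Send("nat", Box::new(Recv("nat", Box::new(Var("a")))));
    let withdraw = Send("nat", Box::new(Offer(vec![("ok", Var("a")), ("err", Var("a"))])));
    Rec("a", Box::new(Choose(vec![("deposit", deposit), ("withdraw", withdraw)])))
}

fn main() {
    println!("{}", dual(&atm_auth()) == client_auth()); // prints true
}
```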

Session types in Rust

Think back to last lecture: the typestate pattern allowed us to take a specific state machine, and mechanically encode it as types in Rust. Typestates are general, meaning they apply to any kind of finite state machine, but they are verbose, requiring a fair amount of code for any specific state machine that we want to encode.

Session types are like typestate, but specialized for the class of state machines used in two-party communications. The send/receive/offer/choose/close constructs form a state machine library, or a set of constructs from which a class of state machines can be represented. To implement them in Rust, we are going to create a session type library which can be reused for many different kinds of communication.

Session type primitives

First, we need to define the session type language in Rust’s type system. When encoding typestate for File, we defined two new types Reading and Eof to represent the state of a file. Session types are at a higher level of abstraction, i.e. they are a language of states, or a set of generic rules for composing states. In Rust, we can represent state composition as polymorphic structs:

use std::marker::PhantomData;

// "S" is shorthand for "sigma", meaning "the rest of the session"
pub struct Send<T, S>(PhantomData<(T, S)>);
pub struct Recv<T, S>(PhantomData<(T, S)>);
pub struct Offer<Left, Right>(PhantomData<(Left, Right)>);
pub struct Choose<Left, Right>(PhantomData<(Left, Right)>);
pub struct Close; // equivalent to epsilon

Like before, we have to use PhantomData for Rust not to raise a compiler error about unused struct type arguments. Given these structs, we can use them to define the ATM protocol:

type Id = String;
type AtmDeposit = Recv<u64, Send<u64, Close>>;
type AtmWithdraw = Recv<u64, Choose<Close, Close>>;
type AtmServer =
  Recv<Id,
  Choose<
    Offer<AtmDeposit, AtmWithdraw>,
    Close>>;

Observe that we are only defining types, not expressions. At its core, our session type library allows us to formally describe communication as types, and then later check that an implementation adheres to the type. This encoding of session types into Rust is equivalent to the earlier language, although we have lost our branch labels on offer/choose. To simplify the lecture, we are not going to look at implementing recursive session types (you’ll see that on your assignment).
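One way to see that these definitions live purely at the type level: every primitive wraps only `PhantomData` (or nothing, in the case of `Close`), so a protocol type occupies zero bytes at runtime, even when it mentions `String` or `u64`. The small check below copies the definitions above so that it is self-contained; the `atm_server_size` helper is a hypothetical addition for illustration.

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// Session type primitives, as defined above.
pub struct Send<T, S>(PhantomData<(T, S)>);
pub struct Recv<T, S>(PhantomData<(T, S)>);
pub struct Offer<Left, Right>(PhantomData<(Left, Right)>);
pub struct Choose<Left, Right>(PhantomData<(Left, Right)>);
pub struct Close;

// The ATM protocol, as defined above.
type Id = String;
type AtmDeposit = Recv<u64, Send<u64, Close>>;
type AtmWithdraw = Recv<u64, Choose<Close, Close>>;
type AtmServer = Recv<Id, Choose<Offer<AtmDeposit, AtmWithdraw>, Close>>;

// Number of bytes a value of the protocol type would occupy at runtime.
pub fn atm_server_size() -> usize {
    size_of::<AtmServer>()
}

fn main() {
    // The protocol mentions String and u64, but PhantomData stores neither.
    println!("{}", atm_server_size()); // prints 0
    // A real String, by contrast, is not zero-sized.
    println!("{}", size_of::<Id>() > 0); // prints true
}
```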

Dual through traits

One particularly cool part of this construction is that we can express our dual rules using Rust’s trait system of associated types and conditional trait implementations. We start by defining our core trait:

pub trait HasDual {
  type Dual;
}

If a type implements HasDual, then it has a dual type, and that dual type is the associated type Dual. For example, the dual of a close is a close, which we can write as:

impl HasDual for Close {
  type Dual = Close;
}

Where it gets cool is that we can use conditional trait implementation to translate the inductive dual rules into impl blocks.

impl<T, S> HasDual for Send<T, S> where S: HasDual {
  type Dual = Recv<T, S::Dual>;
}

impl<T, S> HasDual for Recv<T, S> where S: HasDual {
  type Dual = Send<T, S::Dual>;
}

These rules say that when S implements HasDual, it has an associated type S::Dual which is its dual. We can then insert that associated type into the container, swapping a Send to Recv and vice versa. The choose/offer rules follow the same pattern:

impl<Left, Right> HasDual for Choose<Left, Right>
where Left: HasDual, Right: HasDual {
  type Dual = Offer<Left::Dual, Right::Dual>;
}

impl<Left, Right> HasDual for Offer<Left, Right>
where Left: HasDual, Right: HasDual {
  type Dual = Choose<Left::Dual, Right::Dual>;
}

Ultimately, we can use Rust’s fully qualified trait syntax to create the ATM’s client type:

type AtmClient = <AtmServer as HasDual>::Dual;
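As a sanity check on the trait-based construction, here is a sketch that compares types at runtime with `std::any::TypeId`. The `ExpectedClient` alias (the client protocol written out by hand), the `ClientDual` alias, and the `same_type` helper are illustrative additions, not part of the lecture’s library. If the impls are right, `AtmClient` equals `ExpectedClient`, and dualizing `AtmClient` once more gives back `AtmServer`.

```rust
use std::any::TypeId;
use std::marker::PhantomData;

pub struct Send<T, S>(PhantomData<(T, S)>);
pub struct Recv<T, S>(PhantomData<(T, S)>);
pub struct Offer<Left, Right>(PhantomData<(Left, Right)>);
pub struct Choose<Left, Right>(PhantomData<(Left, Right)>);
pub struct Close;

pub trait HasDual { type Dual; }
impl HasDual for Close { type Dual = Close; }
impl<T, S: HasDual> HasDual for Send<T, S> { type Dual = Recv<T, S::Dual>; }
impl<T, S: HasDual> HasDual for Recv<T, S> { type Dual = Send<T, S::Dual>; }
impl<L: HasDual, R: HasDual> HasDual for Choose<L, R> { type Dual = Offer<L::Dual, R::Dual>; }
impl<L: HasDual, R: HasDual> HasDual for Offer<L, R> { type Dual = Choose<L::Dual, R::Dual>; }

type Id = String;
type AtmDeposit = Recv<u64, Send<u64, Close>>;
type AtmWithdraw = Recv<u64, Choose<Close, Close>>;
type AtmServer = Recv<Id, Choose<Offer<AtmDeposit, AtmWithdraw>, Close>>;
type AtmClient = <AtmServer as HasDual>::Dual;

// The client protocol written out by hand, for comparison.
type ExpectedClient =
    Send<Id, Offer<Choose<Send<u64, Recv<u64, Close>>, Send<u64, Offer<Close, Close>>>, Close>>;

// Dualizing twice should give back the original protocol.
type ClientDual = <AtmClient as HasDual>::Dual;

/// Returns true if A and B are the same type.
pub fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

fn main() {
    println!("{}", same_type::<AtmClient, ExpectedClient>()); // prints true
    println!("{}", same_type::<ClientDual, AtmServer>()); // prints true
}
```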

Session-typed channels

Now we have successfully translated the session type system and its inductive dual rules into Rust. However, we still haven’t defined what checks whether a Rust program correctly implements a particular session type. Since session types are fundamentally about passing messages, we’re going to build our implementation on top of Rust’s channels, just like we saw in the concurrency lecture.

A session-typed channel is a Chan<S> struct that implements the communication for the session type S. We would like it to adhere to the following interface:

use std::sync::mpsc::{Sender, Receiver};

pub struct Chan<S> { ... }

impl<S> Chan<S> where S: HasDual {
  pub fn new() -> (Chan<S>, Chan<S::Dual>) { ... }
}

For example, the following program, in which the client sends a number and the server prints it, should be well-typed:

type Server = Recv<u64, Close>;
type Client = <Server as HasDual>::Dual;

fn server(c: Chan<Server>) {
  let (c, n): (Chan<Close>, u64) = c.recv();
  println!("{}", n);
  c.close();
}

fn client(c: Chan<Client>) {
  let c: Chan<Close> = c.send(5);
  c.close();
}

use std::thread;

fn main() {
  let (server_chan, client_chan) = Chan::new();
  let server_thread = thread::spawn(move || server(server_chan));
  let client_thread = thread::spawn(move || client(client_chan));

  server_thread.join().unwrap();
  client_thread.join().unwrap();
}

The basic idea is that when a channel is parameterized by a session type S, that type determines which action is available on the channel: receiving, sending, closing, offering, or choosing. For example, in the above code, when we created server_chan of type Chan<Server>, the first operation is c.recv(), which is consistent with the server type Recv<u64, Close>.

Our construction should ensure that only the desired action is available at any given point in the protocol. The way we ensure this is that after performing an action, like c.recv(), we change the type of the channel. We don’t actually change the channel value itself, just its type. For example, we need implementations like the following:

use std::marker;

impl Chan<Close> {
  pub fn close(self) { ... }
}

impl<T, S> Chan<Send<T, S>> where T: marker::Send + 'static
{
  pub fn send(self, x: T) -> Chan<S> { ... }
}

These say: when a Chan is in a Close state, the only method on it is fn close(self), which consumes the channel and returns nothing, since the session is over. When a Chan is in a Send<T, S> state, the only method on it is fn send(self, x: T) -> Chan<S>, which sends the value to the other party and returns a new channel starting at the next step in the protocol.

This all seems crazy unsafe (and it will, in fact, use unsafe code), but it is safe as long as the underlying implementation is correct. First, we need to look at the core channel methods: how are we sending bits over the wire?

use std::mem::transmute;

pub struct Chan<S> {
  sender: Sender<Box<u8>>,
  receiver: Receiver<Box<u8>>,
  _data: PhantomData<S>,
}

impl<S> Chan<S> {
  unsafe fn write<T>(&self, x: T)
  where
    T: marker::Send + 'static,
  {
    let sender: &Sender<Box<T>> = transmute(&self.sender);
    sender.send(Box::new(x)).unwrap();
  }

  unsafe fn read<T>(&self) -> T
  where
    T: marker::Send + 'static,
  {
    let receiver: &Receiver<Box<T>> = transmute(&self.receiver);
    *receiver.recv().unwrap()
  }
}

A Chan contains a Sender for sending messages to the counterparty and a Receiver for receiving messages. For flexibility, the channels are nominally typed as carrying Box<u8>, but they are really used to carry type-erased heap pointers. The core functions of a Chan are the magically unsafe read and write functions: write boxes a value of any type T and sends it by reinterpreting the sender as a Sender<Box<T>>, and read reinterprets the receiver as a Receiver<Box<T>> and unboxes the value it receives.

This is made possible by the enigmatic mem::transmute function, which can take a value of some type and reinterpret its bits as any other type of the same size. Wow! That sounds pretty dangerous, because it is. (It is not uncommon, though: OCaml has an equivalent function called Obj.magic.) You can think about this as equivalent to casting through void* pointers in C. While this is terribly unsafe, so long as both parties strictly adhere to dual session types, each side should only ever send or receive values of the appropriate type. For starters, let’s look at send and recv:

impl<T, S> Chan<Send<T, S>>
where
  T: marker::Send + 'static,
{
  pub fn send(self, x: T) -> Chan<S> {
    unsafe {
      self.write(x);
      transmute(self)
    }
  }
}

impl<T, S> Chan<Recv<T, S>>
where
  T: marker::Send + 'static,
{
  pub fn recv(self) -> (Chan<S>, T) {
    unsafe {
      let a = self.read();
      (transmute(self), a)
    }
  }
}

The first implementation says: when we have a channel with session type Send<T, S>, then we can call a function send(self, x: T) that sends a T and returns a channel of type Chan<S>, the rest of the session type. We accomplish this by (unsafely) calling self.write to send the value, then again calling transmute to turn our Chan<Send<T, S>> into a Chan<S>. Similarly, in recv, we first read a value of type T, then transmute the channel from Chan<Recv<T, S>> into Chan<S>.
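For comparison, here is a sketch of a variant that avoids unsafe code entirely. This is a swapped-in technique, not the implementation from the lecture or from Session Types for Rust: messages travel as `Box<dyn Any + Send>` and are recovered with `downcast`, and a channel changes type by moving its sender and receiver into a new `Chan` with a fresh `PhantomData` (the hypothetical `cast` helper) instead of via transmute. Only send, recv, and close are shown, and `run_session` is a hypothetical driver for the earlier Recv<u64, Close> example.

```rust
use std::any::Any;
use std::marker::PhantomData;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Session type primitives (only the ones this sketch needs).
pub struct Send<T, S>(PhantomData<(T, S)>);
pub struct Recv<T, S>(PhantomData<(T, S)>);
pub struct Close;

pub trait HasDual { type Dual; }
impl HasDual for Close { type Dual = Close; }
impl<T, S: HasDual> HasDual for Send<T, S> { type Dual = Recv<T, S::Dual>; }
impl<T, S: HasDual> HasDual for Recv<T, S> { type Dual = Send<T, S::Dual>; }

// Type-erased message: any sendable value, boxed.
type Msg = Box<dyn Any + std::marker::Send>;

pub struct Chan<S> {
    sender: Sender<Msg>,
    receiver: Receiver<Msg>,
    _data: PhantomData<S>,
}

impl<S> Chan<S> {
    // Move the same sender/receiver into a channel with a new session type.
    fn cast<S2>(self) -> Chan<S2> {
        Chan { sender: self.sender, receiver: self.receiver, _data: PhantomData }
    }
}

impl<S: HasDual> Chan<S> {
    pub fn new() -> (Chan<S>, Chan<S::Dual>) {
        let (sender1, receiver1) = channel();
        let (sender2, receiver2) = channel();
        let c1 = Chan { sender: sender1, receiver: receiver2, _data: PhantomData };
        let c2 = Chan { sender: sender2, receiver: receiver1, _data: PhantomData };
        (c1, c2)
    }
}

impl<T: std::marker::Send + 'static, S> Chan<Send<T, S>> {
    pub fn send(self, x: T) -> Chan<S> {
        self.sender.send(Box::new(x)).unwrap();
        self.cast()
    }
}

impl<T: std::marker::Send + 'static, S> Chan<Recv<T, S>> {
    pub fn recv(self) -> (Chan<S>, T) {
        let msg = self.receiver.recv().unwrap();
        // With dual session types on both ends this downcast always succeeds;
        // a mismatch panics here instead of causing undefined behavior.
        let x = *msg
            .downcast::<T>()
            .unwrap_or_else(|_| panic!("protocol violation: unexpected message type"));
        (self.cast(), x)
    }
}

impl Chan<Close> {
    // Dropping the channel ends the session.
    pub fn close(self) {}
}

type Server = Recv<u64, Close>;
type Client = <Server as HasDual>::Dual;

fn server(c: Chan<Server>) -> u64 {
    let (c, n) = c.recv();
    c.close();
    n
}

fn client(c: Chan<Client>) {
    c.send(5).close();
}

/// Runs one session on two threads and returns the number the server received.
pub fn run_session() -> u64 {
    let (server_chan, client_chan) = Chan::<Server>::new();
    let server_thread = thread::spawn(move || server(server_chan));
    let client_thread = thread::spawn(move || client(client_chan));
    client_thread.join().unwrap();
    server_thread.join().unwrap()
}

fn main() {
    println!("{}", run_session()); // prints 5
}
```

The trade-off is a small runtime type check inside downcast, in exchange for not relying on transmute; the session-type checking at compile time is unchanged.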

Lastly, we need to implement choose and offer. First, let’s define our desired API with a quick example.

type Server = Offer<Recv<u64, Close>, Send<u64, Close>>;
type Client = <Server as HasDual>::Dual;

fn server(c: Chan<Server>) {
  // offer should ask the client which branch to take
  match c.offer() {
    Branch::Left(c) => {
      let (c, n) = c.recv();
      println!("{}", n);
      c.close();
    }
    Branch::Right(c) => {
      let c = c.send(0);
      c.close();
    }
  }
}

fn client(c: Chan<Client>) {
  let c = c.left(); // tell server to run left branch
  let c = c.send(1); // send allowed because we're in left branch
  c.close();
}

We can concretely implement this API as follows:

impl<Left, Right> Chan<Choose<Left, Right>> {
  pub fn left(self) -> Chan<Left> {
    unsafe {
      self.write(true);
      transmute(self)
    }
  }

  pub fn right(self) -> Chan<Right> {
    unsafe {
      self.write(false);
      transmute(self)
    }
  }
}

pub enum Branch<L, R> {
  Left(L),
  Right(R),
}

impl<Left, Right> Chan<Offer<Left, Right>> {
  pub fn offer(self) -> Branch<Chan<Left>, Chan<Right>> {
    unsafe {
      if self.read() {
        Branch::Left(transmute(self))
      } else {
        Branch::Right(transmute(self))
      }
    }
  }
}

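We internally represent the choice of branch with a simple boolean, written by left/right and read by offer. We use separate left and right functions (as opposed to a single choose function) because the type of channel returned differs in each case. With that, we’re about done! We just need a body for close, which simply drops the channel since the session is over, and the Chan::new function:

use std::sync::mpsc::channel;

impl Chan<Close> {
  // Consuming `self` drops the sender and receiver, ending the session.
  pub fn close(self) {}
}

impl<S> Chan<S> where S: HasDual {
  pub fn new() -> (Chan<S>, Chan<S::Dual>) {
    // Each mpsc channel is one-directional, so create one per direction.
    let (sender1, receiver1) = channel();
    let (sender2, receiver2) = channel();
    // c1 sends on channel 1 and receives on channel 2...
    let c1 = Chan {
      sender: sender1,
      receiver: receiver2,
      _data: PhantomData,
    };
    // ...and c2 is wired the opposite way.
    let c2 = Chan {
      sender: sender2,
      receiver: receiver1,
      _data: PhantomData,
    };
    (c1, c2)
  }
}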

This constructs two channels, ensuring bidirectional communication can occur as necessary.
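The cross-wiring in Chan::new can be seen in isolation with plain std channels. In this sketch, the `Endpoint` struct is a hypothetical, untyped stand-in for Chan<S> that carries String messages; `pair` wires two endpoints the same way Chan::new does, and `ping_pong` sends one message in each direction to show that each endpoint receives exactly what the other sends.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Hypothetical untyped endpoint: a sender to the peer and a receiver from the peer.
struct Endpoint {
    sender: Sender<String>,
    receiver: Receiver<String>,
}

// Same wiring as Chan::new: endpoint 1 sends on channel 1 and receives on
// channel 2, and endpoint 2 does the opposite.
fn pair() -> (Endpoint, Endpoint) {
    let (sender1, receiver1) = channel();
    let (sender2, receiver2) = channel();
    (
        Endpoint { sender: sender1, receiver: receiver2 },
        Endpoint { sender: sender2, receiver: receiver1 },
    )
}

/// Sends "ping" one way and "pong" back; returns (what b got, what a got).
pub fn ping_pong() -> (String, String) {
    let (a, b) = pair();
    let peer = thread::spawn(move || {
        let got = b.receiver.recv().unwrap();
        b.sender.send(String::from("pong")).unwrap();
        got
    });
    a.sender.send(String::from("ping")).unwrap();
    let reply = a.receiver.recv().unwrap();
    (peer.join().unwrap(), reply)
}

fn main() {
    let (b_got, a_got) = ping_pong();
    println!("{} {}", b_got, a_got); // prints "ping pong"
}
```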

Error catching

So we have a complete construction, but does it have the desired effects? If we try to implement a session type incorrectly, what will happen? Let’s take a look. Recall the simple protocol from earlier:

type Server = Recv<u64, Close>;
type Client = <Server as HasDual>::Dual;

Let’s say we accidentally try to implement the server by sending instead of receiving.

fn server(c: Chan<Server>) {
  let c = c.send(1);
  c.close();
}
error[E0599]: no method named `send` found for type `Chan<Recv<u64, Close>>` in the current scope
   --> test.rs:157:13
    |
49  | pub struct Chan<S> {
    | ---------------------- method `send` not found for this
...
157 |   let c = c.send(1);
    |             ^^^^

Good, now what if we try to do this later in the protocol?

fn server(c: Chan<Server>) {
  let (c, n) = c.recv();
  println!("{}", n);
  let c = c.send(1);
  c.close();
}
error[E0599]: no method named `send` found for type `Chan<Close>` in the current scope
   --> test.rs:159:13
    |
49  | pub struct Chan<S> {
    | ---------------------- method `send` not found for this
...
159 |   let c = c.send(1);
    |             ^^^^

Same error, which means our channels are successfully transitioning types. Does our dual construction work correctly?

fn client(c: Chan<Client>) {
  let (c, n) = c.recv();
  println!("{}", n);
  c.close();
}
error[E0599]: no method named `recv` found for type `Chan<Send<u64, Close>>` in the current scope
   --> test.rs:163:18
    |
49  | pub struct Chan<S> {
    | ---------------------- method `recv` not found for this
...
163 |   let (c, n) = c.recv();
    |                  ^^^^

Yup. Does our interface enforce that messages sent and received have the right types?

fn client(c: Chan<Client>) {
  let c = c.send("Hello");
  c.close();
}
error[E0308]: mismatched types
   --> test.rs:163:18
    |
163 |   let c = c.send("Hello");
    |                  ^^^^^^^ expected u64, found reference
    |
    = note: expected type `u64`
               found type `&'static str`

Wonderful! But let’s think about subtler edge cases. What if we tried to reuse a channel?

fn client(c: Chan<Client>) {
  c.send(1);
  let c = c.send(2);
  c.close();
}
error[E0382]: use of moved value: `c`
   --> test.rs:164:11
    |
163 |   c.send(1);
    |   - value moved here
164 |   let c = c.send(2);
    |           ^ value used here after move
    |
    = note: move occurs because `c` has type `Chan<Send<u64, Close>>`, which does not implement the `Copy` trait

Aha! Rust to the rescue. Because we defined our methods to consume ownership of the input channel, this means that we must execute actions in the intended order (this is similar in spirit to enforcing state machine transitions from last time).

Complete ATM example

At last, we can finish our verified ATM example.

type Id = String;
type AtmDeposit = Recv<u64, Send<u64, Close>>;
type AtmWithdraw = Recv<u64, Choose<Close, Close>>;
type AtmServer =
  Recv<Id,
  Choose<
    Offer<AtmDeposit, AtmWithdraw>,
    Close>>;
type AtmClient = <AtmServer as HasDual>::Dual;

// Placeholder: a real ATM would check the ID against its records.
fn approved(_id: &str) -> bool { true }

pub fn atm_server(c: Chan<AtmServer>) {
  let (c, id) = c.recv();
  if !approved(&id) {
    c.right().close();
    return;
  }
  let mut balance = 100; // placeholder: a real ATM would look up the balance for `id`

  let c = c.left();
  match c.offer() {
    Branch::Left(c) => { // Deposit
      let (c, amt) = c.recv();
      balance += amt;
      c.send(balance).close();
    }
    Branch::Right(c) => { // Withdraw
      let (c, amt) = c.recv();
      if balance >= amt {
        balance -= amt;
        c.left().close();
      } else {
        c.right().close();
      }
    }
  }
}

fn atm_client(c: Chan<AtmClient>) {
  let id = String::from("wcrichto");
  let c = c.send(id);
  match c.offer() {
    Branch::Left(c) => {
      let c = c.right(); // withdraw
      let c = c.send(105);
      match c.offer() {
        Branch::Left(c) => {
          println!("Withdrawal succeeded.");
          c.close();
        }
        Branch::Right(c) => {
          println!("Insufficient funds.");
          c.close();
        }
      }
    }
    Branch::Right(c) => {
      println!("Invalid authorization");
      c.close();
    }
  }
}

use std::thread;

fn main() {
  let (server_chan, client_chan) = Chan::new();
  let server_thread = thread::spawn(move || atm_server(server_chan));
  let client_thread = thread::spawn(move || atm_client(client_chan));

  server_thread.join().unwrap();
  client_thread.join().unwrap();

  // Prints "Insufficient funds."
}