Data availability proof-friendly state tree transitions

OK, you’re absolutely right. It is definitely possible to store the entire state in a sparse binary Merkle tree with 2**160 leaves. Two key tricks make this work:

  1. Setting up an initial (empty) trie. Every bottom-level node is 0x00 * 32. If every level N node is X, every node one level up (level N+1) is sha3(X + X). Hence, you can calculate the root hash and intermediate nodes of an empty trie with only 160 rounds of computation.
  2. Compressing a Merkle branch for a value. One disadvantage of this kind of tree is that a Merkle branch, considered naively, is much longer than in a Patricia tree. A binary Patricia tree with N nodes has an average Merkle branch length of ~32 * log2(N) bytes in theory, and about 10% more in practice because of overhead, whereas a sparse binary tree has a Merkle branch length of ~32 * 160 bytes, or even more if we decide to increase the address length. This can be solved by noticing that for most of the Merkle path, the subtree being considered contains only one non-empty leaf, and so the opposite node (i.e. the node that would go into the proof) is the root of an empty subtree. There is only one possible hash for an empty subtree at any given height, so each such node can be compressed down to one bit.

The Merkle proof length can then be shortened fairly easily to 64 + 32 * log2(N) bytes, where the 64 bytes of overhead consist of (i) the 32-byte key, and (ii) a 256-bit bitmap whose ith bit indicates whether the ith opposite node should be read from the Merkle proof or taken as the empty subtree root for that height. This is comparable to the data overhead of the existing binary Patricia tree and can be further optimized if needed.
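The bitmap compression described above can be sketched as follows. This is only an illustration under stated assumptions: the names `empty_hashes`, `compress_proof` and `decompress_proof` are hypothetical, the hash is the standard library's SHA3-256 used as a stand-in for `sha3`, the tree has 256 levels, and bit i of the bitmap is assumed to refer to the ith opposite node counted from the root.

```python
from hashlib import sha3_256

DEPTH = 256  # assumption: 256-bit keys, so the bitmap is 32 bytes


def sha3(data):
    # Stand-in hash (assumption): NIST SHA3-256 from the standard library.
    return sha3_256(data).digest()


def empty_hashes(depth=DEPTH):
    # hashes[h] is the root hash of an empty subtree of height h.
    hashes = [b'\x00' * 32]
    for _ in range(depth):
        hashes.append(sha3(hashes[-1] * 2))
    return hashes


EMPTY = empty_hashes()


def compress_proof(sidenodes):
    # sidenodes[i] is the opposite node at step i from the root, i.e. the
    # root of a subtree of height DEPTH - 1 - i.
    bitmap = 0
    kept = []
    for i, node in enumerate(sidenodes):
        if node != EMPTY[DEPTH - 1 - i]:
            bitmap |= 1 << i   # bit set: node must be read from the proof
            kept.append(node)
    return bitmap.to_bytes(DEPTH // 8, 'big'), kept


def decompress_proof(bitmap_bytes, kept):
    # Rebuild the full list of opposite nodes, filling in empty-subtree
    # roots wherever the bitmap bit is clear.
    bitmap = int.from_bytes(bitmap_bytes, 'big')
    remaining = iter(kept)
    return [next(remaining) if (bitmap >> i) & 1 else EMPTY[DEPTH - 1 - i]
            for i in range(DEPTH)]
```

As a rough size check, with N = 2**20 non-empty leaves the formula gives 64 + 32 * 20 = 704 bytes, versus 32 * 256 = 8192 bytes for an uncompressed 256-level branch.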

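The key benefit here is a massive gain in simplicity. Here’s a code implementation of get/update and Merkle proof creation and verification. Note that it uses 256-bit keys (256 levels) rather than the 160-bit addresses discussed above, and that it assumes a 32-byte hash function sha3 and a key-value store db with get and put methods, neither of which is defined here: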

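def new_tree(db):
    # Build an empty 256-level tree and return its root hash.
    # Every node at a given height is identical, so one hash per level suffices.
    h = b'\x00' * 32
    for i in range(256):
        newh = sha3(h + h)
        db.put(newh, h + h)
        h = newh
    return h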

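def key_to_path(k):
    # Interpret the key bytes as a big-endian integer; for a 32-byte key,
    # bit 255 selects the branch taken at the root.
    return int.from_bytes(k, 'big')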

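def get(db, root, key):
    # Walk from the root to the leaf, reading the path bits from most
    # significant to least significant (1 = right child, 0 = left child).
    v = root
    path = key_to_path(key)
    for i in range(256):
        if (path >> 255) & 1:
            v = db.get(v)[32:]
        else:
            v = db.get(v)[:32]
        path <<= 1
    return v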

def update(db, root, key, value):
    v = root
    path = path2 = key_to_path(key)
    sidenodes = []
    # Descend to the leaf, recording the sibling at each level.
    for i in range(256):
        node = db.get(v)
        if (path >> 255) & 1:
            sidenodes.append(node[:32])
            v = node[32:]
        else:
            sidenodes.append(node[32:])
            v = node[:32]
        path <<= 1
    # Rebuild the branch from the leaf upward, storing each new node.
    v = value
    for i in range(256):
        if (path2 & 1):
            newv = sha3(sidenodes[-1] + v)
            db.put(newv, sidenodes[-1] + v)
        else:
            newv = sha3(v + sidenodes[-1])
            db.put(newv, v + sidenodes[-1])
        path2 >>= 1
        v = newv
        sidenodes.pop()
    return v

def make_merkle_proof(db, root, key):
    # The (uncompressed) proof is the list of sibling nodes along the path,
    # ordered from the root down to the leaf.
    v = root
    path = key_to_path(key)
    sidenodes = []
    for i in range(256):
        node = db.get(v)
        if (path >> 255) & 1:
            sidenodes.append(node[:32])
            v = node[32:]
        else:
            sidenodes.append(node[32:])
            v = node[:32]
        path <<= 1
    return sidenodes

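def verify_proof(proof, root, key, value):
    # Recompute the root from the leaf upward, consuming the proof from its
    # last element (leaf-level sibling) to its first (child of the root).
    path = key_to_path(key)
    v = value
    for i in range(256):
        if (path & 1):
            newv = sha3(proof[-1-i] + v)
        else:
            newv = sha3(v + proof[-1-i])
        path >>= 1
        v = newv
    return root == v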
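
The functions above rely on a `sha3` function and a `db` object that the post does not define. A minimal sketch of stand-ins for experimenting with the code, assuming an in-memory dictionary as the store (the class name `MemoryDB` is hypothetical) and the standard library's SHA3-256 as the hash; if the intended `sha3` is Ethereum's Keccak-256, the resulting hashes will differ, since that is not the same function as NIST SHA3-256:

```python
from hashlib import sha3_256


def sha3(data):
    # Stand-in 32-byte hash (assumption): NIST SHA3-256.
    return sha3_256(data).digest()


class MemoryDB:
    # Hypothetical dict-backed store exposing the get/put interface
    # that the tree functions call.
    def __init__(self):
        self.kv = {}

    def get(self, key):
        return self.kv[key]

    def put(self, key, value):
        self.kv[key] = value
```

With these definitions placed before the tree functions, a round trip would be `db = MemoryDB()`, `root = new_tree(db)`, then `root = update(db, root, key, value)` for a 32-byte `key` and a 32-byte `value`. After that, `get(db, root, key)` returns `value`, and `verify_proof(make_merkle_proof(db, root, key), root, key, value)` returns `True`. Values need to be exactly 32 bytes because nodes are split at byte 32 when read back.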