Probabilistic Programming

Rapid model development

BAli-Phy 4 (unreleased) contains a language for expressing a wide range of probabilistic models. The goals of the language are:
  • expressivity: to be expressive enough that researchers can spend their time designing models instead of designing new inference software.
  • automatic inference: the inference should largely take care of itself after the model is specified.
MCMC is used for inference.

BAli-Phy implements a universal probabilistic programming language (PPL). Universal PPLs allow inferring the number of random variables and the relationships among them (see Ronquist et al., 2021). This differs from probabilistic graphical modeling (PGM) languages, such as Stan, BUGS, and RevBayes, where the model structure is fixed and cannot be changed after it is initialized.

Theory: Bayesian hierarchical models as programs

A PPL allows users to write a probabilistic model in the form of a computer program. The model program draws random variables from their prior distributions, and incorporates data by calling functions to "observe" data from the data distribution. This is a natural way to write Bayesian hierarchical models.

Each time it is run, the model program will draw different values for the random variables from their prior distributions, and compute the prior probability and the likelihood of the observed data.

run   mean          sigma         log(prior)     log(likelihood)
1     3.000849912   1.130289503   -3.24195035    -15.5905309
2     5.76676590    0.364175082   -2.049752039   -1.977179616
3     ...           ...           ...            ...
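
As a rough illustration of what one run computes, the following plain-Haskell sketch draws mean and sigma and then evaluates the log prior and the log likelihood. It is not BAli-Phy code. The priors (Normal(0, 10) on the mean, Exponential(1) on sigma), all function names, and the use of explicitly supplied uniform numbers as the source of randomness are assumptions made for this example.

```haskell
-- Log density of Normal(mu, sd) at x.
logNormalPdf :: Double -> Double -> Double -> Double
logNormalPdf mu sd x = -0.5 * log (2 * pi) - log sd - 0.5 * z * z
  where z = (x - mu) / sd

-- Log density of Exponential(rate) at x, for x >= 0.
logExponentialPdf :: Double -> Double -> Double
logExponentialPdf rate x = log rate - rate * x

-- Draw from Normal(mu, sd) given two uniforms in (0,1) (Box-Muller transform).
drawNormal :: Double -> Double -> Double -> Double -> Double
drawNormal mu sd u1 u2 = mu + sd * sqrt (-2 * log u1) * cos (2 * pi * u2)

-- Draw from Exponential(rate) given one uniform in (0,1) (inverse CDF).
drawExponential :: Double -> Double -> Double
drawExponential rate u = -log (1 - u) / rate

-- The quantities reported for one run of the model program.
data Run = Run
  { runMean          :: Double
  , runSigma         :: Double
  , runLogPrior      :: Double
  , runLogLikelihood :: Double
  } deriving Show

-- One run: draw the random variables, then score them against the data.
runModel :: [Double] -> (Double, Double, Double) -> Run
runModel observed (u1, u2, u3) = Run mean sigma logPrior logLik
  where
    mean     = drawNormal 0 10 u1 u2     -- assumed prior: Normal(0, 10)
    sigma    = drawExponential 1 u3      -- assumed prior: Exponential(1)
    logPrior = logNormalPdf 0 10 mean + logExponentialPdf 1 sigma
    logLik   = sum [ logNormalPdf mean sigma y | y <- observed ]
```

Calling runModel with a different triple of uniform numbers gives a different row of the table. The data passed as the first argument stay the same across runs.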

The sequence of random choices that are made during a program run is called a trace. The trace completely determines the course of a program run, and includes all the random variables that we wish to infer. If different program runs take different branches of a conditional statement, then their traces may include different random variables. In theory a model program can be written in any language that allows (i) recording the trace for a program run, and (ii) replaying the program given a trace.
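
The two requirements, recording and replaying a trace, can be sketched in plain Haskell. This is a toy representation chosen for illustration, not how BAli-Phy stores traces: the Model type, the functions draw, record and replay, and the variable names are all hypothetical, and distributions are left out so that each draw simply takes the next value from a supply.

```haskell
-- A model program either finishes with a result, or asks for the value
-- of a named random variable and continues.
data Model a
  = Done a
  | Draw String (Double -> Model a)

instance Functor Model where
  fmap f (Done a)   = Done (f a)
  fmap f (Draw n k) = Draw n (fmap f . k)

instance Applicative Model where
  pure = Done
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad Model where
  Done a   >>= g = g a
  Draw n k >>= g = Draw n (\x -> k x >>= g)

-- Ask for one named random variable.
draw :: String -> Model Double
draw name = Draw name Done

-- The trace: every random choice made during a run, in order.
type Trace = [(String, Double)]

-- (i) Run the program on a supply of values and record the trace.
record :: [Double] -> Model a -> (a, Trace)
record _        (Done a)   = (a, [])
record (v : vs) (Draw n k) = let (a, t) = record vs (k v) in (a, (n, v) : t)
record []       (Draw n _) = error ("no value left for " ++ n)

-- (ii) Replay the program using the values stored in a trace.
replay :: Trace -> Model a -> a
replay t m = fst (record (map snd t) m)

-- Random control flow: the two branches create different random variables.
example :: Model Double
example = do
  p <- draw "p"
  if p < 0.5
    then draw "x"
    else do
      y <- draw "y"
      z <- draw "z"
      return (y + z)
```

With the supply [0.25, 0.75] the recorded trace contains p and x. With [0.75, 0.25, 0.5] it contains p, y and z, so the two traces hold different random variables.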

Inference under the model involves sampling from the posterior distribution of program traces. The posterior probability of a trace is proportional to the product of (i) the prior probability of the trace and (ii) the likelihood of the observed data given the trace. The simplest way to approximate the posterior distribution is to run the model program many times and weight each trace by its likelihood. However, this is too inefficient in practice, because most traces drawn from the prior have a very low likelihood.
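
The weighting scheme can be written down in a few lines. The sketch below is a generic likelihood-weighting estimator in plain Haskell, not part of BAli-Phy. The function names are hypothetical, and each run is assumed to be summarized by one quantity of interest together with its log likelihood.

```haskell
-- Turn log likelihoods into weights that sum to 1.
-- Subtracting the maximum first avoids numerical underflow in exp.
normalizedWeights :: [Double] -> [Double]
normalizedWeights logLs = map (/ total) ws
  where
    m     = maximum logLs
    ws    = [ exp (l - m) | l <- logLs ]
    total = sum ws

-- Estimate a posterior mean from runs drawn from the prior.
-- Each run is (value of interest, log likelihood of the data).
posteriorMean :: [(Double, Double)] -> Double
posteriorMean runs = sum (zipWith (*) ws xs)
  where
    xs = map fst runs
    ws = normalizedWeights (map snd runs)
```

For example, posteriorMean [(1, log 1), (3, log 3)] weights the two runs 1:3 and returns approximately 2.5. When one run has a much higher likelihood than all the others, its weight is close to 1 and the remaining runs contribute almost nothing.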

MCMC is a more efficient approach to inference. To conduct inference using MCMC, we need to be able to

  1. propose new program traces by modifying one or more random variables
  2. re-execute the model program with a modified trace
  3. decide whether to accept the proposed trace.
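
The three steps can be sketched as a Metropolis step on a toy trace. This is a plain-Haskell illustration under stated assumptions, not BAli-Phy's implementation: the trace is a flat list of numbers, the proposal is a symmetric shift of one variable, "re-executing the model" is reduced to calling a log-posterior function, and the random inputs (the shift and the uniform number used for the accept decision) are passed in explicitly. All names are hypothetical.

```haskell
type Trace = [Double]

-- Step 1: propose a new trace by shifting variable i by delta.
propose :: Int -> Double -> Trace -> Trace
propose i delta t = [ if j == i then x + delta else x | (j, x) <- zip [0 ..] t ]

-- Steps 2 and 3: score the proposed trace (standing in for re-executing
-- the model program) and accept it with probability min(1, ratio).
-- u is a uniform number in (0,1); the proposal is assumed symmetric.
mhStep :: (Trace -> Double) -> (Int, Double, Double) -> Trace -> Trace
mhStep logPosterior (i, delta, u) current
  | log u < logPosterior proposed - logPosterior current = proposed
  | otherwise                                            = current
  where
    proposed = propose i delta current

-- Apply a sequence of moves, one after another.
runChain :: (Trace -> Double) -> [(Int, Double, Double)] -> Trace -> Trace
runChain logPosterior moves start =
  foldl (\t move -> mhStep logPosterior move t) start moves

-- A toy target: independent standard normal variables (up to a constant).
standardNormalLogPosterior :: Trace -> Double
standardNormalLogPosterior t = sum [ -0.5 * x * x | x <- t ]
```

Starting from [0.0], the move (0, 1.0, 0.9) is rejected because log 0.9 exceeds the log ratio of -0.5, while the move (0, 1.0, 0.5) is accepted and yields [1.0].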

The host environment operates on model program traces to perform inference. In contrast, the model program describes distributions on data but has no concept of inference. Therefore, modifications to the trace are not controlled by the model program, and are in fact invisible to it. The host environment may also implement the model program language by running model programs inside an interpreter. The host environment operates on a different level from the model program, and may be written in a different language. In BAli-Phy, the host environment is written mostly in C++.

The naive approach to MCMC inference reruns the entire model program from scratch whenever the trace changes, which is inefficient. BAli-Phy addresses this problem by determining which parts of the model program's execution depend on random variables that have changed. It can then rerun just the affected parts, which saves substantial computation time.

BAli-Phy uses Haskell as a model language, because Haskell makes it possible to determine which parts of the model program execution depend on a changed random variable. This is because Haskell represents control-flow statements (such as loops and if-then statements) as functions.
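
To make the point about control flow concrete, here is a small plain-Haskell sketch in which a conditional and a loop are both ordinary functions. The names ifThenElse and sumSquares are chosen for this example only.

```haskell
-- A conditional as an ordinary function: the result depends on the
-- condition and on whichever branch is selected.
ifThenElse :: Bool -> a -> a -> a
ifThenElse True  t _ = t
ifThenElse False _ e = e

-- A loop as an ordinary function: summing squares with a fold
-- instead of a for-loop.
sumSquares :: [Double] -> Double
sumSquares = foldr (\x acc -> x * x + acc) 0
```

Because evaluation is lazy, the branch that is not selected is never evaluated. The result of ifThenElse therefore depends only on the condition and the chosen branch.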

We can therefore construct an execution dependency graph, which contains an edge between every function output and each input to that function that might change. When a random variable changes, the graph allows us to identify the part of the execution that depends on that random variable, and to re-execute only that part. This graph is similar to the graph of a PGM. However, unlike a PGM, the shape of the graph is not fixed, but depends on the values of the random variables.
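
The idea of finding the affected part of the graph can be sketched with a small reachability computation. This is a simplified plain-Haskell illustration, not BAli-Phy's data structure: the graph is a fixed association list, and the node names (f1, lik1, and so on) are hypothetical labels for the pieces of a two-point regression.

```haskell
type Node = String

-- Each entry names a computed result and the inputs it reads.
type Graph = [(Node, [Node])]

-- Results that read the given node directly.
dependents :: Graph -> Node -> [Node]
dependents g n = [ out | (out, ins) <- g, n `elem` ins ]

-- Every result that must be re-executed when `changed` changes,
-- found by following dependency edges outward from the changed node.
affected :: Graph -> Node -> [Node]
affected g changed = go [] (dependents g changed)
  where
    go seen [] = reverse seen
    go seen (n : ns)
      | n `elem` seen = go seen ns
      | otherwise     = go (n : seen) (ns ++ dependents g n)

-- A regression with two data points: each prediction reads a and b,
-- and each likelihood term reads one prediction and sigma.
regressionGraph :: Graph
regressionGraph =
  [ ("f1",   ["a", "b"])
  , ("f2",   ["a", "b"])
  , ("lik1", ["f1", "sigma"])
  , ("lik2", ["f2", "sigma"])
  ]
```

Here affected regressionGraph "sigma" contains only the two likelihood terms, so a change to sigma does not require recomputing the predictions f1 and f2.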

Examples

  1. Linear regression
  2. Jointly infer tree and alignment
  3. Use GTR+M7 model to infer tree and alignment
Note: These examples won't work with version 3. They will be part of version 4, which is expected to come out in mid-2022.

To run these examples now you can

Haskell syntax

It helps to know that:

  • f x is the function f applied to x
  • let x = y is a simple assignment (no side-effects)
  • do...x <- y... performs some kind of action and assigns the result to x.
    The action could be random sampling, an IO operation, etc, depending on the context.
  • map f [x1,x2,...] applies the function f to every element of the list.
    The result looks like [f x1, f x2, ...].
    It is used instead of for-loops.
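
The four constructs above look like this in ordinary Haskell. The definitions are made up for illustration, and the do block runs in IO rather than in a probability context.

```haskell
-- f x: apply the function f to x.
double :: Int -> Int
double x = 2 * x

-- let x = y: name a value, with no side effects.
six :: Int
six = let y = 3 in double y

-- do ... x <- y ...: perform an action and bind its result to x.
-- Here the action is an IO action; in a model it could be random sampling.
greet :: IO String
greet = do
  name <- return "BAli-Phy"
  return ("hello " ++ name)

-- map f [x1, x2, ...] gives [f x1, f x2, ...], replacing a for-loop.
doubled :: [Int]
doubled = map double [1, 2, 3]
```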

For quick introductions to Haskell syntax you might want to take a look at the short interactive tutorial at tryhaskell.org or the quick tour of Haskell syntax.

Linear regression

Here is a short program that performs linear regression. The goal is to find a line f(x) = b*x + a that best predicts y[i] from x[i]. The data ys gives y[i] at each location x[i] in xs.

import           Probability
import           Data.Frame

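model xs ys = do
    -- Priors on the slope b, the intercept a, and the noise scale sigma
    b     <- normal 0.0 1.0
    a     <- normal 0.0 1.0
    sigma <- exponential 1.0

    -- The regression line, with slope b and intercept a
    let f x = b * x + a

    -- Likelihood: each y is normally distributed around f x
    ys ~> independent [ normal (f x) sigma | x <- xs ]

    -- Parameters to log at each MCMC iteration
    let loggers = ["b" %=% b, "a" %=% a, "sigma" %=% sigma]

    return loggers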

main = do
  xy_data <- readTable "xy.csv"

  let xs = xy_data $$ ("x", AsDouble)
      ys = xy_data $$ ("y", AsDouble)

  mcmc $ model xs ys

  • sampling from a distribution looks like b <- normal 0.0 1.0.
    (This specifies a prior term.)
  • observing data from a distribution looks like data ~> distribution.
    (This specifies a likelihood term.)
  • defining a function looks like let f x = b*x + a.
    (This defines the best fit line.)
  • logging parameters is done by the code return ["b" %=% b, "a" %=% a, ...].
    (This writes a corresponding JSON object each MCMC iteration.)

Note that for each x, the distribution of y(x) is normal (f x) sigma. The term f x is the location predicted by the best-fit line, but there is a distribution because the observed point may not fall exactly on the line.
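
The likelihood term that the observation line specifies can be written out by hand. The following plain-Haskell sketch is independent of BAli-Phy, uses hypothetical function names, and assumes that sigma is the standard deviation of the normal distribution.

```haskell
-- Log density of Normal(mu, sd) at y.
logNormalPdf :: Double -> Double -> Double -> Double
logNormalPdf mu sd y = -0.5 * log (2 * pi) - log sd - 0.5 * z * z
  where z = (y - mu) / sd

-- Log likelihood of the observations: one independent normal term per
-- data point, centered on the line's prediction f x.
logLikelihood :: Double -> Double -> Double -> [Double] -> [Double] -> Double
logLikelihood a b sigma xs ys =
  sum [ logNormalPdf (f x) sigma y | (x, y) <- zip xs ys ]
  where
    f x = b * x + a
```

With a = 1, b = 2 and sigma = 1, the points (0,1), (1,3) and (2,5) lie exactly on the line, so each term attains its maximum value of -0.5 * log (2 * pi). Moving any y off the line lowers the total.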

You can find this file in bali-phy/tests/prob_prog/regression/ and run it as bali-phy -m LinearRegression.hs --iter=1000.

Tree and alignment inference

Here is a short program that infers the tree and alignment from a FASTA file given on the command line.

import           Probability
import           Bio.Alignment
import           Bio.Alphabet
import           Tree
import           Tree.Newick
import           SModel
import           IModel
import           Probability.Distribution.OnTree
import           System.Environment  -- for getArgs

branch_length_dist topology branch = gamma 0.5 (2.0 / intToDouble n)
    where n = numBranches topology

model seq_data = do
    let taxa            = map sequence_name seq_data
        tip_seq_lengths = get_sequence_lengths dna seq_data

    -- Tree
    scale1 <- gamma 0.5 2.0
    tree   <- uniform_labelled_tree taxa branch_length_dist
    let tree1 = scale_branch_lengths scale1 tree

    -- Indel model
    logLambda   <- log_laplace (-4.0) 0.707
    mean_length <- (1.0 +) <$> exponential 10.0
    let imodel = rs07 logLambda mean_length tree

    -- Substitution model
    freqs  <- symmetric_dirichlet_on ["A", "C", "G", "T"] 1.0
    kappa1 <- log_normal 0.0 1.0
    kappa2 <- log_normal 0.0 1.0
    let tn93_model = tn93' dna kappa1 kappa2 freqs

    -- Alignment
    alignment <- random_alignment tree1 imodel tip_seq_lengths

    -- Observation
    seq_data ~> ctmc_on_tree tree1 alignment tn93_model

    return
        [ "tree1" %=% write_newick tree1
        , "logLambda" %=% logLambda
        , "mean_length" %=% mean_length
        , "kappa1" %=% kappa1
        , "kappa2" %=% kappa2
        , "frequencies" %=% freqs
        , "scale1" %=% scale1
        , "|T|" %=% tree_length tree
        , "scale1*|T|" %=% tree_length tree1
        , "|A|" %=% alignment_on_tree_length alignment
        ]

main = do
    [filename] <- getArgs

    let seq_data = load_sequences filename

    mcmc $ model seq_data

You can find this file in bali-phy/tests/prob_prog/infer_tree/1/ and run it as bali-phy -m Main.hs --iter=1000 --- 5d-muscle.fasta.

Tree and alignment inference under a dN/dS-across-sites model

The value of this framework becomes clearer when we perform alignment inference under a more complicated substitution model. Here we have factored the substitution model and its prior out into a function gtr_m7_model.

This uses a model of positive selection that (i) is based on GTR rather than HKY and (ii) allows dN/dS to have a Beta distribution across sites. The program is based on the Haskell code that would be generated by running

bali-phy bglobin.fasta -S function[w,gtr+x3+dNdS[w]]+m7 --test -Inone

import           Probability
import           Bio.Alignment
import           Bio.Alphabet
import           Tree
import           Tree.Newick
import           IModel
import           SModel
import           Probability.Distribution.OnTree
import           System.Environment  -- for getArgs

gtr_m7_model codons = do

    -- GTR model parameters
    let nucs = getNucleotides codons
    sym <- symmetric_dirichlet_on (letter_pair_names nucs) 1.0
    pi  <- symmetric_dirichlet_on (letters nucs) 1.0
    let gtr_model = gtr' sym pi nucs

    -- Positive selection model based on GTR
    let pos_sel_model w = gtr_model & x3 codons & dNdS w

    -- M7 model parameters
    mu    <- uniform 0.0 1.0            -- mean of dN/dS
    gamma <- beta 1.0 10.0              -- spread of dN/dS
    let m7_model = pos_sel_model & m7 mu gamma 4

    let loggers =
            [ "gtr:sym" %=% sym
            , "gtr:pi" %=% pi
            , "m7:mu" %=% mu
            , "m7:gamma" %=% gamma
            ]
    return (m7_model, loggers)

branch_length_dist topology b = gamma 0.5 (2.0 / intToDouble n)
    where n = numBranches topology

model seq_data = do

    let the_codons      = codons dna standard_code
        taxa            = map sequence_name seq_data
        tip_seq_lengths = get_sequence_lengths the_codons seq_data

    -- Tree
    scale1 <- gamma 0.5 2.0
    tree   <- uniform_labelled_tree taxa branch_length_dist
    let tree1 = scale_branch_lengths scale1 tree

    -- Indel model
    logLambda   <- log_laplace (-4.0) 0.707
    mean_length <- (1.0 +) <$> exponential 10.0
    let imodel = rs07 logLambda mean_length tree

    -- Substitution model
    (m7_model, log_m7_smodel) <- gtr_m7_model the_codons

    -- Alignment
    alignment <- random_alignment tree1 imodel tip_seq_lengths

    -- Observation
    seq_data ~> ctmc_on_tree tree1 alignment m7_model

    return
        [ "tree1" %=% write_newick tree1
        , "scale" %=% scale1
        , "S1" %>% log_m7_smodel
        , "|T|" %=% tree_length tree
        , "scale1*|T|" %=% tree_length tree1
        ]

main = do
    [filename] <- getArgs

    let seq_data = load_sequences filename

    mcmc $ model seq_data

You can run it as bali-phy -m M7.hs --iter=1000 --- bglobin.fasta.
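
The model above is built up with chains such as gtr_model & x3 codons & dNdS w. The sketch below shows how such a chain reads, on the assumption that & in BAli-Phy behaves like the reverse-application operator from Data.Function, which the text above does not state. The Model type and the functions base, withCodons and withOmega are hypothetical stand-ins, not BAli-Phy functions.

```haskell
import Data.Function ((&))

-- A stand-in for a substitution model: just the list of steps applied.
newtype Model = Model [String] deriving (Eq, Show)

base :: Model
base = Model ["gtr"]

-- Hypothetical model transformers.
withCodons :: Model -> Model
withCodons (Model steps) = Model (steps ++ ["x3"])

withOmega :: Double -> Model -> Model
withOmega w (Model steps) = Model (steps ++ ["dNdS " ++ show w])

-- x & f means f x, so the chain reads left to right.
pipeline :: Double -> Model
pipeline w = base & withCodons & withOmega w

-- The same model written with nested function application.
nested :: Double -> Model
nested w = withOmega w (withCodons base)
```

Both pipeline 0.5 and nested 0.5 evaluate to Model ["gtr","x3","dNdS 0.5"]. The chained form lists the steps in the order in which they are applied.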

Language properties

The modeling language is a functional language, and uses Haskell syntax. Features currently implemented include:

  1. Random control flow works, allowing if-then-else and loops that depend on random variables.
  2. Composite Objects work, and can be used to define random data structures.
  3. [unreleased] Random numbers of random variables. Random variables can be conditionally created, without the need for reversible-jump methods.
  4. [unreleased] Lazy random variables. Infinite lists of random variables can be created. Random variables are only instantiated if they are accessed.
  5. MCMC works, even when the number of variables is changing.
  6. Functions work, and can be used to define random variables.
  7. Modules work, and allow code to be factored in a clean manner.
  8. Packages work, and allow researchers to distribute their work separately from the BAli-Phy architecture.
  9. Optimization works, and speeds up the user's code via techniques such as inlining.
  10. Recursive random variables. Random processes on trees that are not known in advance.
  11. JSON logging. This enables logging inferred parameters when their dimension and number is not fixed.
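
Items 3 and 4 rely on lazy evaluation, which can be illustrated in plain Haskell. The sketch below is not BAli-Phy code: a small deterministic generator stands in for real random sampling, all names are hypothetical, and Int is assumed to be 64 bits wide.

```haskell
-- A deterministic pseudo-random step (linear congruential generator),
-- standing in for a real random source. Assumes a 64-bit Int.
lcg :: Int -> Int
lcg s = (1103515245 * s + 12345) `mod` 2147483648

-- An infinite list of values in [0, 1). Elements are computed lazily,
-- so only the ones that are accessed are ever evaluated.
uniforms :: Int -> [Double]
uniforms seed = [ fromIntegral s / 2147483648 | s <- drop 1 (iterate lcg seed) ]

-- A random number of random variables: the first value decides how many
-- (between 1 and 5) of the remaining infinitely many values are used.
components :: Int -> [Double]
components seed =
  case uniforms seed of
    (u : rest) -> take (1 + floor (u * 5)) rest
    []         -> []
```

Evaluating components 42 touches at most six elements of the infinite list. The rest are never computed.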

Features that are expected to be completed by mid-2022 include:

  • Type checking. Type checking will enable polymorphism and give useful error messages for program errors.
  • Time Trees and the relaxed clock. Rooted trees implemented as a data structure within the language. (partially implemented)
  • Custom MCMC moves. The ability to add custom MCMC transition kernels will be added. (partially implemented)
  • Port alignment/tree inference. Move alignment and tree inference completely to the model framework.

comments and suggestions: benjamin . redelings * gmail + com