Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud
We introduce a new family of deep neural network models. Instead of specifying a discrete sequence of hidden layers, we parameterize the derivative of the hidden state using a neural network. The output of the network is computed using a black-box differential equation solver. These continuous-depth models have constant memory cost, adapt their evaluation strategy to each input, and can explicitly trade numerical precision for speed. We demonstrate these properties in continuous-depth residual networks and continuous-time latent variable models. We also construct continuous normalizing flows, a generative model that can train by maximum likelihood, without partitioning or ordering the data dimensions. For training, we show how to scalably backpropagate through any ODE solver, without access to its internal operations. This allows end-to-end training of ODEs within larger models.
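
The central idea, a hidden state whose derivative is given by a neural network and whose output comes from an ODE solver, can be illustrated with a minimal NumPy sketch. Everything named here is a hypothetical stand-in rather than the authors' implementation: `dynamics` is a single tanh layer, `odeint_fixed` is a simple fixed-step integrator (the abstract refers to a black-box solver, which in practice would typically be adaptive), and `n_steps` plays the role of the precision/speed knob. A single Euler step of size 1 reduces to the familiar residual update `h + f(h)`.

```python
import numpy as np

def dynamics(h, t, params):
    """Hypothetical dynamics network f(h, t; theta): a single tanh layer."""
    W, b = params
    return np.tanh(W @ h + b)

def euler_step(f, h, t, dt, params):
    # One explicit Euler step; with dt = 1 this is a residual block h + f(h).
    return h + dt * f(h, t, params)

def rk4_step(f, h, t, dt, params):
    # One classical fourth-order Runge-Kutta step.
    k1 = f(h, t, params)
    k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt, params)
    k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt, params)
    k4 = f(h + dt * k3, t + dt, params)
    return h + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def odeint_fixed(f, h0, t0, t1, params, n_steps, step=rk4_step):
    """Integrate dh/dt = f(h, t; params) from t0 to t1 in n_steps equal steps.

    Assumption: a fixed-step stand-in for the black-box solver; more steps
    cost more function evaluations but reduce numerical error.
    """
    h = np.asarray(h0, dtype=float)
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        h = step(f, h, t, dt, params)
        t += dt
    return h

# Illustrative parameters and input (random, not trained).
rng = np.random.default_rng(0)
params = (0.5 * rng.standard_normal((4, 4)), np.zeros(4))
h0 = rng.standard_normal(4)

coarse = odeint_fixed(dynamics, h0, 0.0, 1.0, params, n_steps=2)    # cheaper
fine = odeint_fixed(dynamics, h0, 0.0, 1.0, params, n_steps=200)    # more precise
print("max gap between coarse and fine solve:", np.max(np.abs(coarse - fine)))
```

The sketch only covers the forward pass. Training would additionally require gradients with respect to `params`, which the abstract says are obtained by backpropagating through the solver without access to its internal operations.
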
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Imputation | PhysioNet Challenge 2012 | MSE (10^-3) | 3.907 | Latent ODE (RNN enc.) |
| Imputation | PhysioNet Challenge 2012 | MSE (10^-3) | 5.93 | RNN-VAE |
| Imputation | MuJoCo | MSE (10^-2, 50% missing) | 0.447 | Latent ODE (RNN enc.) |
| Imputation | MuJoCo | MSE (10^-2, 50% missing) | 6.1 | RNN-VAE |
| Time Series Forecasting | MuJoCo | MSE (10^-2, 50% missing) | 1.377 | Latent ODE (RNN enc.) |
| Time Series Forecasting | MuJoCo | MSE (10^-2, 50% missing) | 1.782 | RNN-VAE |
| Time Series Forecasting | USHCN-Daily | MSE | 0.83 | NeuralODE-VAE-Mask |
| Time Series Forecasting | USHCN-Daily | MSE | 0.96 | NeuralODE-VAE |
| Time Series Forecasting | PhysioNet Challenge 2012 | MSE stdev | 0.145 | RNN-VAE |
| Time Series Forecasting | PhysioNet Challenge 2012 | MSE (10^-3) | 3.055 | RNN-VAE |
| Time Series Forecasting | PhysioNet Challenge 2012 | MSE stdev | 0.052 | Latent ODE (RNN enc.) |
| Time Series Forecasting | PhysioNet Challenge 2012 | MSE (10^-3) | 3.162 | Latent ODE (RNN enc.) |
| Feature Engineering | PhysioNet Challenge 2012 | MSE (10^-3) | 3.907 | Latent ODE (RNN enc.) |
| Feature Engineering | PhysioNet Challenge 2012 | MSE (10^-3) | 5.93 | RNN-VAE |
| Feature Engineering | MuJoCo | MSE (10^-2, 50% missing) | 0.447 | Latent ODE (RNN enc.) |
| Feature Engineering | MuJoCo | MSE (10^-2, 50% missing) | 6.1 | RNN-VAE |
| Time Series Analysis | MuJoCo | MSE (10^-2, 50% missing) | 1.377 | Latent ODE (RNN enc.) |
| Time Series Analysis | MuJoCo | MSE (10^-2, 50% missing) | 1.782 | RNN-VAE |
| Time Series Analysis | USHCN-Daily | MSE | 0.83 | NeuralODE-VAE-Mask |
| Time Series Analysis | USHCN-Daily | MSE | 0.96 | NeuralODE-VAE |
| Time Series Analysis | PhysioNet Challenge 2012 | MSE stdev | 0.145 | RNN-VAE |
| Time Series Analysis | PhysioNet Challenge 2012 | MSE (10^-3) | 3.055 | RNN-VAE |
| Time Series Analysis | PhysioNet Challenge 2012 | MSE stdev | 0.052 | Latent ODE (RNN enc.) |
| Time Series Analysis | PhysioNet Challenge 2012 | MSE (10^-3) | 3.162 | Latent ODE (RNN enc.) |
| Multivariate Time Series Forecasting | MuJoCo | MSE (10^-2, 50% missing) | 1.377 | Latent ODE (RNN enc.) |
| Multivariate Time Series Forecasting | MuJoCo | MSE (10^-2, 50% missing) | 1.782 | RNN-VAE |
| Multivariate Time Series Forecasting | USHCN-Daily | MSE | 0.83 | NeuralODE-VAE-Mask |
| Multivariate Time Series Forecasting | USHCN-Daily | MSE | 0.96 | NeuralODE-VAE |
| Multivariate Time Series Forecasting | PhysioNet Challenge 2012 | MSE stdev | 0.145 | RNN-VAE |
| Multivariate Time Series Forecasting | PhysioNet Challenge 2012 | MSE (10^-3) | 3.055 | RNN-VAE |
| Multivariate Time Series Forecasting | PhysioNet Challenge 2012 | MSE stdev | 0.052 | Latent ODE (RNN enc.) |
| Multivariate Time Series Forecasting | PhysioNet Challenge 2012 | MSE (10^-3) | 3.162 | Latent ODE (RNN enc.) |