JAX-DIPS: Neural bootstrapping of finite discretization methods and application to elliptic problems with discontinuities
Pouria Mistani, Samira Pakravan, Rajesh Ilango, and Frederic Gibou
Under Review
We present a scalable strategy for the development of mesh-free hybrid neuro-symbolic partial differential equation solvers based on existing mesh-based numerical discretization methods. In particular, this strategy can be used to efficiently train neural network surrogate models for the solution functions and operators of partial differential equations while retaining the accuracy and convergence properties of state-of-the-art numerical solvers. The presented neural bootstrapping method (hereafter dubbed NBM) is based on evaluating the finite discretization residuals of the PDE system, obtained on implicit Cartesian cells centered on a set of random collocation points, with respect to the trainable parameters of the neural network. We apply NBM to the important class of elliptic problems with jump conditions across irregular interfaces in three spatial dimensions. We show that the method is convergent: model accuracy improves as the number of collocation points in the domain increases. The algorithms presented here are implemented and released in a software package named JAX-DIPS (https://github.com/JAX-DIPS/JAX-DIPS), standing for differentiable interfacial PDE solver. JAX-DIPS is developed purely in JAX, offering end-to-end differentiability from mesh generation to higher-level discretization abstractions, geometric integrations, and interpolations, thus facilitating research into the use of differentiable algorithms for developing hybrid PDE solvers.
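As a minimal, hypothetical sketch of the neural bootstrapping idea summarized above (not the JAX-DIPS API), the snippet below forms a loss from finite-difference residuals of a simple 1D Poisson-type problem evaluated on small cells centered at random collocation points, with the solution represented by a neural network; the names `solution_net`, `residual_at_point`, and `nbm_loss`, the 1D stencil, and the source term are illustrative assumptions only.

```python
# Hypothetical sketch of the neural bootstrapping method (NBM) described above;
# names and the 1D Poisson residual are illustrative, not the JAX-DIPS API.
import jax
import jax.numpy as jnp

def solution_net(params, x):
    # Small MLP surrogate u_theta(x); tanh activations, scalar output.
    h = x
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b).squeeze()

def residual_at_point(params, x, dx=1e-2):
    # Second-order central finite-difference residual of -u''(x) = f(x),
    # evaluated on a small cell centered at the collocation point x.
    u = lambda z: solution_net(params, z)
    lap = (u(x + dx) - 2.0 * u(x) + u(x - dx)) / dx**2
    f = jnp.pi**2 * jnp.sin(jnp.pi * x).squeeze()  # example source term
    return -lap - f

def nbm_loss(params, points):
    # Mean squared discretization residual over random collocation points;
    # gradients w.r.t. params flow through the discretization itself.
    res = jax.vmap(lambda x: residual_at_point(params, x))(points)
    return jnp.mean(res**2)

# Initialize a small MLP and a batch of random collocation points in [0, 1].
sizes = [1, 32, 32, 1]
keys = jax.random.split(jax.random.PRNGKey(0), len(sizes) - 1)
params = [(0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
          for k, m, n in zip(keys, sizes[:-1], sizes[1:])]
points = jax.random.uniform(jax.random.PRNGKey(1), (128, 1))

loss, grads = jax.value_and_grad(nbm_loss)(params, points)
```

In the setting of the paper, the 1D stencil above would be replaced by a three-dimensional discretization on implicit Cartesian cells with jump conditions enforced across the irregular interface.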