TensorMesh
A fast, differentiable, JIT-free, debugging-friendly finite element library for PyTorch.
Why TensorMesh
Built for the workflow where finite elements meet deep learning — without sacrificing the speed and accuracy you expect from a real FEM library.
GPU-Native & Differentiable
Built on PyTorch — move the entire FEM workflow to GPU with one line. Autograd flows seamlessly through assembly and solve for end-to-end differentiable PDE pipelines.
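As a rough illustration of what "autograd through assembly and solve" means, here is a minimal sketch in plain PyTorch. It does not use TensorMesh: the helper `assemble_1d_stiffness`, the 1D problem, and the objective are all assumptions chosen for brevity. It assembles a P1 stiffness matrix that depends on a per-element coefficient `kappa`, solves the system, and backpropagates an objective to `kappa`.

```python
import torch


def assemble_1d_stiffness(kappa: torch.Tensor, h: float) -> torch.Tensor:
    """Dense P1 stiffness matrix for -(kappa u')' on a uniform 1D mesh."""
    n_el = kappa.shape[0]
    ref = torch.tensor([[1.0, -1.0], [-1.0, 1.0]],
                       dtype=kappa.dtype, device=kappa.device)
    k_local = (kappa / h)[:, None, None] * ref            # (n_el, 2, 2)
    left = torch.arange(n_el, device=kappa.device)
    conn = torch.stack([left, left + 1], dim=1)           # (n_el, 2)
    rows = conn[:, :, None].expand(-1, 2, 2).reshape(-1)
    cols = conn[:, None, :].expand(-1, 2, 2).reshape(-1)
    K = torch.zeros(n_el + 1, n_el + 1, dtype=kappa.dtype, device=kappa.device)
    # Out-of-place scatter-add keeps the autograd graph back to kappa.
    return K.index_put((rows, cols), k_local.reshape(-1), accumulate=True)


# Switching device is the only change needed to run on GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
n_el = 8
h = 1.0 / n_el
kappa = torch.ones(n_el, dtype=torch.float64, device=device, requires_grad=True)

K = assemble_1d_stiffness(kappa, h)
b = torch.full((n_el + 1,), h, dtype=torch.float64, device=device)  # load for f = 1

# Homogeneous Dirichlet conditions: keep interior rows and columns only.
u_int = torch.linalg.solve(K[1:-1, 1:-1], b[1:-1])

# Objective: approximate integral of u; gradients flow through the solve.
objective = h * u_int.sum()
objective.backward()
print(kappa.grad)
```

Every entry of `kappa.grad` comes out negative here, which matches the physics: stiffening any element reduces the mean displacement.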
Tensorized Assembly
A fully tensorized Map-Reduce algorithm powered by TensorGalerkin fuses element-wise ops into monolithic GPU kernels — order-of-magnitude speedups over CPU-based FEM stacks.
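The Map-Reduce idea can be sketched without TensorMesh or TensorGalerkin. In the generic version below, the "map" step computes all local P1 triangle stiffness matrices in one batched operation and the "reduce" step scatter-adds them into the global matrix. The function name `assemble_p1_stiffness` and the dense output are assumptions for illustration only; a real library would use a sparse format and its own kernels.

```python
import torch


def assemble_p1_stiffness(points: torch.Tensor, cells: torch.Tensor) -> torch.Tensor:
    """Assemble the P1 Laplace stiffness matrix on triangles with batched ops."""
    xy = points[cells]                                    # (n_el, 3, 2)

    # Map: edge opposite each vertex; local K_ij = (e_i . e_j) / (4 * area).
    edges = xy[:, [2, 0, 1]] - xy[:, [1, 2, 0]]           # (n_el, 3, 2)
    d1 = xy[:, 1] - xy[:, 0]
    d2 = xy[:, 2] - xy[:, 0]
    area = 0.5 * (d1[:, 0] * d2[:, 1] - d1[:, 1] * d2[:, 0]).abs()
    k_local = torch.einsum("eid,ejd->eij", edges, edges) / (4.0 * area)[:, None, None]

    # Reduce: scatter-add every local entry into a flattened global matrix.
    n = points.shape[0]
    rows = cells[:, :, None].expand(-1, 3, 3).reshape(-1)
    cols = cells[:, None, :].expand(-1, 3, 3).reshape(-1)
    flat = torch.zeros(n * n, dtype=points.dtype, device=points.device)
    flat = flat.index_add(0, rows * n + cols, k_local.reshape(-1))
    return flat.reshape(n, n)


# Unit square split into two triangles.
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
                      dtype=torch.float64)
cells = torch.tensor([[0, 1, 2], [0, 2, 3]])
K = assemble_p1_stiffness(points, cells)
print(K)
```

There is no Python loop over elements: the cost is a handful of tensor operations regardless of mesh size, which is what makes this pattern GPU-friendly.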
JIT-Free & Debugging-Friendly
Eager execution with no compilation overhead. Dynamic meshes, adaptive refinement, and interactive workflows just work — no recompilation latency, no opaque traces.
Element & Mesh Support
Triangular, tetrahedral, pyramid, and prismatic elements. Automated mesh generation for common geometries with seamless Gmsh and VTK-HDF5 I/O.
Flexible Solvers
Powered by torch-sla — linear, nonlinear, and eigenvalue solvers across CPU/GPU backends with autograd, batched solves, and multi-GPU scaling.
Pythonic API
Custom weak forms in pure Python — no DSL, no form compiler. If you can write PyTorch, you can write FEM.
See it in action
Real outputs from the example gallery — meshes, fields, and animations rendered straight from TensorMesh.
3D Poisson
Tetrahedral mesh, cut view of the scalar field.
Allen–Cahn phase field
Nonlinear time evolution with Newton iteration per step.
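A "Newton iteration per step" can be sketched as follows. This is not the gallery script: it assumes a 1D lumped-mass P1 discretization with natural boundary conditions, backward Euler in time, and a dense Jacobian obtained from `torch.autograd.functional.jacobian`. The helper `newton_solve` and the parameters `eps2` and `dt` are hypothetical names chosen for this example.

```python
import math
import torch
from torch.autograd.functional import jacobian


def newton_solve(residual, u0, tol=1e-10, max_iter=20):
    """Solve residual(u) = 0 with Newton's method and an autograd Jacobian."""
    u = u0.clone()
    for _ in range(max_iter):
        r = residual(u)
        if r.norm() < tol:
            break
        J = jacobian(residual, u)                 # dense (n, n) Jacobian
        u = u - torch.linalg.solve(J, r)
    return u


# 1D lumped-mass P1 operator A = M^{-1} K with natural (Neumann) boundaries.
n = 21
h = 1.0 / (n - 1)
x = torch.linspace(0.0, 1.0, n, dtype=torch.float64)
off = torch.ones(n - 1, dtype=torch.float64)
K = (2.0 * torch.eye(n, dtype=torch.float64)
     - torch.diag(off, 1) - torch.diag(off, -1)) / h
K[0, 0] = K[-1, -1] = 1.0 / h
m = torch.full((n,), h, dtype=torch.float64)
m[0] = m[-1] = h / 2.0
A = K / m[:, None]

eps2, dt = 0.01, 0.1                               # assumed model parameters
u = 0.1 * torch.cos(math.pi * x)                   # small initial perturbation

for _ in range(10):
    u_old = u

    def residual(v, u_old=u_old):
        # Backward Euler for u_t = eps2 * u_xx + u - u^3.
        return (v - u_old) / dt + eps2 * (A @ v) + v**3 - v

    u = newton_solve(residual, u_old)

print(u.abs().max())
```

Using the previous time level as the initial guess usually keeps Newton within a few iterations when the time step is moderate.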
Wave equation
Explicit central-difference time integration.
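For readers unfamiliar with the scheme, the sketch below shows explicit central-difference time stepping on a 1D wave equation with fixed ends. It is a standalone illustration, not the TensorMesh example: the uniform-grid stencil stands in for a lumped-mass P1 discretization, and the grid size, wave speed, and CFL factor of 0.5 are assumptions.

```python
import math
import torch

n = 101
h = 1.0 / (n - 1)
c = 1.0
dt = 0.5 * h / c                 # satisfies the CFL limit c * dt / h <= 1
n_steps = 200
x = torch.linspace(0.0, 1.0, n, dtype=torch.float64)


def acceleration(u):
    """c^2 * u_xx with homogeneous Dirichlet ends (lumped-mass stencil)."""
    a = torch.zeros_like(u)
    a[1:-1] = c**2 * (u[2:] - 2.0 * u[1:-1] + u[:-2]) / h**2
    return a


# Standing wave with zero initial velocity; first step from a Taylor expansion.
u_prev = torch.sin(math.pi * x)
u = u_prev + 0.5 * dt**2 * acceleration(u_prev)

for _ in range(n_steps - 1):
    # Central difference in time: u^{k+1} = 2 u^k - u^{k-1} + dt^2 * a(u^k).
    u_next = 2.0 * u - u_prev + dt**2 * acceleration(u)
    u_prev, u = u, u_next

t_final = n_steps * dt
exact = torch.sin(math.pi * x) * math.cos(math.pi * c * t_final)
max_error = (u - exact).abs().max().item()
print(f"t = {t_final:.2f}, max error = {max_error:.2e}")
```

The scheme needs no linear solve per step, which is why it pairs well with a lumped (diagonal) mass matrix, at the price of a time-step restriction.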
Hyperelastic rubber
Large-deformation solid mechanics with a Newton solver.
Lid-driven cavity
Incompressible Navier–Stokes; velocity field and streamlines.
Magnetostatics
3D magnetic field around a current-carrying wire (stabilized nodal curl–curl).
Topology optimization
Compliance minimization via the Optimality Criteria method.
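The Optimality Criteria update itself is short. The sketch below follows the standard heuristic density update with a bisection on the volume Lagrange multiplier; it assumes uniform element volumes and uses hypothetical names (`oc_update`, `move`, `rho_min`). The sensitivities `dc` are synthetic here, whereas in a real pipeline they would come from autograd or an adjoint solve. TensorMesh's own example may differ in detail.

```python
import torch


def oc_update(rho, dc, vol_frac, move=0.2, rho_min=1e-3):
    """One Optimality Criteria step; dc holds (negative) compliance sensitivities."""
    lo, hi = 1e-9, 1e9
    rho_new = rho
    while (hi - lo) / (hi + lo) > 1e-6:
        lam = 0.5 * (lo + hi)                              # trial multiplier
        candidate = rho * torch.sqrt(-dc / lam)
        # Limit the change per step, then enforce the density bounds.
        candidate = torch.minimum(torch.maximum(candidate, rho - move), rho + move)
        rho_new = candidate.clamp(rho_min, 1.0)
        if rho_new.mean() > vol_frac:
            lo = lam                                       # too much material
        else:
            hi = lam                                       # too little material
    return rho_new


rho = torch.full((100,), 0.5, dtype=torch.float64)
dc = -torch.linspace(1.0, 10.0, 100, dtype=torch.float64)  # synthetic sensitivities
rho_new = oc_update(rho, dc, vol_frac=0.5)
print(rho_new.mean(), rho_new.min(), rho_new.max())
```

Elements with larger sensitivity magnitude gain material and the rest lose it, while the mean density stays at the target volume fraction.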
Physics-informed learning
A network trained to minimize the assembled Galerkin residual.
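To make "minimize the assembled Galerkin residual" concrete, here is a small self-contained sketch in plain PyTorch. The 1D Poisson problem, the lumped load vector, the network size, and the `x * (1 - x)` factor that enforces the boundary conditions are all assumptions for illustration; none of this is taken from the TensorMesh gallery.

```python
import math
import torch

torch.manual_seed(0)

# Assembled 1D P1 system for -u'' = pi^2 sin(pi x), u(0) = u(1) = 0.
n_el = 16
h = 1.0 / n_el
x = torch.linspace(0.0, 1.0, n_el + 1, dtype=torch.float64)
off = torch.ones(n_el, dtype=torch.float64)
K = (2.0 * torch.eye(n_el + 1, dtype=torch.float64)
     - torch.diag(off, 1) - torch.diag(off, -1)) / h
b = h * math.pi**2 * torch.sin(math.pi * x)        # lumped load vector

net = torch.nn.Sequential(
    torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
).double()


def nodal_solution():
    # The factor x * (1 - x) enforces the Dirichlet conditions exactly.
    return x * (1.0 - x) * net(x[:, None]).squeeze(-1)


def galerkin_loss():
    u = nodal_solution()
    r = (K @ u - b)[1:-1] / h                       # interior residual rows
    return (r**2).mean()


optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
losses = []
for _ in range(500):
    optimizer.zero_grad()
    loss = galerkin_loss()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"initial loss {losses[0]:.3e}, final loss {losses[-1]:.3e}")
```

Unlike a collocation-style PINN loss, the loss here only involves the discrete operator `K` and load `b`, so no derivatives of the network with respect to its input are needed.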
30+ runnable examples
Eleven problem categories. Every script ships with the repo, and its rendered notebook lives in the docs.
Basics
Mesh visualization, basis functions, element gallery.
Poisson
2D/3D Poisson, batched RHS, h-adaptivity.
Diffusion
Heat equation and Allen–Cahn phase field.
Wave
Explicit central-difference time integration.
Solid
Hyperelasticity, contact, plasticity, large deformation.
Fluid
Lid-driven cavity, cylinder flow, Rayleigh–Bénard, Taylor–Green.
Magnetostatics
3D Maxwell: field around a wire via nodal curl–curl.
Inverse design
Coefficient ID and density-based topology optimization, via autograd.
Physics-informed
Train a network to minimize the assembled Galerkin residual.
Dataset
Batch mesh & field generation for ML training.
Distributed
Multi-GPU assembly and mesh partitioning.
From mesh to solution in pure Python
A complete Poisson solver — no DSL, no JIT, no surprises. Just PyTorch autograd flowing through every step.
import math
import torch
from tensormesh import ElementAssembler, NodeAssembler, Mesh, Condenser

# 1. Triangular mesh of the unit square.
mesh = Mesh.gen_rectangle(chara_length=0.05)

# 2. Stiffness weak form: a(u, v) = ∫ ∇u · ∇v dΩ
class LaplaceAssembler(ElementAssembler):
    def forward(self, gradu, gradv):
        return gradu @ gradv

# 3. Load weak form: l(v) = ∫ f v dΩ
class SourceAssembler(NodeAssembler):
    def forward(self, v, f):
        return f * v

# 4. Source term, evaluated at every mesh node.
x, y = mesh.points[:, 0], mesh.points[:, 1]
f_vals = 2 * math.pi**2 * torch.sin(math.pi * x) * torch.sin(math.pi * y)

# 5. Assemble.
K = LaplaceAssembler.from_mesh(mesh)()
b = SourceAssembler.from_mesh(mesh)(point_data={"f": f_vals})

# 6. Apply Dirichlet BCs via static condensation, then solve.
condenser = Condenser(mesh.boundary_mask)
K_, b_ = condenser(K, b)
u = condenser.recover(K_.solve(b_, verbose=True))
[torch-sla] solve: n=431, nnz=2859, dtype=float64, device=cpu, symmetric=True, spd=False, backend=scipy, method=lu
L2 error: 3.135e-03
The same script runs unchanged on GPU with mesh = mesh.cuda(), and becomes differentiable with mesh.points.requires_grad_(True).
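For readers curious about what the Dirichlet step in the script above does conceptually, the sketch below eliminates boundary degrees of freedom from a small dense system and then scatters the reduced solution back to the full vector. It is written in plain PyTorch; the functions `condense` and `recover` are hypothetical stand-ins that mirror the idea only and say nothing about how TensorMesh's `Condenser` is implemented.

```python
import torch


def condense(K, b, boundary_mask, boundary_values):
    """Eliminate prescribed DOFs: K_ff u_f = b_f - K_fb u_b."""
    free = ~boundary_mask
    K_ff = K[free][:, free]
    K_fb = K[free][:, boundary_mask]
    return K_ff, b[free] - K_fb @ boundary_values


def recover(u_free, boundary_mask, boundary_values):
    """Scatter the reduced solution and boundary values into a full vector."""
    u = torch.zeros(boundary_mask.shape[0], dtype=u_free.dtype)
    u[~boundary_mask] = u_free
    u[boundary_mask] = boundary_values
    return u


# 1D Laplace problem on 5 nodes with u(0) = 0, u(1) = 1 and no source.
n = 5
h = 1.0 / (n - 1)
off = torch.ones(n - 1, dtype=torch.float64)
K = (2.0 * torch.eye(n, dtype=torch.float64)
     - torch.diag(off, 1) - torch.diag(off, -1)) / h
K[0, 0] = K[-1, -1] = 1.0 / h
b = torch.zeros(n, dtype=torch.float64)

boundary_mask = torch.zeros(n, dtype=torch.bool)
boundary_mask[0] = boundary_mask[-1] = True
boundary_values = torch.tensor([0.0, 1.0], dtype=torch.float64)

K_ff, b_f = condense(K, b, boundary_mask, boundary_values)
u = recover(torch.linalg.solve(K_ff, b_f), boundary_mask, boundary_values)
print(u)
```

The reduced matrix `K_ff` is symmetric positive definite for this problem, so the solve is well posed even though the full `K` is singular.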
Speed without compromises
Benchmarked against FEniCS, Firedrake, MFEM, scikit-fem, JAX-FEM, and torch-fem on 3D Poisson, linear elasticity, and topology optimization.
3D Poisson — total time vs DOFs
Wall-clock time on tetrahedral meshes for every framework, CPU and CUDA. TensorMesh (CUDA) scales linearly past 10⁶ DOFs.
3D Linear elasticity — total time vs DOFs
Same comparison on a vector-valued elasticity problem. CUDA backends widen the lead on larger meshes.
Stuck on a PDE?
Ask the maintainers.
Reach out through one of the project's support channels, or email Shizheng Wen at shizheng.wen@sam.math.ethz.ch for research collaborations.