TensorMesh

A fast, differentiable, JIT-free, debugging-friendly finite element library for PyTorch.

Why TensorMesh

Built for the workflow where finite elements meet deep learning — without sacrificing the speed and accuracy you expect from a real FEM library.

GPU-Native & Differentiable

Built on PyTorch — move the entire FEM workflow to GPU with one line. Autograd flows seamlessly through assembly and solve for end-to-end differentiable PDE pipelines.
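As a minimal sketch of what "autograd through the solve" means, the snippet below uses only plain PyTorch calls rather than the TensorMesh API. The 1D toy problem, the scalar coefficient `kappa`, and the dense matrix are illustrative assumptions; a real FEM pipeline would use sparse operators.

```python
import torch

# Use the GPU when one is available; everything below is device-agnostic.
device = "cuda" if torch.cuda.is_available() else "cpu"
opts = dict(dtype=torch.float64, device=device)

# Toy problem (hypothetical): -kappa * u'' = 1 on (0, 1), u(0) = u(1) = 0,
# discretized with linear elements on n interior nodes.
n = 9
h = 1.0 / (n + 1)
kappa = torch.tensor(2.0, requires_grad=True, **opts)

# Tridiagonal stiffness matrix for kappa = 1 (dense here for brevity).
main = torch.full((n,), 2.0 / h, **opts)
off = torch.full((n - 1,), -1.0 / h, **opts)
K0 = torch.diag(main) + torch.diag(off, 1) + torch.diag(off, -1)
b = torch.full((n,), h, **opts)  # load vector for f = 1

u = torch.linalg.solve(kappa * K0, b)  # forward solve
loss = u.sum()
loss.backward()  # backpropagate through the solve

# Because u = K0^{-1} b / kappa, the exact sensitivity is -loss / kappa.
expected = (-loss / kappa).detach()
print(f"autograd: {kappa.grad.item():.6f}, analytic: {expected.item():.6f}")
```

The gradient from `backward()` should match the closed-form sensitivity, which is the property an end-to-end differentiable PDE pipeline relies on.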

Tensorized Assembly

A fully tensorized Map-Reduce algorithm powered by TensorGalerkin fuses element-wise ops into monolithic GPU kernels — order-of-magnitude speedups over CPU-based FEM stacks.
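The Map-Reduce idea can be illustrated in plain PyTorch for linear (P1) triangles. This is a sketch under stated assumptions, not TensorMesh's or TensorGalerkin's actual kernels: the function name `assemble_p1_stiffness` is hypothetical, and the closed-form element matrix is the textbook one for the Laplace operator. The "map" computes every element matrix in one batched operation; the "reduce" sums entries that share a global index.

```python
import torch

def assemble_p1_stiffness(points, cells):
    """Assemble the P1 stiffness matrix of a 2D triangle mesh without Python loops.

    points: (n_nodes, 2) float tensor; cells: (n_cells, 3) long tensor.
    """
    p = points[cells]  # (n_cells, 3, 2) vertex coordinates per element

    # Map: all element matrices at once.
    # e[:, i] is the edge opposite local vertex i; for linear triangles
    # Ke[i, j] = (e_i . e_j) / (4 * area).
    e = p[:, [2, 0, 1]] - p[:, [1, 2, 0]]
    d1, d2 = p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]
    area = 0.5 * (d1[:, 0] * d2[:, 1] - d1[:, 1] * d2[:, 0]).abs()
    Ke = torch.einsum("cid,cjd->cij", e, e) / (4.0 * area)[:, None, None]

    # Reduce: scatter-add local entries into the global sparse matrix.
    rows = cells[:, :, None].expand(-1, 3, 3).reshape(-1)
    cols = cells[:, None, :].expand(-1, 3, 3).reshape(-1)
    n = points.shape[0]
    K = torch.sparse_coo_tensor(torch.stack([rows, cols]), Ke.reshape(-1), (n, n))
    return K.coalesce()  # coalesce() sums duplicate (row, col) entries

# Unit square split into two triangles.
points = torch.tensor(
    [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float64
)
cells = torch.tensor([[0, 1, 2], [0, 2, 3]])
K = assemble_p1_stiffness(points, cells)
print(K.to_dense())
```

Because every step is an ordinary tensor operation, the same function runs on GPU tensors and stays differentiable with respect to `points`.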

JIT-Free & Debugging-Friendly

Eager execution with no compilation overhead. Dynamic meshes, adaptive refinement, and interactive workflows just work — no recompilation latency, no opaque traces.

Element & Mesh Support

Triangular, tetrahedral, pyramid, and prismatic elements. Automated mesh generation for common geometries with seamless Gmsh and VTK-HDF5 I/O.

Flexible Solvers

Powered by torch-sla — linear, nonlinear, and eigenvalue solvers across CPU/GPU backends with autograd, batched solves, and multi-GPU scaling.

Pythonic API

Custom weak forms in pure Python — no DSL, no form compiler. If you can write PyTorch, you can write FEM.

From mesh to solution in pure Python

A complete Poisson solver — no DSL, no JIT, no surprises. Just PyTorch autograd flowing through every step.

quickstart.py
import math
import torch
from tensormesh import ElementAssembler, NodeAssembler, Mesh, Condenser

# 1. Triangular mesh of the unit square.
mesh = Mesh.gen_rectangle(chara_length=0.05)

# 2. Stiffness weak form:  a(u, v) = ∫ ∇u · ∇v dΩ
class LaplaceAssembler(ElementAssembler):
    def forward(self, gradu, gradv):
        return gradu @ gradv

# 3. Load weak form:  l(v) = ∫ f v dΩ
class SourceAssembler(NodeAssembler):
    def forward(self, v, f):
        return f * v

# 4. Source term, evaluated at every mesh node.
x, y = mesh.points[:, 0], mesh.points[:, 1]
f_vals = 2 * math.pi**2 * torch.sin(math.pi * x) * torch.sin(math.pi * y)

# 5. Assemble.
K = LaplaceAssembler.from_mesh(mesh)()
b = SourceAssembler.from_mesh(mesh)(point_data={"f": f_vals})

# 6. Apply Dirichlet BCs via static condensation, then solve.
condenser = Condenser(mesh.boundary_mask)
K_, b_ = condenser(K, b)
u = condenser.recover(K_.solve(b_, verbose=True))

Figure: numerical solution on the unit square, approximating the exact solution u(x, y) = sin(πx)sin(πy).

$ python quickstart.py
[torch-sla] solve: n=431, nnz=2859, dtype=float64, device=cpu, symmetric=True, spd=False, backend=scipy, method=lu
L2 error: 3.135e-03
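
The source term in step 4 is a manufactured right-hand side: with u(x, y) = sin(πx)sin(πy), which vanishes on the boundary of the unit square, −Δu = 2π² sin(πx)sin(πy). The listing above does not show how the reported L2 error is computed, so the sketch below only checks the reference solution itself, using plain PyTorch autograd at random sample points. The helper name `u_exact` is an assumption for illustration.

```python
import math
import torch

def u_exact(xy):
    """Candidate exact solution u(x, y) = sin(pi x) sin(pi y)."""
    return torch.sin(math.pi * xy[:, 0]) * torch.sin(math.pi * xy[:, 1])

# Random sample points in the unit square (not mesh nodes).
torch.manual_seed(0)
xy = torch.rand(8, 2, dtype=torch.float64, requires_grad=True)

u = u_exact(xy)
# Each u[i] depends only on xy[i], so differentiating the sum
# yields the per-point gradients.
(grad_u,) = torch.autograd.grad(u.sum(), xy, create_graph=True)

# Laplacian = u_xx + u_yy, one second derivative per coordinate.
laplacian = torch.zeros_like(u)
for d in range(2):
    (second,) = torch.autograd.grad(grad_u[:, d].sum(), xy, retain_graph=True)
    laplacian = laplacian + second[:, d]

f = 2 * math.pi**2 * u_exact(xy)  # source term from step 4
residual = (-laplacian - f).abs().max().item()
print(f"max |-Δu - f| = {residual:.2e}")
```

The residual should sit at floating-point round-off, confirming that the quickstart's source term and the stated exact solution belong together.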

To run the same script on GPU, move the mesh to the device with mesh = mesh.cuda(). To make it differentiable with respect to the node coordinates, call mesh.points.requires_grad_(True).
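
Step 6 of the quickstart imposes Dirichlet conditions by eliminating the boundary unknowns before the solve and restoring them afterwards. How `Condenser` does this internally is not shown here; the sketch below is the textbook elimination written in plain PyTorch on a dense 1D problem, and the helper names `eliminate_dirichlet` and `recover` are hypothetical.

```python
import torch

def eliminate_dirichlet(K, b, boundary_mask, u_boundary):
    """Restrict K u = b to the free nodes, moving known values to the right-hand side."""
    free = ~boundary_mask
    K_ff = K[free][:, free]
    K_fb = K[free][:, boundary_mask]
    return K_ff, b[free] - K_fb @ u_boundary

def recover(u_free, boundary_mask, u_boundary):
    """Scatter free and prescribed values back into a full nodal vector."""
    u = torch.empty(boundary_mask.numel(), dtype=u_free.dtype)
    u[~boundary_mask] = u_free
    u[boundary_mask] = u_boundary
    return u

# 1D Laplace problem -u'' = 0 on (0, 1) with u(0) = 0, u(1) = 1; exact solution u = x.
n = 6
h = 1.0 / (n - 1)
main = torch.full((n,), 2.0 / h, dtype=torch.float64)
main[0] = main[-1] = 1.0 / h  # end nodes touch a single element
off = torch.full((n - 1,), -1.0 / h, dtype=torch.float64)
K = torch.diag(main) + torch.diag(off, 1) + torch.diag(off, -1)
b = torch.zeros(n, dtype=torch.float64)

boundary_mask = torch.zeros(n, dtype=torch.bool)
boundary_mask[[0, -1]] = True
u_boundary = torch.tensor([0.0, 1.0], dtype=torch.float64)

K_ff, b_f = eliminate_dirichlet(K, b, boundary_mask, u_boundary)
u = recover(torch.linalg.solve(K_ff, b_f), boundary_mask, u_boundary)
print(u)
```

The unconstrained matrix `K` is singular (constants lie in its null space), while the reduced block `K_ff` is symmetric positive definite, which is why the elimination comes before the solve.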

Read the full quickstart →

Speed without compromises

Benchmarked against FEniCS, Firedrake, MFEM, scikit-fem, JAX-FEM, and torch-fem on 3D Poisson, linear elasticity, and topology optimization.

3D Poisson — total time vs DOFs

Wall-clock time on tetrahedral meshes for every framework, CPU and CUDA. TensorMesh (CUDA) scales linearly past 10⁶ DOFs.

3D Linear elasticity — total time vs DOFs

Same comparison on a vector-valued elasticity problem. CUDA backends widen the lead on larger meshes.

Open community · Apache 2.0

Stuck on a PDE?
Ask the maintainers.

Pick one of the community channels, or email Shizheng Wen at shizheng.wen@sam.math.ethz.ch for research collaborations.