<aside> 📎
Author: Thomas Ghorbanian
</aside>
Scientific computing has long faced a trade-off: high-level libraries offer convenience, rich syntax, and diverse data types that enable rapid prototyping, while low-level languages provide fine-grained control over memory management, enabling direct kernel and hardware-specific optimizations crucial for high performance.
In recent years, pairing Python interfaces with high-performance kernels in lower-level languages has emerged as a new scientific software design pattern (see: JAX, Numba, Taichi). The approach separates code semantics from compiler optimizations, and continued improvements in the latter are narrowing performance trade-offs. The success of this framework has encouraged more researchers to adopt high-level tooling, leading to the development of domain-specific extensions in fields traditionally dominated by low-level libraries, such as molecular dynamics (MD), computational fluid dynamics (CFD), and finite element analysis (FEA).
As this paradigm gains momentum, many new researchers have grown comfortable using these tools as black boxes. However, high-level and low-level mindsets need not be mutually exclusive. Understanding the architecture of high-level tools makes us more capable scientific software developers: it helps us write more efficient code, use advanced features effectively, and diagnose performance bottlenecks with greater precision.
We'll take JAX as an example, exploring how insights into its internal workings can lead to more thoughtful algorithm design and implementation.
While all high-level libraries ultimately execute as machine instructions, JAX distinguishes itself in its translation process. Libraries like NumPy already offer excellent performance for array operations by leveraging pre-compiled C and Fortran code. JAX takes this one step further by employing a compilation strategy that allows for optimization across operation boundaries and more effective use of hardware accelerators.
Figure 1. High-level overview of JAX’s execution pipeline.
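To make the contrast with NumPy concrete, here is a minimal sketch, assuming a small element-wise function chosen purely for illustration (the names `poly_sin_numpy` and `poly_sin_jax` are hypothetical). The NumPy version dispatches one pre-compiled kernel per operation, while the `jax.jit` version hands the whole function body to the compiler, which is then free to optimize across operation boundaries.

```python
import numpy as np
import jax
import jax.numpy as jnp

def poly_sin_numpy(x):
    # Each NumPy call runs its own pre-compiled kernel and
    # materializes an intermediate array before the next call starts.
    return x**2 + 3.0 * np.sin(x)

@jax.jit
def poly_sin_jax(x):
    # Under jit the whole body is traced and handed to the XLA compiler
    # as one unit, so it may fuse these element-wise steps together.
    return x**2 + 3.0 * jnp.sin(x)

x_np = np.linspace(0.0, 1.0, 1000, dtype=np.float32)
x_jax = jnp.asarray(x_np)

y_np = poly_sin_numpy(x_np)
y_jax = poly_sin_jax(x_jax)  # first call triggers tracing and compilation

# Both paths compute the same values up to floating-point tolerance.
print(np.allclose(y_np, np.asarray(y_jax), rtol=1e-4, atol=1e-5))
```

Whether the compiled version is actually faster depends on array size, hardware, and compilation overhead, so it is worth measuring rather than assuming.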
JAX achieves these optimizations through three key technologies: tracing, transformations, and compilation.
When a function is first called, JAX creates a blueprint of its computational structure via tracing:
- The function's inputs are replaced with `Tracer` objects, which hold only abstract information about the inputs (shape and data type) rather than their actual values.
- Operations applied to the `Tracer` objects are intercepted (recorded) as a sequence of granular primitives (e.g. `add`, `multiply`, `sin`).

This process generates a jaxpr (JAX exPRession): a static intermediate representation of the function. Jaxprs are independent of specific input values, and serve as the basis for JAX to transform (technology 2) and compile (technology 3) computations without re-executing the original Python code.
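The tracing steps above can be observed directly. The following is a small sketch, assuming two toy functions `f` and `g` that are not from the original text: `jax.make_jaxpr` prints the jaxpr recorded for `f`, and a Python-side log shows that a `jax.jit`-wrapped function is traced only once for inputs of the same shape and data type.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # During tracing, x and y are Tracer objects, so this print shows
    # abstract information (shape and dtype) rather than concrete values.
    print("tracing f with:", x)
    return jnp.sin(x) * y + 1.0

x = jnp.arange(3.0)
y = jnp.ones(3)

# make_jaxpr traces f and returns its jaxpr without compiling it.
closed_jaxpr = jax.make_jaxpr(f)(x, y)
print(closed_jaxpr)

# Names of the recorded primitives, in the order they were applied.
# Note that JAX names the multiply primitive "mul".
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)

trace_log = []  # grows only when the Python body of g actually runs

@jax.jit
def g(x):
    trace_log.append("traced")  # Python side effect: happens at trace time only
    return x * 2.0

g(jnp.ones(3))   # first call: the Python body runs with Tracers
g(jnp.zeros(3))  # same shape and dtype: the cached computation is reused
print(len(trace_log))
```

Calling `g` with a different shape or data type would trigger a fresh trace, since the abstract information the jaxpr was built from no longer matches.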
Transformations in JAX are higher-order functions: they take functions as input and produce new, modified functions as output. Like creating recipe variations (e.g. a larger batch version, a calorie-reduced version, or an optimized quick-prep version), transformations of functions are what make JAX inherently versatile.
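As a brief illustration of transformations as higher-order functions, the sketch below composes three standard ones (`jax.grad`, `jax.vmap`, and `jax.jit`) on a toy scalar function. The function `f` and the variable names are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A scalar function: f(x) = x * sin(x)
    return jnp.sin(x) * x

# Each transformation takes a function and returns a new function.
df = jax.grad(f)              # derivative of f with respect to x
batched_df = jax.vmap(df)     # apply df independently to each element of a batch
fast_batched_df = jax.jit(batched_df)  # compiled version of the batched derivative

xs = jnp.linspace(0.0, 1.0, 5)
out = fast_batched_df(xs)

# Analytical derivative for comparison: f'(x) = x * cos(x) + sin(x)
expected = xs * jnp.cos(xs) + jnp.sin(xs)
print(jnp.allclose(out, expected, atol=1e-5))
```

Because each transformation returns an ordinary Python callable, they can be stacked in whatever order suits the problem.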
JAX's transformations work by reinterpreting the computational graph and jaxpr representation of the function generated during tracing. For example: