diff --git a/Python/jaxpi.py b/Python/jaxpi.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f185f3c3d470e7e215c033275b23c3828d74743
--- /dev/null
+++ b/Python/jaxpi.py
@@ -0,0 +1,67 @@
+#
+# jaxpi.py
+# Neil Gershenfeld 12/21/24
+# Jax pi calculation benchmark
+# pi = 3.14159265358979323846
+#
+import jax
+import jax.numpy as jnp
+import numpy as np
+import time
+#
+NPTS = 100000000
+#
+a = 0.5
+b = 0.75
+c = 0.25
+#
+# alternate compilation values to prevent caching
+#
+a0 = 0.6
+b0 = 0.7
+c0 = 0.2
+#
+print("\nNumPy version:")
+def num_calcpi(a,b,c):
+   i = np.arange(1,(NPTS+1),dtype=float)
+   pi = np.sum(a/((i-b)*(i-c)))
+   return pi
+start_time = time.time()
+pi = num_calcpi(a,b,c)
+end_time = time.time()
+mflops = NPTS*5.0/(1.0e6*(end_time-start_time))
+print("NPTS = %d, pi = %f"%(NPTS,pi))
+print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops))
+#
+print("\ncompile Jax version:")
+def jax_calcpi(a,b,c):
+   i = jnp.arange(1,(NPTS+1),dtype=float)
+   pi = jnp.sum(a/((i-b)*(i-c)))
+   return pi
+start_time = time.time()
+pi = jax_calcpi(a0,b0,c0).block_until_ready()
+end_time = time.time()
+print("time = %f"%(end_time-start_time))
+#
+print("\nrun Jax version:")
+start_time = time.time()
+pi = jax_calcpi(a,b,c).block_until_ready()
+end_time = time.time()
+mflops = NPTS*5.0/(1.0e6*(end_time-start_time))
+print("NPTS = %d, pi = %f"%(NPTS,pi))
+print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops))
+#
+print("\ncompile Jax Jit version:")
+jax_jit_calcpi = jax.jit(jax_calcpi)
+start_time = time.time()
+pi = jax_jit_calcpi(a0,b0,c0).block_until_ready()
+end_time = time.time()
+print("time = %f"%(end_time-start_time))
+#
+print("\nrun Jax Jit version:")
+start_time = time.time()
+pi = jax_jit_calcpi(a,b,c).block_until_ready()
+end_time = time.time()
+mflops = NPTS*5.0/(1.0e6*(end_time-start_time))
+print("NPTS = %d, pi = %f"%(NPTS,pi))
+print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops))
diff --git a/README.md b/README.md
index d12e4546e1d6e1e5748e0a0f5cbce048fcb78b4d..fa682a5f6c04ae5ef2a59370840bd1f9cc843777 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@
 |1,090|[numbapig.py](Python/numbapig.py)|Python, Numba, CUDA, 5120 cores|NVIDIA V100|March, 2020|
 |1,062|[taichipi.py](Python/taichipi.py)|Python, Taichi, 5120 cores|NVIDIA V100|March, 2023|
 |811|prior|Cray XT4|C, MPI, 2048 processes|prior|
+|604|[jaxpi.py](Python/jaxpi.py)|Python, Jax, 5120 cores|NVIDIA V100|December, 2024|
 |501|[rayonpi.rs](Rust/rayonpi.rs)|Rust, Rayon, 96 cores<br>cargo run --release|Graviton4|December, 2024|
 |484|[threadpi.rs](Rust/threadpi.rs)|Rust, threads, 96 cores<br>cargo run --release -- 96|Graviton4|December, 2024|
 |315|[numbapip.py](Python/numbapip.py)|Python, Numba, parallel, fastmath<br>96 cores|Intel 2x Xeon Platinum 8175M|February, 2020|