diff --git a/Python/jaxpi.py b/Python/jaxpi.py new file mode 100644 index 0000000000000000000000000000000000000000..0f185f3c3d470e7e215c033275b23c3828d74743 --- /dev/null +++ b/Python/jaxpi.py @@ -0,0 +1,67 @@ +# +# jaxpi.py +# Neil Gershenfeld 12/21/24 +# Jax pi calculation benchmark +# pi = 3.14159265358979323846 +# +import jax +import jax.numpy as jnp +import numpy as np +import time +# +NPTS = 100000000 +# +a = 0.5 +b = 0.75 +c = 0.25 +# +# alternate compilation values to prevent caching +# +a0 = 0.6 +b0 = 0.7 +c0 = 0.2 +# +print("\nNumPy version:") +def num_calcpi(a,b,c): + i = np.arange(1,(NPTS+1),dtype=float) + pi = np.sum(a/((i-b)*(i-c))) + return pi +start_time = time.time() +pi = num_calcpi(a,b,c) +end_time = time.time() +mflops = NPTS*5.0/(1.0e6*(end_time-start_time)) +print("NPTS = %d, pi = %f"%(NPTS,pi)) +print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops)) +# +print("\ncompile Jax version:") +def jax_calcpi(a,b,c): + i = jnp.arange(1,(NPTS+1),dtype=float) + pi = jnp.sum(a/((i-b)*(i-c))) + return pi +start_time = time.time() +pi = jax_calcpi(a0,b0,c0).block_until_ready() +end_time = time.time() +print("time = %f"%(end_time-start_time)) +# +print("\nrun Jax version:") +start_time = time.time() +pi = jax_calcpi(a,b,c).block_until_ready() +end_time = time.time() +mflops = NPTS*5.0/(1.0e6*(end_time-start_time)) +print("NPTS = %d, pi = %f"%(NPTS,pi)) +print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops)) +# +print("\ncompile Jax Jit version:") +jax_jit_calcpi = jax.jit(jax_calcpi) +start_time = time.time() +pi = jax_jit_calcpi(a0,b0,c0).block_until_ready() +end_time = time.time() +print("time = %f"%(end_time-start_time)) +# +print("\nrun Jax Jit version:") +start_time = time.time() +pi = jax_jit_calcpi(a,b,c).block_until_ready() +end_time = time.time() +mflops = NPTS*5.0/(1.0e6*(end_time-start_time)) +print("NPTS = %d, pi = %f"%(NPTS,pi)) +print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops)) diff --git a/README.md b/README.md index d12e4546e1d6e1e5748e0a0f5cbce048fcb78b4d..fa682a5f6c04ae5ef2a59370840bd1f9cc843777 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ |1,090|[numbapig.py](Python/numbapig.py)|Python, Numba, CUDA, 5120 cores|NVIDIA V100|March, 2020| |1,062|[taichipi.py](Python/taichipi.py)|Python, Taichi, 5120 cores|NVIDIA V100|March, 2023| |811|prior|Cray XT4|C, MPI, 2048 processes|prior| +|604|[jaxpi.py](Python/jaxpi.py)|Python, Jax, 5120 cores|NVIDIA V100|December, 2024| |501|[rayonpi.rs](Rust/rayonpi.rs)|Rust, Rayon, 96 cores<br>cargo run --release|Graviton4|December, 2024| |484|[threadpi.rs](Rust/threadpi.rs)|Rust, threads, 96 cores<br>cargo run --release -- 96|Graviton4|December, 2024| |315|[numbapip.py](Python/numbapip.py)|Python, Numba, parallel, fastmath<br>96 cores|Intel 2x Xeon Platinum 8175M|February, 2020|