//
// cudapit.cu
// Neil Gershenfeld 3/1/20
// calculation of pi by a CUDA multi-GPU thread sum
// pi = 3.14159265358979323846 
//
#include <iostream>
#include <chrono>
#include <thread>
#include <vector>
#include <cstdint>
// thread blocks and threads per block that together produce the npts partial sums
uint64_t blocks{1024};
uint64_t threads{1024};
// number of series terms accumulated by each GPU thread
uint64_t nloop{10000000};
// number of per-thread partial sums held on each device
uint64_t npts{blocks*threads};
// per-device partial sums, indexed by device number
std::vector<double> results{};
//
// init
//    Each thread sums its own contiguous run of nloop terms of the series
//       sum_{j>=1} 0.5/((j-0.75)*(j-0.25))  ( = 4*(1 - 1/3 + 1/5 - ...) = pi )
//    and stores that partial sum in arr[i], where i is the global thread index.
//
//    arr    device array with at least npts elements (one per thread)
//    nloop  number of consecutive terms summed by each thread
//    npts   number of partial sums to produce; threads with i >= npts do nothing
//    index  slice number: offsets the term range by index*npts*nloop so that
//           different launches cover disjoint parts of the series
//
//    Expects a 1D launch with at least npts threads in total.
//
__global__ void init(double *arr,uint64_t nloop,uint64_t npts,int index) {
   // widen before multiplying so the index cannot wrap in 32 bits
   uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x;
   // surplus threads of an oversized grid must not write past the array
   if (i >= npts)
      return;
   uint64_t start = nloop*i+npts*nloop*index+1;
   uint64_t end = start+nloop;
   // accumulate locally and store once instead of updating global memory per term
   double acc = 0;
   for (uint64_t j = start; j < end; ++j)
      acc += 0.5/((j-0.75)*(j-0.25));
   arr[i] = acc;
   }
//
// reduce_sum
//    One pairwise-add pass: the thread with global index idx < len folds
//    arr[idx+len] into arr[idx]; threads with idx >= len do nothing.
//
//    arr  device array; elements [0, 2*len) are accessed
//    len  number of pairs added in this pass
//
__global__ void reduce_sum(double *arr,uint64_t len) {
   const uint64_t idx = blockIdx.x*blockDim.x+threadIdx.x;
   if (idx >= len)
      return;
   arr[idx] = arr[idx]+arr[idx+len];
   }
//
// reduce
//    Sum the npts elements of the device array arr into arr[0] by repeated
//    pairwise-add passes of reduce_sum. The launches are queued on the
//    default stream and not synchronized here; the caller must synchronize
//    before reading arr[0].
//
//    Works for any npts >= 1. When npts is a power of two the passes are
//    len = npts/2, npts/4, ..., 1 over arr[0 .. 2*len).
//
void reduce(double *arr) {
   uint64_t n = npts; // elements still to be combined, always arr[0 .. n)
   while (n > 1) {
      uint64_t len = n >> 1;
      // for odd n leave arr[0] out of this pass and pair up arr[1 .. n),
      // so that no element is dropped
      uint64_t skip = n-2*len;
      // only len threads do work, so launch just enough blocks to cover them
      uint64_t grid = (len+threads-1)/threads;
      reduce_sum<<<grid,threads>>>(arr+skip,len);
      n -= len;
      }
   }
//
// cuda_ok
//    Returns true when err is cudaSuccess; otherwise prints which step
//    failed on which device to stderr and returns false.
//
static bool cuda_ok(cudaError_t err,const char *what,int index) {
   if (err == cudaSuccess)
      return true;
   std::cerr << "GPU " << index << ": " << what << " failed: "
      << cudaGetErrorString(err) << '\n';
   return false;
   }
//
// sum
//    Compute device index's share of the series: select the device, fill a
//    device array with npts partial sums (init), add them up (reduce) and
//    store the total in results[index].
//
//    index  CUDA device number; results must already hold an element for it
//
//    If any CUDA call fails, a message is written to stderr and
//    results[index] is left unchanged.
//
void sum(int index) {
   double *darr = nullptr;
   double partial = 0;
   bool ok = cuda_ok(cudaSetDevice(index),"cudaSetDevice",index)
      && cuda_ok(cudaMalloc(&darr,npts*sizeof(double)),"cudaMalloc",index);
   if (ok) {
      init<<<blocks,threads>>>(darr,nloop,npts,index);
      reduce(darr);
      // launch errors are reported by cudaGetLastError, execution errors by
      // the synchronize that follows
      ok = cuda_ok(cudaGetLastError(),"kernel launch",index)
         && cuda_ok(cudaDeviceSynchronize(),"kernel execution",index)
         && cuda_ok(cudaMemcpy(&partial,darr,sizeof(double),cudaMemcpyDeviceToHost),"cudaMemcpy",index);
      }
   if (ok)
      results[index] = partial;
   // darr is still null if allocation never happened; cudaFree(nullptr) is a no-op
   cudaFree(darr);
   }
//
// main
//    Start one host thread per visible CUDA device, each running sum() for
//    its device, then add the per-device partial sums and print the result,
//    the elapsed wall-clock time and an estimated GFlops rate (counting 5
//    floating-point operations per series term).
//
//    Returns 1 if the device count cannot be obtained or no device exists.
//
int main(void) {
   int ngpus = 0;
   cudaError_t err = cudaGetDeviceCount(&ngpus);
   if (err != cudaSuccess) {
      std::cerr << "cudaGetDeviceCount failed: " << cudaGetErrorString(err) << '\n';
      return 1;
      }
   if (ngpus < 1) {
      std::cerr << "no CUDA devices found\n";
      return 1;
      }
   // give results its final size before any worker starts: growing the
   // vector while workers write results[index] would be a data race
   results.assign(static_cast<size_t>(ngpus),0.0);
   std::vector<std::thread> workers;
   workers.reserve(ngpus);
   double pi = 0;
   auto tstart = std::chrono::high_resolution_clock::now();
   for (int i = 0; i < ngpus; ++i)
      workers.emplace_back(sum,i);
   for (int i = 0; i < ngpus; ++i) {
      workers[i].join();
      pi += results[i];
      }
   auto tend = std::chrono::high_resolution_clock::now();
   auto dt = std::chrono::duration_cast<std::chrono::microseconds>(tend-tstart).count();
   // dt is in microseconds: operations per microsecond / 1e3 = GFlops
   auto gflops = npts*nloop*ngpus*5.0/dt/1e3;
   std::cout << "npts: " << npts << " nloop: " << nloop << " ngpus: " << ngpus << " pi: " << pi << '\n';
   std::cout << "time: " << 1e-6*dt << " estimated GFlops: " << gflops << '\n';
   return 0;
   }