#include <stdlib.h>
#include <stdio.h>
#include <math.h>

#include <mpi.h>

#define GNUPLT_PARAM_FILE "diffusion.param.gnuplt"
#define OUTPUT_DATA_FILE "diffusion.dat"

#define NPOINTS    512
#define NITERS     100000
#define NOUT       1000
#define XLEN       2.0
#define DIFF_COEF  1.0

/*
 * Gather the distributed field onto rank 0 and append one snapshot to datfile.
 *
 * Collective over MPI_COMM_WORLD: every rank must call this function.
 *
 *   datfile     open output stream; only used on rank 0 (must be non-NULL there)
 *   my_rank     rank of the caller in MPI_COMM_WORLD
 *   world_size  number of ranks in MPI_COMM_WORLD
 *   total_size  total number of grid points across all ranks
 *   my_size     number of grid points owned by the caller; must equal
 *               total_size*(my_rank+1)/world_size - total_size*my_rank/world_size
 *               (integer division), which is the layout rank 0 assumes below
 *   u           local field; u[1..my_size] are the owned cells
 *   t           simulation time written in the first column
 *   xlen        domain length; the domain is [-xlen/2, xlen/2]
 *   dx          grid spacing
 *
 * Output format: one line "t x u" per grid point followed by a blank line.
 * Aborts the whole MPI job if rank 0 cannot allocate its gather buffers.
 */
void write_step(FILE* datfile,
                int my_rank, int world_size, int total_size, int my_size,
                double* u, double t, double xlen, double dx) {
  double *tmpbuf = NULL;
  int *counts = NULL;
  int *displs = NULL;

  if (my_rank == 0) {
    // Size each buffer from its own element type.
    tmpbuf = malloc(sizeof *tmpbuf * (size_t)total_size);
    counts = malloc(sizeof *counts * (size_t)world_size);
    displs = malloc(sizeof *displs * (size_t)world_size);

    if ((total_size > 0 && tmpbuf == NULL) || counts == NULL || displs == NULL) {
      fprintf(stderr, "write_step: out of memory\n");
      free(tmpbuf);
      free(displs);
      free(counts);
      MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
      return;
    }

    // Receive layout: rank r contributes the index range
    // [total_size*r/world_size, total_size*(r+1)/world_size).
    for (int rank = 0; rank < world_size; rank++) {
      displs[rank] = total_size * rank / world_size;
      counts[rank] = total_size * (rank + 1) / world_size - displs[rank];
    }
  }

  // Skip the left ghost cell u[0]; only owned cells are gathered.
  MPI_Gatherv(&u[1], my_size, MPI_DOUBLE, tmpbuf,
              counts, displs, MPI_DOUBLE, 0, MPI_COMM_WORLD);

  if (my_rank == 0) {
    for (int i = 0; i < total_size; i++) {
      // Cell-centred coordinate of 0-based global cell i.
      double x = dx * ((double)i + 0.5) - 0.5 * xlen;
      fprintf(datfile, " %24.14lf %24.14lf %24.14lf\n", t, x, tmpbuf[i]);
    }

    // Blank line separates consecutive snapshots in the data file.
    fprintf(datfile, "\n");

    free(tmpbuf);
    free(displs);
    free(counts);
  }
}

/*
 * Explicit finite-difference solver for the 1-D diffusion equation on
 * NPOINTS cells, split across the ranks of MPI_COMM_WORLD.
 *
 * Each rank owns cells [my_start, my_end) stored in u[1..my_size], with
 * u[0] and u[my_size+1] used as ghost cells filled from the neighbouring
 * ranks every iteration. Rank 0 writes the gnuplot parameter file and, via
 * write_step(), the initial state plus one snapshot every NOUT iterations.
 *
 * Returns 0 on success; aborts the MPI job on allocation or file-open
 * failure.
 */
int main(int argc, char *argv[]) {
  // M_PI is not part of ISO C and is missing in strict -std=c11 builds, so
  // use a local constant instead.
  const double pi = 3.14159265358979323846;

  double dx, dt, dtdx2, x, t;

  FILE *datfile = NULL;
  FILE *paramfile = NULL;

  double *uold, *unew;

  int world_size, my_rank, my_start, my_end, my_size;
  int left_rank, right_rank;

  MPI_Request reqs[4];

  MPI_Init(&argc, &argv);

  MPI_Comm_size(MPI_COMM_WORLD, &world_size);
  MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);

  // The halo exchange below sends u[1] and u[my_size], so every rank needs
  // at least one owned cell.
  if (world_size > NPOINTS) {
    if (my_rank == 0) {
      fprintf(stderr, "error: %d ranks requested but only %d grid points\n",
              world_size, NPOINTS);
    }
    MPI_Finalize();
    return EXIT_FAILURE;
  }

  dx = XLEN / NPOINTS;
  // Time step chosen so that DIFF_COEF * dt / dx^2 == 0.5.
  dt = 0.5 * dx * dx / DIFF_COEF;
  dtdx2 = dt / (dx * dx);

  // Block decomposition; write_step() assumes this same layout.
  my_start = NPOINTS * my_rank / world_size;
  my_end = NPOINTS * (my_rank + 1) / world_size;
  my_size = my_end - my_start;

  // Neighbours wrap around, so the exchange is periodic.
  left_rank =  my_rank > 0            ? my_rank - 1 : world_size-1;
  right_rank = my_rank < world_size-1 ? my_rank + 1 : 0;

  t = 0.0;

  // Two extra entries hold the left and right ghost cells.
  uold = malloc(sizeof *uold * (size_t)(my_size + 2));
  unew = malloc(sizeof *unew * (size_t)(my_size + 2));
  if (uold == NULL || unew == NULL) {
    fprintf(stderr, "rank %d: out of memory\n", my_rank);
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    return EXIT_FAILURE;
  }

  // Initial conditions: raised cosine evaluated at the cell centres.
  for (int i = 1; i <= my_size; i++) {
    x = dx * ((double)(i+my_start) - 0.5) - 0.5 * XLEN;
    uold[i] = 0.5 * cos(2.0 * pi * x / XLEN) + 0.5;
  }

  // NOTE(review): these outer ghost values are overwritten by the periodic
  // halo exchange before they are ever read. If fixed-value boundaries were
  // intended, the neighbour ranks above would need to change - confirm.
  if (my_rank == 0) {
    uold[0] = 0.0;
    unew[0] = 0.0;
  }

  if (my_rank == world_size-1) {
    uold[my_size+1] = 0.0;
    unew[my_size+1] = 0.0;
  }

  if (my_rank == 0) {
    datfile = fopen(OUTPUT_DATA_FILE, "w");
    if (datfile == NULL) {
      perror(OUTPUT_DATA_FILE);
      MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
      return EXIT_FAILURE;
    }

    paramfile = fopen(GNUPLT_PARAM_FILE, "w");
    if (paramfile == NULL) {
      perror(GNUPLT_PARAM_FILE);
      MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
      return EXIT_FAILURE;
    }

    fprintf(paramfile, "npoints = %d\n", NPOINTS);
    fprintf(paramfile, "niters = %d\n", NITERS);
    fprintf(paramfile, "nout = %d\n", NOUT);
    fprintf(paramfile, "xlen = %lf\n", XLEN);
    fprintf(paramfile, "dt = %lf\n", dt);
    fclose(paramfile);
  }

  // Snapshot of the initial state at t = 0.
  write_step(datfile, my_rank, world_size, NPOINTS, my_size, uold, t, XLEN, dx);

  for (int iter = 1; iter <= NITERS; iter++) {
    // Halo exchange: tag 0 carries data moving rightwards, tag 1 leftwards.
    MPI_Isend(&uold[my_size],  1, MPI_DOUBLE, right_rank, 0, MPI_COMM_WORLD, &reqs[0]);
    MPI_Irecv(&uold[0],        1, MPI_DOUBLE, left_rank,  0, MPI_COMM_WORLD, &reqs[1]);

    MPI_Isend(&uold[1],         1, MPI_DOUBLE, left_rank,  1, MPI_COMM_WORLD, &reqs[2]);
    MPI_Irecv(&uold[my_size+1], 1, MPI_DOUBLE, right_rank, 1, MPI_COMM_WORLD, &reqs[3]);

    MPI_Waitall(4, reqs, MPI_STATUSES_IGNORE);

    // Three-point explicit update of the owned cells.
    for (int i = 1; i <= my_size; i++) {
      unew[i] = uold[i] + DIFF_COEF * dtdx2
                        * (uold[i+1] - 2.0 * uold[i] + uold[i-1]);
    }

    // Advance to the new time level before any output, so that uold and t
    // describe the state after `iter` steps.
    double *tmpptr = uold;
    uold = unew;
    unew = tmpptr;

    t += dt;

    // Snapshots land exactly on multiples of NOUT steps, including the
    // final state.
    if (iter % NOUT == 0) {
      write_step(datfile, my_rank, world_size, NPOINTS, my_size, uold, t, XLEN, dx);
    }
  }

  // fclose flushes buffered output; report a failed flush instead of
  // silently losing data.
  if (my_rank == 0 && fclose(datfile) != 0) {
    perror(OUTPUT_DATA_FILE);
  }

  free(uold);
  free(unew);

  MPI_Finalize();

  return 0;
}