Program Listing for File mesh.hpp

Return to documentation for file (src/mesh/mesh.hpp)

#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <utility>

#include "output.hpp"

// Overridable compile-time constant (define it before including this header to
// change it); presumably the default number of ghost cells per side. It is not
// referenced anywhere in this header itself.
// NOTE(review): identifiers beginning with "__" are reserved for the
// implementation in C++; consider renaming (e.g. LILIM_DEFAULT_NGHOST) in a
// coordinated change across the project.
#ifndef __LILIM_DEFAULT_NGHOST
#define __LILIM_DEFAULT_NGHOST 2
#endif

namespace lili::mesh {
/// Selects one of the six ghost-cell slabs of a mesh. "Prev" is the slab that
/// precedes index 0 along an axis and "Next" is the slab that follows the last
/// interior index. The underlying values are stable (0..5) and fit in a byte.
enum class MeshGhostLocation : std::uint8_t {
  XPrev = 0,  ///< x-axis, before the interior
  XNext = 1,  ///< x-axis, after the interior
  YPrev = 2,  ///< y-axis, before the interior
  YNext = 3,  ///< y-axis, after the interior
  ZPrev = 4,  ///< z-axis, before the interior
  ZNext = 5   ///< z-axis, after the interior
};

/// Plain aggregate describing a mesh domain.
///
/// Every member defaults to zero so that a default-initialised `MeshSize`
/// never holds indeterminate values. Aggregate initialisation
/// (`MeshSize{dim, nx, ...}`) keeps the same member order as before.
struct MeshSize {
  int dim = 0;  // Number of spatial dimensions
  int nx = 0;   // Interior cells along x
  int ny = 0;   // Interior cells along y
  int nz = 0;   // Interior cells along z
  int ngx = 0;  // Ghost cells along x (per side)
  int ngy = 0;  // Ghost cells along y (per side)
  int ngz = 0;  // Ghost cells along z (per side)
  // NOTE(review): the following are presumably the physical extents (lx..lz)
  // and origin (x0..z0) of the domain; they are not used in this header.
  double lx = 0.0;
  double ly = 0.0;
  double lz = 0.0;
  double x0 = 0.0;
  double y0 = 0.0;
  double z0 = 0.0;
};

/// Class-spelled twin of `MeshSize`: same public members in the same order.
///
/// Every member defaults to zero so a default-initialised instance never holds
/// indeterminate values. The type remains an aggregate.
/// NOTE(review): why a separate type is needed is not visible here (perhaps
/// for a binding layer) — confirm before merging it with `MeshSize`.
class MeshSizeC {
 public:
  int dim = 0;  // Number of spatial dimensions
  int nx = 0;   // Interior cells along x
  int ny = 0;   // Interior cells along y
  int nz = 0;   // Interior cells along z
  int ngx = 0;  // Ghost cells along x (per side)
  int ngy = 0;  // Ghost cells along y (per side)
  int ngz = 0;  // Ghost cells along z (per side)
  // NOTE(review): presumably physical extents and origin; unused in this header.
  double lx = 0.0;
  double ly = 0.0;
  double lz = 0.0;
  double x0 = 0.0;
  double y0 = 0.0;
  double z0 = 0.0;
};

// Presumably writes a human-readable summary of `mesh_size` to the project
// output stream `lout`.
// NOTE(review): defined outside this header; exact output format unverified.
void PrintMeshSize(const MeshSize& mesh_size, lili::output::LiliCout& lout);

// Takes `mesh_size` by non-const reference, so it updates it in place.
// NOTE(review): presumably recomputes `mesh_size.dim` from ny/nz (as
// Mesh::UpdateTotalSizes does for a Mesh) — confirm against the definition.
void UpdateMeshSizeDim(MeshSize& mesh_size);

/// Dense 3D array of `T` with symmetric ghost-cell layers on every axis.
///
/// Storage is one contiguous heap block of `nt()` value-initialised elements
/// laid out x-fastest. Cell (i, j, k) lives at flat offset
///   (ngx + i) + ntx * ((ngy + j) + nty * (ngz + k)).
/// Interior cells use 0 <= i < nx (likewise j, k). Ghost cells use
/// -ngx <= i < 0 and nx <= i < nx + ngx.
///
/// The dimension is always derived from the sizes (3 if nz > 1, else 2 if
/// ny > 1, else 1) whenever storage is (re)initialised. A default-constructed
/// (empty) mesh reports 0.
///
/// NOTE(review): sizes are `int`, so meshes with more than INT_MAX elements
/// overflow `nt()`.
template <typename T>
class Mesh {
 public:
  // ---------------------------------------------------------------- creation

  /// Empty mesh: every size is zero and no storage is allocated.
  Mesh() = default;

  /// 1D mesh of `nx` cells without ghost cells.
  Mesh(int nx) : Mesh(nx, 1, 1, 0, 0, 0) {}

  /// 2D mesh of `nx` x `ny` cells without ghost cells.
  Mesh(int nx, int ny) : Mesh(nx, ny, 1, 0, 0, 0) {}

  /// Mesh of `nx` x `ny` x `nz` cells without ghost cells.
  Mesh(int nx, int ny, int nz) : Mesh(nx, ny, nz, 0, 0, 0) {}

  /// Mesh with the same ghost width `ng` on every axis.
  Mesh(int nx, int ny, int nz, int ng) : Mesh(nx, ny, nz, ng, ng, ng) {}

  /// Mesh with per-axis ghost widths. All elements start value-initialised.
  Mesh(int nx, int ny, int nz, int ngx, int ngy, int ngz)
      : nx_(nx), ny_(ny), nz_(nz), ngx_(ngx), ngy_(ngy), ngz_(ngz) {
    InitializeData();
  }

  /// Mesh sized from `domain_size`. `domain_size.dim` is ignored because the
  /// dimension is recomputed from the sizes.
  Mesh(const MeshSize& domain_size)
      : Mesh(domain_size.nx, domain_size.ny, domain_size.nz, domain_size.ngx,
             domain_size.ngy, domain_size.ngz) {}

  /// Deep copy: allocates a fresh block and copies every element, ghosts
  /// included.
  Mesh(const Mesh& other)
      : nx_(other.nx_),
        ny_(other.ny_),
        nz_(other.nz_),
        ngx_(other.ngx_),
        ngy_(other.ngy_),
        ngz_(other.ngz_) {
    InitializeData();
    std::copy(other.data_, other.data_ + other.nt_, data_);
  }

  /// Steals `other`'s storage. `other` is left as an empty mesh. This relies
  /// on every member having a default initialiser, so nothing indeterminate is
  /// swapped into `other`.
  Mesh(Mesh&& other) noexcept : Mesh() { swap(*this, other); }

  ~Mesh() { delete[] data_; }  // delete[] on nullptr is a no-op

  friend void swap(Mesh<T>& first, Mesh<T>& second) noexcept {
    using std::swap;
    swap(first.dim_, second.dim_);
    swap(first.nx_, second.nx_);
    swap(first.ny_, second.ny_);
    swap(first.nz_, second.nz_);
    swap(first.ngx_, second.ngx_);
    swap(first.ngy_, second.ngy_);
    swap(first.ngz_, second.ngz_);
    swap(first.ntx_, second.ntx_);
    swap(first.nty_, second.nty_);
    swap(first.ntz_, second.ntz_);
    swap(first.nt_, second.nt_);
    swap(first.data_, second.data_);
  }

  // ----------------------------------------------------------------- getters
  constexpr int dim() const { return dim_; }
  constexpr int nx() const { return nx_; }
  constexpr int ny() const { return ny_; }
  constexpr int nz() const { return nz_; }
  constexpr int ngx() const { return ngx_; }
  constexpr int ngy() const { return ngy_; }
  constexpr int ngz() const { return ngz_; }
  constexpr int ntx() const { return ntx_; }
  constexpr int nty() const { return nty_; }
  constexpr int ntz() const { return ntz_; }
  constexpr int nt() const { return nt_; }
  /// Raw pointer to the `nt()` elements (including ghosts); may be nullptr.
  constexpr T* data() const { return data_; }

  // --------------------------------------------------------------- operators

  /// Copy/move assignment via copy-and-swap (`other` is taken by value).
  Mesh<T>& operator=(Mesh<T> other) {
    swap(*this, other);
    return *this;
  }

  /// Sets every element, ghosts included, to `value`.
  Mesh<T>& operator=(T value) {
    std::fill(data_, data_ + nt_, value);
    return *this;
  }

  /// Adds `value` to every element, ghosts included.
  Mesh<T>& operator+=(T value) {
    for (int i = 0; i < nt_; ++i) {
      data_[i] += value;
    }
    return *this;
  }

  /// Flat access into the whole block (ghost cells are not skipped).
  T operator()(int i) const { return data_[i]; }
  T& operator()(int i) { return data_[i]; }

  /// Cell access in interior-relative coordinates. Negative indices, or
  /// indices >= n, reach the ghost cells. No bounds checking is done.
  T operator()(int i, int j, int k) const { return data_[Index(i, j, k)]; }
  T& operator()(int i, int j, int k) { return data_[Index(i, j, k)]; }

  // ------------------------------------------------------------------ sizing

  /// Recomputes dim() and the total (ghost-inclusive) sizes from the
  /// interior sizes and ghost widths. Does not touch the storage.
  void UpdateTotalSizes() {
    dim_ = (nz_ > 1) ? 3 : ((ny_ > 1) ? 2 : 1);

    ntx_ = nx_ + 2 * ngx_;
    nty_ = ny_ + 2 * ngy_;
    ntz_ = nz_ + 2 * ngz_;

    nt_ = ntx_ * nty_ * ntz_;
  }

  /// True when interior sizes and ghost widths both match `other`.
  bool SameSizeAs(const Mesh& other) const {
    return nx_ == other.nx_ && ny_ == other.ny_ && nz_ == other.nz_ &&
           ngx_ == other.ngx_ && ngy_ == other.ngy_ && ngz_ == other.ngz_;
  }

  /// Recomputes the totals and replaces the storage with `nt()`
  /// value-initialised elements. Any previous contents are discarded.
  void InitializeData() {
    UpdateTotalSizes();
    // Allocate before releasing the old block. If `new` throws, data_ still
    // owns a valid allocation, so the destructor cannot double-delete.
    T* fresh = new T[nt_]();
    delete[] data_;
    data_ = fresh;
  }

  /// Changes interior sizes and ghost widths. Storage is reallocated (and
  /// value-initialised, discarding old contents) only when some size actually
  /// changes. Otherwise the data is kept as-is.
  void Resize(int nx, int ny, int nz, int ngx, int ngy, int ngz) {
    const bool size_changed = nx != nx_ || ny != ny_ || nz != nz_ ||
                              ngx != ngx_ || ngy != ngy_ || ngz != ngz_;

    nx_ = nx;
    ny_ = ny;
    nz_ = nz;
    ngx_ = ngx;
    ngy_ = ngy;
    ngz_ = ngz;

    if (size_changed) {
      InitializeData();  // also recomputes the totals
    } else {
      UpdateTotalSizes();
    }
  }

  /// Reinterprets the existing buffer under new sizes without moving any
  /// element. Allowed only when the total element count is unchanged.
  /// Otherwise prints an error and terminates the process with exit code 2.
  void Shrink(int nx, int ny, int nz, int ngx, int ngy, int ngz) {
    const int old_nt = nt_;

    nx_ = nx;
    ny_ = ny;
    nz_ = nz;
    ngx_ = ngx;
    ngy_ = ngy;
    ngz_ = ngz;

    UpdateTotalSizes();

    if (nt_ != old_nt) {
      std::cerr << "Cannot shrink the mesh inplace..." << std::endl;
      std::exit(2);
    }
  }

  // ------------------------------------------------------------ ghost cells

  /// Fills the ghost slab `gl` of this mesh from the interior of `other`.
  ///
  /// A "Prev" slab receives the last ng interior layers of `other` along that
  /// axis. A "Next" slab receives the first ng interior layers. `other` must
  /// match this mesh's interior size on the two transverse axes and have at
  /// least ng interior layers along the copy axis. Otherwise an error is
  /// printed and the process exits with code 2. `other` may be `*this`,
  /// because source (interior) and destination (ghost) regions never overlap.
  void CopyToGhost(const Mesh& other, MeshGhostLocation gl) {
    switch (gl) {
      case MeshGhostLocation::XPrev:
        RequireValidGhostSource(other.ny() == ny_ && other.nz() == nz_ &&
                                other.nx() >= ngx_);
        CopyRegion(other, -ngx_, 0, 0, ny_, 0, nz_, other.nx(), 0, 0);
        break;
      case MeshGhostLocation::XNext:
        RequireValidGhostSource(other.ny() == ny_ && other.nz() == nz_ &&
                                other.nx() >= ngx_);
        CopyRegion(other, nx_, nx_ + ngx_, 0, ny_, 0, nz_, -nx_, 0, 0);
        break;
      case MeshGhostLocation::YPrev:
        RequireValidGhostSource(other.nx() == nx_ && other.nz() == nz_ &&
                                other.ny() >= ngy_);
        CopyRegion(other, 0, nx_, -ngy_, 0, 0, nz_, 0, other.ny(), 0);
        break;
      case MeshGhostLocation::YNext:
        RequireValidGhostSource(other.nx() == nx_ && other.nz() == nz_ &&
                                other.ny() >= ngy_);
        CopyRegion(other, 0, nx_, ny_, ny_ + ngy_, 0, nz_, 0, -ny_, 0);
        break;
      case MeshGhostLocation::ZPrev:
        RequireValidGhostSource(other.nx() == nx_ && other.ny() == ny_ &&
                                other.nz() >= ngz_);
        CopyRegion(other, 0, nx_, 0, ny_, -ngz_, 0, 0, 0, other.nz());
        break;
      case MeshGhostLocation::ZNext:
        RequireValidGhostSource(other.nx() == nx_ && other.ny() == ny_ &&
                                other.nz() >= ngz_);
        CopyRegion(other, 0, nx_, 0, ny_, nz_, nz_ + ngz_, 0, 0, -nz_);
        break;
      default:
        std::cerr << "Invalid ghost location..." << std::endl;
        break;
    }
  }

  // ---------------------------------------------------------- interpolation
  //
  // Coordinates are in cell units of the interior-relative index space.
  // The lower corner cell is floor(coord). Rounding toward -inf (not toward
  // zero) keeps every weight in [0, 1] for negative coordinates that fall
  // inside the ghost region. Reads the corner cell and its +1 neighbours, so
  // the caller must keep them within the allocated (ghost-inclusive) range.

  /// Linear interpolation along x (uses the j = k = 0 row).
  T LinearInterpolation(double x) const {
    const auto [ix, xd] = SplitCoordinate(x);

    return (1.0 - xd) * (*this)(ix, 0, 0) + xd * (*this)(ix + 1, 0, 0);
  }

  /// Bilinear interpolation in the x-y plane (uses the k = 0 plane).
  T BilinearInterpolation(double x, double y) const {
    const auto [ix, xd] = SplitCoordinate(x);
    const auto [iy, yd] = SplitCoordinate(y);

    return (1.0 - xd) *
               ((1.0 - yd) * (*this)(ix, iy, 0) + yd * (*this)(ix, iy + 1, 0)) +
           xd * ((1.0 - yd) * (*this)(ix + 1, iy, 0) +
                 yd * (*this)(ix + 1, iy + 1, 0));
  }

  /// Trilinear interpolation over the 8 cells surrounding (x, y, z).
  T TrilinearInterpolation(double x, double y, double z) const {
    const auto [ix, xd] = SplitCoordinate(x);
    const auto [iy, yd] = SplitCoordinate(y);
    const auto [iz, zd] = SplitCoordinate(z);

    return (1.0 - xd) * ((1.0 - yd) * ((1.0 - zd) * (*this)(ix, iy, iz) +
                                       zd * (*this)(ix, iy, iz + 1)) +
                         yd * ((1.0 - zd) * (*this)(ix, iy + 1, iz) +
                               zd * (*this)(ix, iy + 1, iz + 1))) +
           xd * ((1.0 - yd) * ((1.0 - zd) * (*this)(ix + 1, iy, iz) +
                               zd * (*this)(ix + 1, iy, iz + 1)) +
                 yd * ((1.0 - zd) * (*this)(ix + 1, iy + 1, iz) +
                       zd * (*this)(ix + 1, iy + 1, iz + 1)));
  }

  /// Dispatches on dim(): 2 -> bilinear, 3 -> trilinear, anything else ->
  /// linear. Unused coordinates are ignored.
  T Interpolation(double x, double y, double z) const {
    switch (dim_) {
      case 2:
        return BilinearInterpolation(x, y);
      case 3:
        return TrilinearInterpolation(x, y, z);
      default:
        return LinearInterpolation(x);
    }
  }

 private:
  // Flat offset of interior-relative cell (i, j, k) in the x-fastest block.
  int Index(int i, int j, int k) const {
    return ngx_ + i + ntx_ * (ngy_ + j + nty_ * (ngz_ + k));
  }

  // Aborts with exit code 2 when a CopyToGhost source mesh has the wrong shape.
  static void RequireValidGhostSource(bool valid) {
    if (!valid) {
      std::cerr << "Invalid ghost mesh size..." << std::endl;
      std::exit(2);
    }
  }

  // Copies other(i + di, j + dj, k + dk) into (*this)(i, j, k) for every cell
  // of the box [i0, i1) x [j0, j1) x [k0, k1). Iterates x-innermost to match
  // the memory layout.
  void CopyRegion(const Mesh& other, int i0, int i1, int j0, int j1, int k0,
                  int k1, int di, int dj, int dk) {
    for (int k = k0; k < k1; ++k) {
      for (int j = j0; j < j1; ++j) {
        for (int i = i0; i < i1; ++i) {
          (*this)(i, j, k) = other(i + di, j + dj, k + dk);
        }
      }
    }
  }

  // Splits a coordinate into its lower cell index floor(x) and the fractional
  // offset x - floor(x), which lies in [0, 1).
  static std::pair<int, double> SplitCoordinate(double x) {
    const double base = std::floor(x);
    return {static_cast<int>(base), x - base};
  }

  int dim_ = 0;                      // Mesh dimension (derived from ny_, nz_)
  int nx_ = 0, ny_ = 0, nz_ = 0;     // Interior mesh sizes
  int ngx_ = 0, ngy_ = 0, ngz_ = 0;  // Ghost widths (same before and after)
  int ntx_ = 0, nty_ = 0, ntz_ = 0, nt_ = 0;  // Totals including ghost cells

  T* data_ = nullptr;  // Owned new[]-allocated block of nt_ elements
};

// Presumably writes `mesh` to the file `file_name` under the dataset/field
// name `data_name`; `include_ghost` (default false) presumably selects whether
// ghost cells are written as well.
// NOTE(review): defined outside this header; file format unverified. `mesh`
// is taken by non-const reference even though saving should not modify it —
// confirm against the definition.
void SaveMesh(Mesh<double>& mesh, const char* file_name, const char* data_name,
              bool include_ghost = false);

// Presumably the inverse of SaveMesh: reads the dataset `data_name` from
// `file_name` into `mesh`; `include_ghost` (default false) presumably selects
// whether ghost cells are read too.
// NOTE(review): whether `mesh` is resized to match the file is not visible
// here — confirm against the definition.
void LoadMeshTo(Mesh<double>& mesh, const char* file_name,
                const char* data_name, bool include_ghost = false);
}  // namespace lili::mesh