XGCa
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
broadcast_views.hpp
Go to the documentation of this file.
1 #ifndef BROADCAST_VIEWS_HPP
2 #define BROADCAST_VIEWS_HPP
3 
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

#include "space_settings.hpp"
#include "my_mpi.hpp"
6 
// Implementation detail of resize_view(); not part of the public interface.
namespace broadcast_views_detail {

// Rebuild `view` as ViewType(label, extents[0], ..., extents[N-1]) where N == sizeof...(Is):
// one constructor argument is expanded per index, so every rank shares a single code path.
template <typename ViewType, std::size_t... Is>
void resize_view_impl(ViewType& view, const std::vector<int>& extents, std::index_sequence<Is...>) {
  view = ViewType(view.label(), extents[Is]...);
}

}  // namespace broadcast_views_detail

/**
 * @brief Replace `view` with a newly constructed view of the same type and label whose
 *        extents are taken from `extents`.
 *
 * The previous contents are NOT copied into the new view; callers are expected to fill it
 * afterwards. Any positive rank is supported: one entry of `extents` is used per dimension.
 * (An earlier version handled ranks 1-6 only and silently left higher-rank views unchanged.)
 * Rank-0 views are left untouched.
 *
 * ViewType must provide a static `rank`, a `label()` accessor, and a constructor
 * `ViewType(label, extent_0, ..., extent_{rank-1})`.
 *
 * @param view    View to reallocate in place.
 * @param extents New extents; must hold at least ViewType::rank entries (extra entries are ignored).
 */
template <typename ViewType>
void resize_view(ViewType& view, const std::vector<int>& extents) {
  constexpr std::size_t view_rank = ViewType::rank;
  if constexpr (view_rank > 0) {
    broadcast_views_detail::resize_view_impl(view, extents, std::make_index_sequence<view_rank>{});
  }
}
24 
25 // Function template to broadcast and resize an arbitrary number of Kokkos Views
26 template <typename... Views>
27 void broadcast_views(const MPI_Comm& comm, const int ROOT_RANK, Views&... views) {
28  int my_rank;
29  int nranks;
30  MPI_Comm_rank( comm, &my_rank);
31  MPI_Comm_size( comm, &nranks);
32 
33  // Allocate extents and use extents found on root rank
34  std::vector<std::vector<int>> extents_list;
35  ([&](){
36  std::vector<int> extents(Views::rank);
37  if(my_rank==ROOT_RANK){
38  for (int i = 0; i < Views::rank; i++) {
39  extents[i] = views.extent(i);
40  }
41  }
42  extents_list.push_back(extents);
43  }(), ...); // Fold expression for multiple views
44 
45  // Loop through each view, broadcast its extents
46  for (int v = 0; v < (sizeof...(Views)); v++) {
47  MPI_Bcast(extents_list[v].data(), extents_list[v].size(), MPI_INT, ROOT_RANK, comm);
48  }
49 
50  // Resize each view on all other ranks to match extents from Rank ROOT_RANK
51  if(my_rank!=ROOT_RANK){
52  int view_index = 0;
53  ([&](){
54  resize_view(views, extents_list[view_index]);
55  view_index++;
56  }(), ...);
57  }
58 
59  // Broadcast view data from ROOT_RANK
60  ([&](){
61  MPI_Bcast(views.data(), views.size(), get_mpi_type(views), ROOT_RANK, comm);
62  }(), ...);
63 }
64 
65 #endif
void broadcast_views(const MPI_Comm &comm, const int ROOT_RANK, Views &...views)
Definition: broadcast_views.hpp:27
MPI_Datatype get_mpi_type()
Definition: my_mpi.hpp:200
void resize_view(ViewType &view, const std::vector< int > &extents)
Definition: broadcast_views.hpp:9