XGCa
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
access_add.hpp
Go to the documentation of this file.
1 #ifndef ACCESS_ADD_HPP
2 #define ACCESS_ADD_HPP
3 #include "space_settings.hpp"
4 
5 /* ScatterType selects how scatter (accumulate-into-array) operations avoid data races:
6  * either with atomics or with per-thread array replication. */
7 enum class ScatterType{
8  Atomic,
10 };
11 
12 #ifdef USE_GPU
13 # ifdef USE_ARRAY_REPLICATION
14 # error Cannot use ARRAY REPLICATION on GPU
15 # endif
16 #else
17 # ifndef USE_ARRAY_REPLICATION
18 # error Must use ARRAY REPLICATION if executing on CPU for now
19 # endif
20 #endif
21 
22 #ifdef USE_ARRAY_REPLICATION
24 #else
26 #endif
27 
28 // returns omp thread if on CPU (using array replication strategy for scatter) or
29 // returns 0 on GPU (uses atomics, no replication needed)
30 KOKKOS_INLINE_FUNCTION int get_thread(){
31 #if defined(USE_ARRAY_REPLICATION) && defined(USE_OMP)
32  return omp_get_thread_num();
33 #else
34  return 0;
35 #endif
36 }
37 
38 // access_add (could use scatter view instead): Uses atomic if array replication is off
39 template<typename T>
40 KOKKOS_INLINE_FUNCTION void access_add(T* addr, T val){
41 #ifdef USE_ARRAY_REPLICATION
42  *addr += val;
43 #else
44  Kokkos::atomic_add(addr, val);
45 #endif
46 }
47 
48 /* Sums each thread's copy of a replicated View (copies are indexed by its first dimension) into copy 0, in place; a no-op unless USE_ARRAY_REPLICATION is defined.
49  */
50 template<typename T>
52 #ifdef USE_ARRAY_REPLICATION
53  int n_threads = view.extent(0);
54  int size_per_thread = view.size()/n_threads;
55 
56  auto thread_0_ptr = view.data();
57  auto thread_i_ptr = view.data();
58 
59  for(int i = 1; i<n_threads; i++){
60  thread_i_ptr += size_per_thread;
61  Kokkos::parallel_for("reduce_replicated_array", Kokkos::RangePolicy<HostExSpace>( 0, size_per_thread), KOKKOS_LAMBDA(const int idx){
62  thread_0_ptr[idx] += thread_i_ptr[idx];
63  });
64  }
65 #endif
66 }
67 
68 #endif
KOKKOS_INLINE_FUNCTION int get_thread()
Definition: access_add.hpp:30
KOKKOS_INLINE_FUNCTION void access_add(T *addr, T val)
Definition: access_add.hpp:40
ScatterType
Definition: access_add.hpp:7
void reduce_replicated_array(T &view)
Definition: access_add.hpp:51
constexpr ScatterType SCATTER_TYPE_GLOBAL
Definition: access_add.hpp:25
void parallel_for(const std::string name, int n_ptl, Function func, Option option, HostAoSoA aosoa_h, DeviceAoSoA aosoa_d)
Definition: streamed_parallel_for.hpp:252