array
C++ library for multi-dimensional arrays
matrix.h
Go to the documentation of this file.
1 // Copyright 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
18 #ifndef NDARRAY_MATRIX_H
19 #define NDARRAY_MATRIX_H
20 
21 #include "array/array.h"
22 
23 namespace nda {
24 
30 template <index_t Rows = dynamic, index_t Cols = dynamic>
32 
35 template <class T, index_t Rows = dynamic, index_t Cols = dynamic, class Alloc = std::allocator<T>>
37 template <class T, index_t Rows = dynamic, index_t Cols = dynamic>
39 template <class T, index_t Rows = dynamic, index_t Cols = dynamic>
41 
43 template <index_t Length = dynamic>
45 
46 template <class T, index_t Length = dynamic, class Alloc = std::allocator<T>>
48 template <class T, index_t Length = dynamic>
50 template <class T, index_t Length = dynamic>
52 
55 template <class T, index_t Rows, index_t Cols>
57 template <class T, index_t Length>
59 
61 template <class Shape, class Fn>
62 void for_each_matrix_index(const Shape& s, Fn&& fn) {
63  for (index_t i : s.i()) {
64  for (index_t j : s.j()) {
65  fn(std::tuple<index_t, index_t>(i, j));
66  }
67  }
68 }
69 
// shape_traits specialization for matrix_shape<Rows, Cols>. Index traversal is
// routed through for_each_matrix_index, which visits the second index
// innermost.
70 template <index_t Rows, index_t Cols>
71 class shape_traits<matrix_shape<Rows, Cols>> {
72 public:
74 
  // Calls `fn` with every (i, j) index tuple of `s` by delegating to
  // for_each_matrix_index. `fn` is passed on as an lvalue, so the caller's
  // functor is not moved from.
75  template <class Fn>
76  static void for_each_index(const shape_type& s, Fn&& fn) {
77  for_each_matrix_index(s, fn);
78  }
79 
  // Calls fn(base[s[i]]) for each index i of `s`. The value of s[i] is used
  // as an offset into `base`, so `fn` receives the element rather than the
  // index. The iteration helper invoked with these arguments is presumably
  // for_each_matrix_index; confirm against the full header.
  // NOTE(review): `fn` is a forwarding reference (Fn&&), yet it is captured
  // with std::move. That moves from an lvalue functor owned by the caller;
  // std::forward<Fn>(fn) is the conventional spelling. The closure is also
  // not `mutable`, so the captured copy of `fn` is const inside the lambda.
  // A functor with only a non-const operator() will not compile here, and
  // any state it keeps is never visible to the caller. Confirm this is
  // intended.
80  template <class Ptr, class Fn>
81  static void for_each_value(const shape_type& s, Ptr base, Fn&& fn) {
83  s, [=, fn = std::move(fn)](const typename shape_type::index_type& i) { fn(base[s[i]]); });
84  }
85 };
86 
87 } // namespace nda
88 
89 #endif // NDARRAY_MATRIX_H
index_of_rank<rank()> index_type
Definition: array.h:1076
Definition: array.h:1036
static NDARRAY_HOST_DEVICE void for_each_value(const Shape &shape, Ptr base, Fn &&fn)
Definition: array.h:1882
Definition: array.h:1961
Definition: array.h:1963
Main header for array library.
Definition: absl.h:10
Definition: array.h:1867
Definition: array.h:3077
std::ptrdiff_t index_t
Definition: array.h:87
Definition: array.h:231
void for_each_matrix_index(const Shape &s, Fn &&fn)
Definition: matrix.h:62
static NDARRAY_HOST_DEVICE void for_each_index(const Shape &shape, Fn &&fn)
Definition: array.h:1874