// Include guard for this header, followed by the SFINAE alias
// enable_if_ein_op: it is well-formed only when T::is_ein_op is exactly
// std::true_type.
// NOTE(review): the `template <class T>` header that introduces T is not
// visible in this listing.
19 #ifndef NDARRAY_EIN_REDUCE_H 20 #define NDARRAY_EIN_REDUCE_H 30 using enable_if_ein_op =
31 std::enable_if_t<std::is_same<typename T::is_ein_op, std::true_type>::value>;
// SFINAE alias: well-formed only when T::is_assign is exactly std::true_type.
// NOTE(review): the `template <class T>` header that introduces T is not
// visible in this listing.
34 using enable_if_ein_assign =
35 std::enable_if_t<std::is_same<typename T::is_assign, std::true_type>::value>;
// Template header (parameter Derived) and two marker typedefs:
// is_ein_op is std::true_type and is_assign is std::false_type. These are the
// member names tested by enable_if_ein_op and enable_if_ein_assign.
// NOTE(review): the class head owning these members is not visible in this
// listing; other types here derive from ein_op_base<...>, so presumably this
// is ein_op_base -- confirm.
43 template <
class Derived>
45 using is_ein_op = std::true_type;
46 using is_assign = std::false_type;
50 const Derived& derived()
const {
return *
static_cast<const Derived*
>(
this); }
52 auto operator-()
const {
return make_ein_op_negate(derived()); }
// Binary `+` between this expression and another operand `r`.
// Participates in overload resolution only when T satisfies enable_if_ein_op;
// returns make_ein_op_add(derived(), r).
53 template <
class T,
class = enable_if_ein_op<T>>
54 auto operator+(
const T& r)
const {
55 return make_ein_op_add(derived(), r);
// Binary `-` between this expression and another operand `r`.
// Participates in overload resolution only when T satisfies enable_if_ein_op;
// returns make_ein_op_sub(derived(), r).
57 template <
class T,
class = enable_if_ein_op<T>>
58 auto operator-(
const T& r)
const {
59 return make_ein_op_sub(derived(), r);
// Binary `*` between this expression and another operand `r`.
// Participates in overload resolution only when T satisfies enable_if_ein_op;
// returns make_ein_op_mul(derived(), r).
61 template <
class T,
class = enable_if_ein_op<T>>
62 auto operator*(
const T& r)
const {
63 return make_ein_op_mul(derived(), r);
// Binary `/` between this expression and another operand `r`.
// Participates in overload resolution only when T satisfies enable_if_ein_op;
// returns make_ein_op_div(derived(), r).
65 template <
class T,
class = enable_if_ein_op<T>>
66 auto operator/(
const T& r)
const {
67 return make_ein_op_div(derived(), r);
// ein_op pairs an operand of type Op with a compile-time pack of indices Is.
// It derives from ein_op_base with itself as the Derived parameter.
73 template <
class Op,
size_t... Is>
74 struct ein_op :
public ein_op_base<ein_op<Op, Is...>> {
// Takes the operand by value and moves it into the `op` member.
// NOTE(review): the declaration of `op` is not visible in this listing.
76 ein_op(Op op) : op(std::
move(op)) {}
// Compile-time constant: -1 when the index pack Is is empty, otherwise
// variadic_max(Is...).
79 static constexpr
index_t MaxIndex =
sizeof...(Is) == 0 ? -1 : variadic_max(Is...);
// Evaluates the operand for the index tuple `i`: for each index in Is, takes
// std::get<Is>(i) and passes the selected values to `op`.
// NOTE(review): the template header introducing Idx is not visible in this
// listing.
85 NDARRAY_INLINE decltype(op(Is...)) operator()(const Idx& i)
const {
86 return op(std::get<Is>(i)...);
// `=` with another operand `r`. This const operator does not modify the
// object; it returns make_ein_op_assign(*this, r).
// Enabled only when T satisfies enable_if_ein_op.
90 template <
class T,
class = enable_if_ein_op<T>>
91 auto operator=(
const T& r)
const {
92 return make_ein_op_assign(*
this, r);
// `+=` with another operand `r`. This const operator does not modify the
// object; it returns make_ein_op_add_assign(*this, r).
// Enabled only when T satisfies enable_if_ein_op.
94 template <
class T,
class = enable_if_ein_op<T>>
95 auto operator+=(
const T& r)
const {
96 return make_ein_op_add_assign(*
this, r);
// `-=` with another operand `r`. This const operator does not modify the
// object; it returns make_ein_op_sub_assign(*this, r).
// Enabled only when T satisfies enable_if_ein_op.
98 template <
class T,
class = enable_if_ein_op<T>>
99 auto operator-=(
const T& r)
const {
100 return make_ein_op_sub_assign(*
this, r);
// `*=` with another operand `r`. This const operator does not modify the
// object; it returns make_ein_op_mul_assign(*this, r).
// Enabled only when T satisfies enable_if_ein_op.
102 template <
class T,
class = enable_if_ein_op<T>>
103 auto operator*=(
const T& r)
const {
104 return make_ein_op_mul_assign(*
this, r);
// Base for expressions with a single operand. Derives from
// ein_op_base<Derived>, where Derived is the concrete expression type.
109 template <
class Op,
class Derived>
110 struct ein_unary_op :
public ein_op_base<Derived> {
// Copies the operand into the `op` member.
// NOTE(review): the declaration of `op` is not visible in this listing.
112 ein_unary_op(
const Op& op) : op(op) {}
// A unary expression has the same MaxIndex as its operand.
113 static constexpr
index_t MaxIndex = Op::MaxIndex;
// Unary expression type produced by make_ein_op_negate. Derives from
// ein_unary_op with itself as the Derived parameter and forwards the operand
// to the base constructor.
// NOTE(review): the template headers (Op, Idx) and the body of operator() are
// not visible in this listing.
118 struct ein_negate_op :
public ein_unary_op<Op, ein_negate_op<Op>> {
119 using base = ein_unary_op<Op, ein_negate_op<Op>>;
120 ein_negate_op(
const Op& op) : base(op) {}
122 NDARRAY_INLINE
auto operator()(
const Idx& i)
const {
// Factory: returns an ein_negate_op<Op> constructed from `op`.
// NOTE(review): the template header introducing Op is not visible in this
// listing.
128 auto make_ein_op_negate(
const Op& op) {
129 return ein_negate_op<Op>(op);
// Unary expression that converts its operand's value to Type.
// Derives from ein_unary_op with itself as the Derived parameter.
133 template <
class Type,
class Op>
134 struct ein_cast_op :
public ein_unary_op<Op, ein_cast_op<Type, Op>> {
135 using base = ein_unary_op<Op, ein_cast_op<Type, Op>>;
136 ein_cast_op(
const Op& op) : base(op) {}
// Evaluates the operand at `i` and static_casts the result to Type.
// NOTE(review): the template header introducing Idx is not visible in this
// listing.
138 NDARRAY_INLINE
auto operator()(
const Idx& i)
const {
139 return static_cast<Type
>(base::op(i));
// Base for expressions with two operands. Derives from ein_op_base<Derived>,
// where Derived is the concrete expression type.
144 template <
class OpA,
class OpB,
class Derived>
145 struct ein_binary_op :
public ein_op_base<Derived> {
// Copies both operands into the members op_a and op_b.
// NOTE(review): the declarations of op_a/op_b are not visible in this listing.
148 ein_binary_op(
const OpA& a,
const OpB& b) : op_a(a), op_b(b) {}
// MaxIndex is the larger of the two operands' MaxIndex values.
149 static constexpr
index_t MaxIndex = internal::max(OpA::MaxIndex, OpB::MaxIndex);
// Macro-generated binary expression types.
// - NDARRAY_MAKE_EIN_BINARY_HELPERS(name, op): defines make_<name>(a, b), which
//   returns name<OpA, OpB>(a, b).
// - NDARRAY_MAKE_EIN_BINARY_OP(name, op, is_assign_): defines struct `name`
//   deriving from ein_binary_op, sets is_assign to is_assign_, and gives it an
//   operator()(i) that evaluates `op_a(i) op op_b(i)`.
// - NDARRAY_MAKE_EIN_BINARY_FN(name, fn): same shape, but operator()(i) calls
//   fn(op_a(i), op_b(i)) after `using internal::min` / `using internal::max`.
// The instantiations below pass std::false_type for + - * / and
// std::true_type for = += -= *=; min and max are generated with the FN form.
// NOTE(review): in this listing the multi-line macro definitions are collapsed
// onto a single line.
152 #define NDARRAY_MAKE_EIN_BINARY_HELPERS(name, op) \ 153 template <class OpA, class OpB> \ 154 auto make_##name(const OpA& a, const OpB& b) { \ 155 return name<OpA, OpB>(a, b); \ 158 #define NDARRAY_MAKE_EIN_BINARY_OP(name, op, is_assign_) \ 159 template <class OpA, class OpB> \ 160 struct name : public ein_binary_op<OpA, OpB, name<OpA, OpB>> { \ 161 using base = ein_binary_op<OpA, OpB, name>; \ 162 name(const OpA& a, const OpB& b) : base(a, b) {} \ 163 using is_assign = is_assign_; \ 164 template <class Idx> \ 165 NDARRAY_INLINE auto operator()(const Idx& i) const { \ 166 return base::op_a(i) op base::op_b(i); \ 169 NDARRAY_MAKE_EIN_BINARY_HELPERS(name, op) 171 #define NDARRAY_MAKE_EIN_BINARY_FN(name, fn) \ 172 template <class OpA, class OpB> \ 173 struct name : public ein_binary_op<OpA, OpB, name<OpA, OpB>> { \ 174 using base = ein_binary_op<OpA, OpB, name>; \ 175 name(const OpA& a, const OpB& b) : base(a, b) {} \ 176 template <class Idx> \ 177 NDARRAY_INLINE auto operator()(const Idx& i) const { \ 178 using internal::min; \ 179 using internal::max; \ 180 return fn(base::op_a(i), base::op_b(i)); \ 183 NDARRAY_MAKE_EIN_BINARY_HELPERS(name, op) 186 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_add, +, std::false_type);
187 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_sub, -, std::false_type);
188 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_mul, *, std::false_type);
189 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_div, /, std::false_type);
190 NDARRAY_MAKE_EIN_BINARY_FN(ein_op_min, min);
191 NDARRAY_MAKE_EIN_BINARY_FN(ein_op_max, max);
193 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_assign, =, std::true_type);
194 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_add_assign, +=, std::true_type);
195 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_sub_assign, -=, std::true_type);
196 NDARRAY_MAKE_EIN_BINARY_OP(ein_op_mul_assign, *=, std::true_type);
// The three generator macros are undefined here so they do not leak out of
// the header. Then cast<Type>(op) wraps op.derived() in
// internal::ein_cast_op<Type, Op>, whose operator() static_casts the operand's
// value to Type.
198 #undef NDARRAY_MAKE_EIN_BINARY_FN 199 #undef NDARRAY_MAKE_EIN_BINARY_OP 200 #undef NDARRAY_MAKE_EIN_BINARY_HELPERS 206 template <
class Type,
class Op>
207 auto cast(
const internal::ein_op_base<Op>& op) {
208 return internal::ein_cast_op<Type, Op>(op.derived());
// min of two ein expressions: unwraps both arguments with derived() and
// returns internal::make_ein_op_min of them.
212 template <
class OpA,
class OpB>
213 auto min(
const internal::ein_op_base<OpA>& a,
const internal::ein_op_base<OpB>& b) {
214 return internal::make_ein_op_min(a.derived(), b.derived());
// max of two ein expressions: unwraps both arguments with derived() and
// returns internal::make_ein_op_max of them.
216 template <
class OpA,
class OpB>
217 auto max(
const internal::ein_op_base<OpA>& a,
const internal::ein_op_base<OpB>& b) {
218 return internal::make_ein_op_max(a.derived(), b.derived());
// Reconciles several dims that describe the same loop index; the return type
// is a const reference to the first dim's type.
// The two defaulted template parameters remove this overload when
// any(not_equal(Dim0::Min, Dims::Min)...) or
// any(not_equal(Dim0::Extent, Dims::Extent)...) is true.
// Visible checks: a branch on dim0.stride() != 0, an assert that every other
// dim reports is_in_range(dim0), and asserts that all mins and all extents
// equal dim0's.
// NOTE(review): the else branch, the return statements and the closing braces
// are not visible in this listing, so which branch each assert belongs to
// cannot be confirmed here.
234 template <
class Dim0,
class... Dims,
235 class = std::enable_if_t<!any(not_equal(Dim0::Min, Dims::Min)...)>,
236 class = std::enable_if_t<!any(not_equal(Dim0::Extent, Dims::Extent)...)>>
237 NDARRAY_INLINE
const Dim0& reconcile_dim(
const Dim0& dim0,
const Dims&... dims) {
238 if (dim0.stride() != 0) {
242 assert(
all(dims.is_in_range(dim0)...));
247 assert(
all(dim0.min() == dims.min()...));
248 assert(
all(dim0.extent() == dims.extent()...));
254 NDARRAY_INLINE
dim<0, 1, 0> reconcile_dim() {
return {}; }
// Helper that expands a tuple of dims into the variadic reconcile_dim call:
// std::get<Is>(dims) is passed for every index in the sequence.
256 template <
class... Dims,
size_t... Is>
257 NDARRAY_INLINE
auto reconcile_dim(
const std::tuple<Dims...>& dims, index_sequence<Is...>) {
258 return reconcile_dim(std::get<Is>(dims)...);
// Entry point for a tuple of dims: builds an index sequence of
// sizeof...(Dims) elements and forwards to the (tuple, index_sequence)
// overload.
260 template <
class... Dims>
261 NDARRAY_INLINE
auto reconcile_dim(
const std::tuple<Dims...>& dims) {
262 return reconcile_dim(dims, make_index_sequence<
sizeof...(Dims)>());
// Returns op.shape().dims() for an operand that has a shape.
// NOTE(review): the function signature is not visible in this listing;
// presumably this is the dims_of overload for an array_ref<T, Shape> operand
// -- confirm.
266 template <
class T,
class Shape>
268 return op.
shape().dims();
// Fallback dims_of: returns an empty tuple and ignores its argument.
// NOTE(review): the template header introducing T is not visible in this
// listing.
271 NDARRAY_INLINE std::tuple<> dims_of(
const T& op) {
272 return std::tuple<>();
// Template header taking a new stride plus three further index_t parameters
// (Min, Extent, and a stride).
// NOTE(review): the declaration that follows this header is not visible in
// this listing, and the identifiers appear split across lines here
// (NewStr|ide, Str|ide). Presumably this is the with_stride overload for a
// single dim -- confirm.
276 template <index_t NewStr
ide, index_t Min, index_t Extent, index_t Str
ide>
// with_stride for a one-element tuple: applies with_stride<NewStride> to the
// contained dim and re-wraps the result with std::make_tuple.
// NOTE(review): the first template parameter's identifier appears split
// across lines in this listing (NewStr|ide).
280 template <index_t NewStr
ide,
class Dim>
281 NDARRAY_INLINE
auto with_stride(
const std::tuple<Dim>& maybe_dim) {
282 return std::make_tuple(with_stride<NewStride>(std::get<0>(maybe_dim)));
// with_stride for an empty tuple; takes and returns std::tuple<>.
// NOTE(review): the function body is not visible in this listing, and the
// template parameter's identifier appears split across lines (NewStr|ide).
284 template <index_t NewStr
ide>
285 NDARRAY_INLINE std::tuple<> with_stride(std::tuple<> maybe_dim) {
// Empty tag types that select a gather_dims overload for an ein_op operand:
// - is_inferred_shape: gathered dims are passed through with_stride<dynamic>.
// - is_result_shape:   gathered dims are returned unchanged.
// - is_operand_shape:  gathered dims are passed through with_stride<0>.
290 class is_inferred_shape {};
291 class is_result_shape {};
292 class is_operand_shape {};
// gather_dims for a result operand: computes index_of<Dim, Is...>() and
// passes it to get_or_empty over dims_of(op.op). The result is returned
// unchanged.
// NOTE(review): presumably this yields the operand's dim for loop index Dim,
// or an empty tuple when Dim is not among Is -- confirm against get_or_empty.
295 template <
size_t Dim,
class Dims,
size_t... Is>
296 NDARRAY_INLINE
auto gather_dims(is_result_shape,
const ein_op<Dims, Is...>& op) {
298 return get_or_empty<index_of<Dim, Is...>()>(dims_of(op.op));
// gather_dims for an inferred shape: same lookup as the result-shape overload
// (get_or_empty at index_of<Dim, Is...>() over dims_of(op.op)), but the
// result is passed through with_stride<dynamic>.
300 template <
size_t Dim,
class Dims,
size_t... Is>
301 NDARRAY_INLINE
auto gather_dims(is_inferred_shape,
const ein_op<Dims, Is...>& op) {
303 return with_stride<dynamic>(get_or_empty<index_of<Dim, Is...>()>(dims_of(op.op)));
// gather_dims for an operand: same lookup as the result-shape overload
// (get_or_empty at index_of<Dim, Is...>() over dims_of(op.op)), but the
// result is passed through with_stride<0>.
305 template <
size_t Dim,
class Dims,
size_t... Is>
306 NDARRAY_INLINE
auto gather_dims(is_operand_shape,
const ein_op<Dims, Is...>& op) {
308 return with_stride<0>(get_or_empty<index_of<Dim, Is...>()>(dims_of(op.op)));
// gather_dims for a unary expression: recurses into the single operand op.op
// with the same kind tag.
311 template <
size_t Dim,
class Kind,
class Op,
class X>
312 NDARRAY_INLINE
auto gather_dims(Kind kind,
const ein_unary_op<Op, X>& op) {
313 return gather_dims<Dim>(kind, op.op);
// gather_dims for a binary expression: recurses into op_a and op_b with the
// same kind tag and concatenates the two resulting tuples.
315 template <
size_t Dim,
class Kind,
class OpA,
class OpB,
class X>
316 NDARRAY_INLINE
auto gather_dims(Kind kind,
const ein_binary_op<OpA, OpB, X>& op) {
317 return std::tuple_cat(gather_dims<Dim>(kind, op.op_a), gather_dims<Dim>(kind, op.op_b));
// gather_dims for two (kind, operand) pairs: gathers each operand under its
// own kind tag and concatenates the two resulting tuples.
320 template <
size_t Dim,
class Kind0,
class Op0,
class Kind1,
class Op1>
321 NDARRAY_INLINE
auto gather_dims(Kind0 kind0,
const Op0& op0, Kind1 kind1,
const Op1& op1) {
322 return std::tuple_cat(gather_dims<Dim>(kind0, op0), gather_dims<Dim>(kind1, op1));
// Builds a shape with one dim per loop index in Is: for each index, gathers
// the dims from the given kind/operand arguments, reconciles them into a
// single dim with reconcile_dim, and passes the results to make_shape.
324 template <
size_t... Is,
class... KindAndOps>
325 NDARRAY_UNIQUE
auto make_ein_reduce_shape(
326 index_sequence<Is...>,
const KindAndOps&... kind_and_ops) {
327 return make_shape(reconcile_dim(gather_dims<Is>(kind_and_ops...))...);
// Wraps a callable `op` in internal::ein_op<Op, Is...>, moving it into the
// wrapper. Enabled only when internal::enable_if_callable<Op, decltype(Is)...>
// is well-formed.
// NOTE(review): the signature line is not visible in this listing; the
// generated index at the end of the file gives it as `auto ein(Op op)`.
338 template <
size_t... Is,
class Op,
class = internal::enable_if_callable<Op, decltype(Is)...>>
340 return internal::ein_op<Op, Is...>(std::move(op));
// ein overload enabled only when the number of indices equals Shape::rank();
// forwards to ein<Is...> with op.ref().
// NOTE(review): the signature line is not visible in this listing; the
// T/Shape/Alloc parameters suggest an array<T, Shape, Alloc> argument --
// confirm.
342 template <
size_t... Is,
class T,
class Shape,
class Alloc,
343 class = std::enable_if_t<
sizeof...(Is) == Shape::rank()>>
345 return ein<Is...>(op.
ref());
// Second ein overload with the same visible template header and body as the
// one before it: enabled only when the number of indices equals
// Shape::rank(), and forwards to ein<Is...> with op.ref().
// NOTE(review): the signature line is not visible in this listing, so how
// this overload's parameter differs from the previous one cannot be told from
// here.
347 template <
size_t... Is,
class T,
class Shape,
class Alloc,
348 class = std::enable_if_t<
sizeof...(Is) == Shape::rank()>>
350 return ein<Is...>(op.
ref());
// ein overload taking a raw pointer `x` and an element count `N`, with a
// single index I0.
// NOTE(review): the function body is not visible in this listing.
364 template <
size_t I0,
class T>
365 auto ein(T* x,
size_t N) {
// Template header with a single index I0, an element type T and a
// compile-time size N.
// NOTE(review): the declaration following this header is not visible in this
// listing; presumably an ein overload for a fixed-size array -- confirm.
370 template <
size_t I0,
class T,
size_t N>
// Evaluates an assignment expression (enabled only when Expr satisfies
// internal::enable_if_ein_assign).
// Visible steps:
//  - LoopRank is Expr::MaxIndex + 1.
//  - The loop shape is built over indices [0, LoopRank), with expr.op_a tagged
//    is_result_shape and expr.op_b tagged is_operand_shape.
//  - for_each_index_in_order(reduction_shape, expr) is called with that shape.
// NOTE(review): the signature and several intermediate lines are not visible
// in this listing; the generated index at the end of the file gives the
// signature as `NDARRAY_INLINE auto ein_reduce(const Expr &expr)`.
418 template <
class Expr,
class =
internal::enable_if_ein_assign<Expr>>
420 constexpr
index_t LoopRank = Expr::MaxIndex + 1;
426 auto reduction_shape = internal::make_ein_reduce_shape(internal::make_index_sequence<LoopRank>(),
427 internal::is_result_shape(), expr.op_a, internal::is_operand_shape(), expr.op_b);
439 for_each_index_in_order(reduction_shape, expr);
// Computes a result shape for `expr` over the indices ResultIs by calling the
// internal make_ein_reduce_shape with the is_inferred_shape tag.
// Enabled only when Expr satisfies internal::enable_if_ein_op.
// NOTE(review): the lines after result_shape is computed (including the
// return) are not visible in this listing.
446 template <
size_t... ResultIs,
class Expr,
class = internal::enable_if_ein_op<Expr>>
447 NDARRAY_UNIQUE
auto make_ein_reduce_shape(
const Expr& expr) {
448 auto result_shape = internal::make_ein_reduce_shape(
449 internal::index_sequence<ResultIs...>(), internal::is_inferred_shape(), expr);
// Allocates a result array for `expr`: the shape comes from
// make_ein_reduce_shape<ResultIs...>(expr), and the array is created with
// make_array<T>(result_shape, init, alloc).
// `init` defaults to T() and `alloc` to Alloc(), where Alloc defaults to
// std::allocator<T>. Enabled only when Expr satisfies
// internal::enable_if_ein_op.
// NOTE(review): the signature line and the rest of the body are not visible
// in this listing; the generated index at the end of the file names this
// function make_ein_sum.
475 template <
class T,
size_t... ResultIs,
class Expr,
class Alloc = std::allocator<T>,
476 class = internal::enable_if_ein_op<Expr>>
478 const Expr& expr,
const T& init = T(),
const Alloc& alloc = Alloc()) {
479 auto result_shape = make_ein_reduce_shape<ResultIs...>(expr);
480 auto result = make_array<T>(result_shape, init, alloc);
487 #endif // NDARRAY_EIN_REDUCE_H NDARRAY_HOST_DEVICE Shape & shape()
Definition: array.h:2105
NDARRAY_INLINE auto ein_reduce(const Expr &expr)
Definition: ein_reduce.h:419
NDARRAY_HOST_DEVICE array_ref< T, Shape > make_array_ref(T *base, const Shape &shape)
Definition: array.h:1970
const interval< 0,-1 > all
Definition: array.h:363
auto ein(Op op)
Definition: ein_reduce.h:339
NDARRAY_INLINE auto make_ein_sum(const Expr &expr, const T &init=T(), const Alloc &alloc=Alloc())
Definition: ein_reduce.h:477
auto cast(const internal::ein_op_base< Op > &op)
Definition: ein_reduce.h:207
Main header for array library.
void move(const array_ref< TSrc, ShapeSrc > &src, const array_ref< TDst, ShapeDst > &dst)
Definition: array.h:2804
NDARRAY_HOST_DEVICE auto make_shape(Dims...dims)
Definition: array.h:1040
std::ptrdiff_t index_t
Definition: array.h:87
NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t min() const
Definition: array.h:293
NDARRAY_HOST_DEVICE auto make_compact(const Shape &s)
Definition: array.h:1620
auto min(const internal::ein_op_base< OpA > &a, const internal::ein_op_base< OpB > &b)
Definition: ein_reduce.h:213
array_ref< T, Shape > ref()
Definition: array.h:2675
NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t extent() const
Definition: array.h:296