#ifndef NDARRAY_ARRAY_H
#define NDARRAY_ARRAY_H

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>

// Some things in this header are unbearably slow without inlining.
#if defined(__GNUC__) && !defined(__clang__)
#define NDARRAY_INLINE inline __attribute__((always_inline))
#elif defined(__clang__)
#if defined(__CUDACC__)
#define NDARRAY_INLINE __forceinline__
#else
#define NDARRAY_INLINE inline __attribute__((always_inline))
#endif
#else
#define NDARRAY_INLINE inline
#endif

#define NDARRAY_UNIQUE static

// Mark functions as callable from both host and device code when compiling
// with CUDA.
#if defined(__CUDACC__)
#define NDARRAY_HOST_DEVICE __device__ __host__
#else
#define NDARRAY_HOST_DEVICE
#endif

#if defined(__GNUC__) || defined(__clang__)
#define NDARRAY_RESTRICT __restrict__
#else
#define NDARRAY_RESTRICT
#endif

#if defined(__CUDACC__)
#define NDARRAY_PRINT_ERR(...) printf(__VA_ARGS__)
#else
#define NDARRAY_PRINT_ERR(...) fprintf(stderr, __VA_ARGS__)
#endif

namespace nda {

using size_t = std::size_t;
/** The type of indices and sizes. */
#ifdef NDARRAY_INT_INDICES
using index_t = int;
#define NDARRAY_INDEX_T_FMT "%d"
#else
using index_t = std::ptrdiff_t;
#define NDARRAY_INDEX_T_FMT "%td"
#endif

// Sentinel values: `dynamic` indicates a min, extent, or stride that is a
// runtime value; `unresolved` indicates a stride that has not been computed
// yet.
constexpr index_t dynamic = std::numeric_limits<index_t>::min();
constexpr index_t unresolved = std::numeric_limits<index_t>::max();

namespace internal {

// std::declval is not declared NDARRAY_HOST_DEVICE; provide our own.
template <typename T>
NDARRAY_HOST_DEVICE typename std::add_rvalue_reference<T>::type declval() noexcept;
NDARRAY_INLINE constexpr index_t abs(index_t x) { return x >= 0 ? x : -x; }

NDARRAY_INLINE constexpr bool is_static(index_t x) { return x != dynamic; }
NDARRAY_INLINE constexpr bool is_dynamic(index_t x) { return x == dynamic; }

NDARRAY_INLINE constexpr bool is_resolved(index_t x) { return x != unresolved; }
NDARRAY_INLINE constexpr bool is_unresolved(index_t x) { return x == unresolved; }

constexpr bool is_dynamic(index_t a, index_t b) { return is_dynamic(a) || is_dynamic(b); }

// Two compile-time values are known to be not equal only if both are static
// and differ.
constexpr bool not_equal(index_t a, index_t b) { return is_static(a) && is_static(b) && a != b; }

template <index_t A, index_t B>
using disable_if_not_equal = std::enable_if_t<!not_equal(A, B)>;

// Arithmetic on possibly-dynamic values: any operation involving `dynamic`
// produces `dynamic`.
constexpr index_t static_abs(index_t x) { return is_dynamic(x) ? dynamic : abs(x); }
constexpr index_t static_add(index_t a, index_t b) { return is_dynamic(a, b) ? dynamic : a + b; }
constexpr index_t static_sub(index_t a, index_t b) { return is_dynamic(a, b) ? dynamic : a - b; }
constexpr index_t static_mul(index_t a, index_t b) { return is_dynamic(a, b) ? dynamic : a * b; }
constexpr index_t static_min(index_t a, index_t b) {
  return is_dynamic(a, b) ? dynamic : (a < b ? a : b);
}
constexpr index_t static_max(index_t a, index_t b) {
  return is_dynamic(a, b) ? dynamic : (a > b ? a : b);
}
// An index value that may be a compile-time constant. This non-specialized
// version stores no data; the value is the template parameter.
template <index_t Value>
struct constexpr_index {
  NDARRAY_HOST_DEVICE constexpr_index(index_t value = Value) { assert(value == Value); }
  NDARRAY_HOST_DEVICE constexpr_index& operator=(index_t value) {
    assert(value == Value);
    return *this;
  }
  NDARRAY_HOST_DEVICE NDARRAY_INLINE operator index_t() const { return Value; }
};

// The dynamic specialization stores the value at runtime.
template <>
struct constexpr_index<dynamic> {
  index_t value_;

  NDARRAY_HOST_DEVICE constexpr_index(index_t value) : value_(value) {}
  NDARRAY_HOST_DEVICE constexpr_index& operator=(index_t value) {
    value_ = value;
    return *this;
  }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE operator index_t() const { return value_; }
};
template <typename T>
constexpr const T& max(const T& a, const T& b) {
  return (a < b) ? b : a;
}
template <typename T>
constexpr const T& min(const T& a, const T& b) {
  return (b < a) ? b : a;
}
} // namespace internal

// An iterator over a contiguous range of indices.
class index_iterator {
  index_t i_;

public:
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator(index_t i) : i_(i) {}

  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool operator==(const index_iterator& r) const {
    return i_ == r.i_;
  }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool operator!=(const index_iterator& r) const {
    return i_ != r.i_;
  }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t operator*() const { return i_; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator++() {
    ++i_;
    return *this;
  }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator++(int) { return index_iterator(i_++); }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t operator[](index_t n) const { return i_ + n; }
};
template <index_t Min, index_t Extent, index_t Stride>
class dim;
/** An interval of indices, with an optionally compile-time constant min and
 * extent. */
template <index_t Min_ = dynamic, index_t Extent_ = dynamic>
class interval {
protected:
  internal::constexpr_index<Min_> min_;
  internal::constexpr_index<Extent_> extent_;

public:
  static constexpr index_t Min = Min_;
  static constexpr index_t Extent = Extent_;
  static constexpr index_t Max = internal::static_sub(internal::static_add(Min, Extent), 1);

  NDARRAY_HOST_DEVICE interval(index_t min, index_t extent) : min_(min), extent_(extent) {}
  NDARRAY_HOST_DEVICE interval(index_t min)
      : interval(min, internal::is_static(Extent) ? Extent : 1) {}
  NDARRAY_HOST_DEVICE interval() : interval(internal::is_static(Min) ? Min : 0) {}

  NDARRAY_HOST_DEVICE interval(const interval&) = default;
  NDARRAY_HOST_DEVICE interval(interval&&) = default;
  NDARRAY_HOST_DEVICE interval& operator=(const interval&) = default;
  NDARRAY_HOST_DEVICE interval& operator=(interval&&) = default;

  /** Copy or assign from another interval, if the static mins and extents are
   * compatible. */
  template <index_t CopyMin, index_t CopyExtent,
      class = internal::disable_if_not_equal<Min, CopyMin>,
      class = internal::disable_if_not_equal<Extent, CopyExtent>>
  NDARRAY_HOST_DEVICE interval(const interval<CopyMin, CopyExtent>& other)
      : interval(other.min(), other.extent()) {}
  template <index_t CopyMin, index_t CopyExtent,
      class = internal::disable_if_not_equal<Min, CopyMin>,
      class = internal::disable_if_not_equal<Extent, CopyExtent>>
  NDARRAY_HOST_DEVICE interval& operator=(const interval<CopyMin, CopyExtent>& other) {
    set_min(other.min());
    set_extent(other.extent());
    return *this;
  }

  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t min() const { return min_; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_min(index_t min) { min_ = min; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t extent() const { return extent_; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t size() const { return extent_; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_extent(index_t extent) { extent_ = extent; }

  /** The max index in this interval, min() + extent() - 1. */
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t max() const { return min_ + extent_ - 1; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_max(index_t max) { set_extent(max - min_ + 1); }

  /** Check if the index `at`, or all of the interval or dim `at`, is within
   * this interval. */
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool is_in_range(index_t at) const {
    return min_ <= at && at <= max();
  }
  template <index_t OtherMin, index_t OtherExtent>
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool is_in_range(
      const interval<OtherMin, OtherExtent>& at) const {
    return min_ <= at.min() && at.max() <= max();
  }
  template <index_t OtherMin, index_t OtherExtent, index_t OtherStride>
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool is_in_range(
      const dim<OtherMin, OtherExtent, OtherStride>& at) const {
    return min_ <= at.min() && at.max() <= max();
  }

  /** Iterate over the indices in this interval. */
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator begin() const { return index_iterator(min_); }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator end() const {
    return index_iterator(max() + 1);
  }

  template <index_t OtherMin, index_t OtherExtent>
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool operator==(
      const interval<OtherMin, OtherExtent>& other) const {
    return min_ == other.min() && extent_ == other.extent();
  }
  template <index_t OtherMin, index_t OtherExtent>
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool operator!=(
      const interval<OtherMin, OtherExtent>& other) const {
    return !operator==(other);
  }
};
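// A sketch of basic interval usage; the values here are illustrative, not
// from the original source:
//
//   interval<> x(2, 10);   // min 2, extent 10: indices 2..11
//   interval<0> y(5);      // compile-time constant min 0, extent 5
//   assert(x.is_in_range(11) && !x.is_in_range(12));
//   for (index_t i : x) { /* i = 2, 3, ..., 11 */ }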
/** An interval with a compile-time constant extent. */
template <index_t Extent>
using fixed_interval = interval<dynamic, Extent>;
/** Clamp an index to the range [min, max]. */
NDARRAY_INLINE NDARRAY_HOST_DEVICE constexpr index_t clamp(index_t x, index_t min, index_t max) {
  return internal::min(internal::max(x, min), max);
}

/** Clamp an index to the range spanned by an interval or dim `r`. */
template <class Range>
NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t clamp(index_t x, const Range& r) {
  return clamp(x, r.min(), r.max());
}
/** A dimension of a shape: an interval together with a stride, each with an
 * optionally compile-time constant value. */
template <index_t Min_ = dynamic, index_t Extent_ = dynamic, index_t Stride_ = dynamic>
class dim : protected interval<Min_, Extent_> {
protected:
  internal::constexpr_index<Stride_> stride_;

  using base_range = interval<Min_, Extent_>;
  using base_range::extent_;
  using base_range::min_;

public:
  using base_range::Extent;
  using base_range::Max;
  using base_range::Min;

  static constexpr index_t Stride = Stride_;
  static constexpr index_t DefaultStride = internal::is_static(Stride) ? Stride : unresolved;

  NDARRAY_HOST_DEVICE dim(index_t min, index_t extent, index_t stride = DefaultStride)
      : base_range(min, extent), stride_(stride) {}
  NDARRAY_HOST_DEVICE dim(index_t extent) : dim(internal::is_static(Min) ? Min : 0, extent) {}
  NDARRAY_HOST_DEVICE dim() : dim(internal::is_static(Extent) ? Extent : 0) {}
  NDARRAY_HOST_DEVICE dim(const base_range& interval, index_t stride = DefaultStride)
      : dim(interval.min(), interval.extent(), stride) {}
  NDARRAY_HOST_DEVICE dim(const dim&) = default;
  NDARRAY_HOST_DEVICE dim(dim&&) = default;
  NDARRAY_HOST_DEVICE dim& operator=(const dim&) = default;
  NDARRAY_HOST_DEVICE dim& operator=(dim&&) = default;

  /** Copy or assign from another dim, if the static min, extent, and stride
   * are compatible. */
  template <index_t CopyMin, index_t CopyExtent, index_t CopyStride,
      class = internal::disable_if_not_equal<Min, CopyMin>,
      class = internal::disable_if_not_equal<Extent, CopyExtent>,
      class = internal::disable_if_not_equal<Stride, CopyStride>>
  NDARRAY_HOST_DEVICE dim(const dim<CopyMin, CopyExtent, CopyStride>& other)
      : dim(other.min(), other.extent()) {
    set_stride(other.stride());
  }
  template <index_t CopyMin, index_t CopyExtent, index_t CopyStride,
      class = internal::disable_if_not_equal<Min, CopyMin>,
      class = internal::disable_if_not_equal<Extent, CopyExtent>,
      class = internal::disable_if_not_equal<Stride, CopyStride>>
  NDARRAY_HOST_DEVICE dim& operator=(const dim<CopyMin, CopyExtent, CopyStride>& other) {
    set_min(other.min());
    set_extent(other.extent());
    set_stride(other.stride());
    return *this;
  }

  using base_range::begin;
  using base_range::end;
  using base_range::extent;
  using base_range::size;
  using base_range::is_in_range;
  using base_range::max;
  using base_range::min;
  using base_range::set_extent;
  using base_range::set_max;
  using base_range::set_min;

  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t stride() const { return stride_; }
  NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_stride(index_t stride) {
    if (internal::is_static(Stride_)) {
      // A compile-time constant stride can only be "set" to its static value,
      // or left unresolved.
      assert(internal::is_unresolved(stride) || stride == Stride_);
    } else {
      stride_ = stride;
    }
  }

  /** The offset in flat memory of the index `at` in this dim. */
  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t flat_offset(index_t at) const {
    return (at - min_) * stride_;
  }

  template <index_t OtherMin, index_t OtherExtent, index_t OtherStride>
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool operator==(
      const dim<OtherMin, OtherExtent, OtherStride>& other) const {
    return min_ == other.min() && extent_ == other.extent() && stride_ == other.stride();
  }
  template <index_t OtherMin, index_t OtherExtent, index_t OtherStride>
  NDARRAY_INLINE NDARRAY_HOST_DEVICE bool operator!=(
      const dim<OtherMin, OtherExtent, OtherStride>& other) const {
    return !operator==(other);
  }
};
/** A dim with a compile-time constant extent, and optionally a compile-time
 * constant stride. */
template <index_t Extent, index_t Stride = dynamic>
using fixed_dim = dim<dynamic, Extent, Stride>;

/** A dim with a stride known to be one at compile time. */
template <index_t Min = dynamic, index_t Extent = dynamic>
using dense_dim = dim<Min, Extent, 1>;

/** A dim with a compile-time constant stride. */
template <index_t Stride>
using strided_dim = dim<dynamic, dynamic, Stride>;

/** A dim with a stride of zero, which broadcasts one value over its extent. */
template <index_t Min = dynamic, index_t Extent = dynamic>
using broadcast_dim = dim<Min, Extent, 0>;
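// A sketch of how these dims compose; values illustrative:
//
//   dim<> d(0, 100, 3);        // min 0, extent 100, stride 3, all runtime
//   dense_dim<> inner(0, 100); // stride fixed to 1 at compile time
//   assert(d.flat_offset(5) == 15);  // (at - min) * stride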
namespace internal {

// An iterator over the sub-intervals produced by split().
template <index_t InnerExtent = dynamic>
class split_iterator {
  interval<dynamic, InnerExtent> i;
  index_t outer_max;

public:
  NDARRAY_HOST_DEVICE split_iterator(const interval<dynamic, InnerExtent>& i, index_t outer_max)
      : i(i), outer_max(outer_max) {}

  NDARRAY_HOST_DEVICE bool operator==(const split_iterator& r) const {
    return i.min() == r.i.min();
  }
  NDARRAY_HOST_DEVICE bool operator!=(const split_iterator& r) const {
    return i.min() != r.i.min();
  }

  NDARRAY_HOST_DEVICE const interval<dynamic, InnerExtent>& operator*() const { return i; }

  NDARRAY_HOST_DEVICE split_iterator& operator+=(index_t n) {
    if (is_static(InnerExtent)) {
      // When the inner extent is a compile-time constant, we can't shrink the
      // last sub-interval. Shift it backwards instead, so it overlaps the
      // previous iteration but stays in bounds.
      i.set_min(i.min() + InnerExtent * n);
      if (i.min() <= outer_max && i.max() > outer_max) { i.set_min(outer_max - InnerExtent + 1); }
    } else {
      // The inner extent is a runtime value; clamp the last sub-interval to
      // the outer bounds.
      i.set_min(i.min() + i.extent() * n);
      index_t max = min(i.max(), outer_max);
      i.set_extent(max - i.min() + 1);
    }
    return *this;
  }
  NDARRAY_HOST_DEVICE split_iterator operator+(index_t n) const {
    split_iterator<InnerExtent> result(*this);
    return result += n;
  }
  NDARRAY_HOST_DEVICE split_iterator& operator++() { return *this += 1; }
  NDARRAY_HOST_DEVICE split_iterator operator++(int) {
    split_iterator<InnerExtent> result(*this);
    *this += 1;
    return result;
  }
  NDARRAY_HOST_DEVICE index_t operator-(const split_iterator& r) const {
    return r.i.extent() > 0 ? (i.max() - r.i.min() + r.i.extent() - i.extent()) / r.i.extent() : 0;
  }
  NDARRAY_HOST_DEVICE split_iterator operator[](index_t n) const {
    split_iterator result(*this);
    return result += n;
  }
};
// The result of split(): an iterable range of sub-intervals.
template <index_t InnerExtent = dynamic>
class split_result {
public:
  using iterator = split_iterator<InnerExtent>;

private:
  iterator begin_;
  iterator end_;

public:
  NDARRAY_HOST_DEVICE split_result(iterator begin, iterator end) : begin_(begin), end_(end) {}

  NDARRAY_HOST_DEVICE iterator begin() const { return begin_; }
  NDARRAY_HOST_DEVICE iterator end() const { return end_; }
  NDARRAY_HOST_DEVICE index_t size() const { return end_ - begin_; }
  NDARRAY_HOST_DEVICE iterator operator[](index_t i) const { return begin_ + i; }
};

} // namespace internal
/** Split the interval v into sub-intervals of compile-time constant extent
 * InnerExtent. If InnerExtent does not evenly divide v.extent(), the last
 * sub-interval is shifted to overlap the previous one, so it stays in
 * bounds. */
template <index_t InnerExtent, index_t Min, index_t Extent>
NDARRAY_HOST_DEVICE internal::split_result<InnerExtent> split(const interval<Min, Extent>& v) {
  assert(v.extent() >= InnerExtent);
  using iterator = typename internal::split_result<InnerExtent>::iterator;
  return {iterator(interval<dynamic, InnerExtent>(v.min()), v.max()),
      iterator(interval<dynamic, InnerExtent>(v.max() + 1), v.max())};
}
template <index_t InnerExtent, index_t Min, index_t Extent, index_t Stride>
NDARRAY_HOST_DEVICE internal::split_result<InnerExtent> split(const dim<Min, Extent, Stride>& v) {
  return split<InnerExtent>(interval<Min, Extent>(v.min(), v.extent()));
}

/** Split the interval v into sub-intervals of runtime extent inner_extent.
 * The last sub-interval may be smaller than inner_extent. */
template <index_t Min, index_t Extent>
NDARRAY_HOST_DEVICE internal::split_result<> split(
    const interval<Min, Extent>& v, index_t inner_extent) {
  using iterator = internal::split_result<>::iterator;
  return {iterator(interval<>(v.min(), inner_extent), v.max()),
      iterator(interval<>(v.max() + 1, inner_extent), v.max())};
}
template <index_t Min, index_t Extent, index_t Stride>
NDARRAY_HOST_DEVICE internal::split_result<> split(
    const dim<Min, Extent, Stride>& v, index_t inner_extent) {
  return split(interval<Min, Extent>(v.min(), v.extent()), inner_extent);
}
/** A placeholder indicating an entire dimension when slicing. */
namespace internal {
struct placeholder {
  // Allow the placeholder to satisfy "constructible to interval<>" checks.
  NDARRAY_HOST_DEVICE operator interval<>() const { return interval<>(); }
};
} // namespace internal
constexpr internal::placeholder _{};

namespace internal {

using std::index_sequence;
using std::make_index_sequence;

// Call fn with the elements of the tuple args unpacked as arguments.
template <class Fn, class Args, size_t... Is>
NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply(Fn&& fn, const Args& args, index_sequence<Is...>)
    -> decltype(fn(std::get<Is>(args)...)) {
  return fn(std::get<Is>(args)...);
}
template <class Fn, class Args>
NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply(Fn&& fn, const Args& args)
    -> decltype(internal::apply(fn, args, make_index_sequence<std::tuple_size<Args>::value>())) {
  return internal::apply(fn, args, make_index_sequence<std::tuple_size<Args>::value>());
}

// SFINAE helpers: valid only if Fn is callable with Args..., or with the
// elements of the tuple Args unpacked as arguments.
template <class Fn, class... Args>
using enable_if_callable = decltype(internal::declval<Fn>()(internal::declval<Args>()...));
template <class Fn, class Args>
using enable_if_applicable =
    decltype(internal::apply(internal::declval<Fn>(), internal::declval<Args>()));
// Sum of a variadic list of values, unrolled four at a time to limit
// recursion depth.
NDARRAY_INLINE constexpr index_t sum() { return 0; }
NDARRAY_INLINE constexpr index_t sum(index_t x0) { return x0; }
NDARRAY_INLINE constexpr index_t sum(index_t x0, index_t x1) { return x0 + x1; }
NDARRAY_INLINE constexpr index_t sum(index_t x0, index_t x1, index_t x2) { return x0 + x1 + x2; }
template <class... Rest>
NDARRAY_INLINE constexpr index_t sum(
    index_t x0, index_t x1, index_t x2, index_t x3, Rest... rest) {
  return x0 + x1 + x2 + x3 + sum(rest...);
}

NDARRAY_INLINE constexpr int product() { return 1; }
template <class T, class... Rest>
NDARRAY_INLINE constexpr T product(T first, Rest... rest) {
  return first * product(rest...);
}

NDARRAY_INLINE constexpr index_t variadic_min() { return std::numeric_limits<index_t>::max(); }
template <class... Rest>
NDARRAY_INLINE constexpr index_t variadic_min(index_t first, Rest... rest) {
  return min(first, variadic_min(rest...));
}

NDARRAY_INLINE constexpr index_t variadic_max() { return std::numeric_limits<index_t>::min(); }
template <class... Rest>
NDARRAY_INLINE constexpr index_t variadic_max(index_t first, Rest... rest) {
  return max(first, variadic_max(rest...));
}

// Product of the elements of a tuple.
template <class Tuple, size_t... Is>
NDARRAY_HOST_DEVICE index_t product(const Tuple& t, index_sequence<Is...>) {
  return product(std::get<Is>(t)...);
}

// True iff all (any) of the arguments are true.
template <class... Bools>
constexpr bool all(Bools... bools) {
  return sum((bools ? 0 : 1)...) == 0;
}
template <class... Bools>
constexpr bool any(Bools... bools) {
  return sum((bools ? 1 : 0)...) != 0;
}
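// These folds power most of the metaprogramming below, e.g. (illustrative):
//
//   static_assert(sum(1, 2, 3) == 6, "");
//   static_assert(all(true, true) && !any(false, false), "");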
// Compute the flat offset of a tuple of indices into a tuple of dims.
template <class Dims, class Indices, size_t... Is>
NDARRAY_HOST_DEVICE index_t flat_offset(
    const Dims& dims, const Indices& indices, index_sequence<Is...>) {
  return sum(std::get<Is>(dims).flat_offset(std::get<Is>(indices))...);
}

// The smallest and largest flat offsets of a tuple of dims. Negative strides
// contribute to the min; positive strides to the max.
template <class Dims, size_t... Is>
NDARRAY_HOST_DEVICE index_t flat_min(const Dims& dims, index_sequence<Is...>) {
  return sum((std::get<Is>(dims).extent() - 1) * min<index_t>(0, std::get<Is>(dims).stride())...);
}
template <class Dims, size_t... Is>
NDARRAY_HOST_DEVICE index_t flat_max(const Dims& dims, index_sequence<Is...>) {
  return sum((std::get<Is>(dims).extent() - 1) *
             internal::max<index_t>(0, std::get<Is>(dims).stride())...);
}
// Combine a slice (an index, interval, dim, or the placeholder) with the dim
// it slices, producing a dim with the bounds of the slice and the stride of
// the sliced dim.
template <index_t DimMin, index_t DimExtent, index_t DimStride>
NDARRAY_HOST_DEVICE auto range_with_stride(index_t at, const dim<DimMin, DimExtent, DimStride>& d) {
  return dim<dynamic, 1, DimStride>(at, 1, d.stride());
}
template <index_t CropMin, index_t CropExtent, index_t DimMin, index_t DimExtent, index_t Stride>
NDARRAY_HOST_DEVICE auto range_with_stride(
    const interval<CropMin, CropExtent>& x, const dim<DimMin, DimExtent, Stride>& d) {
  return dim<CropMin, CropExtent, Stride>(x.min(), x.extent(), d.stride());
}
template <index_t CropMin, index_t CropExtent, index_t CropStride, class Dim>
NDARRAY_HOST_DEVICE auto range_with_stride(
    const dim<CropMin, CropExtent, CropStride>& x, const Dim& d) {
  return dim<CropMin, CropExtent>(x.min(), x.extent(), d.stride());
}
template <index_t Min, index_t Extent, index_t Stride>
NDARRAY_HOST_DEVICE auto range_with_stride(
    const decltype(_)&, const dim<Min, Extent, Stride>& d) {
  return d;
}

template <class Intervals, class Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto intervals_with_strides(
    const Intervals& intervals, const Dims& dims, index_sequence<Is...>) {
  return std::make_tuple(range_with_stride(std::get<Is>(intervals), std::get<Is>(dims))...);
}
// Slicing with a plain index reduces the rank; drop those dims from the
// resulting shape, keeping dims sliced by an interval, dim, or placeholder.
template <class Dim>
NDARRAY_HOST_DEVICE std::tuple<> skip_slices_impl(const Dim& d, index_t) {
  return std::tuple<>();
}
template <class Dim, index_t Min, index_t Extent>
NDARRAY_HOST_DEVICE std::tuple<Dim> skip_slices_impl(const Dim& d, const interval<Min, Extent>&) {
  return std::tuple<Dim>(d);
}
template <class Dim, index_t Min, index_t Extent, index_t Stride>
NDARRAY_HOST_DEVICE std::tuple<Dim> skip_slices_impl(
    const Dim& d, const dim<Min, Extent, Stride>&) {
  return std::tuple<Dim>(d);
}
template <class Dim>
NDARRAY_HOST_DEVICE std::tuple<Dim> skip_slices_impl(const Dim& d, const decltype(_)&) {
  return std::tuple<Dim>(d);
}

template <class Dims, class Intervals, size_t... Is>
NDARRAY_HOST_DEVICE auto skip_slices(
    const Dims& dims, const Intervals& intervals, index_sequence<Is...>) {
  return std::tuple_cat(skip_slices_impl(std::get<Is>(dims), std::get<Is>(intervals))...);
}
template <class Dims, class Indices, size_t... Is>
NDARRAY_HOST_DEVICE bool is_in_range(
    const Dims& dims, const Indices& indices, index_sequence<Is...>) {
  return all(std::get<Is>(dims).is_in_range(std::get<Is>(indices))...);
}
// The min index of a slice of a dim: an index is its own min; intervals and
// dims have a min(); the placeholder takes the whole dim.
template <class Dim>
NDARRAY_HOST_DEVICE index_t min_of_range(index_t x, const Dim&) {
  return x;
}
template <index_t Min, index_t Extent, class Dim>
NDARRAY_HOST_DEVICE index_t min_of_range(const interval<Min, Extent>& x, const Dim&) {
  return x.min();
}
template <index_t Min, index_t Extent, index_t Stride, class Dim>
NDARRAY_HOST_DEVICE index_t min_of_range(const dim<Min, Extent, Stride>& x, const Dim&) {
  return x.min();
}
template <class Dim>
NDARRAY_HOST_DEVICE index_t min_of_range(const decltype(_)&, const Dim& dim) {
  return dim.min();
}

template <class Intervals, class Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto mins_of_intervals(
    const Intervals& intervals, const Dims& dims, index_sequence<Is...>) {
  return std::make_tuple(min_of_range(std::get<Is>(intervals), std::get<Is>(dims))...);
}
template <class... Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto mins(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
  return std::make_tuple(std::get<Is>(dims).min()...);
}
template <class... Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto extents(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
  return std::make_tuple(std::get<Is>(dims).extent()...);
}
template <class... Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto strides(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
  return std::make_tuple(std::get<Is>(dims).stride()...);
}
template <class... Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto maxs(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
  return std::make_tuple(std::get<Is>(dims).max()...);
}
// Check whether a candidate (stride, extent) does not overlap the memory
// covered by dim.
template <class Dim>
NDARRAY_HOST_DEVICE bool is_stride_ok(index_t stride, index_t extent, const Dim& dim) {
  if (is_unresolved(dim.stride())) {
    // Unresolved dims are not constraints.
    return true;
  }
  if (extent == 1 && abs(stride) == abs(dim.stride()) && dim.extent() > 1) {
    // Don't reuse the stride of another dim with extent > 1, even though an
    // extent-1 dim would fit.
    return false;
  }
  if (dim.extent() * abs(dim.stride()) <= stride) {
    // The dim fits entirely within one step of the candidate stride.
    return true;
  }
  index_t flat_extent = extent * stride;
  if (abs(dim.stride()) >= flat_extent) {
    // The candidate fits entirely within one step of the dim's stride.
    return true;
  }
  return false;
}

// If the candidate stride is OK with respect to all dims, return it;
// otherwise return a sentinel that loses to any valid candidate in
// variadic_min.
template <class... Dims>
NDARRAY_HOST_DEVICE index_t filter_stride(index_t stride, index_t extent, const Dims&... dims) {
  if (all(is_stride_ok(stride, extent, dims)...)) {
    return stride;
  } else {
    return std::numeric_limits<index_t>::max();
  }
}

// A candidate stride: just past the memory footprint of dim.
template <class Dim>
NDARRAY_HOST_DEVICE index_t candidate_stride(const Dim& dim) {
  if (is_unresolved(dim.stride())) {
    return std::numeric_limits<index_t>::max();
  }
  return max<index_t>(1, abs(dim.stride()) * dim.extent());
}
// Find the smallest candidate stride that does not overlap any other dim.
template <class Dims, size_t... Is>
NDARRAY_HOST_DEVICE index_t find_stride(index_t extent, const Dims& dims, index_sequence<Is...>) {
  return variadic_min(filter_stride(1, extent, std::get<Is>(dims)...),
      filter_stride(candidate_stride(std::get<Is>(dims)), extent, std::get<Is>(dims)...)...);
}
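// An illustrative walk-through: given dims {extent 4, stride 1} and
// {extent 3, stride unresolved}, the candidates for the unresolved dim are
// 1 and candidate_stride of the first dim, max(1, 1 * 4) = 4. A stride of 1
// overlaps the first dim, so filter_stride rejects it, and the unresolved
// stride resolves to 4, producing a compact 4x3 layout.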
// Resolve unresolved strides, in order, to the smallest strides that do not
// overlap any other dim.
template <class AllDims>
NDARRAY_HOST_DEVICE void resolve_unknown_strides(AllDims& all_dims) {}
template <class AllDims, class Dim0, class... Dims>
NDARRAY_HOST_DEVICE void resolve_unknown_strides(AllDims& all_dims, Dim0& dim0, Dims&... dims) {
  if (is_unresolved(dim0.stride())) {
    constexpr size_t rank = std::tuple_size<AllDims>::value;
    dim0.set_stride(find_stride(dim0.extent(), all_dims, make_index_sequence<rank>()));
  }
  resolve_unknown_strides(all_dims, dims...);
}
template <class Dims, size_t... Is>
NDARRAY_HOST_DEVICE void resolve_unknown_strides(Dims& dims, index_sequence<Is...>) {
  resolve_unknown_strides(dims, std::get<Is>(dims)...);
}

template <class Dims, size_t... Is>
NDARRAY_HOST_DEVICE bool is_resolved(const Dims& dims, index_sequence<Is...>) {
  return all(is_resolved(std::get<Is>(dims).stride())...);
}
template <class T, class Tuple, size_t... Is>
NDARRAY_HOST_DEVICE std::array<T, sizeof...(Is)> tuple_to_array(
    const Tuple& t, index_sequence<Is...>) {
  return {{std::get<Is>(t)...}};
}
template <class T, class... Ts>
NDARRAY_HOST_DEVICE std::array<T, sizeof...(Ts)> tuple_to_array(const std::tuple<Ts...>& t) {
  return tuple_to_array<T>(t, make_index_sequence<sizeof...(Ts)>());
}

template <class T, size_t N, size_t... Is>
NDARRAY_HOST_DEVICE auto array_to_tuple(const std::array<T, N>& a, index_sequence<Is...>) {
  return std::make_tuple(a[Is]...);
}
template <class T, size_t N>
NDARRAY_HOST_DEVICE auto array_to_tuple(const std::array<T, N>& a) {
  return array_to_tuple(a, make_index_sequence<N>());
}

// A tuple of N elements of type T.
template <class T, size_t N>
using tuple_of_n = decltype(array_to_tuple(internal::declval<std::array<T, N>>()));
// Checks whether each of Args... is constructible to one of the types in a
// tuple of allowed types.
template <class T, class... Args>
struct all_of_any_type : std::false_type {};
template <class T>
struct all_of_any_type<T> : std::true_type {};
template <class... Ts, class Arg, class... Args>
struct all_of_any_type<std::tuple<Ts...>, Arg, Args...> {
  static constexpr bool value = any(std::is_constructible<Ts, Arg>::value...) &&
                                all_of_any_type<std::tuple<Ts...>, Args...>::value;
};

// Checks whether all of Args... are constructible to the single type T.
template <class T, class... Args>
using all_of_type = all_of_any_type<std::tuple<T>, Args...>;
// Get element I of the tuple u if it exists, or a default dim of extent 1 if
// I is beyond the rank of u.
template <size_t I, class T, class... Us, std::enable_if_t<(I < sizeof...(Us)), int> = 0>
NDARRAY_HOST_DEVICE auto convert_dim(const std::tuple<Us...>& u) {
  return std::get<I>(u);
}
template <size_t I, class T, class... Us, std::enable_if_t<(I >= sizeof...(Us)), int> = 0>
NDARRAY_HOST_DEVICE auto convert_dim(const std::tuple<Us...>& u) {
  return decltype(std::get<I>(internal::declval<T>()))(1);
}

template <class T, class U, size_t... Is>
NDARRAY_HOST_DEVICE T convert_dims(const U& u, internal::index_sequence<Is...>) {
  return std::make_tuple(convert_dim<Is, T>(u)...);
}
// A slice to rank DstRank is trivial if all of the dims beyond DstRank have
// extent 1.
template <size_t DstRank, class SrcDims, size_t... Is>
NDARRAY_HOST_DEVICE bool is_trivial_slice(
    const SrcDims& src_dims, internal::index_sequence<Is...>) {
  return all((Is < DstRank || std::get<Is>(src_dims).extent() == 1)...);
}

constexpr index_t factorial(index_t x) { return x == 1 ? 1 : x * factorial(x - 1); }

// Checks if Is... is a permutation of 0, 1, ..., Rank-1: the product of
// (Is + 2)... equals (Rank + 1)! only when each index appears exactly once.
template <size_t Rank, size_t... Is>
using enable_if_permutation = std::enable_if_t<sizeof...(Is) == Rank && all((Is < Rank)...) &&
                                               product((Is + 2)...) == factorial(Rank + 1)>;
// Check that a runtime dim src matches any static values of DimDst, printing
// a useful error if not.
template <class DimDst, class DimSrc>
NDARRAY_HOST_DEVICE void assert_dim_compatible(size_t dim_index, const DimSrc& src) {
  bool compatible = true;
  if (is_static(DimDst::Min) && src.min() != DimDst::Min) {
    NDARRAY_PRINT_ERR("Error converting dim %zu: expected static min " NDARRAY_INDEX_T_FMT
                      ", got " NDARRAY_INDEX_T_FMT "\n",
        dim_index, DimDst::Min, src.min());
    compatible = false;
  }
  if (is_static(DimDst::Extent) && src.extent() != DimDst::Extent) {
    NDARRAY_PRINT_ERR("Error converting dim %zu: expected static extent " NDARRAY_INDEX_T_FMT
                      ", got " NDARRAY_INDEX_T_FMT "\n",
        dim_index, DimDst::Extent, src.extent());
    compatible = false;
  }
  if (is_static(DimDst::Stride) && is_resolved(src.stride()) && src.stride() != DimDst::Stride) {
    NDARRAY_PRINT_ERR("Error converting dim %zu: expected static stride " NDARRAY_INDEX_T_FMT
                      ", got " NDARRAY_INDEX_T_FMT "\n",
        dim_index, DimDst::Stride, src.stride());
    compatible = false;
  }
  assert(compatible);
  (void)compatible;
}
template <class DimsDst, class DimsSrc, size_t... Is>
NDARRAY_HOST_DEVICE void assert_dims_compatible(const DimsSrc& src, index_sequence<Is...>) {
  // A loop over the dims, calling assert_dim_compatible for each one.
  int unused[] = {(assert_dim_compatible<typename std::tuple_element<Is, DimsDst>::type>(
                       Is, std::get<Is>(src)),
      0)...};
  (void)unused;
}
template <class DimsDst, class DimsSrc>
NDARRAY_HOST_DEVICE const DimsSrc& assert_dims_compatible(const DimsSrc& src) {
  assert_dims_compatible<DimsDst>(src, make_index_sequence<std::tuple_size<DimsDst>::value>());
  return src;
}
} // namespace internal

template <class... Dims>
class shape;

/** Make a shape from a list of dims, or from a tuple of dims. */
template <class... Dims>
NDARRAY_HOST_DEVICE auto make_shape(Dims... dims) {
  return shape<Dims...>(dims...);
}
template <class... Dims>
NDARRAY_HOST_DEVICE shape<Dims...> make_shape_from_tuple(const std::tuple<Dims...>& dims) {
  return shape<Dims...>(dims);
}
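// A sketch of building a shape by hand; values illustrative:
//
//   auto s = make_shape(dim<>(0, 640, 1), dim<>(0, 480, 640));
//   // s(x, y) computes the flat offset x * 1 + y * 640.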
/** A tuple of Rank indices. */
template <size_t Rank>
using index_of_rank = internal::tuple_of_n<index_t, Rank>;

/** A shape: a list of dims describing a mapping from a tuple of indices to
 * flat offsets. */
template <class... Dims>
class shape {
public:
  using dims_type = std::tuple<Dims...>;

  /** The rank (number of dimensions) of this shape. */
  static constexpr size_t rank() { return std::tuple_size<dims_type>::value; }

  /** True if this shape is scalar (rank 0). */
  static constexpr bool is_scalar() { return rank() == 0; }

  /** A tuple of indices for this shape. */
  using index_type = index_of_rank<rank()>;

  using size_type = size_t;

  using dim_indices = decltype(internal::make_index_sequence<std::tuple_size<dims_type>::value>());

private:
  dims_type dims_;

public:
  template <class... OtherDims>
  using enable_if_dims_compatible = std::enable_if_t<sizeof...(OtherDims) == rank()>;

  template <class... Args>
  using enable_if_same_rank = std::enable_if_t<(sizeof...(Args) == rank())>;

  template <class... Args>
  using enable_if_all_indices =
      std::enable_if_t<sizeof...(Args) == rank() && internal::all_of_type<index_t, Args...>::value>;

  template <class... Args>
  using enable_if_any_slices =
      std::enable_if_t<sizeof...(Args) == rank() &&
                       internal::all_of_any_type<std::tuple<interval<>, dim<>>, Args...>::value &&
                       !internal::all_of_type<index_t, Args...>::value>;

  template <class... Args>
  using enable_if_any_slices_or_indices =
      std::enable_if_t<sizeof...(Args) == rank() &&
                       internal::all_of_any_type<std::tuple<interval<>, dim<>>, Args...>::value>;

  template <size_t Dim>
  using enable_if_dim = std::enable_if_t<(Dim < rank())>;
  NDARRAY_HOST_DEVICE shape() {}
  /** Construct from a list of dims, checking that any static values are
   * compatible. */
  template <size_t N = sizeof...(Dims), class = std::enable_if_t<(N > 0)>>
  NDARRAY_HOST_DEVICE shape(const Dims&... dims)
      : dims_(internal::assert_dims_compatible<dims_type>(std::make_tuple(dims...))) {}
  NDARRAY_HOST_DEVICE shape(const shape&) = default;
  NDARRAY_HOST_DEVICE shape(shape&&) = default;
  NDARRAY_HOST_DEVICE shape& operator=(const shape&) = default;
  NDARRAY_HOST_DEVICE shape& operator=(shape&&) = default;

  /** Construct or assign from another set of dims of the same rank. */
  template <class... OtherDims, class = enable_if_dims_compatible<OtherDims...>>
  NDARRAY_HOST_DEVICE shape(const std::tuple<OtherDims...>& other)
      : dims_(internal::assert_dims_compatible<dims_type>(other)) {}
  template <class... OtherDims, class = enable_if_dims_compatible<OtherDims...>>
  NDARRAY_HOST_DEVICE shape(OtherDims... other_dims) : shape(std::make_tuple(other_dims...)) {}
  template <class... OtherDims, class = enable_if_dims_compatible<OtherDims...>>
  NDARRAY_HOST_DEVICE shape(const shape<OtherDims...>& other)
      : dims_(internal::assert_dims_compatible<dims_type>(other.dims())) {}
  template <class... OtherDims, class = enable_if_dims_compatible<OtherDims...>>
  NDARRAY_HOST_DEVICE shape& operator=(const shape<OtherDims...>& other) {
    dims_ = internal::assert_dims_compatible<dims_type>(other.dims());
    return *this;
  }
  /** Replace unresolved strides with automatically computed, non-overlapping
   * strides. */
  NDARRAY_HOST_DEVICE void resolve() { internal::resolve_unknown_strides(dims_, dim_indices()); }
  NDARRAY_HOST_DEVICE bool is_resolved() const {
    return internal::is_resolved(dims_, dim_indices());
  }
  /** Check if a tuple of indices or slices is within the bounds of this
   * shape. */
  template <class... Args, class = enable_if_any_slices_or_indices<Args...>>
  NDARRAY_HOST_DEVICE bool is_in_range(const std::tuple<Args...>& args) const {
    return internal::is_in_range(dims_, args, dim_indices());
  }
  template <class... Args, class = enable_if_any_slices_or_indices<Args...>>
  NDARRAY_HOST_DEVICE bool is_in_range(Args... args) const {
    return internal::is_in_range(dims_, std::make_tuple(args...), dim_indices());
  }
  /** Compute the flat offset of a tuple of indices. */
  NDARRAY_HOST_DEVICE index_t operator[](const index_type& indices) const {
    return internal::flat_offset(dims_, indices, dim_indices());
  }
  template <class... Args, class = enable_if_all_indices<Args...>>
  NDARRAY_HOST_DEVICE index_t operator()(Args... indices) const {
    return internal::flat_offset(dims_, std::make_tuple(indices...), dim_indices());
  }
  /** Slice this shape with a mix of indices, intervals, dims, and the
   * placeholder `_`, producing the shape of the sliced sub-array. */
  template <class... Args, class = enable_if_any_slices<Args...>>
  NDARRAY_HOST_DEVICE auto operator[](const std::tuple<Args...>& args) const {
    auto new_dims = internal::intervals_with_strides(args, dims_, dim_indices());
    auto new_dims_no_slices = internal::skip_slices(new_dims, args, dim_indices());
    return make_shape_from_tuple(new_dims_no_slices);
  }
  template <class... Args, class = enable_if_any_slices<Args...>>
  NDARRAY_HOST_DEVICE auto operator()(Args... args) const {
    return operator[](std::make_tuple(args...));
  }
  /** Get the D'th dim. */
  template <size_t D, class = enable_if_dim<D>>
  NDARRAY_HOST_DEVICE auto& dim() {
    return std::get<D>(dims_);
  }
  template <size_t D, class = enable_if_dim<D>>
  NDARRAY_HOST_DEVICE const auto& dim() const {
    return std::get<D>(dims_);
  }
  /** Get a copy of the d'th dim, where d is a runtime value. */
  NDARRAY_HOST_DEVICE nda::dim<> dim(size_t d) const {
    return internal::tuple_to_array<nda::dim<>>(dims_)[d];
  }
  NDARRAY_HOST_DEVICE const dims_type& dims() const { return dims_; }

  NDARRAY_HOST_DEVICE index_type min() const { return internal::mins(dims_, dim_indices()); }
  NDARRAY_HOST_DEVICE index_type max() const { return internal::maxs(dims_, dim_indices()); }
  NDARRAY_HOST_DEVICE index_type extent() const { return internal::extents(dims_, dim_indices()); }
  NDARRAY_HOST_DEVICE index_type stride() const { return internal::strides(dims_, dim_indices()); }

  /** The range of flat offsets spanned by this shape. */
  NDARRAY_HOST_DEVICE index_t flat_min() const { return internal::flat_min(dims_, dim_indices()); }
  NDARRAY_HOST_DEVICE index_t flat_max() const { return internal::flat_max(dims_, dim_indices()); }
  NDARRAY_HOST_DEVICE size_type flat_extent() const {
    index_t e = flat_max() - flat_min() + 1;
    return e < 0 ? 0 : static_cast<size_type>(e);
  }

  /** The number of indices addressed by this shape. */
  NDARRAY_HOST_DEVICE size_type size() const {
    index_t s = internal::product(extent(), dim_indices());
    return s < 0 ? 0 : static_cast<size_type>(s);
  }
  NDARRAY_HOST_DEVICE bool empty() const { return size() == 0; }

  /** True if there are no gaps between the flat offsets addressed by this
   * shape. */
  NDARRAY_HOST_DEVICE bool is_compact() const { return flat_extent() <= size(); }

  /** True if no two distinct indices map to the same flat offset. */
  NDARRAY_HOST_DEVICE bool is_one_to_one() const {
    return flat_extent() >= size();
  }
  /** True if all flat offsets of this shape, shifted by offset, are also
   * addressed by other. */
  template <typename OtherShape>
  NDARRAY_HOST_DEVICE bool is_subset_of(const OtherShape& other, index_t offset) const {
    return flat_min() >= other.flat_min() + offset && flat_max() <= other.flat_max() + offset;
  }
  /** Names for the first few dims. */
  NDARRAY_HOST_DEVICE auto& i() { return dim<0>(); }
  NDARRAY_HOST_DEVICE const auto& i() const { return dim<0>(); }
  NDARRAY_HOST_DEVICE auto& j() { return dim<1>(); }
  NDARRAY_HOST_DEVICE const auto& j() const { return dim<1>(); }
  NDARRAY_HOST_DEVICE auto& k() { return dim<2>(); }
  NDARRAY_HOST_DEVICE const auto& k() const { return dim<2>(); }

  /** Image-style names: x, y, z, c (channel), w. */
  NDARRAY_HOST_DEVICE auto& x() { return dim<0>(); }
  NDARRAY_HOST_DEVICE const auto& x() const { return dim<0>(); }
  NDARRAY_HOST_DEVICE auto& y() { return dim<1>(); }
  NDARRAY_HOST_DEVICE const auto& y() const { return dim<1>(); }
  NDARRAY_HOST_DEVICE auto& z() { return dim<2>(); }
  NDARRAY_HOST_DEVICE const auto& z() const { return dim<2>(); }
  NDARRAY_HOST_DEVICE auto& c() { return dim<2>(); }
  NDARRAY_HOST_DEVICE const auto& c() const { return dim<2>(); }
  NDARRAY_HOST_DEVICE auto& w() { return dim<3>(); }
  NDARRAY_HOST_DEVICE const auto& w() const { return dim<3>(); }

  /** Image-style extents. */
  NDARRAY_HOST_DEVICE index_t width() const { return x().extent(); }
  NDARRAY_HOST_DEVICE index_t height() const { return y().extent(); }
  NDARRAY_HOST_DEVICE index_t channels() const { return c().extent(); }

  /** Matrix-style extents. */
  NDARRAY_HOST_DEVICE index_t rows() const { return i().extent(); }
  NDARRAY_HOST_DEVICE index_t columns() const { return j().extent(); }
  /** Two shapes are equal if all of their dims are equal. */
  template <class... OtherDims, class = enable_if_same_rank<OtherDims...>>
  NDARRAY_HOST_DEVICE bool operator==(const shape<OtherDims...>& other) const {
    return dims_ == other.dims();
  }
  template <class... OtherDims, class = enable_if_same_rank<OtherDims...>>
  NDARRAY_HOST_DEVICE bool operator!=(const shape<OtherDims...>& other) const {
    return dims_ != other.dims();
  }
};
namespace internal {

template <size_t... DimIndices, class Shape, size_t... Extras>
NDARRAY_HOST_DEVICE auto transpose_impl(const Shape& shape, index_sequence<Extras...>) {
  return make_shape(
      shape.template dim<DimIndices>()..., shape.template dim<sizeof...(DimIndices) + Extras>()...);
}

} // namespace internal

/** Reorder the dims of a shape: the dims DimIndices... come first, and any
 * remaining dims follow in their original order. DimIndices... must be a
 * permutation. */
template <size_t... DimIndices, class... Dims,
    class = internal::enable_if_permutation<sizeof...(DimIndices), DimIndices...>>
NDARRAY_HOST_DEVICE auto transpose(const shape<Dims...>& shape) {
  return internal::transpose_impl<DimIndices...>(
      shape, internal::make_index_sequence<sizeof...(Dims) - sizeof...(DimIndices)>());
}
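// E.g. (illustrative): transpose<1, 0>(make_shape(x, y)) produces the shape
// make_shape(y, x).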
/** Reorder the dims of a shape without requiring a full permutation; used to
 * reorder loops rather than data. */
template <size_t... DimIndices, class... Dims>
NDARRAY_HOST_DEVICE auto reorder(const shape<Dims...>& shape) {
  return internal::transpose_impl<DimIndices...>(
      shape, internal::make_index_sequence<sizeof...(Dims) - sizeof...(DimIndices)>());
}
namespace internal {

template <class Fn, class Idx>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void for_each_index_in_order_impl(Fn&& fn, const Idx& idx) {
  fn(idx);
}

template <class Fn, class OuterIdx, class Dim0>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index_in_order_impl(
    Fn&& fn, const OuterIdx& idx, const Dim0& dim0) {
  for (index_t i : dim0) {
    fn(std::tuple_cat(std::tuple<index_t>(i), idx));
  }
}

// Two-dimensional base case, to help the compiler generate a tight inner
// loop.
template <class Fn, class OuterIdx, class Dim0, class Dim1>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index_in_order_impl(
    Fn&& fn, const OuterIdx& idx, const Dim0& dim0, const Dim1& dim1) {
  for (index_t i : dim0) {
    for (index_t j : dim1) {
      fn(std::tuple_cat(std::tuple<index_t, index_t>(j, i), idx));
    }
  }
}

template <class Fn, class OuterIdx, class Dim0, class... Dims>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index_in_order_impl(
    Fn&& fn, const OuterIdx& idx, const Dim0& dim0, const Dims&... dims) {
  for (index_t i : dim0) {
    for_each_index_in_order_impl(fn, std::tuple_cat(std::tuple<index_t>(i), idx), dims...);
  }
}

template <class Fn, class OuterIdx, class Dim0, class Dim1, class... Dims>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index_in_order_impl(
    Fn&& fn, const OuterIdx& idx, const Dim0& dim0, const Dim1& dim1, const Dims&... dims) {
  for (index_t i : dim0) {
    for (index_t j : dim1) {
      for_each_index_in_order_impl(
          fn, std::tuple_cat(std::tuple<index_t, index_t>(j, i), idx), dims...);
    }
  }
}

template <class Dims, class Fn, size_t... Is>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void for_each_index_in_order(
    Fn&& fn, const Dims& dims, index_sequence<Is...>) {
  // The dims are passed in reverse order, so the first dim is the innermost
  // loop.
  for_each_index_in_order_impl(fn, std::tuple<>(), std::get<sizeof...(Is) - 1 - Is>(dims)...);
}
template <typename TSrc, typename TDst>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void move_assign(TSrc& src, TDst& dst) {
  dst = std::move(src);
}
template <typename TSrc, typename TDst>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void copy_assign(const TSrc& src, TDst& dst) {
  dst = src;
}
// Advance pointers by the stride of dimension D. Each "pointer" is a pair of
// (pointer, tuple of strides).
template <size_t D, class Ptr0>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void advance(Ptr0& ptr0) {
  std::get<0>(ptr0) += std::get<D>(std::get<1>(ptr0));
}
template <size_t D, class Ptr0, class Ptr1>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void advance(Ptr0& ptr0, Ptr1& ptr1) {
  std::get<0>(ptr0) += std::get<D>(std::get<1>(ptr0));
  std::get<0>(ptr1) += std::get<D>(std::get<1>(ptr1));
}
// The innermost loop when all innermost strides are 1: a dense loop the
// compiler can vectorize. NDARRAY_RESTRICT promises the ranges don't alias.
template <class Fn, class Ptr0, class... Ptrs>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order_inner_dense(
    index_t extent, Fn&& fn, Ptr0 NDARRAY_RESTRICT ptr0, Ptrs NDARRAY_RESTRICT... ptrs) {
  Ptr0 end = ptr0 + extent;
  while (ptr0 < end) {
    fn(*ptr0++, *ptrs++...);
  }
}

// Innermost dimension (std::true_type overload).
template <size_t, class ExtentType, class Fn, class... Ptrs>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order_impl(
    std::true_type, const ExtentType& extent, Fn&& fn, Ptrs... ptrs) {
  index_t extent_d = std::get<0>(extent);
  if (all((std::get<0>(std::get<1>(ptrs)) == 1)...)) {
    for_each_value_in_order_inner_dense(extent_d, fn, std::get<0>(ptrs)...);
  } else {
    for (index_t i = 0; i < extent_d; i++) {
      fn(*std::get<0>(ptrs)...);
      advance<0>(ptrs...);
    }
  }
}

// Outer dimensions (std::false_type overload).
template <size_t D, class ExtentType, class Fn, class... Ptrs>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order_impl(
    std::false_type, const ExtentType& extent, Fn&& fn, Ptrs... ptrs) {
  index_t extent_d = std::get<D>(extent);
  for (index_t i = 0; i < extent_d; i++) {
    using is_inner_loop = std::conditional_t<D == 1, std::true_type, std::false_type>;
    for_each_value_in_order_impl<D - 1>(is_inner_loop(), extent, fn, ptrs...);
    advance<D>(ptrs...);
  }
}

template <size_t D, class ExtentType, class Fn, class... Ptrs,
    std::enable_if_t<(D < std::tuple_size<ExtentType>::value), int> = 0>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void for_each_value_in_order(
    const ExtentType& extent, Fn&& fn, Ptrs... ptrs) {
  using is_inner_loop = std::conditional_t<D == 0, std::true_type, std::false_type>;
  for_each_value_in_order_impl<D>(is_inner_loop(), extent, fn, ptrs...);
}

// Scalar case: D is -1 (wrapped around) when the shape is rank 0.
template <size_t D, class Fn, class... Ptrs, std::enable_if_t<(D == -1), int> = 0>
NDARRAY_INLINE NDARRAY_HOST_DEVICE void for_each_value_in_order(
    const std::tuple<>& extent, Fn&& fn, Ptrs... ptrs) {
  fn(*std::get<0>(ptrs)...);
}
template <size_t Rank>
NDARRAY_HOST_DEVICE auto make_default_dense_shape() {
  // The innermost dim is dense; the remaining dims are unconstrained.
  using inner_dim = std::conditional_t<(Rank > 0), std::tuple<dense_dim<>>, std::tuple<>>;
  return make_shape_from_tuple(
      std::tuple_cat(inner_dim(), tuple_of_n<dim<>, max<size_t>(1, Rank) - 1>()));
}
template <index_t CurrentStride>
NDARRAY_HOST_DEVICE std::tuple<> make_compact_dims() {
  return std::tuple<>();
}

// Rewrite dynamic strides so the dims are compact, where this can be done at
// compile time; remaining dynamic strides are resolved at runtime.
template <index_t CurrentStride, index_t Min, index_t Extent, index_t Stride, class... Dims>
NDARRAY_HOST_DEVICE auto make_compact_dims(
    const dim<Min, Extent, Stride>& dim0, const Dims&... dims) {
  constexpr index_t NextStride = static_mul(CurrentStride, Extent);
  // Keep a static stride if the dim has one; otherwise take the current
  // compact stride (possibly dynamic).
  constexpr index_t NewStride = is_static(Stride) ? Stride : CurrentStride;
  return std::tuple_cat(std::make_tuple(dim<Min, Extent, NewStride>(dim0.min(), dim0.extent())),
      make_compact_dims<NextStride>(dims...));
}

template <class Dims, size_t... Is>
NDARRAY_HOST_DEVICE auto make_compact_dims(const Dims& dims, index_sequence<Is...>) {
  // Seed the innermost compact stride from any dim with a static stride of 1.
  constexpr index_t MinStride =
      variadic_max(std::tuple_element<Is, Dims>::type::Stride == 1
                       ? static_abs(std::tuple_element<Is, Dims>::type::Extent)
                       : 1 ...);
  constexpr bool AnyStrideGreaterThanOne =
      any((std::tuple_element<Is, Dims>::type::Stride > 1)...);
  constexpr bool AllDynamic = all(is_dynamic(std::tuple_element<Is, Dims>::type::Stride)...);
  constexpr index_t NextStride = AnyStrideGreaterThanOne ? dynamic : (AllDynamic ? 1 : MinStride);
  return make_compact_dims<NextStride>(std::get<Is>(dims)...);
}
template <index_t Min, index_t Extent, index_t Stride, class DimSrc>
NDARRAY_HOST_DEVICE bool is_dim_compatible(const dim<Min, Extent, Stride>&, const DimSrc& src) {
  return (is_dynamic(Min) || src.min() == Min) && (is_dynamic(Extent) || src.extent() == Extent) &&
         (is_dynamic(Stride) || src.stride() == Stride);
}

template <class... DimsDst, class ShapeSrc, size_t... Is>
NDARRAY_HOST_DEVICE bool is_shape_compatible(
    const shape<DimsDst...>&, const ShapeSrc& src, index_sequence<Is...>) {
  return all(is_dim_compatible(DimsDst(), src.template dim<Is>())...);
}
// Clamp each dim of a to the bounds of the corresponding dim of b, preserving
// a's stride and any compatible static values.
template <class DimA, class DimB>
NDARRAY_HOST_DEVICE auto clamp_dims(const DimA& a, const DimB& b) {
  constexpr index_t Min = static_max(DimA::Min, DimB::Min);
  constexpr index_t Max = static_min(DimA::Max, DimB::Max);
  constexpr index_t Extent = static_add(static_sub(Max, Min), 1);
  index_t min = internal::max(a.min(), b.min());
  index_t max = internal::min(a.max(), b.max());
  index_t extent = max - min + 1;
  return dim<Min, Extent, DimA::Stride>(min, extent, a.stride());
}
template <class DimsA, class DimsB, size_t... Is>
NDARRAY_HOST_DEVICE auto clamp(const DimsA& a, const DimsB& b, index_sequence<Is...>) {
  return make_shape(clamp_dims(std::get<Is>(a), std::get<Is>(b))...);
}
// The position of I in Is..., or sizeof...(Is) if it is not present.
template <size_t I>
constexpr size_t index_of() {
  return 0;
}
template <size_t I, size_t I0, size_t... Is>
constexpr size_t index_of() {
  return I == I0 ? 0 : 1 + index_of<I, Is...>();
}

template <size_t I, class T, std::enable_if_t<(I < std::tuple_size<T>::value), int> = 0>
NDARRAY_INLINE NDARRAY_HOST_DEVICE auto get_or_empty(const T& t) {
  return std::make_tuple(std::get<I>(t));
}
template <size_t I, class T, std::enable_if_t<(I >= std::tuple_size<T>::value), int> = 0>
NDARRAY_INLINE NDARRAY_HOST_DEVICE std::tuple<> get_or_empty(const T& t) {
  return std::tuple<>();
}

// Invert the permutation Is..., reordering the elements of t back to their
// original positions.
template <size_t... Is, class T, size_t... Js>
NDARRAY_HOST_DEVICE auto unshuffle(const T& t, index_sequence<Js...>) {
  return std::tuple_cat(get_or_empty<index_of<Js, Is...>()>(t)...);
}
template <size_t... Is, class... Ts>
NDARRAY_HOST_DEVICE auto unshuffle(const std::tuple<Ts...>& t) {
  return unshuffle<Is...>(t, make_index_sequence<variadic_max(Is...) + 1>());
}
template <class ShapeDst, class ShapeSrc>
using enable_if_shapes_compatible =
    std::enable_if_t<std::is_constructible<ShapeDst, ShapeSrc>::value>;

template <class ShapeDst, class ShapeSrc>
using enable_if_shapes_explicitly_compatible =
    std::enable_if_t<(ShapeDst::rank() <= ShapeSrc::rank())>;

template <class ShapeDst, class ShapeSrc>
using enable_if_shapes_copy_compatible = std::enable_if_t<(ShapeDst::rank() == ShapeSrc::rank())>;

template <class Alloc>
using enable_if_allocator = decltype(internal::declval<Alloc>().allocate(0));
} // namespace internal

/** A shape of rank Rank with entirely dynamic dims. */
template <size_t Rank>
using shape_of_rank = decltype(make_shape_from_tuple(internal::tuple_of_n<dim<>, Rank>()));

/** A shape of rank Rank with a dense innermost dim. */
template <size_t Rank>
using dense_shape = decltype(internal::make_default_dense_shape<Rank>());
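// E.g. (illustrative): dense_shape<2> is shape<dense_dim<>, dim<>>, the usual
// layout of a 2-D buffer with a dense innermost dimension, while
// shape_of_rank<2> leaves every min, extent, and stride dynamic.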
/** Make a compact version of the shape s: dynamic strides are replaced with
 * strides that tightly pack the dims. */
template <class Shape>
NDARRAY_HOST_DEVICE auto make_compact(const Shape& s) {
  auto static_compact =
      make_shape_from_tuple(internal::make_compact_dims(s.dims(), typename Shape::dim_indices()));
  static_compact.resolve();
  return static_compact;
}

/** A dense shape with compile-time constant extents, starting at min 0. */
template <index_t... Extents>
using fixed_dense_shape = decltype(make_compact(make_shape(dim<0, Extents>()...)));
/** Check if the runtime values of src are compatible with any static values
 * of ShapeDst. */
template <class ShapeDst, class ShapeSrc,
    class = internal::enable_if_shapes_compatible<ShapeSrc, ShapeDst>>
NDARRAY_HOST_DEVICE bool is_compatible(const ShapeSrc& src) {
  return internal::is_shape_compatible(ShapeDst(), src, typename ShapeSrc::dim_indices());
}
/** Convert the shape src to a shape of type ShapeDst. If ShapeDst has lower
 * rank, the extra dims of src must have extent 1. */
template <class ShapeDst, class ShapeSrc,
    class = internal::enable_if_shapes_explicitly_compatible<ShapeDst, ShapeSrc>>
NDARRAY_HOST_DEVICE ShapeDst convert_shape(const ShapeSrc& src) {
  assert(
      internal::is_trivial_slice<ShapeDst::rank()>(src.dims(), typename ShapeSrc::dim_indices()));
  return internal::convert_dims<typename ShapeDst::dims_type>(
      src.dims(), typename ShapeDst::dim_indices());
}
/** Check if src can be explicitly converted to ShapeDst via convert_shape. */
template <class ShapeDst, class ShapeSrc,
    class = internal::enable_if_shapes_explicitly_compatible<ShapeSrc, ShapeDst>>
NDARRAY_HOST_DEVICE bool is_explicitly_compatible(const ShapeSrc& src) {
  return internal::is_shape_compatible(ShapeDst(), src, typename ShapeSrc::dim_indices());
}
/** Call fn with each index in shape, ordered so the first dim is the
 * innermost loop. */
template <class Shape, class Fn,
    class = internal::enable_if_callable<Fn, typename Shape::index_type>>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index_in_order(const Shape& shape, Fn&& fn) {
  internal::for_each_index_in_order(fn, shape.dims(), typename Shape::dim_indices());
}

/** Call fn with a reference to each value addressed by shape, starting at
 * base. */
template <class Shape, class Ptr, class Fn,
    class = internal::enable_if_callable<Fn, typename std::remove_pointer<Ptr>::type&>>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order(
    const Shape& shape, Ptr base, Fn&& fn) {
  if (!base) {
    assert(shape.empty());
    return;
  }
  base += shape[shape.min()];
  auto base_and_stride = std::make_pair(base, shape.stride());
  internal::for_each_value_in_order<Shape::rank() - 1>(shape.extent(), fn, base_and_stride);
}
/** Call fn with pairs of values from two arrays, at each index in shape. */
template <class Shape, class ShapeA, class PtrA, class ShapeB, class PtrB, class Fn,
    class = internal::enable_if_callable<Fn, typename std::remove_pointer<PtrA>::type&,
        typename std::remove_pointer<PtrB>::type&>>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order(const Shape& shape,
    const ShapeA& shape_a, PtrA base_a, const ShapeB& shape_b, PtrB base_b, Fn&& fn) {
  if (!base_a || !base_b) {
    assert(shape.empty());
    return;
  }
  base_a += shape_a[shape.min()];
  base_b += shape_b[shape.min()];
  auto a = std::make_pair(base_a, shape_a.stride());
  auto b = std::make_pair(base_b, shape_b.stride());
  internal::for_each_value_in_order<Shape::rank() - 1>(shape.extent(), fn, a, b);
}
namespace internal {

// Two dims can be fused when the outer dim steps by exactly the flat
// footprint of the inner dim.
NDARRAY_INLINE NDARRAY_HOST_DEVICE bool can_fuse(const dim<>& inner, const dim<>& outer) {
  return inner.stride() * inner.extent() == outer.stride();
}
NDARRAY_INLINE NDARRAY_HOST_DEVICE dim<> fuse(const dim<>& inner, const dim<>& outer) {
  assert(can_fuse(inner, outer));
  return dim<>(inner.min(), inner.extent() * outer.extent(), inner.stride());
}

// A pair of dims being copied from one shape to another.
struct copy_dims {
  dim<> src;
  dim<> dst;
};

inline bool operator<(const dim<>& l, const dim<>& r) { return l.stride() < r.stride(); }
inline bool operator<(const copy_dims& l, const copy_dims& r) {
  return l.dst.stride() < r.dst.stride();
}

// std::sort is not available in device code, so use a simple sort instead.
template <class Iterator>
NDARRAY_HOST_DEVICE void bubble_sort(Iterator begin, Iterator end) {
  for (Iterator i = begin; i != end; ++i) {
    for (Iterator j = i; j != end; ++j) {
      if (*j < *i) { std::swap(*i, *j); }
    }
  }
}
// Sort the dims by stride, fuse fusable dims, and pad with extent-1 dims, to
// minimize the number of loops needed to traverse the shape.
template <class Shape>
NDARRAY_HOST_DEVICE auto dynamic_optimize_shape(const Shape& shape) {
  auto dims = internal::tuple_to_array<dim<>>(shape.dims());

  // Sort the dims by stride.
  bubble_sort(dims.begin(), dims.end());

  // Find and fuse dims that can be fused.
  size_t rank = dims.size();
  for (size_t i = 0; i + 1 < rank;) {
    if (can_fuse(dims[i], dims[i + 1])) {
      dims[i] = fuse(dims[i], dims[i + 1]);
      for (size_t j = i + 1; j + 1 < rank; j++) {
        dims[j] = dims[j + 1];
      }
      rank--;
    } else {
      i++;
    }
  }

  // Replace the now-unused dims with extent-1 dims, so they add no loops.
  for (size_t i = rank; i < dims.size(); i++) {
    dims[i] = dim<>(0, 1, 0);
  }

  return make_shape_from_tuple(array_to_tuple(dims));
}
// As dynamic_optimize_shape, but sorting and fusing the src and dst dims of a
// copy together, so both shapes stay index-compatible.
template <class ShapeSrc, class ShapeDst,
    class = enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
NDARRAY_HOST_DEVICE auto dynamic_optimize_copy_shapes(const ShapeSrc& src, const ShapeDst& dst) {
  constexpr size_t rank = ShapeSrc::rank();
  static_assert(rank == ShapeDst::rank(), "copy shapes must have same rank.");
  auto src_dims = internal::tuple_to_array<dim<>>(src.dims());
  auto dst_dims = internal::tuple_to_array<dim<>>(dst.dims());

  std::array<copy_dims, rank> dims;
  for (size_t i = 0; i < rank; i++) {
    dims[i] = {src_dims[i], dst_dims[i]};
  }

  // Sort the dim pairs by the stride of the destination.
  bubble_sort(dims.begin(), dims.end());

  // Fuse dims that can be fused in both shapes at once.
  size_t new_rank = dims.size();
  for (size_t i = 0; i + 1 < new_rank;) {
    if (dims[i].src.extent() == dims[i].dst.extent() && can_fuse(dims[i].src, dims[i + 1].src) &&
        can_fuse(dims[i].dst, dims[i + 1].dst)) {
      dims[i].src = fuse(dims[i].src, dims[i + 1].src);
      dims[i].dst = fuse(dims[i].dst, dims[i + 1].dst);
      for (size_t j = i + 1; j + 1 < new_rank; j++) {
        dims[j] = dims[j + 1];
      }
      new_rank--;
    } else {
      i++;
    }
  }

  // Replace the now-unused dims with extent-1 dims.
  for (size_t i = new_rank; i < dims.size(); i++) {
    dims[i] = {dim<>(0, 1, 0), dim<>(0, 1, 0)};
  }

  for (size_t i = 0; i < dims.size(); i++) {
    src_dims[i] = dims[i].src;
    dst_dims[i] = dims[i].dst;
  }

  return std::make_pair(
      make_shape_from_tuple(array_to_tuple(src_dims)),
      make_shape_from_tuple(array_to_tuple(dst_dims)));
}
template <class Shape>
NDARRAY_HOST_DEVICE auto optimize_shape(const Shape& shape) {
  // In the general case, optimize the shape dynamically.
  return dynamic_optimize_shape(shape);
}
// Rank-1 shapes have nothing to optimize.
template <class Dim0>
NDARRAY_HOST_DEVICE auto optimize_shape(const shape<Dim0>& shape) {
  return shape;
}

template <class ShapeSrc, class ShapeDst>
NDARRAY_HOST_DEVICE auto optimize_copy_shapes(const ShapeSrc& src, const ShapeDst& dst) {
  return dynamic_optimize_copy_shapes(src, dst);
}
template <class Dim0Src, class Dim0Dst>
NDARRAY_HOST_DEVICE auto optimize_copy_shapes(
    const shape<Dim0Src>& src, const shape<Dim0Dst>& dst) {
  return std::make_pair(src, dst);
}
template <class T>
NDARRAY_HOST_DEVICE T* pointer_add(T* x, index_t offset) {
  return x != nullptr ? x + offset : x;
}
} // namespace internal

/** Shape traits: hooks for customizing how algorithms traverse a shape. */
template <class Shape>
class shape_traits {
public:
  using shape_type = Shape;

  /** Call fn with each index in shape. */
  template <class Fn>
  NDARRAY_HOST_DEVICE static void for_each_index(const Shape& shape, Fn&& fn) {
    for_each_index_in_order(shape, fn);
  }

  /** Call fn with a reference to each value addressed by shape, in an
   * optimized traversal order. */
  template <class Ptr, class Fn>
  NDARRAY_HOST_DEVICE static void for_each_value(const Shape& shape, Ptr base, Fn&& fn) {
    auto opt_shape = internal::optimize_shape(shape);
    for_each_value_in_order(opt_shape, base, fn);
  }
};
/** Copy shape traits: hooks for customizing how copies are traversed. */
template <class ShapeSrc, class ShapeDst>
class copy_shape_traits {
public:
  using src_shape_type = ShapeSrc;
  using dst_shape_type = ShapeDst;

  template <class Fn, class TSrc, class TDst>
  NDARRAY_HOST_DEVICE static void for_each_value(
      const ShapeSrc& shape_src, TSrc src, const ShapeDst& shape_dst, TDst dst, Fn&& fn) {
    // We don't care about the order in which the callback is called, so
    // optimize the shapes for traversal.
    auto opt_shape = internal::optimize_copy_shapes(shape_src, shape_dst);
    const auto& opt_shape_src = opt_shape.first;
    const auto& opt_shape_dst = opt_shape.second;

    for_each_value_in_order(opt_shape_dst, opt_shape_src, src, opt_shape_dst, dst, fn);
  }
};
/** Call fn with every index in the shape s. for_each_index passes a tuple of
 * indices; for_all_indices unpacks the indices into separate arguments. */
template <size_t... LoopOrder, class Shape, class Fn,
    class = internal::enable_if_callable<Fn, typename Shape::index_type>,
    std::enable_if_t<(sizeof...(LoopOrder) == 0), int> = 0>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index(const Shape& s, Fn&& fn) {
  for_each_index_in_order(s, fn);
}
template <size_t... LoopOrder, class Shape, class Fn,
    class = internal::enable_if_applicable<Fn, typename Shape::index_type>,
    std::enable_if_t<(sizeof...(LoopOrder) == 0), int> = 0>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_all_indices(const Shape& s, Fn&& fn) {
  using index_type = typename Shape::index_type;
  for_each_index(s, [fn = std::move(fn)](const index_type& i) { internal::apply(fn, i); });
}
/** As above, with the loops run in the order given by LoopOrder. */
template <size_t... LoopOrder, class Shape, class Fn,
    class = internal::enable_if_callable<Fn, index_of_rank<sizeof...(LoopOrder)>>,
    std::enable_if_t<(sizeof...(LoopOrder) != 0), int> = 0>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_index(const Shape& s, Fn&& fn) {
  using index_type = index_of_rank<sizeof...(LoopOrder)>;
  for_each_index_in_order(reorder<LoopOrder...>(s),
      [fn = std::move(fn)](const index_type& i) { fn(internal::unshuffle<LoopOrder...>(i)); });
}
template <size_t... LoopOrder, class Shape, class Fn,
    class = internal::enable_if_callable<Fn, decltype(LoopOrder)...>,
    std::enable_if_t<(sizeof...(LoopOrder) != 0), int> = 0>
NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_all_indices(const Shape& s, Fn&& fn) {
  using index_type = index_of_rank<sizeof...(LoopOrder)>;
  for_each_index_in_order(reorder<LoopOrder...>(s), [fn = std::move(fn)](const index_type& i) {
    internal::apply(fn, internal::unshuffle<LoopOrder...>(i));
  });
}
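// A sketch of the difference between these loops; values illustrative:
//
//   for_each_index(s, [](const std::tuple<index_t, index_t>& i) { ... });
//   for_all_indices(s, [](index_t x, index_t y) { ... });
//   for_all_indices<1, 0>(s, fn);  // the same, with the loops reordered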
template <class T, class Shape>
class array_ref;
template <class T, class Shape, class Alloc>
class array;

/** Make an array_ref from a base pointer and a shape. */
template <class T, class Shape>
NDARRAY_HOST_DEVICE array_ref<T, Shape> make_array_ref(T* base, const Shape& shape) {
  return {base, shape};
}
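// A sketch of wrapping an existing buffer; values illustrative:
//
//   float data[480][640];
//   auto ref = make_array_ref(
//       &data[0][0], make_shape(dim<>(0, 640, 1), dim<>(0, 480, 640)));
//   ref(10, 20) = 1.0f;  // writes data[20][10]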
namespace internal {

struct no_resolve {};

template <class T, class Shape>
NDARRAY_HOST_DEVICE array_ref<T, Shape> make_array_ref_no_resolve(T* base, const Shape& shape) {
  return {base, shape, no_resolve{}};
}

// Make an array_ref of the slice of (base, shape) described by args, with the
// base pointer offset to the min of the slice.
template <class T, class Shape, class... Args>
NDARRAY_HOST_DEVICE auto make_array_ref_at(
    T base, const Shape& shape, const std::tuple<Args...>& args) {
  auto new_shape = shape[args];
  auto new_mins = mins_of_intervals(args, shape.dims(), make_index_sequence<sizeof...(Args)>());
  auto old_min_offset = shape[new_mins];
  return make_array_ref_no_resolve(internal::pointer_add(base, old_min_offset), new_shape);
}

} // namespace internal
/** A reference to an array: a shape mapping indices to flat offsets, plus a
 * base pointer. Does not own its data. */
template <class T, class Shape>
class array_ref {
public:
  using value_type = T;
  using reference = T&;
  using const_reference = const T&;
  using pointer = T*;
  using index_type = typename Shape::index_type;
  using shape_type = Shape;
  using shape_traits_type = shape_traits<Shape>;
  using size_type = size_t;

  /** The rank (number of dimensions) of this array. */
  static constexpr size_t rank() { return Shape::rank(); }

  /** True if this array is scalar (rank 0). */
  static constexpr bool is_scalar() { return Shape::is_scalar(); }

  template <class OtherShape>
  using enable_if_shape_compatible = internal::enable_if_shapes_compatible<Shape, OtherShape>;

  template <class... Args>
  using enable_if_all_indices =
      std::enable_if_t<sizeof...(Args) == rank() && internal::all_of_type<index_t, Args...>::value>;

  template <class... Args>
  using enable_if_any_slices =
      std::enable_if_t<sizeof...(Args) == rank() &&
                       internal::all_of_any_type<std::tuple<interval<>, dim<>>, Args...>::value &&
                       !internal::all_of_type<index_t, Args...>::value>;

  template <size_t Dim>
  using enable_if_dim = std::enable_if_t<(Dim < rank())>;

private:
  pointer base_;
  Shape shape_;

public:
  /** Make an array_ref from a base pointer and a shape; the shape is
   * resolved. */
  NDARRAY_HOST_DEVICE array_ref(pointer base = nullptr, const Shape& shape = Shape())
      : base_(base), shape_(shape) {
    shape_.resolve();
  }
  /** Make an array_ref without resolving the shape. */
  NDARRAY_HOST_DEVICE array_ref(pointer base, const Shape& shape, internal::no_resolve)
      : base_(base), shape_(shape) {}

  /** Shallow copy or assignment of another array_ref. */
  NDARRAY_HOST_DEVICE array_ref(const array_ref& other) = default;
  NDARRAY_HOST_DEVICE array_ref(array_ref&& other) = default;
  NDARRAY_HOST_DEVICE array_ref& operator=(const array_ref& other) = default;
  NDARRAY_HOST_DEVICE array_ref& operator=(array_ref&& other) = default;
  /** Shallow copy or assign from an array_ref with a compatible shape. */
  template <class OtherShape, class = enable_if_shape_compatible<OtherShape>>
  NDARRAY_HOST_DEVICE array_ref(const array_ref<T, OtherShape>& other)
      : array_ref(other.base(), other.shape(), internal::no_resolve{}) {}
  template <class OtherShape, class = enable_if_shape_compatible<OtherShape>>
  NDARRAY_HOST_DEVICE array_ref& operator=(const array_ref<T, OtherShape>& other) {
    base_ = other.base();
    shape_ = other.shape();
    return *this;
  }
  /** Get a reference to the element at a tuple or list of indices. */
  NDARRAY_HOST_DEVICE reference operator[](const index_type& indices) const {
    return base_[shape_[indices]];
  }
  template <class... Args, class = enable_if_all_indices<Args...>>
  NDARRAY_HOST_DEVICE reference operator()(Args... indices) const {
    return base_[shape_(indices...)];
  }

  /** Slice this array with a mix of indices, intervals, and dims. */
  template <class... Args, class = enable_if_any_slices<Args...>>
  NDARRAY_HOST_DEVICE auto operator[](const std::tuple<Args...>& args) const {
    return internal::make_array_ref_at(base_, shape_, args);
  }
  template <class... Args, class = enable_if_any_slices<Args...>>
  NDARRAY_HOST_DEVICE auto operator()(Args... args) const {
    return internal::make_array_ref_at(base_, shape_, std::make_tuple(args...));
  }
  /** Call fn with a reference to each value in this array. */
  template <class Fn, class = internal::enable_if_callable<Fn, reference>>
  NDARRAY_HOST_DEVICE void for_each_value(Fn&& fn) const {
    shape_traits_type::for_each_value(shape_, base_, fn);
  }

  /** Pointer to the element at the min index of the shape. */
  NDARRAY_HOST_DEVICE pointer base() const { return base_; }

  /** Pointer to the element at the beginning of the flat addressed range. */
  NDARRAY_HOST_DEVICE pointer data() const {
    return internal::pointer_add(base_, shape_.flat_min());
  }

  /** The shape of this array. */
  NDARRAY_HOST_DEVICE Shape& shape() { return shape_; }
  NDARRAY_HOST_DEVICE const Shape& shape() const { return shape_; }
  template <size_t D, class = enable_if_dim<D>>
  NDARRAY_HOST_DEVICE auto& dim() {
    return shape_.template dim<D>();
  }
  template <size_t D, class = enable_if_dim<D>>
  NDARRAY_HOST_DEVICE const auto& dim() const {
    return shape_.template dim<D>();
  }

  NDARRAY_HOST_DEVICE size_type size() const { return shape_.size(); }
  NDARRAY_HOST_DEVICE bool empty() const { return base() != nullptr ? shape_.empty() : true; }
  NDARRAY_HOST_DEVICE bool is_compact() const { return shape_.is_compact(); }
  /** Named dims, forwarded to the shape. */
  NDARRAY_HOST_DEVICE auto& i() { return shape_.i(); }
  NDARRAY_HOST_DEVICE const auto& i() const { return shape_.i(); }
  NDARRAY_HOST_DEVICE auto& j() { return shape_.j(); }
  NDARRAY_HOST_DEVICE const auto& j() const { return shape_.j(); }
  NDARRAY_HOST_DEVICE auto& k() { return shape_.k(); }
  NDARRAY_HOST_DEVICE const auto& k() const { return shape_.k(); }

  NDARRAY_HOST_DEVICE auto& x() { return shape_.x(); }
  NDARRAY_HOST_DEVICE const auto& x() const { return shape_.x(); }
  NDARRAY_HOST_DEVICE auto& y() { return shape_.y(); }
  NDARRAY_HOST_DEVICE const auto& y() const { return shape_.y(); }
  NDARRAY_HOST_DEVICE auto& z() { return shape_.z(); }
  NDARRAY_HOST_DEVICE const auto& z() const { return shape_.z(); }
  NDARRAY_HOST_DEVICE auto& c() { return shape_.c(); }
  NDARRAY_HOST_DEVICE const auto& c() const { return shape_.c(); }
  NDARRAY_HOST_DEVICE auto& w() { return shape_.w(); }
  NDARRAY_HOST_DEVICE const auto& w() const { return shape_.w(); }

  NDARRAY_HOST_DEVICE index_t width() const { return shape_.width(); }
  NDARRAY_HOST_DEVICE index_t height() const { return shape_.height(); }
  NDARRAY_HOST_DEVICE index_t channels() const { return shape_.channels(); }

  NDARRAY_HOST_DEVICE index_t rows() const { return shape_.rows(); }
  NDARRAY_HOST_DEVICE index_t columns() const { return shape_.columns(); }
  /** Compare the contents of this array_ref to another: not equal if the
   * shapes differ, or if any elements differ. */
  NDARRAY_HOST_DEVICE bool operator!=(const array_ref& other) const {
    if (shape_ != other.shape_) { return true; }

    // TODO: This could be implemented more efficiently with an early-out.
    bool result = false;
    copy_shape_traits<Shape, Shape>::for_each_value(
        shape_, base_, other.shape_, other.base_, [&](const_reference a, const_reference b) {
          if (a != b) { result = true; }
        });
    return result;
  }
  NDARRAY_HOST_DEVICE bool operator==(const array_ref& other) const { return !operator!=(other); }
  /** Allow implicit conversion to a reference for scalar (rank-0) arrays. */
  template <std::size_t R = rank(), typename = std::enable_if_t<R == 0>>
  NDARRAY_HOST_DEVICE operator reference() const {
    return *base_;
  }

  /** Change the shape of this array_ref, and offset the base pointer by
   * offset. The new shape must address a subset of the old one. */
  NDARRAY_HOST_DEVICE void set_shape(const Shape& new_shape, index_t offset = 0) {
    assert(new_shape.is_resolved());
    assert(new_shape.is_subset_of(shape_, -offset));
    shape_ = new_shape;
    base_ = internal::pointer_add(base_, offset);
  }
};
/** array_refs of rank Rank, with entirely dynamic shapes. */
template <class T, size_t Rank>
using array_ref_of_rank = array_ref<T, shape_of_rank<Rank>>;
template <class T, size_t Rank>
using const_array_ref_of_rank = array_ref_of_rank<const T, Rank>;

/** array_refs of rank Rank, with a dense innermost dim. */
template <class T, size_t Rank>
using dense_array_ref = array_ref<T, dense_shape<Rank>>;
template <class T, size_t Rank>
using const_dense_array_ref = dense_array_ref<const T, Rank>;
/** A multi-dimensional array container that owns its data, with a shape
 * describing the layout. */
template <class T, class Shape, class Alloc = std::allocator<T>>
class array {
public:
  using allocator_type = Alloc;
  using alloc_traits = std::allocator_traits<Alloc>;
  using value_type = T;
  using reference = T&;
  using const_reference = const T&;
  using pointer = typename alloc_traits::pointer;
  using const_pointer = typename alloc_traits::const_pointer;
  using index_type = typename Shape::index_type;
  using shape_type = Shape;
  using shape_traits_type = shape_traits<Shape>;
  using copy_shape_traits_type = copy_shape_traits<Shape, Shape>;
  using size_type = size_t;

  /** The rank (number of dimensions) of this array. */
  static constexpr size_t rank() { return Shape::rank(); }

  /** True if this array is scalar (rank 0). */
  static constexpr bool is_scalar() { return Shape::is_scalar(); }

  template <class... Args>
  using enable_if_all_indices =
      std::enable_if_t<sizeof...(Args) == rank() && internal::all_of_type<index_t, Args...>::value>;

  template <class... Args>
  using enable_if_any_slices =
      std::enable_if_t<sizeof...(Args) == rank() &&
                       internal::all_of_any_type<std::tuple<interval<>, dim<>>, Args...>::value &&
                       !internal::all_of_type<index_t, Args...>::value>;

  template <size_t Dim>
  using enable_if_dim = std::enable_if_t<(Dim < rank())>;
private:
  Alloc alloc_;
  pointer buffer_;
  size_type buffer_size_;
  pointer base_;
  Shape shape_;

  // After shape_ is set, allocate the buffer and point base_ at the min
  // offset of the shape.
  void allocate() {
    assert(!buffer_);
    shape_.resolve();
    size_type flat_extent = shape_.flat_extent();
    if (flat_extent > 0) {
      buffer_size_ = flat_extent;
      buffer_ = alloc_traits::allocate(alloc_, buffer_size_);
    }
    base_ = buffer_ - shape_.flat_min();
  }

  // Call the constructor on all of the elements of the array.
  void construct() {
    assert(base_ || shape_.empty());
    for_each_value([&](T& x) { alloc_traits::construct(alloc_, &x); });
  }
  void construct(const T& init) {
    assert(base_ || shape_.empty());
    for_each_value([&](T& x) { alloc_traits::construct(alloc_, &x, init); });
  }
  void copy_construct(const array& other) {
    assert(base_ || shape_.empty());
    assert(shape_ == other.shape_);
    copy_shape_traits_type::for_each_value(other.shape_, other.base_, shape_, base_,
        [&](const_reference src, reference dst) { alloc_traits::construct(alloc_, &dst, src); });
  }
  void move_construct(array& other) {
    assert(base_ || shape_.empty());
    assert(shape_ == other.shape_);
    copy_shape_traits_type::for_each_value(
        other.shape_, other.base_, shape_, base_, [&](reference src, reference dst) {
          alloc_traits::construct(alloc_, &dst, std::move(src));
        });
  }

  // Call the destructor on every element, then free the buffer.
  void destroy() {
    assert(base_ || shape_.empty());
    for_each_value([&](T& x) { alloc_traits::destroy(alloc_, &x); });
  }
  void deallocate() {
    if (buffer_) {
      destroy();
      base_ = nullptr;
      shape_ = Shape();
      alloc_traits::deallocate(alloc_, buffer_, buffer_size_);
      buffer_ = nullptr;
    }
  }
  static Alloc get_allocator_for_move(array& other, std::true_type) {
    return std::move(other.alloc_);
  }
  static Alloc get_allocator_for_move(array& other, std::false_type) { return Alloc(); }
  static Alloc get_allocator_for_move(array& other) {
    return get_allocator_for_move(
        other, typename alloc_traits::propagate_on_container_move_assignment());
  }

  void swap_except_allocator(array& other) {
    using std::swap;
    swap(buffer_, other.buffer_);
    swap(buffer_size_, other.buffer_size_);
    swap(base_, other.base_);
    swap(shape_, other.shape_);
  }

public:
2332 explicit array(
const Alloc& alloc) :
array(Shape(), alloc) {}
2336 array(
const Shape& shape,
const T& value,
const Alloc& alloc) :
array(alloc) {
2337 assign(shape, value);
2339 array(
const Shape& shape,
const T& value) :
array() { assign(shape, value); }
2343 explicit array(
const Shape& shape,
const Alloc& alloc)
2344 : alloc_(alloc), buffer_(nullptr), buffer_size_(0), base_(nullptr), shape_(shape) {
2348 explicit array(
const Shape& shape)
2349 : buffer_(
nullptr), buffer_size_(0), base_(
nullptr), shape_(shape) {
  /** Copy construct from `other`. */
  array(const array& other)
      : alloc_(alloc_traits::select_on_container_copy_construction(other.get_allocator())),
        buffer_(nullptr), buffer_size_(0), base_(nullptr) {
    assign(other);
  }
  /** Copy construct from `other`, using the allocator `alloc`. */
  array(const array& other, const Alloc& alloc)
      : alloc_(alloc), buffer_(nullptr), buffer_size_(0), base_(nullptr) {
    assign(other);
  }

  /** Move construct from `other`. */
  array(array&& other)
      : alloc_(get_allocator_for_move(other)), buffer_(nullptr), buffer_size_(0), base_(nullptr) {
    // If we took other's allocator, or the allocators are equal, we can
    // simply steal other's buffer.
    if (typename alloc_traits::propagate_on_container_move_assignment() ||
        alloc_ == other.get_allocator()) {
      swap_except_allocator(other);
    } else {
      shape_ = other.shape_;
      allocate();
      move_construct(other);
    }
  }
  /** Move construct from `other`, using the allocator `alloc`. */
  array(array&& other, const Alloc& alloc)
      : alloc_(alloc), buffer_(nullptr), buffer_size_(0), base_(nullptr) {
    if (alloc_ == other.get_allocator()) {
      swap_except_allocator(other);
    } else {
      shape_ = other.shape_;
      allocate();
      move_construct(other);
    }
  }
  ~array() { deallocate(); }

  /** Assign the contents of `other` to this array. */
  array& operator=(const array& other) {
    if (base_ == other.base_) {
      if (base_) {
        assert(shape_ == other.shape_);
      } else {
        shape_ = other.shape_;
        assert(shape_.empty());
      }
      return *this;
    }

    if (alloc_traits::propagate_on_container_copy_assignment::value) {
      if (alloc_ != other.get_allocator()) { deallocate(); }
      alloc_ = other.get_allocator();
    }

    assign(other);
    return *this;
  }

  /** Move the contents of `other` to this array. */
  array& operator=(array&& other) {
    using std::swap;
    if (base_ == other.base_) {
      if (base_) {
        assert(shape_ == other.shape_);
      } else {
        swap(shape_, other.shape_);
        assert(shape_.empty());
      }
      return *this;
    }

    if (std::allocator_traits<allocator_type>::propagate_on_container_move_assignment::value) {
      swap(alloc_, other.alloc_);
      swap_except_allocator(other);
    } else if (alloc_ == other.get_allocator()) {
      swap_except_allocator(other);
    } else {
      assign(std::move(other));
    }
    return *this;
  }
  /** Assign the contents of `other` to this array. */
  void assign(const array& other) {
    if (base_ == other.base_) {
      if (base_) {
        assert(shape_ == other.shape_);
      } else {
        shape_ = other.shape_;
        assert(shape_.empty());
      }
      return;
    }

    if (base_ && shape_ == other.shape_) {
      destroy();
    } else {
      deallocate();
      shape_ = other.shape_;
      allocate();
    }
    copy_construct(other);
  }

  /** Move the contents of `other` to this array. */
  void assign(array&& other) {
    if (base_ == other.base_) {
      if (base_) {
        assert(shape_ == other.shape_);
      } else {
        shape_ = other.shape_;
        assert(shape_.empty());
      }
      return;
    }

    if (base_ && shape_ == other.shape_) {
      destroy();
    } else {
      deallocate();
      shape_ = other.shape_;
      allocate();
    }
    move_construct(other);
  }

  /** Assign the array to have a particular `shape`, with all elements assigned
   * to `value`. */
  void assign(Shape shape, const T& value) {
    shape.resolve();
    if (shape_ == shape) {
      destroy();
    } else {
      deallocate();
      shape_ = shape;
      allocate();
    }
    construct(value);
  }

  /** Get the allocator used to allocate memory for this array. */
  const Alloc& get_allocator() const { return alloc_; }
  /** Get a reference to the element at `indices`. */
  reference operator[](const index_type& indices) { return base_[shape_[indices]]; }
  const_reference operator[](const index_type& indices) const { return base_[shape_[indices]]; }
  template <class... Args, class = enable_if_all_indices<Args...>>
  reference operator()(Args... indices) {
    return base_[shape_(indices...)];
  }
  template <class... Args, class = enable_if_all_indices<Args...>>
  const_reference operator()(Args... indices) const {
    return base_[shape_(indices...)];
  }
  /** Make an array_ref of the elements at the slice described by `args`. */
  template <class... Args, class = enable_if_any_slices<Args...>>
  auto operator[](const std::tuple<Args...>& args) {
    return internal::make_array_ref_at(base_, shape_, args);
  }
  template <class... Args, class = enable_if_any_slices<Args...>>
  auto operator()(Args... args) {
    return internal::make_array_ref_at(base_, shape_, std::make_tuple(args...));
  }
  template <class... Args, class = enable_if_any_slices<Args...>>
  auto operator[](const std::tuple<Args...>& args) const {
    return internal::make_array_ref_at(base_, shape_, args);
  }
  template <class... Args, class = enable_if_any_slices<Args...>>
  auto operator()(Args... args) const {
    return internal::make_array_ref_at(base_, shape_, std::make_tuple(args...));
  }
  /** Call `fn` with a reference to each element of this array. */
  template <class Fn, class = internal::enable_if_callable<Fn, reference>>
  void for_each_value(Fn&& fn) {
    shape_traits_type::for_each_value(shape_, base_, fn);
  }
  template <class Fn, class = internal::enable_if_callable<Fn, const_reference>>
  void for_each_value(Fn&& fn) const {
    shape_traits_type::for_each_value(shape_, base_, fn);
  }
  /** Pointer to the element at the min index of the shape. */
  pointer base() { return base_; }
  const_pointer base() const { return base_; }

  /** Pointer to the element at the beginning of the flat buffer. */
  pointer data() { return internal::pointer_add(base_, shape_.flat_min()); }
  const_pointer data() const { return internal::pointer_add(base_, shape_.flat_min()); }

  /** The shape of this array. */
  const Shape& shape() const { return shape_; }

  /** The `D`th dim of this array's shape. */
  template <size_t D, class = enable_if_dim<D>>
  const auto& dim() const {
    return shape_.template dim<D>();
  }

  size_type size() const { return shape_.size(); }
  bool empty() const { return shape_.empty(); }
  bool is_compact() const { return shape_.is_compact(); }
  /** Change the shape of the array to `new_shape`, preserving the elements at
   * the indices common to the old and new shape. */
  void reshape(Shape new_shape) {
    new_shape.resolve();
    if (shape_ == new_shape) { return; }

    // Allocate an array with the new shape.
    array new_array(new_shape);

    // Move the common elements to the new array.
    Shape intersection =
        internal::clamp(new_shape.dims(), shape_.dims(), typename Shape::dim_indices());
    pointer intersection_base =
        internal::pointer_add(new_array.base_, new_shape[intersection.min()]);
    copy_shape_traits_type::for_each_value(
        shape_, base_, intersection, intersection_base, internal::move_assign<T, T>);

    *this = std::move(new_array);
  }
  /** Change the shape of the array to `new_shape`, without moving any
   * elements. The new shape must address a subset of the existing buffer. */
  void set_shape(const Shape& new_shape, index_t offset = 0) {
    static_assert(std::is_trivial<value_type>::value,
        "set_shape is broken for non-trivial types.");
    assert(new_shape.is_resolved());
    assert(new_shape.is_subset_of(shape_, -offset));
    shape_ = new_shape;
    base_ = internal::pointer_add(base_, offset);
  }
  /** Convenience accessors for the dims of this array's shape. */
  const auto& i() const { return shape_.i(); }
  const auto& j() const { return shape_.j(); }
  const auto& k() const { return shape_.k(); }

  const auto& x() const { return shape_.x(); }
  const auto& y() const { return shape_.y(); }
  const auto& z() const { return shape_.z(); }
  const auto& c() const { return shape_.c(); }
  const auto& w() const { return shape_.w(); }

  /** Extents of this array's shape, when interpreted as an image. */
  index_t width() const { return shape_.width(); }
  index_t height() const { return shape_.height(); }
  index_t channels() const { return shape_.channels(); }

  /** Extents of this array's shape, when interpreted as a matrix. */
  index_t rows() const { return shape_.rows(); }
  index_t columns() const { return shape_.columns(); }
  /** Compare the contents of this array to `other`. */
  bool operator==(const array& other) const { return cref() == other.cref(); }
  bool operator!=(const array& other) const { return cref() != other.cref(); }

  /** Swap the contents of this array with `other`. */
  void swap(array& other) {
    using std::swap;
    if (alloc_traits::propagate_on_container_swap::value) {
      swap(alloc_, other.alloc_);
      swap_except_allocator(other);
    } else {
      // If the allocators don't propagate on swap, we have to swap via moves.
      array temp(std::move(other));
      other = std::move(*this);
      *this = std::move(temp);
    }
  }
  /** Make an array_ref referring to the data in this array. */
  array_ref<T, Shape> ref() { return array_ref<T, Shape>(base_, shape_); }
  const_array_ref<T, Shape> cref() const { return const_array_ref<T, Shape>(base_, shape_); }

  /** Allow rank-0 arrays to be implicitly converted to the element. */
  template <std::size_t R = rank(), typename = std::enable_if_t<R == 0>>
  operator reference() {
    return *base_;
  }
  template <std::size_t R = rank(), typename = std::enable_if_t<R == 0>>
  operator const_reference() const {
    return *base_;
  }
private:
  template <typename NewShape, typename T2, typename OldShape, typename Alloc2>
  friend array<T2, NewShape, Alloc2> move_reinterpret_shape(
      array<T2, OldShape, Alloc2>&& from, const NewShape& new_shape, index_t offset);
  template <typename NewShape, typename T2, typename OldShape, typename Alloc2>
  friend array<T2, NewShape, Alloc2> move_reinterpret_shape(array<T2, OldShape, Alloc2>&& from);
};

/** An array with an arbitrary shape of rank `Rank`. */
template <class T, size_t Rank, class Alloc = std::allocator<T>>
using array_of_rank = array<T, shape_of_rank<Rank>, Alloc>;

/** An array with a dense shape of rank `Rank`. */
template <class T, size_t Rank, class Alloc = std::allocator<T>>
using dense_array = array<T, dense_shape<Rank>, Alloc>;
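
// Example (usage sketch, not part of the original header): a rank-3
// `dense_array` can be constructed from extents and indexed with operator().
// The names and extents below are illustrative only.
//
//   nda::dense_array<float, 3> image({640, 480, 3});
//   image(10, 20, 0) = 1.0f;                         // x, y, channel
//   image.for_each_value([](float& v) { v *= 2; });  // scale every element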
/** Make a new array with shape `shape`, allocated using `alloc`. */
template <class T, class Shape, class Alloc = std::allocator<T>,
    class = internal::enable_if_allocator<Alloc>>
auto make_array(const Shape& shape, const Alloc& alloc = Alloc()) {
  return array<T, Shape, Alloc>(shape, alloc);
}
/** Make a new array with shape `shape`, allocated using `alloc`, with all
 * elements copy constructed from `value`. */
template <class T, class Shape, class Alloc = std::allocator<T>,
    class = internal::enable_if_allocator<Alloc>>
auto make_array(const Shape& shape, const T& value, const Alloc& alloc = Alloc()) {
  return array<T, Shape, Alloc>(shape, value, alloc);
}
/** Swap the contents of arrays `a` and `b`. */
template <class T, class Shape, class Alloc>
void swap(array<T, Shape, Alloc>& a, array<T, Shape, Alloc>& b) {
  a.swap(b);
}
/** Copy the elements of the `src` array or array_ref to `dst`. The range of
 * the shape of `dst` is copied, and must be in range of `src`. */
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void copy(const array_ref<TSrc, ShapeSrc>& src, const array_ref<TDst, ShapeDst>& dst) {
  if (dst.shape().empty()) { return; }
  assert(src.shape().is_in_range(dst.shape().min()) && src.shape().is_in_range(dst.shape().max()));

  copy_shape_traits<ShapeSrc, ShapeDst>::for_each_value(
      src.shape(), src.base(), dst.shape(), dst.base(), internal::copy_assign<TSrc, TDst>);
}
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst, class AllocDst,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void copy(const array_ref<TSrc, ShapeSrc>& src, array<TDst, ShapeDst, AllocDst>& dst) {
  copy(src, dst.ref());
}
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst, class AllocSrc,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void copy(const array<TSrc, ShapeSrc, AllocSrc>& src, const array_ref<TDst, ShapeDst>& dst) {
  copy(src.cref(), dst);
}
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst, class AllocSrc, class AllocDst,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void copy(const array<TSrc, ShapeSrc, AllocSrc>& src, array<TDst, ShapeDst, AllocDst>& dst) {
  copy(src.cref(), dst.ref());
}
/** Make a copy of the `src` array or array_ref with the same shape as `src`. */
template <class T, class ShapeSrc,
    class Alloc = std::allocator<typename std::remove_const<T>::type>>
auto make_copy(const array_ref<T, ShapeSrc>& src, const Alloc& alloc = Alloc()) {
  return make_copy(src, src.shape(), alloc);
}
template <class T, class ShapeSrc, class AllocSrc, class AllocDst = AllocSrc,
    class = internal::enable_if_allocator<AllocDst>>
auto make_copy(const array<T, ShapeSrc, AllocSrc>& src, const AllocDst& alloc = AllocDst()) {
  return make_copy(src.cref(), alloc);
}

/** Make a copy of the `src` array or array_ref with a new shape `shape`. */
template <class T, class ShapeSrc, class ShapeDst,
    class Alloc = std::allocator<typename std::remove_const<T>::type>,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
auto make_copy(
    const array_ref<T, ShapeSrc>& src, const ShapeDst& shape, const Alloc& alloc = Alloc()) {
  array<typename std::remove_const<T>::type, ShapeDst, Alloc> dst(shape, alloc);
  copy(src, dst);
  return dst;
}
template <class T, class ShapeSrc, class ShapeDst, class AllocSrc, class AllocDst = AllocSrc,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
auto make_copy(const array<T, ShapeSrc, AllocSrc>& src, const ShapeDst& shape,
    const AllocDst& alloc = AllocDst()) {
  return make_copy(src.cref(), shape, alloc);
}
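
// Example (usage sketch, not part of the original header): `make_copy` with an
// explicit shape can crop or re-lay-out data; here it extracts a 2x2 corner.
//
//   nda::dense_array<int, 2> big({8, 8});
//   auto corner = nda::make_copy(big, nda::shape_of_rank<2>(2, 2));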
/** Make a copy of the `src` array or array_ref with a compact version of
 * `src`'s shape. */
template <class T, class Shape, class Alloc = std::allocator<typename std::remove_const<T>::type>>
auto make_compact_copy(const array_ref<T, Shape>& src, const Alloc& alloc = Alloc()) {
  return make_copy(src, make_compact(src.shape()), alloc);
}
template <class T, class Shape, class AllocSrc, class AllocDst = AllocSrc>
auto make_compact_copy(const array<T, Shape, AllocSrc>& src, const AllocDst& alloc = AllocDst()) {
  return make_compact_copy(src.cref(), alloc);
}
/** Move the elements of the `src` array or array_ref to `dst`. The range of
 * the shape of `dst` is moved, and must be in range of `src`. */
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void move(const array_ref<TSrc, ShapeSrc>& src, const array_ref<TDst, ShapeDst>& dst) {
  if (dst.shape().empty()) { return; }
  assert(src.shape().is_in_range(dst.shape().min()) && src.shape().is_in_range(dst.shape().max()));

  copy_shape_traits<ShapeSrc, ShapeDst>::for_each_value(
      src.shape(), src.base(), dst.shape(), dst.base(), internal::move_assign<TSrc, TDst>);
}
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst, class AllocDst,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void move(const array_ref<TSrc, ShapeSrc>& src, array<TDst, ShapeDst, AllocDst>& dst) {
  move(src, dst.ref());
}
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst, class AllocSrc,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void move(array<TSrc, ShapeSrc, AllocSrc>& src, const array_ref<TDst, ShapeDst>& dst) {
  move(src.ref(), dst);
}
template <class TSrc, class TDst, class ShapeSrc, class ShapeDst, class AllocSrc, class AllocDst,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
void move(array<TSrc, ShapeSrc, AllocSrc>& src, array<TDst, ShapeDst, AllocDst>& dst) {
  move(src.ref(), dst.ref());
}

/** Move the contents of the `src` array to `dst`. */
template <class T, class Shape, class Alloc>
void move(array<T, Shape, Alloc>&& src, array<T, Shape, Alloc>& dst) {
  dst = std::move(src);
}
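
// Example (usage sketch, not part of the original header): `move` is useful
// for element types that are movable but not copyable.
//
//   nda::dense_array<std::unique_ptr<int>, 1> a({4});
//   nda::dense_array<std::unique_ptr<int>, 1> b({4});
//   nda::move(a, b);  // b's elements take ownership; a's are left moved-from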
/** Make a new array with shape `shape`, with elements moved from the `src`
 * array or array_ref. */
template <class T, class ShapeSrc, class ShapeDst, class Alloc = std::allocator<T>,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
auto make_move(
    const array_ref<T, ShapeSrc>& src, const ShapeDst& shape, const Alloc& alloc = Alloc()) {
  array<T, ShapeDst, Alloc> dst(shape, alloc);
  move(src, dst);
  return dst;
}
template <class T, class ShapeSrc, class ShapeDst, class AllocSrc, class AllocDst = AllocSrc,
    class = internal::enable_if_shapes_copy_compatible<ShapeDst, ShapeSrc>>
auto make_move(array<T, ShapeSrc, AllocSrc>& src, const ShapeDst& shape,
    const AllocDst& alloc = AllocDst()) {
  return make_move(src.ref(), shape, alloc);
}
template <class T, class Shape, class Alloc>
auto make_move(array<T, Shape, Alloc>&& src, const Shape& shape, const Alloc& alloc = Alloc()) {
  array<T, Shape, Alloc> dst(shape, alloc);
  move(src.ref(), dst.ref());
  return dst;
}
/** Make a copy of the `src` array or array_ref with a compact version of
 * `src`'s shape, with the elements moved to the new array. */
template <class T, class Shape, class Alloc = std::allocator<T>>
auto make_compact_move(const array_ref<T, Shape>& src, const Alloc& alloc = Alloc()) {
  return make_move(src, make_compact(src.shape()), alloc);
}
template <class T, class Shape, class AllocSrc, class AllocDst = AllocSrc>
auto make_compact_move(array<T, Shape, AllocSrc>& src, const AllocDst& alloc = AllocDst()) {
  return make_compact_move(src.ref(), alloc);
}
template <class T, class Shape, class Alloc>
auto make_compact_move(array<T, Shape, Alloc>&& src, const Alloc& alloc = Alloc()) {
  return make_compact_move(src.ref(), alloc);
}

/** Assign `value` to every element of the `dst` array or array_ref. */
template <class T, class Shape>
void fill(const array_ref<T, Shape>& dst, const T& value) {
  dst.for_each_value([&](T& x) { x = value; });
}
template <class T, class Shape, class Alloc>
void fill(array<T, Shape, Alloc>& dst, const T& value) {
  fill(dst.ref(), value);
}
/** Fill the `dst` array or array_ref by calling the generator `g` once per
 * element, in the order the elements are stored in memory, and assigning the
 * result. */
template <class T, class Shape, class Generator, class = internal::enable_if_callable<Generator>>
void generate(const array_ref<T, Shape>& dst, Generator&& g) {
  dst.for_each_value([&](T& x) { x = g(); });
}
template <class T, class Shape, class Alloc, class Generator,
    class = internal::enable_if_callable<Generator>>
void generate(array<T, Shape, Alloc>& dst, Generator&& g) {
  generate(dst.ref(), g);
}
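
// Example (usage sketch, not part of the original header): the generator is
// called once per element; here it numbers the elements in storage order.
//
//   nda::dense_array<int, 2> a({3, 3});
//   int next = 0;
//   nda::generate(a, [&]() { return next++; });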
/** Fill the `dst` array or array_ref by calling `fn` with the index of each
 * element (as an index tuple) and assigning the result. */
template <typename T, typename Shape, class Fn>
void transform_index(const array_ref<T, Shape>& dst, Fn&& fn) {
  using index_type = typename Shape::index_type;
  for_each_index(
      dst.shape(), [&, fn = std::move(fn)](const index_type& idx) { dst[idx] = fn(idx); });
}
template <typename T, typename Shape, class Fn>
void transform_index(array<T, Shape>& dst, Fn&& fn) {
  transform_index(dst.ref(), fn);
}

/** Fill the `dst` array or array_ref by calling `fn` with the indices of each
 * element (as separate arguments) and assigning the result. */
template <typename T, typename Shape, class Fn>
void transform_indices(const array_ref<T, Shape>& dst, Fn&& fn) {
  using index_type = typename Shape::index_type;
  transform_index(
      dst, [fn = std::move(fn)](const index_type& idx) { return internal::apply(fn, idx); });
}
template <typename T, typename Shape, class Fn>
void transform_indices(array<T, Shape>& dst, Fn&& fn) {
  transform_indices(dst.ref(), fn);
}
/** Check if the contents of the arrays or array_refs `a` and `b` are equal.
 * The shapes must have the same min and extent in every dimension, and every
 * pair of corresponding elements must compare equal. */
template <class TA, class ShapeA, class TB, class ShapeB>
bool equal(const array_ref<TA, ShapeA>& a, const array_ref<TB, ShapeB>& b) {
  if (a.shape().min() != b.shape().min() || a.shape().extent() != b.shape().extent()) {
    return false;
  }
  bool result = true;
  copy_shape_traits<ShapeA, ShapeB>::for_each_value(
      a.shape(), a.base(), b.shape(), b.base(), [&](const TA& a, const TB& b) {
        if (a != b) { result = false; }
      });
  return result;
}
template <class TA, class ShapeA, class TB, class ShapeB, class AllocB>
bool equal(const array_ref<TA, ShapeA>& a, const array<TB, ShapeB, AllocB>& b) {
  return equal(a, b.cref());
}
template <class TA, class ShapeA, class AllocA, class TB, class ShapeB>
bool equal(const array<TA, ShapeA, AllocA>& a, const array_ref<TB, ShapeB>& b) {
  return equal(a.cref(), b);
}
template <class TA, class ShapeA, class AllocA, class TB, class ShapeB, class AllocB>
bool equal(const array<TA, ShapeA, AllocA>& a, const array<TB, ShapeB, AllocB>& b) {
  return equal(a.cref(), b.cref());
}
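
// Example (usage sketch, not part of the original header):
//
//   nda::dense_array<int, 1> a({3}), b({3});
//   nda::fill(a, 1);
//   nda::fill(b, 1);
//   assert(nda::equal(a, b));  // same min, extent, and values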
/** Convert the shape of the array or array_ref `a` to a new shape type
 * `NewShape`. The new shape must be compatible with the old shape. */
template <class NewShape, class T, class OldShape>
array_ref<T, NewShape> convert_shape(const array_ref<T, OldShape>& a) {
  return array_ref<T, NewShape>(a.base(), convert_shape<NewShape>(a.shape()));
}
template <class NewShape, class T, class OldShape, class Allocator>
array_ref<T, NewShape> convert_shape(array<T, OldShape, Allocator>& a) {
  return convert_shape<NewShape>(a.ref());
}
template <class NewShape, class T, class OldShape, class Allocator>
const_array_ref<T, NewShape> convert_shape(const array<T, OldShape, Allocator>& a) {
  return convert_shape<NewShape>(a.cref());
}
/** Reinterpret the elements of the array or array_ref `a` as a different type
 * `U`, which must be the same size as `T`. */
template <class U, class T, class Shape, class = std::enable_if_t<sizeof(T) == sizeof(U)>>
array_ref<U, Shape> reinterpret(const array_ref<T, Shape>& a) {
  return array_ref<U, Shape>(reinterpret_cast<U*>(a.base()), a.shape());
}
template <class U, class T, class Shape, class Alloc,
    class = std::enable_if_t<sizeof(T) == sizeof(U)>>
array_ref<U, Shape> reinterpret(array<T, Shape, Alloc>& a) {
  return reinterpret<U>(a.ref());
}
template <class U, class T, class Shape, class Alloc,
    class = std::enable_if_t<sizeof(T) == sizeof(U)>>
const_array_ref<U, Shape> reinterpret(const array<T, Shape, Alloc>& a) {
  return reinterpret<const U>(a.cref());
}
/** Reinterpret the elements of the const array_ref `a` as non-const `U`. */
template <class U, class T, class Shape>
array_ref<U, Shape> reinterpret_const(const const_array_ref<T, Shape>& a) {
  return array_ref<U, Shape>(const_cast<U*>(a.base()), a.shape());
}
/** Reinterpret the shape of the array or array_ref `a` to be a new shape
 * `new_shape`, with a base pointer offset `offset`. The new shape must address
 * a subset of the indices of the old shape. */
template <class NewShape, class T, class OldShape>
array_ref<T, NewShape> reinterpret_shape(
    const array_ref<T, OldShape>& a, const NewShape& new_shape, index_t offset = 0) {
  array_ref<T, NewShape> result(internal::pointer_add(a.base(), offset), new_shape);
  assert(result.shape().is_subset_of(a.shape(), -offset));
  return result;
}
template <class NewShape, class T, class OldShape, class Allocator>
array_ref<T, NewShape> reinterpret_shape(
    array<T, OldShape, Allocator>& a, const NewShape& new_shape, index_t offset = 0) {
  return reinterpret_shape(a.ref(), new_shape, offset);
}
template <class NewShape, class T, class OldShape, class Allocator>
const_array_ref<T, NewShape> reinterpret_shape(
    const array<T, OldShape, Allocator>& a, const NewShape& new_shape, index_t offset = 0) {
  return reinterpret_shape(a.cref(), new_shape, offset);
}
/** Reinterpret the shape of a moved-from array `from` to be a new shape
 * `new_shape`, taking ownership of `from`'s buffer. */
template <typename NewShape, typename T, typename OldShape, typename Alloc>
array<T, NewShape, Alloc> move_reinterpret_shape(
    array<T, OldShape, Alloc>&& from, const NewShape& new_shape, index_t offset = 0) {
  static_assert(std::is_trivial<T>::value,
      "move_reinterpret_shape is broken for non-trivial types.");
  assert(new_shape.is_subset_of(from.shape(), offset));
  array<T, NewShape, Alloc> result;
  assert(result.alloc_ == from.get_allocator());
  using std::swap;
  swap(result.buffer_, from.buffer_);
  swap(result.buffer_size_, from.buffer_size_);
  swap(result.base_, from.base_);
  result.shape_ = new_shape;
  from.shape_ = OldShape();
  result.base_ += offset;
  return result;
}
template <typename NewShape, typename T, typename OldShape, typename Alloc>
array<T, NewShape, Alloc> move_reinterpret_shape(array<T, OldShape, Alloc>&& from) {
  NewShape new_shape = convert_shape<NewShape>(from.shape());
  return move_reinterpret_shape(std::move(from), new_shape);
}
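
// Example (usage sketch, not part of the original header): reinterpret a
// compact 2-D array as a flat 1-D view of the same 16 elements.
//
//   nda::dense_array<int, 2> a({4, 4});
//   auto flat = nda::reinterpret_shape(a, nda::dense_shape<1>(16));
//   // flat(5) aliases a(1, 1): offset 1*1 + 1*4 = 5 in the flat buffer.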
/** Transpose the shape of the array or array_ref `a`. `DimIndices...` must be
 * a permutation of the dimension indices of `a`. */
template <size_t... DimIndices, class T, class OldShape,
    class = internal::enable_if_permutation<OldShape::rank(), DimIndices...>>
auto transpose(const array_ref<T, OldShape>& a) {
  return make_array_ref(a.base(), transpose<DimIndices...>(a.shape()));
}
template <size_t... DimIndices, class T, class OldShape, class Allocator,
    class = internal::enable_if_permutation<OldShape::rank(), DimIndices...>>
auto transpose(array<T, OldShape, Allocator>& a) {
  return transpose<DimIndices...>(a.ref());
}
template <size_t... DimIndices, class T, class OldShape, class Allocator,
    class = internal::enable_if_permutation<OldShape::rank(), DimIndices...>>
auto transpose(const array<T, OldShape, Allocator>& a) {
  return transpose<DimIndices...>(a.cref());
}

/** Reorder the dimensions of the array or array_ref `a`. Unlike `transpose`,
 * `DimIndices...` need not be a permutation; dims may be dropped or repeated. */
template <size_t... DimIndices, class T, class OldShape>
auto reorder(const array_ref<T, OldShape>& a) {
  return make_array_ref(a.base(), reorder<DimIndices...>(a.shape()));
}
template <size_t... DimIndices, class T, class OldShape, class Allocator>
auto reorder(array<T, OldShape, Allocator>& a) {
  return reorder<DimIndices...>(a.ref());
}
template <size_t... DimIndices, class T, class OldShape, class Allocator>
auto reorder(const array<T, OldShape, Allocator>& a) {
  return reorder<DimIndices...>(a.cref());
}
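
// Example (usage sketch, not part of the original header): `transpose`
// produces a view with the dims reordered; no elements are moved.
//
//   nda::dense_array<int, 2> a({3, 4});
//   auto t = nda::transpose<1, 0>(a);  // t(y, x) aliases a(x, y)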
/** Allocator that owns a fixed buffer of `N` elements of `T`, used to satisfy
 * the first allocation that fits it. If the buffer is already in use, or the
 * allocation does not fit, allocation falls back to `BaseAlloc`. The buffer
 * lives inside the allocator, so it is on the stack if the owning container is
 * on the stack. */
template <class T, size_t N, size_t Alignment = alignof(T), class BaseAlloc = std::allocator<T>>
class auto_allocator {
  alignas(Alignment) char buffer[N * sizeof(T)];
  bool allocated;
  BaseAlloc alloc;

public:
  using value_type = T;

  using propagate_on_container_copy_assignment = std::false_type;
  using propagate_on_container_move_assignment = std::false_type;
  using propagate_on_container_swap = std::false_type;

  auto_allocator() : allocated(false) {}
  // Copies and conversions do not take the source's buffer.
  auto_allocator(const auto_allocator&) : allocated(false) {}
  template <class U, size_t U_N, size_t U_A, class U_BaseAlloc>
  auto_allocator(const auto_allocator<U, U_N, U_A, U_BaseAlloc>&) : allocated(false) {}
  auto_allocator(auto_allocator&& move) : allocated(false) {
    alloc = std::move(move.alloc);
  }

  value_type* allocate(size_t n) {
    if (!allocated && n <= N) {
      allocated = true;
      return reinterpret_cast<value_type*>(&buffer[0]);
    }
    return std::allocator_traits<BaseAlloc>::allocate(alloc, n);
  }
  void deallocate(value_type* ptr, size_t n) noexcept {
    if (ptr == reinterpret_cast<value_type*>(&buffer[0])) {
      allocated = false;
    } else {
      std::allocator_traits<BaseAlloc>::deallocate(alloc, ptr, n);
    }
  }

  template <class U, size_t U_N, size_t U_A>
  friend bool operator==(
      const auto_allocator<U, U_N, U_A>& a, const auto_allocator<U, U_N, U_A>& b) {
    if (a.allocated || b.allocated) {
      return &a.buffer[0] == &b.buffer[0];
    }
    return a.alloc == b.alloc;
  }
  template <class U, size_t U_N, size_t U_A>
  friend bool operator!=(
      const auto_allocator<U, U_N, U_A>& a, const auto_allocator<U, U_N, U_A>& b) {
    return !(a == b);
  }
};
/** Allocator adaptor that wraps `BaseAlloc`, but does not default construct
 * elements when no constructor arguments are given, leaving them
 * uninitialized. */
template <class BaseAlloc>
class uninitialized_allocator : public BaseAlloc {
public:
  using value_type = typename std::allocator_traits<BaseAlloc>::value_type;

  using propagate_on_container_copy_assignment =
      typename std::allocator_traits<BaseAlloc>::propagate_on_container_copy_assignment;
  using propagate_on_container_move_assignment =
      typename std::allocator_traits<BaseAlloc>::propagate_on_container_move_assignment;
  using propagate_on_container_swap =
      typename std::allocator_traits<BaseAlloc>::propagate_on_container_swap;

  uninitialized_allocator select_on_container_copy_construction(
      const uninitialized_allocator& alloc) {
    return std::allocator_traits<BaseAlloc>::select_on_container_copy_construction(alloc);
  }

  value_type* allocate(size_t n) { return std::allocator_traits<BaseAlloc>::allocate(*this, n); }
  void deallocate(value_type* p, size_t n) noexcept {
    return std::allocator_traits<BaseAlloc>::deallocate(*this, p, n);
  }

  // Skip default construction; construct only when arguments are provided.
  template <class... Args>
  NDARRAY_INLINE void construct(value_type* ptr, Args&&... args) {
    if (sizeof...(Args) > 0) {
      std::allocator_traits<BaseAlloc>::construct(*this, ptr, std::forward<Args>(args)...);
    }
  }

  template <class OtherBaseAlloc>
  friend bool operator==(const uninitialized_allocator<BaseAlloc>& a,
      const uninitialized_allocator<OtherBaseAlloc>& b) {
    return static_cast<const BaseAlloc&>(a) == static_cast<const OtherBaseAlloc&>(b);
  }
  template <class OtherBaseAlloc>
  friend bool operator!=(const uninitialized_allocator<BaseAlloc>& a,
      const uninitialized_allocator<OtherBaseAlloc>& b) {
    return static_cast<const BaseAlloc&>(a) != static_cast<const OtherBaseAlloc&>(b);
  }
};
/** A `std::allocator` that does not initialize trivial elements. */
template <class T, class = std::enable_if_t<std::is_trivial<T>::value>>
using uninitialized_std_allocator = uninitialized_allocator<std::allocator<T>>;

/** An `auto_allocator` that does not initialize trivial elements. */
template <class T, size_t N, size_t Alignment = sizeof(T),
    class = std::enable_if_t<std::is_trivial<T>::value>>
using uninitialized_auto_allocator = uninitialized_allocator<auto_allocator<T, N, Alignment>>;
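
// Example (usage sketch, not part of the original header): a scratch buffer of
// trivial elements that skips default initialization on construction.
//
//   nda::array<char, nda::shape_of_rank<1>, nda::uninitialized_std_allocator<char>>
//       scratch({1024});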
}  // namespace nda

#endif  // NDARRAY_ARRAY_H