12 #ifndef HELIB_MATRIX_H
13 #define HELIB_MATRIX_H
23 #include <type_traits>
24 #include <initializer_list>
26 #include <NTL/BasicThreadPool.h>
28 #include "assertions.h"
30 #include "zeroValue.h"
41 template <std::
size_t N>
50 template <
typename Iter1,
typename Iter2>
52 const Iter1& lastLength,
53 const Iter2& firstStride,
54 const Iter2& lastStride,
55 const std::vector<long>& st) :
58 std::copy(firstLength, lastLength, this->lengths.begin());
59 std::copy(firstStride, lastStride, this->strides.begin());
60 this->size = (std::accumulate(
lengths.begin(),
63 std::multiplies<std::size_t>()));
67 template <
typename Iter1,
typename Iter2>
69 const Iter1& lastLength,
70 const Iter2& firstStride,
71 const Iter2& lastStride,
75 std::copy(firstLength, lastLength, this->lengths.begin());
76 std::copy(firstStride, lastStride, this->strides.begin());
77 this->size = (std::accumulate(
lengths.begin(),
80 std::multiplies<std::size_t>()));
84 template <
typename... Dims>
90 std::multiplies<
std::size_t>()))
93 this->strides.back() = 1;
99 template <
typename... Dims>
102 static_assert(
sizeof...(Dims) == N,
"Wrong number of indices given.");
104 std::array<std::size_t, N> args{std::size_t(dims)...};
106 for (
long i = 0; i < long(N); ++i) {
107 if (args[i] >= this->lengths[i]) {
109 "Index given: " + std::to_string(args[i]) +
110 ". Max value is: " + std::to_string(this->lengths[i]));
114 if (this->start.size() == 1) {
115 return std::inner_product(args.begin(),
117 this->strides.begin(),
121 return std::inner_product(args.begin(),
123 this->strides.begin(),
124 this->start.at(args.at(1)) *
strides.back());
128 std::size_t
order()
const {
return N; }
134 else if (this->size == rhs.
size && this->start == rhs.
start &&
147 template <
typename T, std::
size_t N>
153 std::shared_ptr<std::vector<T>> elements_ptr;
154 bool full_view =
true;
162 template <
typename U = T,
164 typename std::enable_if_t<
165 !std::is_convertible<U, std::size_t>::value>* =
nullptr>
168 elements_ptr(
std::make_shared<
std::vector<T>>(subscripts.
size, obj))
171 template <
typename... Dims>
173 subscripts{
std::size_t(
dims)...},
174 elements_ptr(
std::make_shared<
std::vector<T>>(subscripts.
size))
178 Tensor(std::initializer_list<std::vector<T>> lst) :
179 subscripts{lst.
size(), lst.begin()->
size()},
181 std::make_shared<
std::vector<T>>(lst.
size() * lst.begin()->
size()))
183 int column_length = lst.begin()->size();
185 for (
const auto& v : lst) {
186 if (column_length !=
long(v.size()))
188 "Column dimensions do not match on initializer list.");
192 this->elements_ptr->begin() + (column_length * cnt++));
197 const std::shared_ptr<std::vector<T>>& elems) :
198 subscripts(ts), elements_ptr(elems), full_view(false)
217 copy.elements_ptr = std::make_shared<std::vector<T>>(*elements_ptr);
222 std::size_t
order()
const {
return N; }
224 template <
typename... Args>
227 return this->elements_ptr->at(subscripts(args...));
230 template <
typename... Args>
233 return this->elements_ptr->at(subscripts(args...));
236 std::size_t
size()
const {
return this->subscripts.
size; }
238 std::size_t
dims(
int i)
const {
return this->subscripts.
lengths.at(i); }
246 }
else if (this->subscripts != rhs.subscripts) {
249 *elements_ptr == *(rhs.elements_ptr)) {
252 for (
size_t i = 0; i <
dims(0); ++i)
253 for (
size_t j = 0; j <
dims(1); ++j)
254 if (this->
operator()(i, j) != rhs(i, j)) {
266 this->subscripts.lengths.end(),
267 this->subscripts.strides.begin() + 1,
268 this->subscripts.strides.end(),
269 i * this->subscripts.strides.at(0));
270 return Tensor<T, N - 1>(ts, this->elements_ptr);
277 this->subscripts.lengths.end() - 1,
278 this->subscripts.strides.begin(),
279 this->subscripts.strides.end() - 1,
281 return Tensor<T, N - 1>(ts, this->elements_ptr);
286 if (i >= this->
dims(0)) {
288 ". Max value is: " + std::to_string(this->
dims(0)));
300 if (j >= this->
dims(1)) {
302 ". Max value is: " + std::to_string(this->
dims(1)));
306 ts.
start = {
static_cast<long>(j)};
316 for (
const auto& j : js)
317 assertInRange<LogicError>(
320 static_cast<long>(this->
dims(1)),
321 "Index for column does not exist. Given index " + std::to_string(j) +
322 ". Expected index in " +
"range [0, " +
323 std::to_string(this->
dims(1)) +
").");
326 std::vector<std::size_t> lengths = {this->
dims(0), js.size()};
328 std::vector<long> offsets(js);
329 for (std::size_t i = 0; i < offsets.size(); ++i) {
335 this->subscripts.strides.begin(),
336 this->subscripts.strides.end(),
342 template <
typename T2>
344 std::function<T&(T&,
const T2&)> operation)
347 std::array<std::size_t, N> rhs_subscripts;
348 for (std::size_t i = 0; i < N; ++i) {
349 rhs_subscripts[i] = rhs.
dims(i);
353 if (!std::equal(this->subscripts.
lengths.begin(),
354 this->subscripts.lengths.end(),
355 rhs_subscripts.begin())) {
360 if (
static_cast<const void*
>(&this->
data()) ==
361 static_cast<const void*
>(&rhs.
data())) {
366 if (this->full_view && rhs.
fullView()) {
367 const std::vector<T2>& rhs_v = rhs.
data();
368 for (std::size_t i = 0; i < this->elements_ptr->size(); ++i) {
369 operation((*this->elements_ptr)[i], rhs_v[i]);
373 for (std::size_t j = 0; j < this->
dims(1); ++j) {
374 for (std::size_t i = 0; i < this->
dims(0); ++i) {
375 operation(this->
operator()(i, j), rhs(i, j));
383 template <
typename T2>
386 return entrywiseOperation<T2>(
388 [](
auto& lhs,
const auto& rhs) -> decltype(
auto) {
393 template <
typename T2>
396 return entrywiseOperation<T2>(
398 [](
auto& lhs,
const auto& rhs) -> decltype(
auto) {
403 template <
typename T2>
406 return entrywiseOperation<T2>(
408 [](
auto& lhs,
const auto& rhs) -> decltype(
auto) {
416 if (this->full_view) {
417 NTL_EXEC_RANGE(
long(this->elements_ptr->size()), first, last)
418 for (
long i = first; i < last; ++i)
419 fn((*elements_ptr)[i]);
424 NTL_EXEC_RANGE(this->
dims(1), first, last)
425 for (
long j = first; j < last; ++j)
426 for (std::size_t i = 0; i < this->
dims(0); ++i)
427 fn(this->
operator()(i, j));
437 ret.inPlaceTranspose();
447 std::vector<int> permutation(
size());
448 std::iota(permutation.begin(), permutation.end(), 0);
449 for (
int& num : permutation)
453 std::vector<std::vector<int>> cycles;
454 std::vector<bool> seen(
size(),
false);
455 int num_processed = 0;
457 while (num_processed <
long(
size())) {
459 std::vector<int> cycle = {current_pos};
460 seen[current_pos] =
true;
461 while (permutation.at(cycle.back()) != cycle.front()) {
462 seen[permutation.at(cycle.back())] =
true;
463 cycle.push_back(permutation.at(cycle.back()));
465 num_processed += cycle.size();
466 cycles.push_back(std::move(cycle));
468 while (current_pos <
long(
size()) && seen[current_pos])
472 std::vector<std::pair<int, int>> swaps;
473 for (
const auto& cycle : cycles)
474 if (cycle.size() >= 2)
475 for (
int i = cycle.size() - 1; i > 0; --i)
476 swaps.emplace_back(cycle[i], cycle[i - 1]);
478 for (
const auto& swap : swaps)
479 std::swap(elements_ptr->at(swap.first), elements_ptr->at(swap.second));
482 std::make_shared<std::vector<T>>(
size(),
data().front());
495 j = (j + 1) % subscripts.
lengths[1];
497 i = (i + 1) % subscripts.
lengths[0];
501 new_i = (new_i + 1) % subscripts.
lengths[1];
503 new_j = (new_j + 1) % subscripts.
lengths[0];
505 elements_ptr = new_elements;
508 std::reverse(subscripts.
lengths.begin(), subscripts.
lengths.end());
510 subscripts.
start = {0};
514 const std::vector<T>&
data()
const {
return *this->elements_ptr; }
518 template <
typename T,
520 typename std::enable_if_t<
521 std::is_convertible<T, std::size_t>::value>* =
nullptr>
524 HELIB_NTIMER_START(MatrixMultiplicationConv);
535 NTL_EXEC_RANGE(M1.
dims(0), first, last)
536 for (
long i = first; i < last; ++i)
538 for (std::size_t j = 0; j < M2.
dims(1); ++j)
539 for (std::size_t k = 0; k < M2.
dims(0); ++k) {
546 HELIB_NTIMER_STOP(MatrixMultiplicationConv);
551 template <
typename T,
553 typename std::enable_if_t<
554 !std::is_convertible<T, std::size_t>::value>* =
nullptr>
555 inline Tensor<T, 2>
operator*(
const Tensor<T, 2>& M1,
const Tensor<T2, 2>& M2)
557 HELIB_NTIMER_START(MatrixMultiplicationNotConv);
562 if (M1.dims(1) != M2.dims(0)) {
564 "The number of columns in left matrix (" + std::to_string(M1.dims(1)) +
565 ") do not match the number of rows of the right matrix (" +
566 std::to_string(M2.dims(0)) +
").");
570 NTL_EXEC_RANGE(M1.dims(0), first, last)
571 for (
long i = first; i < last; ++i)
574 for (std::size_t j = 0; j < M2.dims(1); ++j)
575 for (std::size_t k = 0; k < M2.dims(0); ++k) {
582 HELIB_NTIMER_STOP(MatrixMultiplicationNotConv);
586 template <
typename T,
typename T2>
594 template <
typename T,
typename T2>
603 template <
typename T>
606 template <
typename T>
610 template <
typename T>
615 std::vector<Matrix<T>> columns;
621 template <
typename T>
624 for (std::size_t i = 0; i < M.
dims(0); ++i) {
625 for (std::size_t j = 0; j < M.
dims(1); ++j)
626 out << M(i, j) <<
" ";
Inherits from Exception and std::logic_error.
Definition: exceptions.h:68
MatrixView(const std::initializer_list< Matrix< T >> lst)
Definition: Matrix.h:618
Inherits from Exception and std::out_of_range.
Definition: exceptions.h:86
Tensor< T, N > & entrywiseOperation(const Tensor< T2, N > &rhs, std::function< T &(T &, const T2 &)> operation)
Definition: Matrix.h:343
Tensor< T, 2 > transpose() const
Definition: Matrix.h:434
Tensor< T, N > & operator+=(const Tensor< T2, N > &rhs)
Definition: Matrix.h:384
Tensor< T, 2 > & inPlaceTranspose()
Definition: Matrix.h:442
std::size_t order() const
Definition: Matrix.h:222
Tensor(Tensor &&other)=default
Tensor< T, N > columns(const std::vector< long > &js) const
Definition: Matrix.h:313
bool operator!=(const Tensor &rhs) const
Definition: Matrix.h:261
Tensor(Dims... dims)
Definition: Matrix.h:172
Tensor< T, 2 > getColumn(std::size_t j) const
Definition: Matrix.h:298
bool fullView() const
Definition: Matrix.h:240
std::size_t dims(int i) const
Definition: Matrix.h:238
Tensor< T, 2 > getRow(std::size_t i) const
Definition: Matrix.h:284
const std::vector< T > & data() const
Definition: Matrix.h:514
Tensor(std::initializer_list< std::vector< T >> lst)
Definition: Matrix.h:178
Tensor< T, N > & operator-=(const Tensor< T2, N > &rhs)
Definition: Matrix.h:394
std::size_t size() const
Definition: Matrix.h:236
T & operator()(Args... args)
Definition: Matrix.h:225
Tensor< T, N > deepCopy() const
Definition: Matrix.h:211
Tensor & operator=(Tensor &&rhs)=default
Tensor< T, N > & hadamard(const Tensor< T2, N > &rhs)
Definition: Matrix.h:404
Tensor(const Tensor &other)=default
Tensor< T, N - 1 > column(std::size_t j) const
Definition: Matrix.h:273
const T & operator()(Args... args) const
Definition: Matrix.h:231
Tensor< T, N > & apply(std::function< void(T &x)> fn)
Definition: Matrix.h:413
bool operator==(const Tensor &rhs) const
Definition: Matrix.h:242
Tensor(const TensorSlice< N > &ts, const std::shared_ptr< std::vector< T >> &elems)
Definition: Matrix.h:196
Tensor(const T &obj, Dims... dims)
Definition: Matrix.h:166
Tensor & operator=(const Tensor &rhs)=default
Tensor< T, N - 1 > row(std::size_t i) const
Definition: Matrix.h:263
Definition: apiAttributes.h:21
void printMatrix(const Matrix< T > &M, std::ostream &out=std::cout)
Definition: Matrix.h:622
Tensor< T, 2 > operator*(const Tensor< T, 2 > &M1, const Tensor< T2, 2 > &M2)
Definition: Matrix.h:522
Tensor< T, 2 > operator+(const Tensor< T, 2 > &M1, const Tensor< T2, 2 > &M2)
Definition: Matrix.h:595
T zeroValue(const T &x)
Given an object x return a zero object of the same type.
Definition: zeroValue.h:29
Tensor< T, 2 > operator-(const Tensor< T, 2 > &M1, const Tensor< T2, 2 > &M2)
Definition: Matrix.h:587
bool operator!=(const TensorSlice &rhs) const
Definition: Matrix.h:141
std::array< std::size_t, N > strides
Definition: Matrix.h:45
std::array< std::size_t, N > lengths
Definition: Matrix.h:44
bool operator==(const TensorSlice &rhs) const
Definition: Matrix.h:130
std::vector< long > start
Definition: Matrix.h:46
TensorSlice(const Iter1 &firstLength, const Iter1 &lastLength, const Iter2 &firstStride, const Iter2 &lastStride, const std::vector< long > &st)
Definition: Matrix.h:51
std::size_t operator()(Dims... dims) const
Definition: Matrix.h:100
std::size_t order() const
Definition: Matrix.h:128
TensorSlice(const Iter1 &firstLength, const Iter1 &lastLength, const Iter2 &firstStride, const Iter2 &lastStride, unsigned long pos)
Definition: Matrix.h:68
TensorSlice(Dims... dims)
Definition: Matrix.h:85
std::size_t size
Definition: Matrix.h:47