libsemigroups  v3.6.0
C++ library for semigroups and monoids
Loading...
Searching...
No Matches
matrix-common.hpp
1//
2// libsemigroups - C++ library for semigroups and monoids
3// Copyright (C) 2026 James D. Mitchell
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU General Public License for more details.
14//
15// You should have received a copy of the GNU General Public License
16// along with this program. If not, see <http://www.gnu.org/licenses/>.
17//
18
19#ifndef LIBSEMIGROUPS_DETAIL_MATRIX_COMMON_HPP_
20#define LIBSEMIGROUPS_DETAIL_MATRIX_COMMON_HPP_
21
22#include <algorithm> // for copy, equal, fill
23#include <cstddef> // for size_t
24#include <initializer_list> // for initializer_list
25#include <iosfwd> // for ostream
26#include <iterator> // for distance
27#include <numeric> // for inner_product
28#include <string> // for basic_string
29#include <string_view> // for basic_string_view
30#include <type_traits> // for is_same_v, ena...
31#include <utility> // for swap, make_pair
32#include <vector> // for vector
33
34#include "libsemigroups/adapters.hpp" // for Hash
35#include "libsemigroups/constants.hpp" // for NEGATIVE_INFINITY
36#include "libsemigroups/debug.hpp" // for LIBSEMIGROUPS_...
37#include "libsemigroups/exception.hpp" // for LIBSEMIGROUPS_...
38#include "libsemigroups/is-matrix.hpp" // for IsMatrix, Matr...
39
40#include "fmt.hpp" // for format
41#include "matrix-exceptions.hpp" // for throw_if_bad_dim
42#include "string.hpp" // for to_string
43
44namespace libsemigroups::detail {
// entry_repr returns the string used to display a single matrix entry (a
// value of the matrix's Scalar type).
//
// It is needed by the exception messages and by to_human_readable_repr.
// Simply writing fmt::format("{}", val) is not sufficient: if val equals
// POSITIVE_INFINITY but its type is an integral type such as size_t, the
// formatter for PositiveInfinity is never selected.
//
// Moreover, starting with fmt v11.1.4 the custom formatters for
// POSITIVE_INFINITY and NEGATIVE_INFINITY stopped working, for reasons that
// were not determined.
template <class Scalar>
[[nodiscard]] std::string entry_repr(Scalar a);
55
56 template <typename Container,
57 typename Subclass,
58 typename TRowView,
59 typename Semiring = void>
60 class MatrixCommon : MatrixPolymorphicBase {
61 public:
63 // MatrixCommon - Aliases - public
65
66 using scalar_type = typename Container::value_type;
67 using scalar_reference = typename Container::reference;
68 using scalar_const_reference = typename Container::const_reference;
69 using semiring_type = Semiring;
70
71 using container_type = Container;
72 using iterator = typename Container::iterator;
73 using const_iterator = typename Container::const_iterator;
74
75 using RowView = TRowView;
76
77 [[nodiscard]] scalar_type scalar_one() const noexcept {
78 return static_cast<Subclass const*>(this)->one_impl();
79 }
80
81 [[nodiscard]] scalar_type scalar_zero() const noexcept {
82 return static_cast<Subclass const*>(this)->zero_impl();
83 }
84
85 [[nodiscard]] Semiring const* semiring() const noexcept {
86 return static_cast<Subclass const*>(this)->semiring_impl();
87 }
88
89 private:
91 // MatrixCommon - Semiring arithmetic - private
93
94 [[nodiscard]] scalar_type plus_no_checks(scalar_type x,
95 scalar_type y) const noexcept {
96 return static_cast<Subclass const*>(this)->plus_no_checks_impl(y, x);
97 }
98
99 [[nodiscard]] scalar_type product_no_checks(scalar_type x,
100 scalar_type y) const noexcept {
101 return static_cast<Subclass const*>(this)->product_no_checks_impl(y, x);
102 }
103
104 protected:
106 // MatrixCommon - Container functions - protected
108
109 void resize(size_t r, size_t c) {
110 if constexpr (std::is_same_v<container_type, std::vector<scalar_type>>) {
111 _container.resize(r * c);
112 }
113 }
114
115 private:
116 // not noexcept because resize isn't
117 template <typename T>
118 void init(T const& m);
119
120 // not noexcept because init isn't
121 void
122 init(std::initializer_list<std::initializer_list<scalar_type>> const& m) {
123 init<std::initializer_list<std::initializer_list<scalar_type>>>(m);
124 }
125
126 public:
128 // MatrixCommon - Constructors + destructor - public
130
131 // none of the constructors are noexcept because they allocate
132 MatrixCommon() = default;
133 MatrixCommon(MatrixCommon const&) = default;
134 MatrixCommon(MatrixCommon&&) = default;
135 MatrixCommon& operator=(MatrixCommon const&) = default;
136 MatrixCommon& operator=(MatrixCommon&&) = default;
137
138 ~MatrixCommon() = default;
139
140 explicit MatrixCommon(std::initializer_list<scalar_type> const& c)
141 : MatrixCommon() {
142 resize(1, c.size());
143 std::copy(c.begin(), c.end(), _container.begin());
144 }
145
146 explicit MatrixCommon(std::vector<std::vector<scalar_type>> const& m)
147 : MatrixCommon() {
148 init(m);
149 }
150
151 MatrixCommon(
152 std::initializer_list<std::initializer_list<scalar_type>> const& m)
153 : MatrixCommon() {
154 init(m);
155 }
156
157 public:
158 explicit MatrixCommon(RowView const& rv) : MatrixCommon() {
159 resize(1, rv.size());
160 std::copy(rv.cbegin(), rv.cend(), _container.begin());
161 }
162
163 // not noexcept because mem allocate is required
164 [[nodiscard]] Subclass one() const;
165
167 // Comparison operators
169
170 // not noexcept because apparently vector::operator== isn't
171 [[nodiscard]] bool operator==(MatrixCommon const& that) const {
172 return _container == that._container;
173 }
174
175 // not noexcept because apparently vector::operator== isn't
176 [[nodiscard]] bool operator==(RowView const& that) const {
177 return number_of_rows() == 1
178 && static_cast<RowView>(*static_cast<Subclass const*>(this))
179 == that;
180 }
181
182 // not noexcept because apparently vector::operator< isn't
183 [[nodiscard]] bool operator<(MatrixCommon const& that) const {
184 return _container < that._container;
185 }
186
187 // not noexcept because apparently vector::operator< isn't
188 [[nodiscard]] bool operator<(RowView const& that) const {
189 return number_of_rows() == 1
190 && static_cast<RowView>(*static_cast<Subclass const*>(this))
191 < that;
192 }
193
194 // not noexcept because operator== isn't
195 template <typename T>
196 [[nodiscard]] bool operator!=(T const& that) const {
197 static_assert(IsMatrix<T> || std::is_same_v<T, RowView>);
198 return !(*this == that);
199 }
200
201 // not noexcept because operator< isn't
202 template <typename T>
203 [[nodiscard]] bool operator>(T const& that) const {
204 static_assert(IsMatrix<T> || std::is_same_v<T, RowView>);
205 return that < *this;
206 }
207
208 // not noexcept because operator< isn't
209 template <typename T>
210 [[nodiscard]] bool operator>=(T const& that) const {
211 static_assert(IsMatrix<T> || std::is_same_v<T, RowView>);
212 return that < *this || that == *this;
213 }
214
215 // not noexcept because operator< isn't
216 template <typename T>
217 [[nodiscard]] bool operator<=(T const& that) const {
218 static_assert(IsMatrix<T> || std::is_same_v<T, RowView>);
219 return *this < that || that == *this;
220 }
221
223 // Attributes
225
226 // not noexcept because vector::operator[] isn't, and neither is
227 // array::operator[]
228 [[nodiscard]] scalar_reference operator()(size_t r, size_t c) {
229 return this->_container[r * number_of_cols() + c];
230 }
231
232 [[nodiscard]] scalar_reference at(size_t r, size_t c) {
233 matrix::throw_if_bad_coords(static_cast<Subclass const&>(*this), r, c);
234 return this->operator()(r, c);
235 }
236
237 // not noexcept because vector::operator[] isn't, and neither is
238 // array::operator[]
239 [[nodiscard]] scalar_const_reference operator()(size_t r, size_t c) const {
240 return this->_container[r * number_of_cols() + c];
241 }
242
243 [[nodiscard]] scalar_const_reference at(size_t r, size_t c) const {
244 matrix::throw_if_bad_coords(static_cast<Subclass const&>(*this), r, c);
245 return this->operator()(r, c);
246 }
247
248 // noexcept because number_of_rows_impl is noexcept
249 [[nodiscard]] size_t number_of_rows() const noexcept {
250 return static_cast<Subclass const*>(this)->number_of_rows_impl();
251 }
252
253 // noexcept because number_of_cols_impl is noexcept
254 [[nodiscard]] size_t number_of_cols() const noexcept {
255 return static_cast<Subclass const*>(this)->number_of_cols_impl();
256 }
257
258 // not noexcept because Hash<T>::operator() isn't
259 [[nodiscard]] size_t hash_value() const {
260 return Hash<Container>()(_container);
261 }
262
264 // Arithmetic operators - in-place
266
267 // not noexcept because memory is allocated
268 void product_inplace_no_checks(Subclass const& A, Subclass const& B);
269 void product_inplace(Subclass const& A, Subclass const& B);
270
271 // not noexcept because iterator increment isn't
272 void operator*=(scalar_type a) {
273 for (auto it = _container.begin(); it < _container.end(); ++it) {
274 *it = product_no_checks(*it, a);
275 }
276 }
277
278 void plus_inplace_no_checks(Subclass const& that);
279
280 // not noexcept because vector::operator[] and array::operator[] aren't
281 void operator+=(Subclass const& that);
282
283 void plus_inplace_no_checks(RowView const& that) {
284 LIBSEMIGROUPS_ASSERT(number_of_rows() == 1);
285 LIBSEMIGROUPS_ASSERT(number_of_cols() == that.size());
286 RowView(*static_cast<Subclass const*>(this)) += that;
287 }
288
289 void operator+=(RowView const& that);
290
291 void operator+=(scalar_type a) {
292 for (auto it = _container.begin(); it < _container.end(); ++it) {
293 *it = plus_no_checks(*it, a);
294 }
295 }
296
297 // TODO(2) implement operator*=(Subclass const&)
298
300 // Arithmetic operators - not in-place
302
303 [[nodiscard]] Subclass plus_no_checks(Subclass const& y) const {
304 Subclass result(*static_cast<Subclass const*>(this));
305 result.plus_inplace_no_checks(y);
306 return result;
307 }
308
309 [[nodiscard]] Subclass operator+(Subclass const& y) const;
310
311 [[nodiscard]] Subclass product_no_checks(Subclass const& y) const {
312 Subclass result(*static_cast<Subclass const*>(this));
313 result.product_inplace_no_checks(*static_cast<Subclass const*>(this), y);
314 return result;
315 }
316
317 // not noexcept because product_inplace_no_checks isn't
318 [[nodiscard]] Subclass operator*(Subclass const& y) const;
319
320 [[nodiscard]] Subclass operator*(scalar_type a) const {
321 Subclass result(*static_cast<Subclass const*>(this));
322 result *= a;
323 return result;
324 }
325
326 [[nodiscard]] Subclass operator+(scalar_type a) const {
327 Subclass result(*static_cast<Subclass const*>(this));
328 result += a;
329 return result;
330 }
331
333 // Iterators
335
336 // noexcept because vector::begin and array::begin are noexcept
337 [[nodiscard]] iterator begin() noexcept {
338 return _container.begin();
339 }
340
341 // noexcept because vector::end and array::end are noexcept
342 [[nodiscard]] iterator end() noexcept {
343 return _container.end();
344 }
345
346 // noexcept because vector::begin and array::begin are noexcept
347 [[nodiscard]] const_iterator begin() const noexcept {
348 return _container.begin();
349 }
350
351 // noexcept because vector::end and array::end are noexcept
352 [[nodiscard]] const_iterator end() const noexcept {
353 return _container.end();
354 }
355
356 // noexcept because vector::cbegin and array::cbegin are noexcept
357 [[nodiscard]] const_iterator cbegin() const noexcept {
358 return _container.cbegin();
359 }
360
361 // noexcept because vector::cend and array::cend are noexcept
362 [[nodiscard]] const_iterator cend() const noexcept {
363 return _container.cend();
364 }
365
366 template <typename Iterator>
367 [[nodiscard]] std::pair<scalar_type, scalar_type>
368 coords(Iterator const& it) const;
369
371 // Modifiers
373
374 // noexcept because vector::swap and array::swap are noexcept
375 void swap(MatrixCommon& that) noexcept {
376 std::swap(_container, that._container);
377 }
378
379 // noexcept because swap is noexcept, and so too are number_of_rows and
380 // number_of_cols
381 void transpose_no_checks() noexcept;
382
383 void transpose() {
384 matrix::throw_if_not_square(static_cast<Subclass&>(*this));
385 transpose_no_checks();
386 }
387
389 // Rows
391
392 // not noexcept because there's an allocation
393 [[nodiscard]] RowView row_no_checks(size_t i) const;
394
395 [[nodiscard]] RowView row(size_t i) const;
396
397 // not noexcept because there's an allocation
398 template <typename T>
399 void rows(T& x) const;
400
402 // Friend functions
404
405 friend std::ostream& operator<<(std::ostream& os, MatrixCommon const& x) {
406 os << detail::to_string(x);
407 return os;
408 }
409
410 private:
412 // Private data
414 container_type _container;
415 }; // class MatrixCommon
416
// Holds the number of rows and columns of a matrix whose dimensions are
// chosen at run time. A default-constructed object describes a 0 x 0 matrix.
template <typename Scalar>
class MatrixDynamicDim {
 public:
  // Delegates to the (rows, cols) constructor with both dimensions 0.
  MatrixDynamicDim() : MatrixDynamicDim(0, 0) {}
  MatrixDynamicDim(MatrixDynamicDim const&)            = default;
  MatrixDynamicDim(MatrixDynamicDim&&)                 = default;
  MatrixDynamicDim& operator=(MatrixDynamicDim const&) = default;
  MatrixDynamicDim& operator=(MatrixDynamicDim&&)      = default;
  ~MatrixDynamicDim()                                  = default;

  // Records r rows and c columns.
  MatrixDynamicDim(size_t r, size_t c)
      : _number_of_cols(c), _number_of_rows(r) {}

  // Exchanges both stored dimensions with those of that.
  void swap(MatrixDynamicDim& that) noexcept {
    using std::swap;
    swap(_number_of_rows, that._number_of_rows);
    swap(_number_of_cols, that._number_of_cols);
  }

 protected:
  // Stored row count.
  [[nodiscard]] size_t number_of_rows_impl() const noexcept {
    return _number_of_rows;
  }

  // Stored column count.
  [[nodiscard]] size_t number_of_cols_impl() const noexcept {
    return _number_of_cols;
  }

 private:
  size_t _number_of_cols;
  size_t _number_of_rows;
};
449
// Supplies the semiring hooks that MatrixCommon expects (plus, product,
// one, zero, semiring) for a semiring whose operations are fixed at compile
// time. Each hook default-constructs the matching function object and calls
// it; there is no semiring object, so semiring_impl returns nullptr.
template <typename PlusOp,
          typename ProdOp,
          typename ZeroOp,
          typename OneOp,
          typename Scalar>
struct MatrixStaticArithmetic {
  using scalar_type = Scalar;

  MatrixStaticArithmetic()                                         = default;
  MatrixStaticArithmetic(MatrixStaticArithmetic const&)            = default;
  MatrixStaticArithmetic(MatrixStaticArithmetic&&)                 = default;
  MatrixStaticArithmetic& operator=(MatrixStaticArithmetic const&) = default;
  MatrixStaticArithmetic& operator=(MatrixStaticArithmetic&&)      = default;

 protected:
  // Semiring sum of x and y, as computed by PlusOp.
  [[nodiscard]] static constexpr scalar_type
  plus_no_checks_impl(scalar_type x, scalar_type y) noexcept {
    return PlusOp{}(x, y);
  }

  // Semiring product of x and y, as computed by ProdOp.
  [[nodiscard]] static constexpr scalar_type
  product_no_checks_impl(scalar_type x, scalar_type y) noexcept {
    return ProdOp{}(x, y);
  }

  // The "one" of the semiring, as returned by OneOp.
  [[nodiscard]] static constexpr scalar_type one_impl() noexcept {
    return OneOp{}();
  }

  // The "zero" of the semiring, as returned by ZeroOp.
  [[nodiscard]] static constexpr scalar_type zero_impl() noexcept {
    return ZeroOp{}();
  }

  // No run-time semiring object exists for static arithmetic.
  [[nodiscard]] static constexpr void const* semiring_impl() noexcept {
    return nullptr;
  }
};
487
489 // RowViews - class for cheaply storing iterators to rows
491
492 template <typename Mat, typename Subclass>
493 class RowViewCommon {
494 static_assert(IsMatrix<Mat>,
495 "the template parameter Mat must be derived from "
496 "MatrixPolymorphicBase");
497
498 public:
499 using const_iterator = typename Mat::const_iterator;
500 using iterator = typename Mat::iterator;
501
502 using scalar_type = typename Mat::scalar_type;
503 using scalar_reference = typename Mat::scalar_reference;
504 using scalar_const_reference = typename Mat::scalar_const_reference;
505
506 using Row = typename Mat::Row;
507 using matrix_type = Mat;
508
509 [[nodiscard]] size_t size() const noexcept {
510 return static_cast<Subclass const*>(this)->length_impl();
511 }
512
513 private:
514 scalar_type plus_no_checks(scalar_type x, scalar_type y) const noexcept {
515 return static_cast<Subclass const*>(this)->plus_no_checks_impl(y, x);
516 }
517
518 scalar_type product_no_checks(scalar_type x, scalar_type y) const noexcept {
519 return static_cast<Subclass const*>(this)->product_no_checks_impl(y, x);
520 }
521
522 public:
523 RowViewCommon() = default;
524 RowViewCommon(RowViewCommon const&) = default;
525 RowViewCommon(RowViewCommon&&) = default;
526 RowViewCommon& operator=(RowViewCommon const&) = default;
527 RowViewCommon& operator=(RowViewCommon&&) = default;
528
529 ~RowViewCommon() = default;
530
531 explicit RowViewCommon(Row const& r)
532 : RowViewCommon(const_cast<Row&>(r).begin()) {}
533
534 // Not noexcept because iterator::operator[] isn't
535 [[nodiscard]] scalar_const_reference operator[](size_t i) const {
536 return _begin[i];
537 }
538
539 // Not noexcept because iterator::operator[] isn't
540 [[nodiscard]] scalar_reference operator[](size_t i) {
541 return _begin[i];
542 }
543
544 // Not noexcept because iterator::operator[] isn't
545 [[nodiscard]] scalar_const_reference operator()(size_t i) const {
546 return (*this)[i];
547 }
548
549 // Not noexcept because iterator::operator[] isn't
550 [[nodiscard]] scalar_reference operator()(size_t i) {
551 return (*this)[i];
552 }
553
554 // noexcept because begin() is
555 [[nodiscard]] const_iterator cbegin() const noexcept {
556 return _begin;
557 }
558
559 // not noexcept because iterator arithmetic isn't
560 [[nodiscard]] const_iterator cend() const {
561 return _begin + size();
562 }
563
564 // noexcept because begin() is
565 [[nodiscard]] const_iterator begin() const noexcept {
566 return _begin;
567 }
568
569 // not noexcept because iterator arithmetic isn't
570 [[nodiscard]] const_iterator end() const {
571 return _begin + size();
572 }
573
574 // noexcept because begin() is
575 [[nodiscard]] iterator begin() noexcept {
576 return _begin;
577 }
578
579 // not noexcept because iterator arithmetic isn't
580 [[nodiscard]] iterator end() noexcept {
581 return _begin + size();
582 }
583
585 // Arithmetic operators
587
588 // not noexcept because operator[] isn't
589 void plus_inplace_no_checks(RowViewCommon const& x);
590
591 // TODO add tests
592 void operator+=(RowViewCommon const& x);
593
594 // not noexcept because operator+= isn't
595 [[nodiscard]] Row plus_no_checks(RowViewCommon const& that) const {
596 Row result(*static_cast<Subclass const*>(this));
597 result.plus_inplace_no_checks(static_cast<Subclass const&>(that));
598 return result;
599 }
600
601 // TODO add tests
602 [[nodiscard]] Row operator+(RowViewCommon const& x);
603
604 // not noexcept because iterator arithmetic isn't
605 void operator+=(scalar_type a) {
606 for (auto& x : *this) {
607 x = plus_no_checks(x, a);
608 }
609 }
610
611 // not noexcept because iterator arithmetic isn't
612 void operator*=(scalar_type a) {
613 for (auto& x : *this) {
614 x = product_no_checks(x, a);
615 }
616 }
617
618 // not noexcept because operator*= isn'tl
619 [[nodiscard]] Row operator*(scalar_type a) const {
620 Row result(*static_cast<Subclass const*>(this));
621 result *= a;
622 return result;
623 }
624
625 template <typename U>
626 [[nodiscard]] bool operator==(U const& that) const {
627 // TODO(1) static assert that U is Row or RowView
628 return std::equal(begin(), end(), that.begin());
629 }
630
631 template <typename U>
632 [[nodiscard]] bool operator!=(U const& that) const {
633 return !(*this == that);
634 }
635
636 template <typename U>
637 [[nodiscard]] bool operator<(U const& that) const {
639 cbegin(), cend(), that.cbegin(), that.cend());
640 }
641
642 template <typename U>
643 [[nodiscard]] bool operator>(U const& that) const {
644 return that < *this;
645 }
646
647 void swap(RowViewCommon& that) noexcept {
648 std::swap(that._begin, _begin);
649 }
650
651 friend std::ostream& operator<<(std::ostream& os, RowViewCommon const& x) {
652 os << detail::to_string(x);
653 return os;
654 }
655
656 protected:
657 explicit RowViewCommon(iterator first) : _begin(first) {}
658
659 private:
660 iterator _begin;
661 }; // class RowViewCommon
662
663 template <typename Mat, typename Subclass>
664 void throw_if_bad_dim(RowViewCommon<Mat, Subclass> const& x,
665 RowViewCommon<Mat, Subclass> const& y,
666 std::string_view arg_desc_x = "the 1st argument",
667 std::string_view arg_desc_y = "the 2nd argument");
668} // namespace libsemigroups::detail
669
670#include "matrix-common.tpp"
671
672#endif // LIBSEMIGROUPS_DETAIL_MATRIX_COMMON_HPP_
T copy(T... args)
T equal(T... args)
constexpr bool IsMatrix
Helper variable template.
Definition is-matrix.hpp:87
T lexicographical_compare(T... args)
void throw_if_not_square(Mat const &x, std::string_view arg_desc="the argument")
Throws if a matrix is not square.
void throw_if_bad_coords(Mat const &x, size_t r, size_t c)
Throws the arguments do not index an entry of a matrix.
T swap(T... args)