Ai
1 Star 4 Fork 1

爱丽丝妹妹与数学妖精/matrix_plus_cpp

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
matrix.hpp 24.96 KB
一键复制 编辑 原始数据 按行查看 历史
AuroraKarow 提交于 2023-12-02 17:05 +08:00 . Update net_decimal from repository
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
MATRIX_BEGIN
matrix_declare struct net_strassen_matrix final {
uint64_t ln_cnt = 0,
ln_cnt_orgn = 0,
col_cnt = 0,
col_cnt_orgn = 0,
ln_begin = 0,
col_begin = 0;
matrix_ptr ptr = nullptr;
uint64_t curr_ln(uint64_t ln) const { return (ln + ln_begin) * col_cnt_orgn + col_begin; }
void init_para(uint64_t _ln_cnt, uint64_t _col_cnt) {
ln_cnt = _ln_cnt;
col_cnt = _col_cnt;
ln_cnt_orgn = _ln_cnt;
col_cnt_orgn = _col_cnt;
}
void allocate(uint64_t _ln_cnt, uint64_t _col_cnt) {
if (ln_cnt == ln_cnt_orgn && col_cnt == col_cnt_orgn) reset(ptr);
ptr = init<matrix_elem_t>(_ln_cnt * _col_cnt);
init_para(_ln_cnt, _col_cnt);
}
long double *operator[](uint64_t ln) const { return ptr + curr_ln(ln); }
net_strassen_matrix operator+(const net_strassen_matrix &right) const {
net_strassen_matrix ans;
if (!(ln_cnt == right.ln_cnt && col_cnt == right.col_cnt)) return ans;
ans.allocate(ln_cnt, col_cnt);
for (auto i = 0ull; i < ln_cnt; ++i) for (auto j = 0ull; j < col_cnt; ++j) ans[i][j] = ptr[curr_ln(i) + j] + right.ptr[right.curr_ln(i) + j];
return ans;
}
net_strassen_matrix operator-(const net_strassen_matrix &right) const {
net_strassen_matrix ans;
if (!(ln_cnt == right.ln_cnt && col_cnt == right.col_cnt)) return ans;
ans.allocate(ln_cnt, col_cnt);
for (auto i = 0ull; i < ln_cnt; ++i) for (auto j = 0ull; j < col_cnt; ++j) ans[i][j] = ptr[curr_ln(i) + j] - right.ptr[right.curr_ln(i) + j];
return ans;
}
net_strassen_matrix operator*(const net_strassen_matrix &right) const {
net_strassen_matrix ans;
if (col_cnt != right.ln_cnt) return ans;
ans.allocate(ln_cnt, right.col_cnt);
for (auto i = 0ull; i < ln_cnt; ++i) for (auto j = 0ull; j < col_cnt; ++j) {
auto coe = ptr[curr_ln(i) + j];
for(auto k = 0ull; k < ans.col_cnt; ++k) ans[i][k] += coe * right.ptr[right.curr_ln(j) + k];
}
return ans;
}
void release() {
if (ln_cnt == ln_cnt_orgn && col_cnt == col_cnt_orgn && ptr) reset(ptr);
else ptr = nullptr;
}
void print() {
for (auto i = 0ull; i < ln_cnt; ++i) {
for (auto j = 0ull; j < col_cnt; ++j) std::cout << *(ptr + curr_ln(i) + j) << '\t';
std::cout << '\n';
}
}
};
callback_matrix void strassen_child(net_strassen_matrix<matrix_elem_t> &ans, const net_strassen_matrix<matrix_elem_t> &src, uint64_t ln_from, uint64_t col_from, uint64_t ln_cnt_child, uint64_t col_cnt_child) {
ans.ptr = src.ptr;
ans.ln_cnt_orgn = src.ln_cnt_orgn;
ans.col_cnt_orgn = src.col_cnt_orgn;
ans.ln_cnt = ln_cnt_child;
ans.col_cnt = col_cnt_child;
ans.ln_begin = src.ln_begin + ln_from;
ans.col_begin = src.col_begin + col_from;
}
callback_matrix void strassen_quartile(const net_strassen_matrix<matrix_elem_t> &src, net_strassen_matrix<matrix_elem_t> &src00, net_strassen_matrix<matrix_elem_t> &src01, net_strassen_matrix<matrix_elem_t> &src10, net_strassen_matrix<matrix_elem_t> &src11) {
auto dim_cnt_sub = src.ln_cnt / 2;
strassen_child(src00, src, 0, 0, dim_cnt_sub, dim_cnt_sub);
strassen_child(src01, src, 0, dim_cnt_sub, dim_cnt_sub, dim_cnt_sub);
strassen_child(src10, src, dim_cnt_sub, 0, dim_cnt_sub, dim_cnt_sub);
strassen_child(src11, src, dim_cnt_sub, dim_cnt_sub, dim_cnt_sub, dim_cnt_sub);
}
callback_matrix net_strassen_matrix<matrix_elem_t> strassen_mult(const net_strassen_matrix<matrix_elem_t> &left, const net_strassen_matrix<matrix_elem_t> &right, uint64_t recursive_gate) {
if (left.ln_cnt <= recursive_gate) return left * right;
net_strassen_matrix<matrix_elem_t> a00, a01, a10, a11, b00, b01, b10, b11;
strassen_quartile(left, a00, a01, a10, a11);
strassen_quartile(right, b00, b01, b10, b11);
auto s0 = b01 - b11;
auto s1 = a00 + a01;
auto s2 = a10 + a11;
auto s3 = b10 - b00;
auto s4 = a00 + a11;
auto s5 = b00 + b11;
auto s6 = a01 - a11;
auto s7 = b10 + b11;
auto s8 = a00 - a10;
auto s9 = b00 + b01;
auto p0 = strassen_mult(a00, s0, recursive_gate);
auto p1 = strassen_mult(s1, b11, recursive_gate);
auto p2 = strassen_mult(s2, b00, recursive_gate);
auto p3 = strassen_mult(a11, s3, recursive_gate);
auto p4 = strassen_mult(s4, s5, recursive_gate);
auto p5 = strassen_mult(s6, s7, recursive_gate);
auto p6 = strassen_mult(s8, s9, recursive_gate);
net_strassen_matrix<matrix_elem_t> ans;
ans.allocate(left.ln_cnt, right.col_cnt);
auto dim_cnt_sub = ans.ln_cnt / 2;
for (auto i = 0ull; i < ans.ln_cnt; ++i) for (auto j = 0ull; j < ans.col_cnt; ++j) if (i < dim_cnt_sub) {
if (j < dim_cnt_sub) ans[i][j] = p4[i][j] + p3[i][j] - p1[i][j] + p5[i][j];
else ans[i][j] = p0[i][j - dim_cnt_sub] + p1[i][j - dim_cnt_sub];
} else {
if (j < dim_cnt_sub) ans[i][j] = p2[i - dim_cnt_sub][j] + p3[i - dim_cnt_sub][j];
else ans[i][j] = p4[i - dim_cnt_sub][j - dim_cnt_sub] + p0[i - dim_cnt_sub][j - dim_cnt_sub] - p2[i - dim_cnt_sub][j - dim_cnt_sub] - p6[i - dim_cnt_sub][j - dim_cnt_sub];
}
s0.release(); s1.release(); s2.release(); s3.release(); s4.release(); s5.release(); s6.release(); s7.release(); s8.release(); s9.release();
p0.release(); p1.release(); p2.release(); p3.release(); p4.release(); p5.release(); p6.release();
return ans;
}
callback_matrix matrix_ptr strassen_mult(const matrix_ptr left, uint64_t left_ln_cnt, uint64_t left_col_cnt, const matrix_ptr right, uint64_t right_col_cnt, uint64_t recursive_gate = 32) {
auto r_p_l = left_col_cnt,
b_p_l = left_ln_cnt,
r_p_r = right_col_cnt,
b_p_r = left_col_cnt,
mx_len = num_extreme({r_p_l, b_p_l, r_p_r, b_p_r});
if (mx_len <= recursive_gate) return mult(left, left_ln_cnt, left_col_cnt, right, right_col_cnt);
auto pad_val = num_pad_pow(mx_len, 2),
pad_ans = mx_len + pad_val;
r_p_l = pad_ans - r_p_l;
b_p_l = pad_ans - b_p_l;
r_p_r = pad_ans - r_p_r;
b_p_r = pad_ans - b_p_r;
net_strassen_matrix<matrix_elem_t> left_stsn, right_stsn;
left_stsn.init_para(pad_ans, pad_ans);
right_stsn.init_para(pad_ans, pad_ans);
left_stsn.ptr = pad(pad_ans, pad_ans, left, left_ln_cnt, left_col_cnt, 0, r_p_l, b_p_l),
right_stsn.ptr = pad(pad_ans, pad_ans, right, left_col_cnt, right_col_cnt, 0, r_p_r, b_p_r);
auto ans = strassen_mult(left_stsn, right_stsn, recursive_gate);
auto ans_ptr = crop(left_ln_cnt, right_col_cnt, ans.ptr, pad_ans, pad_ans, 0, r_p_r, b_p_l);
left_stsn.release();
right_stsn.release();
ans.release();
return ans_ptr;
}
matrix_declare class net_matrix {
protected:
void value_assign(const net_matrix &src) {
ln_cnt = src.ln_cnt;
col_cnt = src.col_cnt;
elem_cnt = src.elem_cnt;
}
void value_copy(const net_matrix &src) {
copy(ptr, elem_cnt, src.ptr, src.elem_cnt);
value_assign(src);
}
template<typename s_arg> void value_copy(const net_matrix<s_arg> &src) {
ptr_alter(ptr, elem_cnt, src.__elem_cnt__(), false);
ln_cnt = src.__ln_cnt__();
col_cnt = src.__col_cnt__();
elem_cnt = src.__elem_cnt__();
for (auto i = 0ull; i < elem_cnt; ++i) if constexpr (std::is_same_v<s_arg, net_decimal>) *(ptr + i) = src.index(i).number_format;
else *(ptr + i) = src.index(i);
}
void value_move(net_matrix &&src) {
move(ptr, std::move(src.ptr));
value_assign(src);
src.reset();
}
void init_list_mtx(net_matrix &&_vect, net_sequence<net_matrix> &vect_set, net_sequence<pos> &left_top_pos, uint64_t curr_left, uint64_t curr_top, uint64_t &left_increase, uint64_t &top_increase) const {
left_increase = _vect.col_cnt;
top_increase = _vect.ln_cnt;
vect_set.emplace_back(std::move(_vect));
left_top_pos.emplace_back(pos{curr_top, curr_left});
}
void init_list_mtx(std::initializer_list<std::initializer_list<net_matrix>> _vect, net_sequence<net_matrix> &vect_set, net_sequence<pos> &left_top_pos, uint64_t curr_left, uint64_t curr_top, uint64_t &left_increase, uint64_t &top_increase) const {
auto curr_sub_top = curr_top;
for(auto ln_temp : _vect) {
auto curr_sub_left = curr_left,
curr_top_increase = 0ull;
for(auto temp : ln_temp) {
auto curr_left_increase = 0ull;
init_list_mtx(std::move(temp), vect_set, left_top_pos, curr_sub_left, curr_sub_top, curr_left_increase, curr_top_increase);
curr_sub_left += curr_left_increase;
}
curr_sub_top += curr_top_increase;
}
}
struct line_data final {
public:
line_data(const net_matrix *buf_ptr = nullptr, uint64_t begin_idx = 0, uint64_t ptr_len = 0) :
_buf_ptr_(buf_ptr),
_ptr_ln_(begin_idx),
_ptr_col_cnt(ptr_len) {}
matrix_elem_t &operator[](uint64_t col) const {
if (col >= _ptr_col_cnt) return neunet_null_ref(matrix_elem_t);
return *(_buf_ptr_->ptr + _ptr_ln_ * _ptr_col_cnt + col);
}
~line_data() {
_buf_ptr_ = nullptr;
_ptr_ln_ = 0;
_ptr_col_cnt = 0;
}
private:
const net_matrix *_buf_ptr_ = nullptr;
uint64_t _ptr_ln_ = 0,
_ptr_col_cnt = 0;
};
public:
net_matrix(uint64_t mtx_ln_cnt, uint64_t mtx_col_cnt, bool rand_init = false, const matrix_elem_t &fst_rng = -1, const matrix_elem_t &snd_rng = 1) :
ln_cnt(mtx_ln_cnt),
col_cnt(mtx_col_cnt),
elem_cnt(mtx_ln_cnt * mtx_col_cnt) {
if (elem_cnt == 0) return;
if (rand_init) ptr = init_rand<matrix_elem_t>(elem_cnt, fst_rng, snd_rng);
else ptr = init<matrix_elem_t>(elem_cnt);
}
template<typename matrix_elem_para, typename matrix_elem_para_v> net_matrix(const i_arg &src) :
elem_cnt(1),
ln_cnt(1),
col_cnt(1),
ptr(init<matrix_elem_t>(1)) {
if constexpr (std::is_same_v<matrix_elem_para, net_decimal> && !std::is_same_v<matrix_elem_t, matrix_elem_para>) *ptr = src.number_format;
else *ptr = src;
}
net_matrix(matrix_ptr &&src = nullptr, uint64_t mtx_ln_cnt = 0, uint64_t mtx_col_cnt = 0) :
ln_cnt(mtx_ln_cnt),
col_cnt(mtx_col_cnt),
elem_cnt(mtx_ln_cnt * mtx_col_cnt) { ptr_move(ptr, std::move(src)); }
net_matrix(const net_sequence<net_matrix> &_vect, const net_sequence<pos> &left_top_pos) {
if (_vect.length != left_top_pos.length) return;
uint64_t ln_cnt_temp = 0,
col_cnt_temp = 0;
for (auto i = 0ull; i < left_top_pos.length; ++i) {
if(!left_top_pos[i].ln) col_cnt_temp += _vect[i].col_cnt;
if(!left_top_pos[i].col) ln_cnt_temp += _vect[i].ln_cnt;
}
value_move(net_matrix(ln_cnt_temp, col_cnt_temp));
for (auto i = 0ull; i < _vect.length; ++i) for (auto j = 0ull; j < _vect[i].ln_cnt; ++j) for(auto k = 0ull; k < _vect[i].col_cnt; ++k) (*this)[left_top_pos[i].ln + j][left_top_pos[i].col + k] = std::move(_vect[i][j][k]);
}
net_matrix(std::initializer_list<std::initializer_list<net_matrix>> _vect) {
if(_vect.size() && _vect.begin()->size()) {
net_sequence<net_matrix> temp_seq;
net_sequence<pos> left_top_pos;
uint64_t top_temp = 0,
left_temp = 0;
init_list_mtx(_vect, temp_seq, left_top_pos, 0, 0, left_temp, top_temp);
value_move(net_matrix(temp_seq, left_top_pos));
}
}
net_matrix(const net_matrix &src) { value_copy(src); }
net_matrix(net_matrix &&src) { value_move(std::move(src)); }
template<typename s_arg> net_matrix(const net_matrix<s_arg> &src) { value_copy(src); }
bool is_matrix() const { return ptr && (ln_cnt * col_cnt == elem_cnt) && ln_cnt && col_cnt; }
template<typename matrix_elem_para, typename matrix_elem_para_v> void fill_elem(const matrix_elem_para &src) {
if constexpr (!std::is_same_v<matrix_elem_para, matrix_elem_t> && std::is_same_v<matrix_elem_para, net_decimal>) fill(ptr, elem_cnt, src.number_format);
else fill<matrix_elem_t>(ptr, elem_cnt, src);
}
net_list<pos> extremum_position(bool get_max, uint64_t from_ln, uint64_t to_ln, uint64_t from_col, uint64_t to_col, uint64_t ln_dilate = 0, uint64_t col_dilate = 0) const {
net_list<pos> ans;
extremum(ans, get_max, ptr, ln_cnt, col_cnt, from_ln, to_ln, from_col, to_col, ln_dilate, col_dilate);
return ans;
}
net_list<pos> extremum_position(bool get_max = true) const { return extremum_position(get_max, 0, ln_cnt - 1, 0, col_cnt - 1); }
net_matrix child(uint64_t from_ln, uint64_t to_ln, uint64_t from_col, uint64_t to_col, uint64_t ln_dilate, uint64_t col_dilate) const {
uint64_t ln_cnt_temp = 0,
col_cnt_temp = 0;
auto ans_temp = sub(ln_cnt_temp, col_cnt_temp, ptr, ln_cnt, col_cnt, from_ln, to_ln, from_col, to_col, ln_dilate, col_dilate);
return net_matrix(std::move(ans_temp), ln_cnt_temp, col_cnt_temp);
}
net_matrix child(uint64_t from_ln, uint64_t c_ln_cnt, uint64_t from_col, uint64_t c_col_cnt) { return child(from_ln, from_ln + c_ln_cnt - 1, from_col, from_col + c_col_cnt - 1, 0, 0); }
net_matrix rotate_rectangle(bool clockwise = true) const { return net_matrix(rotate_rect(ptr, ln_cnt, col_cnt, clockwise), col_cnt, ln_cnt); }
net_matrix mirror_flipping(bool vertical_symmetry = true) const { return net_matrix(mirror_flip(ptr, ln_cnt, col_cnt, vertical_symmetry), ln_cnt, col_cnt); }
bool shape_verify(uint64_t src_ln_cnt, uint64_t src_col_cnt) const { return ln_cnt == src_ln_cnt && col_cnt == src_col_cnt; }
bool shape_verify(const net_matrix &src) const { return shape_verify(src.ln_cnt, src.col_cnt); }
net_matrix reshape(uint64_t re_ln_cnt, uint64_t re_col_cnt) const {
auto ans = *this;
if (re_ln_cnt * re_col_cnt == elem_cnt) {
ans.ln_cnt = re_ln_cnt;
ans.col_cnt = re_col_cnt;
}
return ans;
}
net_matrix reshape(const net_matrix &src) const { return reshape(src.ln_cnt, src.col_cnt); }
net_matrix elem_swap(uint64_t fst_idx, uint64_t snd_idx, bool ln_swap = true) { return net_matrix(ln_col_swap(ptr, ln_cnt, col_cnt, fst_idx, snd_idx, ln_swap), ln_cnt, col_cnt); }
matrix_elem_t elem_sum() {
matrix_elem_t ans = 0;
for (auto i = 0ull; i < elem_cnt; ++i) ans += (*(ptr + i));
return ans;
}
matrix_elem_t elem_sum(uint64_t from_ln, uint64_t to_ln, uint64_t from_col, uint64_t to_col, uint64_t ln_dilate, uint64_t col_dilate) const { return child(from_ln, to_ln, from_col, to_col, ln_dilate, col_dilate).elem_sum(); }
net_matrix elem_wise_opt(const net_matrix &val, uint64_t operation, bool elem_swap = false, long double epsilon = 1e-8) const {
if (shape_verify(val)) return net_matrix(elem_operate(ptr, val.ptr, elem_cnt, operation, elem_swap, epsilon), ln_cnt, col_cnt);
else return net_matrix();
}
template<typename matrix_elem_para, typename matrix_elem_para_v> net_matrix elem_wise_opt(const matrix_elem_para &para, uint64_t operation, bool para_fst = false, long double epsilon = 1e-8) const {
if constexpr (std::is_same_v<matrix_elem_para, net_decimal> && !std::is_same_v<matrix_elem_t, net_decimal>) return net_matrix(elem_operate(ptr, elem_cnt, para.number_format, operation, para_fst, epsilon), ln_cnt, col_cnt);
else return net_matrix(elem_operate<matrix_elem_t>(ptr, elem_cnt, para, operation, para_fst, epsilon), ln_cnt, col_cnt);
}
template<typename matrix_elem_para, typename matrix_elem_para_v> net_matrix broadcast_addition(const matrix_elem_para &para, bool subtract = false, bool para_fst = false) const {
if constexpr (std::is_same_v<matrix_elem_para, net_decimal> && !std::is_same_v<matrix_elem_t, net_decimal>) return net_matrix(broadcast_add(ptr, elem_cnt, para.number_format, subtract, para_fst), ln_cnt, col_cnt);
else return net_matrix(broadcast_add<matrix_elem_t>(ptr, elem_cnt, para, subtract, para_fst), ln_cnt, col_cnt);
}
net_matrix padding(uint64_t top = 0, uint64_t right = 0, uint64_t bottom = 0, uint64_t left = 0, uint64_t ln_dist = 0, uint64_t col_dist = 0) const {
uint64_t ln_cnt_temp = 0,
col_cnt_temp = 0;
auto ans_temp = pad(ln_cnt_temp, col_cnt_temp, ptr, ln_cnt, col_cnt, top, right, bottom, left, ln_dist, col_dist);
return net_matrix(std::move(ans_temp), ln_cnt_temp, col_cnt_temp);
}
net_matrix cropping(uint64_t top = 0, uint64_t right = 0, uint64_t bottom = 0, uint64_t left = 0, uint64_t ln_dist = 0, uint64_t col_dist = 0) {
uint64_t ln_cnt_temp = 0,
col_cnt_temp = 0;
auto ans_temp = crop(ln_cnt_temp, col_cnt_temp, ptr, ln_cnt, col_cnt, top, right, bottom, left, ln_dist, col_dist);
return net_matrix(std::move(ans_temp), ln_cnt_temp, col_cnt_temp);
}
matrix_elem_t &index(uint64_t idx) const {
if (idx >= elem_cnt)return neunet_null_ref(matrix_elem_t);
return *(ptr + idx);
}
net_matrix<long double> float_point_format() const {
if constexpr (std::is_same_v<net_decimal, matrix_elem_t>) {
net_matrix<long double> ans(ln_cnt, col_cnt);
for (auto i = 0ull; i < elem_cnt; ++i) ans.index(i) = (*(ptr + i)).number_format;
return ans;
} else return net_matrix(*this);
}
void reset() {
recycle<matrix_elem_t>(ptr);
ln_cnt = 0;
col_cnt = 0;
elem_cnt = 0;
}
~net_matrix() { reset(); }
protected:
uint64_t ln_cnt = 0,
col_cnt = 0,
elem_cnt = 0;
matrix_ptr ptr = nullptr;
public:
uint64_t __ln_cnt__() const { return ln_cnt; }
uint64_t __col_cnt__() const { return col_cnt; }
uint64_t __elem_cnt__() const { return elem_cnt; }
matrix_elem_t __determinant__() const {
if (ln_cnt != col_cnt) return neunet_null_ref(matrix_elem_t);
return det(ptr, ln_cnt);
}
net_matrix __inverse_() const {
if (ln_cnt != col_cnt) return neunet_null_ref(net_matrix);
return net_matrix(inverser(ptr, ln_cnt), ln_cnt, col_cnt);
}
net_matrix __tranpose__() const { return net_matrix(transposition(ptr, ln_cnt, col_cnt), col_cnt, ln_cnt); }
matrix_elem_t __atom__() const {
if (!(ln_cnt == col_cnt && ln_cnt == 1)) return neunet_null_ref(matrix_elem_t);
return *ptr;
}
net_matrix __abs__() const { return net_matrix(absolute(ptr, elem_cnt), ln_cnt, col_cnt); }
net_matrix __LU_decompose__() const {
if (ln_cnt != col_cnt) return neunet_null_ref(net_matrix);
return net_matrix(LU(ptr, ln_cnt), ln_cnt, col_cnt);
}
net_matrix __adjugation__() const {
if (ln_cnt != col_cnt) return neunet_null_ref(net_matrix);
return net_matrix(adjugate(ptr, ln_cnt), ln_cnt, col_cnt);
}
uint64_t __ranking__() const { return rank(ptr, ln_cnt, col_cnt); }
net_matrix<net_decimal> __decimal_format__() const {
if constexpr (std::is_same_v<net_decimal, matrix_elem_t>) return net_matrix(*this);
else {
net_matrix<net_decimal> ans(ln_cnt, col_cnt);
for (auto i = 0ull; i < elem_cnt; ++i) ans.index(i) = *(ptr + i);
return ans;
}
}
static net_matrix identity(uint64_t dim_cnt) { return net_matrix(init_identity<matrix_elem_t>(dim_cnt), dim_cnt, dim_cnt); }
static net_matrix set_sigma(const net_set<net_matrix> &vect_set) {
if constexpr (std::is_same_v<matrix_elem_t, double> || (std::is_same_v<matrix_elem_t, long double> && sizeof(long double) == sizeof(double))) {
net_matrix ans;
ans.ln_cnt = vect_set[0].ln_cnt;
ans.col_cnt = vect_set[0].col_cnt;
ans.elem_cnt = vect_set[0].elem_cnt;
__m256d ans_reg[MATRIX_UNROLL];
auto ans_ptr = new double [ans.elem_cnt];
auto idx_temp = 0ull;
while (idx_temp < ans.elem_cnt) {
auto idx_next = idx_temp + MATRIX_UNROLL_UNIT;
if (idx_next > ans.elem_cnt) {
for (auto i = idx_temp; i < ans.elem_cnt; ++i) {
*(ans_ptr + i) = (*(vect_set[0].ptr + i));
for (auto j = 1ull; j < vect_set.length; ++j) *(ans_ptr + i) += (*(vect_set[j].ptr + i));
}
break;
}
for (auto x = 0; x < MATRIX_UNROLL; ++x) ans_reg[x] = _mm256_load_pd((double *)vect_set[0].ptr + idx_temp + x * MATRIX_REGSIZE);
for (auto i = 1ull; i < vect_set.length; ++i) for (auto x = 0; x < MATRIX_UNROLL; ++x) ans_reg[x] = _mm256_add_pd(ans_reg[x], _mm256_load_pd((double *)vect_set[i].ptr + idx_temp + x * MATRIX_REGSIZE));
for (auto x = 0; x < MATRIX_UNROLL; ++x) _mm256_store_pd(ans_ptr + idx_temp + x * MATRIX_REGSIZE, ans_reg[x]);
idx_temp = idx_next;
}
ans.ptr = (matrix_elem_t *)ans_ptr;
ans_ptr = nullptr;
return ans;
} else return vect_set.sum;
}
__declspec(property(get=__ln_cnt__)) uint64_t line_count;
__declspec(property(get=__col_cnt__)) uint64_t column_count;
__declspec(property(get=__elem_cnt__)) uint64_t element_count;
__declspec(property(get=is_matrix)) bool verify;
__declspec(property(get=__determinant__)) matrix_elem_t determinant;
__declspec(property(get=__inverse_)) net_matrix inverse;
__declspec(property(get=__tranpose__)) net_matrix transpose;
__declspec(property(get=__atom__)) matrix_elem_t atom;
__declspec(property(get=__abs__)) net_matrix abs;
__declspec(property(get=__LU_decompose__)) net_matrix LU_decompse;
__declspec(property(get=__adjugation__)) net_matrix adjugation;
__declspec(property(get=__ranking__)) uint64_t ranking;
__declspec(property(get=float_point_format)) net_matrix<long double> float_format;
__declspec(property(get=__decimal_format__)) net_matrix<net_decimal> decimal_format;
net_matrix operator+(const net_matrix &src) const {
if (!shape_verify(src)) return net_matrix(*this);
net_matrix ans;
ans.ptr = add(ptr, src.ptr, elem_cnt);
ans.value_assign(src);
return ans;
}
void operator+=(const net_matrix &src) { *this = *this + src; }
net_matrix operator-(const net_matrix &src) const {
if (!shape_verify(src)) return net_matrix(*this);
net_matrix ans;
ans.ptr = add(ptr, src.ptr, elem_cnt, true);
ans.value_assign(src);
return ans;
}
void operator-=(const net_matrix &src) { *this = *this - src; }
net_matrix operator*(const net_matrix &src) const {
net_matrix ans;
if (src.elem_cnt == 1) {
ans.ptr = mult(ptr, elem_cnt, *src.ptr);
ans.col_cnt = col_cnt;
ans.elem_cnt = elem_cnt;
} else {
if (col_cnt != src.ln_cnt) return net_matrix(*this);
ans.ptr = mult(ptr, ln_cnt, col_cnt, src.ptr, src.col_cnt);
ans.col_cnt = src.col_cnt;
ans.elem_cnt = ln_cnt * ans.col_cnt;
}
ans.ln_cnt = ln_cnt;
return ans;
}
template<typename matrix_elem_para, typename matrix_elem_para_v> friend net_matrix operator*(const matrix_elem_para &para, const net_matrix &src) { return src * net_matrix(para); }
void operator*=(const net_matrix &src) { *this = *this * src; }
net_matrix &operator=(const net_matrix &src) {
value_copy(src);
return *this;
}
net_matrix &operator=(net_matrix &&src) {
value_move(std::move(src));
return *this;
}
template<typename s_arg> net_matrix &operator=(const net_matrix<s_arg> &src) {
value_copy(src);
return *this;
}
line_data operator[](uint64_t ln) const {
if (ln > ln_cnt) return {};
return line_data(this, ln, col_cnt);
}
bool operator==(const net_matrix &src) const {
if (shape_verify(src)) {
for (auto i = 0ull; i < elem_cnt; ++i) if (*(ptr + i) != *(src.ptr + i)) return false;
return true;
} else return false;
}
bool operator!=(const net_matrix &src) const { return !(*this == src); }
friend std::ostream &operator<<(std::ostream &out, const net_matrix &src) {
if(src.is_matrix()) for(auto i = 0ull; i < src.elem_cnt; ++i) {
out << *(src.ptr + i);
if(!((i + 1) % src.col_cnt || (i + 1) == src.elem_cnt)) out << '\n';
else out << '\t';
}
return out;
}
};
callback_matrix neunet_vect vect_sum(const net_set<neunet_vect> &vect_set) { return net_matrix<matrix_elem_t>::set_sigma(vect_set); }
MATRIX_END
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C++
1
https://gitee.com/AliceSisMathElf/matrix_plus_cpp.git
git@gitee.com:AliceSisMathElf/matrix_plus_cpp.git
AliceSisMathElf
matrix_plus_cpp
matrix_plus_cpp
master

搜索帮助