#pragma once
// tensor.hpp: super-simple dense tensor class implementation
//
// Copyright (C) 2017 Stillwater Supercomputing, Inc.
// SPDX-License-Identifier: MIT
//
// This file is part of the universal numbers project, which is released under an MIT Open Source license.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <map>
#include <utility>
#include <vector>
#include <universal/blas/exceptions.hpp>

// Portable spelling of the [[nodiscard]] attribute, used below on the tensor iterator accessors.
// For each compiler family listed, _HAS_NODISCARD is defined to 1; MSVC is skipped because,
// per the note in that branch, its headers already provide _NODISCARD.
// NOTE(review): identifiers starting with an underscore and an uppercase letter are reserved for
// the implementation, and MSVC's STL uses these exact names -- a project-prefixed macro would avoid clashes.
// NOTE(review): [[nodiscard]] is a C++17 attribute; this assumes the project builds with C++17 or later -- confirm.
#if defined(__clang__)
/* Clang/LLVM. ---------------------------------------------- */
#define _HAS_NODISCARD 1

#elif defined(__ICC) || defined(__INTEL_COMPILER)
/* Intel ICC/ICPC. ------------------------------------------ */
#define _HAS_NODISCARD 1

#elif defined(__GNUC__) || defined(__GNUG__)
/* GNU GCC/G++. --------------------------------------------- */
#define _HAS_NODISCARD 1

#elif defined(__HP_cc) || defined(__HP_aCC)
/* Hewlett-Packard C/aC++. ---------------------------------- */
#define _HAS_NODISCARD 1

#elif defined(__IBMC__) || defined(__IBMCPP__)
/* IBM XL C/C++. -------------------------------------------- */
#define _HAS_NODISCARD 1

#elif defined(_MSC_VER)
/* Microsoft Visual Studio. --------------------------------- */
// already defines _NODISCARD

#elif defined(__PGI)
/* Portland Group PGCC/PGCPP. ------------------------------- */
#define _HAS_NODISCARD 1

#elif defined(__SUNPRO_C) || defined(__SUNPRO_CC)
/* Oracle Solaris Studio. ----------------------------------- */
#define _HAS_NODISCARD 1

#endif

// _NODISCARD becomes [[nodiscard]] when _HAS_NODISCARD is nonzero. An undefined _HAS_NODISCARD
// evaluates to 0 inside #if, so in that case _NODISCARD expands to nothing.
#if _HAS_NODISCARD
#define _NODISCARD [[nodiscard]]
#else // ^^^ CAN HAZ [[nodiscard]] / NO CAN HAZ [[nodiscard]] vvv
#define _NODISCARD
#endif // _HAS_NODISCARD

namespace sw { namespace universal { namespace blas { 

// Dense two-dimensional tensor stored in row-major order: element (i, j) lives at
// data[i * cols() + j]. Storage is owned by a std::vector, so no manual resource
// management is required.
template<typename Scalar>
class tensor {
public:
	typedef Scalar									value_type;
	typedef const value_type&						const_reference;
	typedef value_type&								reference;
	typedef const value_type*						const_pointer_type;
	typedef typename std::vector<Scalar>::size_type size_type;
	typedef typename std::vector<Scalar>::iterator     iterator;
	typedef typename std::vector<Scalar>::const_iterator const_iterator;
	typedef typename std::vector<Scalar>::reverse_iterator reverse_iterator;
	typedef typename std::vector<Scalar>::const_reverse_iterator const_reverse_iterator;
	static constexpr unsigned AggregationType = UNIVERSAL_AGGREGATE_TENSOR;

	// empty 0x0 tensor
	tensor() : _m{ 0 }, _n{ 0 }, data(0) {}
	// m x n tensor with every element set to zero
	tensor(unsigned m, unsigned n) : _m{ m }, _n{ n }, data(m*n, Scalar(0.0)) { }
	// Construct from nested lists, one inner list per row, e.g. { {1, 2}, {3, 4} }.
	// The column count is taken from the first row, and an empty outer list yields a 0x0 tensor.
	// A row whose length differs from the first row's is left zero-filled in its own slot.
	tensor(std::initializer_list< std::initializer_list<Scalar> > values) : _m{ 0 }, _n{ 0 }, data() {
		if (values.size() == 0) return; // values.begin() must not be dereferenced for an empty list
		const unsigned nrows = static_cast<unsigned>(values.size());
		const unsigned ncols = static_cast<unsigned>(values.begin()->size());
		data.assign(static_cast<size_type>(nrows) * ncols, Scalar(0));
		unsigned r = 0;
		for (const auto& l : values) {
			if (l.size() == ncols) {
				unsigned c = 0;
				for (const auto& v : l) {
					data[static_cast<size_type>(r) * ncols + c] = v;
					++c;
				}
			}
			// advance for every row, so a malformed row cannot shift the later rows upward
			++r;
		}
		_m = nrows;
		_n = ncols;
	}
	tensor(const tensor& A) : _m{ A._m }, _n{ A._n }, data(A.data) {}
	// The user-declared copy constructor above suppresses the implicit move constructor.
	// Default it explicitly, otherwise every move of a tensor would deep-copy its elements.
	tensor(tensor&&) = default;

	// Converting Constructor (SourceType A --> Scalar B): every element is converted with Scalar(A(i, j))
	template<typename SourceType>
	tensor(const tensor<SourceType>& A) : _m{ A.rows() }, _n{A.cols() }{
		data.resize(_m*_n);
		for (unsigned i = 0; i < _m; ++i){
			for (unsigned j = 0; j < _n; ++j){
				data[i*_n + j] = Scalar(A(i,j));
			}
		}
	}


	/* Operators
	/ 	Binary and unitary operators, =,>,!=,...
	*/
	tensor& operator=(const tensor& M) = default;
	tensor& operator=(tensor&& M) = default;

	// Identity tensor operator: zero every element, then write `one` on the main diagonal
	tensor& operator=(const Scalar& one) {
		setzero();
		unsigned smallestDimension = (_m < _n ? _m : _n);
		for (unsigned i = 0; i < smallestDimension; ++i) data[i*_n + i] = one;
		return *this;
	}

	// element access (no bounds checking)
	Scalar operator()(unsigned i, unsigned j) const { return data[i*_n + j]; }
	Scalar& operator()(unsigned i, unsigned j) { return data[i*_n + j]; }
	// row access through a proxy positioned at the first element of row i
	RowProxy<Scalar> operator[](unsigned i) {
		typename std::vector<Scalar>::iterator it = data.begin() + int64_t(i) * int64_t(_n);
		RowProxy<Scalar> proxy(it);
		return proxy;
	}
	ConstRowProxy<Scalar> operator[](unsigned i) const {
		// widen before multiplying (as the non-const overload does), so i * _n cannot wrap in 32 bits
		typename std::vector<Scalar>::const_iterator it = data.begin() + static_cast<int64_t>(i) * static_cast<int64_t>(_n);
		ConstRowProxy<Scalar> proxy(it);
		return proxy;
	}

	// tensor element-wise sum; on a shape mismatch, report to std::cerr and leave *this unchanged
	tensor& operator+=(const tensor& rhs) {
		// check if the matrices are compatible
		if (_m != rhs._m || _n != rhs._n) {
			std::cerr << "Element-wise tensor sum received incompatible matrices ("
				<< _m << ", " << _n << ") += (" << rhs._m << ", " << rhs._n << ")\n";
			return *this; // return without changing
		}
		for (size_type e = 0; e < _m * _n; ++e) {
			data[e] += rhs.data[e];
		}
		return *this;
	}
	// tensor element-wise difference; on a shape mismatch, report to std::cerr and leave *this unchanged
	tensor& operator-=(const tensor& rhs) {
		// check if the matrices are compatible
		if (_m != rhs._m || _n != rhs._n) {
			std::cerr << "Element-wise tensor difference received incompatible matrices ("
				<< _m << ", " << _n << ") -= (" << rhs._m << ", " << rhs._n << ")\n";
			return *this; // return without changing
		}
		for (size_type e = 0; e < _m*_n; ++e) {
			data[e] -= rhs.data[e];
		}
		return *this;
	}

	// multiply all tensor elements by a
	tensor& operator*=(const Scalar& a) {
		for (auto& element : data) element *= a;
		return *this;
	}
	// divide all tensor elements by a
	tensor& operator/=(const Scalar& a) {
		for (auto& element : data) element /= a;
		return *this;
	}

	// modifiers
	inline void setzero() { for (auto& elem : data) elem = Scalar(0); }
	inline void resize(unsigned m, unsigned n) { _m = m; _n = n; data.resize(m * n); }
	// selectors
	inline unsigned rows() const { return _m; }
	inline unsigned cols() const { return _n; }
	inline std::pair<unsigned, unsigned> size() const { return std::make_pair(_m, _n); }

	// In-place transpose: an m x n tensor becomes n x m without a second buffer of Scalars.
	// In row-major order the element at flat index k (0 < k < m*n - 1) moves to
	// (k * m) mod (m*n - 1), and the first and last elements never move. Each permutation
	// cycle is followed once, and a visited bitmap marks which positions are already done.
	tensor& transpose() {
		const std::size_t count = static_cast<std::size_t>(_m) * _n;
		if (count > 1) { // 0 or 1 elements: only the dimensions change
			const std::size_t modulus = count - 1;
			std::vector<bool> visited(count, false);
			visited[0] = true;       // A(0,0) is stationary
			visited[modulus] = true; // A(m-1,n-1) is stationary
			for (std::size_t start = 1; start < modulus; ++start) {
				if (visited[start]) continue;
				Scalar carried = data[start]; // value travelling around the current cycle
				std::size_t index = start;
				do {
					// computed in size_t so index * _m cannot overflow a 32-bit unsigned
					const std::size_t next = (index * _m) % modulus;
					std::swap(data[next], carried);
					visited[index] = true;
					index = next;
				} while (index != start);
			}
		}
		std::swap(_m, _n);
		return *this;
	}

	// Eigen-style factory: an m x n zero tensor (static, like Eigen's Matrix::Zero)
	static tensor Zero(unsigned m, unsigned n) {
		tensor z(m, n);
		return z;
	}

	// iterators
	_NODISCARD iterator begin() noexcept {
		return data.begin();
	}

	_NODISCARD const_iterator begin() const noexcept {
		return data.begin();
	}

	_NODISCARD iterator end() noexcept {
		return data.end();
	}

	_NODISCARD const_iterator end() const noexcept {
		return data.end();
	}

	_NODISCARD reverse_iterator rbegin() noexcept {
		return reverse_iterator(end());
	}

	_NODISCARD const_reverse_iterator rbegin() const noexcept {
		return const_reverse_iterator(end());
	}

	_NODISCARD reverse_iterator rend() noexcept {
		return reverse_iterator(begin());
	}

	_NODISCARD const_reverse_iterator rend() const noexcept {
		return const_reverse_iterator(begin());
	}

private:
	unsigned _m, _n; // m rows and n columns
	std::vector<Scalar> data; // row-major element storage, size _m * _n

};

// Free-function shape queries mirroring tensor::rows(), tensor::cols() and tensor::size().
template<typename Scalar>
inline auto num_rows(const tensor<Scalar>& A) -> unsigned {
	const unsigned r = A.rows();
	return r;
}
template<typename Scalar>
inline auto num_cols(const tensor<Scalar>& A) -> unsigned {
	const unsigned c = A.cols();
	return c;
}
template<typename Scalar>
inline auto size(const tensor<Scalar>& A) -> std::pair<unsigned, unsigned> {
	const std::pair<unsigned, unsigned> shape = A.size();
	return shape;
}

// ostream operator: it only uses public interfaces, so it need not be a friend.
// Prints one line per row. Whatever field width was set on the stream before the call
// is re-applied to every element, and each element is followed by a single space.
template<typename Scalar>
std::ostream& operator<<(std::ostream& ostr, const tensor<Scalar>& A) {
	const auto fieldWidth = ostr.width();
	for (unsigned row = 0; row < A.rows(); ++row) {
		for (unsigned col = 0; col < A.cols(); ++col) {
			ostr << std::setw(fieldWidth) << A(row, col) << " ";
		}
		ostr << '\n';
	}
	return ostr;
}

// tensor element-wise sum: returns A + B.
// On a shape mismatch, tensor::operator+= reports to std::cerr and leaves its left operand
// unchanged, so the result is then a copy of A.
template<typename Scalar>
tensor<Scalar> operator+(const tensor<Scalar>& A, const tensor<Scalar>& B) {
	tensor<Scalar> Sum(A);
	Sum += B;
	return Sum; // return the named local (NRVO/move) instead of copying from the reference += returns
}

// tensor element-wise difference: returns A - B.
// On a shape mismatch, tensor::operator-= reports to std::cerr and leaves its left operand
// unchanged, so the result is then a copy of A.
template<typename Scalar>
tensor<Scalar> operator-(const tensor<Scalar>& A, const tensor<Scalar>& B) {
	tensor<Scalar> Diff(A);
	Diff -= B;
	return Diff; // return the named local (NRVO/move) instead of copying from the reference -= returns
}

// tensor scaling through Scalar multiply: returns a * B, with every element multiplied by a.
// Only the scalar-on-the-left form is provided.
template<typename Scalar>
tensor<Scalar> operator*(const Scalar& a, const tensor<Scalar>& B) {
	tensor<Scalar> A(B);
	A *= a;
	return A; // return the named local (NRVO/move) instead of copying from the reference *= returns
}

// tensor scaling through Scalar divide: returns A / b, with every element divided by b.
template<typename Scalar>
tensor<Scalar> operator/(const tensor<Scalar>& A, const Scalar& b) {
	tensor<Scalar> B(A);
	B /= b;
	return B; // return the named local (NRVO/move) instead of copying from the reference /= returns
}

 
// tensor-vector multiply: returns b with A.rows() entries, where
// b[i] = sum over j of A(i, j) * x[j], accumulated in increasing j starting from Scalar(0).
// NOTE(review): x is assumed to hold at least A.cols() entries; no size check is made here.
template<typename Scalar>
vector<Scalar> operator*(const tensor<Scalar>& A, const vector<Scalar>& x) {
	vector<Scalar> b(A.rows());
	for (unsigned row = 0; row < A.rows(); ++row) {
		Scalar acc = Scalar(0);
		for (unsigned col = 0; col < A.cols(); ++col) acc += A(row, col) * x[col];
		b[row] = acc;
	}
	return b;
}

// tensor-tensor product C = A * B. Each C(i, j) is the dot product of row i of A with
// column j of B, accumulated over k in increasing order starting from Scalar(0).
// Throws matmul_incompatible_matrices when A.cols() != B.rows().
template<typename Scalar>
tensor<Scalar> operator*(const tensor<Scalar>& A, const tensor<Scalar>& B) {
	const bool conformant = (A.cols() == B.rows());
	if (!conformant) {
		throw matmul_incompatible_matrices(incompatible_matrices(A.rows(), A.cols(), B.rows(), B.cols(), "*").what());
	}
	const unsigned inner = A.cols();
	tensor<Scalar> C(A.rows(), B.cols());
	for (unsigned row = 0; row < C.rows(); ++row) {
		for (unsigned col = 0; col < C.cols(); ++col) {
			Scalar dot = Scalar(0);
			for (unsigned k = 0; k < inner; ++k) dot += A(row, k) * B(k, col);
			C(row, col) = dot;
		}
	}
	return C;
}




// Hadamard product A .* B: element-wise multiplication of two equally shaped tensors.
// Throws matmul_incompatible_matrices when the shapes differ.
template<typename Scalar>
tensor<Scalar> operator%(const tensor<Scalar>& A, const tensor<Scalar>& B) {
	const bool sameShape = (A.size() == B.size());
	if (!sameShape) {
		throw matmul_incompatible_matrices(incompatible_matrices(A.rows(), A.cols(), B.rows(), B.cols(), "%").what());
	}
	tensor<Scalar> C(A.rows(), A.cols());
	for (unsigned r = 0; r < C.rows(); ++r) {
		for (unsigned c = 0; c < C.cols(); ++c) C(r, c) = A(r, c) * B(r, c);
	}
	return C;
}


// tensor equivalence test: true when both tensors have the same shape and no pair of
// corresponding elements compares unequal (via Scalar's operator!=).
template<typename Scalar>
bool operator==(const tensor<Scalar>& A, const tensor<Scalar>& B) {
	if (num_rows(A) != num_rows(B)) return false;
	if (num_cols(A) != num_cols(B)) return false;
	for (unsigned r = 0; r < num_rows(A); ++r) {
		for (unsigned c = 0; c < num_cols(A); ++c) {
			if (A(r, c) != B(r, c)) return false;
		}
	}
	return true;
}



// Logical negation of operator==: true when the shapes differ or any element differs.
template<typename Scalar>
bool operator!=(const tensor<Scalar>& A, const tensor<Scalar>& B) {
	const bool identical = (A == B);
	return !identical;
}



// Matrix > x ==> Matrix with 1/0 representing True/False.
// Element-wise comparison against a scalar threshold. The result has the same shape as A,
// holding Scalar(1) where A(i, j) > x and Scalar(0) elsewhere.
template<typename Scalar>
tensor<Scalar> operator>(const tensor<Scalar>& A, const Scalar& x) {
	// The result must have A's shape. Constructing it as cols x rows scrambled the row-major
	// layout, and wrote past the end of its buffer whenever A had more rows than columns.
	tensor<Scalar> B(A.rows(), A.cols());
	for (unsigned i = 0; i < num_rows(A); ++i) {
		for (unsigned j = 0; j < num_cols(A); ++j) {
			B(i, j) = (A(i, j) > x) ? Scalar(1) : Scalar(0);
		}
	}
	return B;
}
 

// maxelement (jq 2022-10-15)
// Returns the largest element magnitude, max |A(i, j)|, or Scalar(0) for an empty tensor.
template<typename Scalar>
Scalar maxelement(const tensor<Scalar>&A) {
	// std overloads serve built-in Scalars (an unqualified call could otherwise pick int abs(int)),
	// while argument-dependent lookup still finds abs() for user-defined number types
	using std::abs;
	if (num_rows(A) == 0 || num_cols(A) == 0) return Scalar(0); // A(0,0) would read an empty buffer
	Scalar x = abs(A(0, 0));
	for (unsigned i = 0; i < num_rows(A); ++i) {
		for (unsigned j = 0; j < num_cols(A); ++j) {
			const Scalar magnitude = abs(A(i, j));
			if (magnitude > x) x = magnitude;
		}
	}
	return x;
}

// minelement (jq 2022-10-15)
// Returns the smallest magnitude among the nonzero elements, min{ |A(i, j)| : A(i, j) != 0 }.
// Returns Scalar(0) when A has no nonzero element, including when A is empty.
template<typename Scalar>
Scalar minelement(const tensor<Scalar>&A) {
	// std overloads serve built-in Scalars (an unqualified call could otherwise pick int abs(int)),
	// while argument-dependent lookup still finds abs() for user-defined number types
	using std::abs;
	bool found = false;
	Scalar smallest = Scalar(0);
	// single pass; an upper bound from maxelement() is not needed
	for (unsigned i = 0; i < num_rows(A); ++i) {
		for (unsigned j = 0; j < num_cols(A); ++j) {
			const Scalar a = A(i, j);
			if (a != 0) {
				const Scalar magnitude = abs(a);
				if (!found || magnitude < smallest) {
					smallest = magnitude;
					found = true;
				}
			}
		}
	}
	return smallest;
}


// Returns a copy of row i of A as a vector of num_cols(A) entries.
// No bounds check is made on i.
template<typename Scalar>
vector<Scalar> getRow(unsigned i, const tensor<Scalar>&A) {
	vector<Scalar> row(num_cols(A));
	size_t col = 0;
	while (col < num_cols(A)) {
		row(col) = A(i, col);
		++col;
	}
	return row;
}

// Returns a copy of column j of A as a vector of num_rows(A) entries.
// No bounds check is made on j.
template<typename Scalar>
vector<Scalar> getCol(unsigned j, const tensor<Scalar>&A) {
	vector<Scalar> column(num_rows(A));
	size_t row = 0;
	while (row < num_rows(A)) {
		column(row) = A(row, j);
		++row;
	}
	return column;
}


// Display Matrix
template<typename Scalar>
void disp(const tensor<Scalar>& A, const size_t COLWIDTH = 10){
    for (size_t i = 0;i<num_rows(A);++i){
        for (size_t j = 0; j<num_cols(A);++j){
            // std::cout <<std::setw(COLWIDTH) << A(i,j) << std::setw(COLWIDTH) << "\t" << std::fixed;
			std::cout << "\t" << A(i,j) << "\t"; // << std::fixed;
        }
        std::cout << "\n";
    }
    std::cout << "\n" << std::endl;
}


}}} // namespace sw::universal::blas
