/*
 * Copyright (c) 2021, Jeffrey Lee
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met: 
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#ifndef VECTOR_H
#define VECTOR_H

#include <type_traits>

/* Fixed-size vector of N elements of type T.
 *
 * Arithmetic operators act component-wise, either against another vector of
 * the same type and size or against a single scalar T. Comparison operators
 * are also component-wise: they return a bool_type holding one result per
 * component rather than a single bool. Element access is not bounds-checked. */
template<typename T,int N> class vector
{
private:
	T v[N];

public:
	static const int size_const = N;
	typedef T value_type;

	/* Per-component comparison result. The element type is whatever "T() > 0"
	 * evaluates to (bool for builtin arithmetic types). decltype is used rather
	 * than the GNU-only typeof so the header also compiles in strict ISO modes
	 * (e.g. -std=c++17); remove_reference guards against a comparison operator
	 * returning a reference, which could not be used as an array element type. */
	typedef struct {
		typename std::remove_reference<decltype(T() > 0)>::type b[N];
	} bool_type;

	inline int size() const { return N; }
	inline T *ptr() { return v; }
	inline T const *ptr() const { return v; }

	/* Unchecked element access. */
	T &operator[](size_t i) { return v[i]; }
	T const &operator[](size_t i) const { return v[i]; }

	/* NOTE(review): f is taken as float rather than T, so the value passes
	 * through float before being assigned to v[i] (e.g. a double T loses
	 * precision). Kept as-is so existing callers are unaffected. */
	inline void set_elem(const int i,float f) { v[i] = f; }

	/* Leaves the components uninitialised. */
	inline vector() {}

	/* Converting constructor: copies the first min(N, N2) components of other
	 * and zero-fills any remaining components. */
	template<typename T2,int N2>
	inline vector(vector<T2,N2> const& other)
	{
		for(int i=0;i<N;i++)
		{
			v[i] = (i<N2?other[i]:0);
		}
	}

	/* Broadcast: every component is set to val. */
	inline vector(T val)
	{
		for(int i=0;i<N;i++)
		{
			v[i] = val;
		}
	}

	/* Component-list constructors. The element count is checked with
	 * static_assert inside the body: an enable_if on the class parameter N is
	 * not SFINAE-friendly for a non-template member, whereas the assertion only
	 * fires when the constructor is actually used with too small an N. Any
	 * components beyond those supplied are left uninitialised. */
	inline vector(T x, T y)
	{
		static_assert(N >= 2, "vector(x, y) requires at least 2 components");
		v[0] = x;
		v[1] = y;
	}

	inline vector(T x, T y, T z)
	{
		static_assert(N >= 3, "vector(x, y, z) requires at least 3 components");
		v[0] = x;
		v[1] = y;
		v[2] = z;
	}

	inline vector(T x, T y, T z, T w)
	{
		static_assert(N >= 4, "vector(x, y, z, w) requires at least 4 components");
		v[0] = x;
		v[1] = y;
		v[2] = z;
		v[3] = w;
	}

	/* Vector math (component-wise) */

	inline vector operator +(const vector &b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] + b.v[i];
		}
		return c;
	}

	inline vector operator -(const vector &b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] - b.v[i];
		}
		return c;
	}

	inline vector operator *(const vector &b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] * b.v[i];
		}
		return c;
	}

	inline vector operator /(const vector &b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] / b.v[i];
		}
		return c;
	}

	inline vector& operator +=(const vector &b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] += b.v[i];
		}
		return *this;
	}

	inline vector& operator -=(const vector &b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] -= b.v[i];
		}
		return *this;
	}

	inline vector& operator *=(const vector &b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] *= b.v[i];
		}
		return *this;
	}

	inline vector& operator /=(const vector &b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] /= b.v[i];
		}
		return *this;
	}

	/* Scalar math (scalar applied to every component) */

	inline vector operator +(T b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] + b;
		}
		return c;
	}

	inline vector operator -(T b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] - b;
		}
		return c;
	}

	inline vector operator *(T b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] * b;
		}
		return c;
	}

	inline vector operator /(T b) const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = v[i] / b;
		}
		return c;
	}

	inline vector& operator +=(T b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] += b;
		}
		return *this;
	}

	inline vector& operator -=(T b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] -= b;
		}
		return *this;
	}

	inline vector& operator *=(T b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] *= b;
		}
		return *this;
	}

	inline vector& operator /=(T b)
	{
		for(int i=0;i<N;i++)
		{
			v[i] /= b;
		}
		return *this;
	}

	/* Other operators */

	inline vector operator-() const
	{
		vector c;
		for(int i=0;i<N;i++)
		{
			c.v[i] = -v[i];
		}
		return c;
	}

	/* Component-wise comparisons against another vector or a scalar; each
	 * returns one comparison result per component in a bool_type. */

	inline bool_type operator<(const vector &b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] < b.v[i];
		}
		return c;
	}

	inline bool_type operator<(const T b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] < b;
		}
		return c;
	}

	inline bool_type operator>(const vector &b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] > b.v[i];
		}
		return c;
	}

	inline bool_type operator>(const T b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] > b;
		}
		return c;
	}

	inline bool_type operator<=(const vector &b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] <= b.v[i];
		}
		return c;
	}

	inline bool_type operator<=(const T b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] <= b;
		}
		return c;
	}

	inline bool_type operator>=(const vector &b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] >= b.v[i];
		}
		return c;
	}

	inline bool_type operator>=(const T b) const
	{
		bool_type c;
		for(int i=0;i<N;i++)
		{
			c.b[i] = v[i] >= b;
		}
		return c;
	}

	/* Utility */

	/* Gather components by index: component i of the result is v[swiz[i]].
	 * swiz is an array parameter and therefore decays to a pointer, so N2
	 * cannot be deduced and must be supplied explicitly, e.g.
	 * swizzle<int,3>(idx). Indices are not bounds-checked. */
	template<typename T2,int N2>
	inline vector<T,N2> swizzle(T2 const swiz[N2]) const
	{
		vector<T,N2> c;
		for(int i=0;i<N2;i++)
		{
			c[i] = v[swiz[i]];
		}
		return c;
	}
};

/* Return a slice of a vector: a reference to the N consecutive components of v
 * starting at index O, viewed as a vector<value_type,N>. Ideally these would be
 * member functions, but calling a member template from dependent code requires
 * the "template" keyword (e.g. "foo = bar.template slice<0,3>()") to stop the
 * compiler from parsing the template argument list as operator<.
 *
 * The enable_if rejects (at compile time) any slice that would start before the
 * first component, be empty, or run past the last component.
 *
 * The returned reference aliases v's storage and is only valid while v is.
 * NOTE(review): the view is formed by reinterpreting a pointer into v's
 * component array as a vector<value_type,N>*, which relies on vector<T,N> being
 * laid out exactly like T[N]; this is not guaranteed by the standard. */
template<int O,int N,typename V,class = typename std::enable_if<(O>=0 && N>0 && O+N<=V::size_const)>::type>
inline vector<typename V::value_type, N> const &slice(const V &v)
{
	return *reinterpret_cast<const vector<typename V::value_type, N>*>(v.ptr()+O);
}

template<int O,int N,typename V,class = typename std::enable_if<(O>=0 && N>0 && O+N<=V::size_const)>::type>
inline vector<typename V::value_type, N> &slice(V &v)
{
	return *reinterpret_cast<vector<typename V::value_type, N>*>(v.ptr()+O);
}

/* vector-scalar min/max */
/* Clamp every component of a from above by the scalar b: component k of the
 * result is min(a[k], b). Relies on a scalar overload of min() for
 * V::value_type, which is not defined in this header. */
template<typename V>
inline V min(V const &a,typename V::value_type const &b)
{
	V out;
	const int count = V::size_const;
	for(int k = 0; k < count; ++k)
		out[k] = min(a[k],b);
	return out;
}

/* Clamp every component of a from below by the scalar b: component k of the
 * result is max(a[k], b). Relies on a scalar overload of max() for
 * V::value_type, which is not defined in this header. */
template<typename V>
inline V max(V const &a,typename V::value_type const &b)
{
	V out;
	int k = 0;
	while(k < V::size_const)
	{
		out[k] = max(a[k],b);
		++k;
	}
	return out;
}

/* vector-scalar MLA/MLS */
/* Component-wise multiply-accumulate against a shared scalar: component k of
 * the result is mla(a[k], b[k], c) (presumably a[k] + b[k]*c; the scalar mla()
 * is not defined in this header). */
template<typename V>
inline V mla(V const &a,V const &b,typename V::value_type const c)
{
	V acc;
	for(int lane = 0; lane < V::size_const; ++lane)
	{
		acc[lane] = mla(a[lane],b[lane],c);
	}
	return acc;
}

/* Component-wise multiply-subtract against a shared scalar: component k of the
 * result is mls(a[k], b[k], c) (presumably a[k] - b[k]*c; the scalar mls() is
 * not defined in this header). */
template<typename V>
inline V mls(V const &a,V const &b,typename V::value_type const c)
{
	V acc;
	int lane = 0;
	while(lane < V::size_const)
	{
		acc[lane] = mls(a[lane],b[lane],c);
		++lane;
	}
	return acc;
}

/* Length of vec: SQRT applied to the dot product of vec with itself. */
template<typename V>
static inline typename V::value_type length(const V &vec)
{
	auto const squared = dot(vec,vec);
	return SQRT(squared);
}

/* Reciprocal length of vec: inversesqrt applied to dot(vec, vec). */
template<typename V>
static inline typename V::value_type inv_length(const V &vec)
{
	auto const squared = dot(vec,vec);
	return inversesqrt(squared);
}

/* Apply the scalar floor() to every component of vec. */
template<typename V>
static inline V floor(const V &vec)
{
	V out;
	int k = 0;
	while(k < V::size_const)
	{
		out[k] = floor(vec[k]);
		++k;
	}
	return out;
}

/* Apply the scalar floor_tozero() (not defined in this header) to every
 * component of vec. */
template<typename V>
static inline V floor_tozero(const V &vec)
{
	V out;
	const int count = V::size_const;
	for(int k = 0; k < count; ++k)
		out[k] = floor_tozero(vec[k]);
	return out;
}

/* Apply the scalar fract() (not defined in this header) to every component of
 * vec. */
template<typename V>
static inline V fract(const V &vec)
{
	V out;
	int k = 0;
	while(k < V::size_const)
	{
		out[k] = fract(vec[k]);
		++k;
	}
	return out;
}

/* Dot product of a and b. It starts from a[0]*b[0] and folds in each remaining
 * component with sum = mla(sum, a[k], b[k]). Both the component count and the
 * result type are taken from V (the first argument), so when mixing SOA and
 * scalar operands it may be necessary to pass the SOA operand first. */
template<typename V,typename V2>
static inline typename V::value_type dot(const V &a,const V2 &b)
{
	typename V::value_type sum = a[0] * b[0];
	int k = 1;
	while(k < V::size_const)
	{
		sum = mla(sum, a[k], b[k]);
		++k;
	}
	return sum;
}

/* Cross product a x b of two 3-component vectors. */
template<typename T>
static inline vector<T,3> cross(const vector<T,3> &a,const vector<T,3> &b)
{
	return vector<T,3>(a[1]*b[2] - a[2]*b[1],
	                   a[2]*b[0] - a[0]*b[2],
	                   a[0]*b[1] - a[1]*b[0]);
}

/* Apply the scalar abs() to every component of a. */
template<typename V>
static inline V abs(const V &a)
{
	V res;
	const int count = V::size_const;
	for(int k = 0; k < count; ++k)
		res[k] = abs(a[k]);
	return res;
}

/* Apply the scalar SQRT() (not defined in this header) to every component of
 * a. */
template<typename V>
static inline V SQRT(const V &a)
{
	V res;
	int k = 0;
	while(k < V::size_const)
	{
		res[k] = SQRT(a[k]);
		++k;
	}
	return res;
}

/* Apply the scalar inversesqrt() (not defined in this header) to every
 * component of a. */
template<typename V>
static inline V inversesqrt(const V &a)
{
	V res;
	const int count = V::size_const;
	for(int k = 0; k < count; ++k)
		res[k] = inversesqrt(a[k]);
	return res;
}

/* Apply the scalar recp() (presumably a reciprocal; not defined in this header)
 * to every component of a. The enable_if restricts this overload to types that
 * expose a non-zero size_const, so element types without such a member cannot
 * resolve the inner recp(a[k]) call back to this template. */
template<typename V,class = typename std::enable_if<V::size_const!=0>::type>
static inline V recp(const V &a)
{
	V res;
	int k = 0;
	while(k < V::size_const)
	{
		res[k] = recp(a[k]);
		++k;
	}
	return res;
}

/* Component-wise select driven by a per-component mask (the bool_type returned
 * by vector's comparison operators): component k of the result is
 * select(a.b[k], b[k], c[k]). The scalar select() is not defined in this
 * header; presumably it yields b[k] when a.b[k] is true and c[k] otherwise. */
template<typename V>
static inline V select(const typename V::bool_type &a,const V &b,const V &c)
{
	V res;
	const int count = V::size_const;
	for(int k = 0; k < count; ++k)
		res[k] = select(a.b[k],b[k],c[k]);
	return res;
}

/* Scale v by its reciprocal length (inv_length(v)), i.e. towards unit length. */
template<typename V>
static inline V normalise(const V &v)
{
	auto const scale = inv_length(v);
	return v * scale;
}

#endif
