#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "mat4.h"
#include "etc.h"

//!Matrices are arrays of 16 floats in column-major order: element (row, col) is stored at index col * 4 + row.


//creates a 4x4 identity matrix on the heap and returns its pointer.
//the matrix is an array of 16 floats; the caller owns it and must free() it.
//returns NULL if the allocation fails.
//
// 1 0 0 0
// 0 1 0 0
// 0 0 1 0
// 0 0 0 1
//
float * mat4_identity(void)
{
	//calloc zero-fills, so only the diagonal has to be written afterwards.
	float *matrix = calloc(16, sizeof *matrix);
	if (matrix == NULL)
	{
		return NULL;
	}

	//the diagonal sits every 5 slots in a 4x4 array: indices 0, 5, 10, 15.
	for (int i = 0; i < 16; i += 5)
	{
		matrix[i] = 1.0f;
	}
	return matrix;
}

//returns a projectio matrix.
float * mat4_projection(float fov, float z_near, float z_far, float aspect_ratio)
{
	//float D2R = M_PI / 180.0;
	//float frustrum_height = 1.0 / tan(D2R * fov / 2);

	float *matrix = malloc(sizeof(float) * 16);

	//don't forget to set all to 0, to avoid having the matrix fail randomly for
	//apparently no reason
	memset(matrix, 0, sizeof(float) * 16);

	float frustrum_height = 1.0f / tanf(fov * M_PI / 360);
	float frustrum_width = frustrum_height / aspect_ratio;
	float near_m_far = z_near - z_far;

	matrix[0] = frustrum_width;
	matrix[5] = frustrum_height;
	matrix[10] = (z_far + z_near) / near_m_far;
	matrix[11] = -1.0f;
	matrix[14] = (2 * z_far * z_near) / near_m_far;

	return matrix;
}


//moves a matrix by the given vector by adding it onto the translation
//column (indices 12, 13, 14 of the column-major array). the other 13
//elements are not touched.
//
// .  .  .  +x
// .  .  .  +y
// .  .  .  +z
// .  .  .  .
//
void mat4_translate(float *matrix, struct vec3 direction)
{
	float *translation = matrix + 12;

	translation[0] += direction.x;
	translation[1] += direction.y;
	translation[2] += direction.z;
}

//overwrites the translation column (indices 12, 13, 14 of the column-major
//array) with position. the previous values are replaced, not added to, and
//the other 13 elements are not touched.
void mat4_position(float *matrix, struct vec3 position)
{
	float *translation = &matrix[12];

	translation[0] = position.x;
	translation[1] = position.y;
	translation[2] = position.z;
}


//multiplies two column-major 4x4 matrices (out = a * b) and writes the result
//to matrix_out. matrix_out may be the same buffer as matrix_a and/or matrix_b:
//the product is built in a local array first and copied out at the end, so
//calls like mat4_multiply(m, x, m) are safe.
void mat4_multiply(float *matrix_a, float *matrix_b, float *matrix_out)
{
	float result[16];

	for (int col = 0; col < 4; col++)
	{
		for (int row = 0; row < 4; row++)
		{
			//element (row, col) lives at col * 4 + row; accumulate the dot product
			//of a's row with b's column in the same k = 0..3 order as before.
			float sum = matrix_a[row] * matrix_b[col * 4];
			for (int k = 1; k < 4; k++)
			{
				sum += matrix_a[k * 4 + row] * matrix_b[col * 4 + k];
			}
			result[col * 4 + row] = sum;
		}
	}

	memcpy(matrix_out, result, sizeof result);
}


//updates the view matrix.
void mat4_look_at(float *matrix, struct vec3 position, struct vec3 target, struct vec3 up_dir)
{
	struct vec3 front = vec3_normalize(vec3_subtract(target,position));
	struct vec3 side = vec3_normalize(vec3_cross(front, up_dir));
	struct vec3 up = vec3_normalize(vec3_cross(side, front));

	matrix[0] = side.x;
	matrix[4] = side.y;
	matrix[8] = side.z;

	matrix[1] = up.x;
	matrix[5] = up.y;
	matrix[9] = up.z;

	matrix[2] = -front.x;
	matrix[6] = -front.y;
	matrix[10] = -front.z;

	matrix[3] = 0.0f;
	matrix[7] = 0.0f;
	matrix[11] = 0.0f;

	matrix[12] = -vec3_dot(side, position);
	matrix[13] = -vec3_dot(up, position);
	matrix[14] = vec3_dot(front, position);
	matrix[15] = 1.0f;
}


//transforms v in place by the column-major matrix m, treating v as a point
//with w = 1. only the x, y and z rows are evaluated: the resulting w is never
//computed and no divide by w happens, so the result is exact only for
//matrices whose bottom row is 0 0 0 1.
void mat4_multiply_vec3(struct vec3 *v, float * m)
{
	//snapshot the input first, because v is overwritten one component at a time.
	const float in[4] = { v->x, v->y, v->z, 1.0f };

	v->x = m[0] * in[0] + m[4] * in[1] + m[8]  * in[2] + m[12] * in[3];
	v->y = m[1] * in[0] + m[5] * in[1] + m[9]  * in[2] + m[13] * in[3];
	v->z = m[2] * in[0] + m[6] * in[1] + m[10] * in[2] + m[14] * in[3];
}


//writes v.x, v.y and v.z onto the first three diagonal entries (indices 0, 5,
//10) of the column-major matrix m. the existing values are replaced rather
//than multiplied, and every other element is left as it was.
void mat4_scale(float *m, struct vec3 v)
{
	const float factors[3] = { v.x, v.y, v.z };

	//diagonal entries of a 4x4 array are 5 slots apart.
	for (int i = 0; i < 3; i++)
	{
		m[i * 5] = factors[i];
	}
}

//writes the rotation described by quaternion q into m as a 4x4 matrix in the
//column-major layout used by the rest of this file (element (row, col) at
//m[col * 4 + row]). the translation column is zeroed and m[15] is set to 1.
//
//formula: https://www.opengl-tutorial.org/assets/faq_quaternions/index.html#Q54
//that FAQ indexes matrices row-major, so its off-diagonal terms are
//transposed here. copying them verbatim stored the inverse rotation: for a
//90 degree turn about +z, +x was mapped to -y instead of +y.
//NOTE(review): callers that compensated for the old transposed output
//(e.g. by conjugating quaternions) need updating.
//assumes q is normalized; a non-unit quaternion also scales/shears.
void mat4_from_quat(float *m, struct quat q)
{
	float xx = q.x * q.x;
	float xy = q.x * q.y;
	float xz = q.x * q.z;
	float xw = q.x * q.w;

	float yy = q.y * q.y;
	float yz = q.y * q.z;
	float yw = q.y * q.w;

	float zz = q.z * q.z;
	float zw = q.z * q.w;

	//column 0
	m[0]  = 1 - 2 * (yy + zz);
	m[1]  =     2 * (xy + zw);
	m[2]  =     2 * (xz - yw);
	m[3]  = 0;
	//column 1
	m[4]  =     2 * (xy - zw);
	m[5]  = 1 - 2 * (xx + zz);
	m[6]  =     2 * (yz + xw);
	m[7]  = 0;
	//column 2
	m[8]  =     2 * (xz + yw);
	m[9]  =     2 * (yz - xw);
	m[10] = 1 - 2 * (xx + yy);
	m[11] = 0;
	//column 3: no translation
	m[12] = 0;
	m[13] = 0;
	m[14] = 0;
	m[15] = 1;
}
