/*****************************************************************************

        MidiFileTrack.cpp
        Author: Laurent de Soras, 2022

--- Legal stuff ---

This program is free software. It comes without any warranty, to
the extent permitted by applicable law. You can redistribute it
and/or modify it under the terms of the Do What The Fuck You Want
To Public License, Version 2, as published by Sam Hocevar. See
http://sam.zoy.org/wtfpl/COPYING for more details.

*Tab=3***********************************************************************/



/*\\\ INCLUDE FILES \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/

#include "fstb/def.h"
#include "MidiFileTrack.h"

#include <algorithm>
#include <stdexcept>

#include <cassert>
#include <cstdint>



/*\\\ PUBLIC \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Parses one MTrk chunk and collects its notes, tempo changes and time
// signatures.
// mtrk_ptr points to the beginning of the chunk (the "MTrk" tag).
// end_ptr sets the limit of the valid memory area provided by the caller.
// It is either the end or past the end of the chunk, or the end of the valid
// input buffer if the file was truncated for some reason.
// On output, mtrk_ptr is updated and points right after the chunk, i.e. on
// the next track if any. It is left untouched if an exception is thrown.
// rep_strategy selects how a note restarted while still playing is handled
// (see start_note()).
// Throws: std::runtime_error (or the derived std::overflow_error for
// truncated data) if the track cannot be fully decoded.
MidiFileTrack::MidiFileTrack (const uint8_t * &mtrk_ptr, const uint8_t *end_ptr, NoteRepeat rep_strategy)
:	_rep_strategy (rep_strategy)
{
	assert (mtrk_ptr != nullptr);
	assert (end_ptr != nullptr);
	assert (mtrk_ptr < end_ptr);

	// Chunk header: 4-byte tag followed by a 4-byte big-endian payload size
	if (end_ptr - mtrk_ptr < 8)
	{
		throw std::overflow_error ("Truncated MTrk header.");
	}
	if (   mtrk_ptr [0] != uint8_t ('M')
	    || mtrk_ptr [1] != uint8_t ('T')
	    || mtrk_ptr [2] != uint8_t ('r')
	    || mtrk_ptr [3] != uint8_t ('k'))
	{
		throw std::runtime_error ("Not an MTrk chunk");
	}

	// Assembled as unsigned: with a signed type, a first byte >= 0x80 would
	// shift into the sign bit and yield a negative length, which would pass
	// the truncation test below and move mtrk_ptr backward.
	const uint32_t len_raw =
		  (uint32_t (mtrk_ptr [4]) << 24)
		| (uint32_t (mtrk_ptr [5]) << 16)
		| (uint32_t (mtrk_ptr [6]) <<  8)
		|  uint32_t (mtrk_ptr [7]);
	const auto     payload_ptr = mtrk_ptr + 8;

	// end_ptr - payload_ptr is >= 0 here. Comparing on 64 bits keeps the test
	// valid whatever the width of ptrdiff_t.
	if (uint64_t (end_ptr - payload_ptr) < uint64_t (len_raw))
	{
		throw std::overflow_error ("Truncated MTrk payload");
	}
	// Safe conversion: len_raw does not exceed the available buffer size
	const auto     len = ptrdiff_t (len_raw);

	// Loops over events
	auto           evt_ptr         = payload_ptr;
	const auto     payload_end_ptr = payload_ptr + len;
	while (evt_ptr < payload_end_ptr)
	{
		if (_track_end_flag)
		{
			throw std::runtime_error ("Unexpected event after End of Track");
		}
		process_event (evt_ptr, payload_end_ptr);
	}

	// Every started note must have been stopped within the track
	for (const auto &chn : _chn_arr)
	{
		if (chn.is_any_note_playing ())
		{
			throw std::runtime_error ("Note on without matching Note off");
		}
	}

	mtrk_ptr = payload_end_ptr;

	// Don't check the End of Track event because a lot of otherwise valid MIDI
	// tracks don't have it.
//	assert (_track_end_flag);
}



// Returns a copy of the notes collected for channel chn_idx, in the order
// they were started.
// chn_idx must be in [0 ; midi::_nbr_chn[.
MidiFileTrack::NoteArray	MidiFileTrack::get_notes (int chn_idx) const
{
	assert (chn_idx >= 0);
	assert (chn_idx < midi::_nbr_chn);

	const auto &   channel = _chn_arr [chn_idx];

	return channel._note_arr;
}



// Merges the notes of all channels into a single array, then orders the
// result by timestamp with sort_notes().
MidiFileTrack::NoteArray	MidiFileTrack::get_notes_all_chn () const
{
	NoteArray      merged;
	for (const auto &channel : _chn_arr)
	{
		const auto &   src = channel._note_arr;
		merged.insert (merged.end (), src.begin (), src.end ());
	}
	sort_notes (merged);

	return merged;
}



// Gives read access to the tempo changes found in the track, keyed by
// timestamp. Tempo values are expressed in quarter notes per minute.
const MidiFileTrack::TempoMap &	MidiFileTrack::use_tempo_map () const noexcept
{
	const TempoMap &  tempo_map = _tempo_map;
	return tempo_map;
}



// Gives read access to the time signatures found in the track, keyed by
// timestamp. Values hold the raw numerator and denominator bytes of the
// Time Signature meta-events.
const MidiFileTrack::TimeSigMap &	MidiFileTrack::use_time_sig_map () const noexcept
{
	const TimeSigMap &   time_sig_map = _time_sig_map;
	return time_sig_map;
}



// Sorts notes by ascending timestamp.
// The sort is stable: notes sharing the same timestamp (chords, or notes of
// different channels merged by get_notes_all_chn()) keep their relative
// order. This makes the result deterministic instead of depending on the
// standard library implementation, as it would with std::sort.
void	MidiFileTrack::sort_notes (NoteArray &notes) noexcept
{
	std::stable_sort (notes.begin (), notes.end (),
		[] (const Note &lhs, const Note &rhs)
		{
			return lhs._timestamp < rhs._timestamp;
		}
	);
}



/*\\\ PROTECTED \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



/*\\\ PRIVATE \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Tells whether at least one instance of the given note number is currently
// started and not yet stopped on this channel.
bool	MidiFileTrack::Channel::is_note_playing (uint8_t note) const noexcept
{
	assert (note < midi::_nbr_notes);

	const auto &   pending = _note_alloc [note];

	return (pending.size () > 0);
}



// Tells whether any note number is currently playing on this channel.
// Scans note numbers in ascending order and stops at the first active one.
bool	MidiFileTrack::Channel::is_any_note_playing () const noexcept
{
	int            note = 0;
	while (note < midi::_nbr_notes && ! is_note_playing (uint8_t (note)))
	{
		++ note;
	}

	return (note < midi::_nbr_notes);
}



// Decodes one track event starting at cur_ptr and advances cur_ptr past it.
// An event is a delta-time (variable-length quantity) followed by a leading
// byte which selects the kind of event: meta event (0xFF), SysEx forms, or a
// standard MIDI message, possibly using running status.
// Throws: std::runtime_error on malformed data, std::overflow_error if the
// event runs past end_ptr.
void	MidiFileTrack::process_event (const uint8_t * &cur_ptr, const uint8_t *end_ptr)
{
	const auto     dt   = read_varlen (cur_ptr, end_ptr);
	const auto     lead = read_byte (cur_ptr, end_ptr);
	_timestamp += dt;

	// Non-MIDI message (meta events)
	if (lead == 0xFF)
	{
		process_meta (cur_ptr, end_ptr);
		return;
	}

	// SysEx continuation or escaped event (continue_flag set),
	// or SysEx, full or beginning part (continue_flag cleared)
	if (lead == midi::_st_com_sysex_end || lead == midi::_st_com_sysex)
	{
		process_sysex (cur_ptr, end_ptr, (lead == midi::_st_com_sysex_end));
		return;
	}

	// Other MIDI event. The leading byte is either a new status byte or,
	// with running status, the first data byte.
	_ms.push_byte (lead);
	const auto     running = _ms.get_running_status ();
	if (running == 0)
	{
		throw std::runtime_error ("MIDI data byte without running status");
	}

	// Number of bytes still to be fetched from the stream for this message
	int            remaining = MidiState::get_param_len (running);
	if (! MidiState::is_status (lead))
	{
		// The leading byte already was the first parameter
		-- remaining;
	}

	_ms.get_message (_msg);
	process_std_midi (_msg);

	for ( ; remaining > 0; -- remaining)
	{
		_ms.push_byte (read_byte (cur_ptr, end_ptr));
		_ms.get_message (_msg);
		process_std_midi (_msg);
	}
}



// Handles a meta event, once its 0xFF leading byte has been consumed.
// Layout: type byte, variable-length data size, then the data bytes.
// On success, cur_ptr points right after the event data.
// Throws: std::overflow_error if the data runs past end_ptr.
void	MidiFileTrack::process_meta (const uint8_t * &cur_ptr, const uint8_t *end_ptr)
{
	const uint8_t  meta_type = read_byte (cur_ptr, end_ptr);
	const ptrdiff_t data_len = read_varlen (cur_ptr, end_ptr);

	const bool     truncated = (end_ptr - cur_ptr < data_len);
	if (truncated)
	{
		throw std::overflow_error ("Truncated meta-event");
	}

	// Content is interpreted before moving the pointer, so a throwing
	// handler leaves cur_ptr on the event data.
	process_meta_content (cur_ptr, data_len, meta_type);
	cur_ptr += data_len;
}



// Interprets the data of a meta event.
// cur_ptr points to the first data byte and len is the number of data bytes;
// the caller has checked that they are readable. type is the meta-event type.
// Tempo changes are stored in _tempo_map (quarter notes per minute) and time
// signatures in _time_sig_map, both keyed by the current timestamp.
// Other meta-event types are ignored.
// Throws: std::runtime_error on malformed tempo or time signature events.
void	MidiFileTrack::process_meta_content (const uint8_t *cur_ptr, ptrdiff_t len, uint8_t type)
{
	switch (type)
	{
	case Meta::_trk_end:
		assert (len == 0);
		_track_end_flag = true;
		break;

	case Meta::_tempo:
		{
			// Data: 24-bit big-endian number of microseconds per quarter note.
			// len comes from the file, so it is checked at run time: an assert
			// alone would let release builds read past the event data.
			if (len < 3)
			{
				throw std::runtime_error ("Truncated tempo meta-event");
			}
			const int   us =
				(int (cur_ptr [0]) << 16) + (int (cur_ptr [1]) << 8) + cur_ptr [2];
			if (us == 0)
			{
				throw std::runtime_error ("Invalid tempo setting");
			}
			const auto     tempo = 60 * 1e6 / us;
			_tempo_map [_timestamp] = tempo;
		}
		break;

	case Meta::_time_sig:
		{
			// Only the numerator and denominator bytes are used; the MIDI clock
			// related bytes that follow them are ignored. Checked at run time
			// for the same reason as the tempo length.
			if (len < 2)
			{
				throw std::runtime_error ("Truncated time signature meta-event");
			}
			const auto     num = cur_ptr [0];
			const auto     den = cur_ptr [1];
			// NOTE(review): the SMF specification stores the denominator as a
			// power-of-two exponent, for which 0 (whole note) is legal. It is
			// rejected here as before; confirm this is intended.
			if (num == 0 || den == 0)
			{
				throw std::runtime_error ("Invalid time signature");
			}
			_time_sig_map [_timestamp] = { num, den };
		}
		break;

	default:
		// Other meta events are not needed and are skipped by the caller
		break;
	}
}



// Processes a standard MIDI message as assembled by the MIDI state.
// msg is presumably empty while no complete message is available (TODO
// confirm against MidiState::get_message); otherwise it starts with the
// status byte.
// Handles Note On / Note Off and the All Notes Off / All Sound Off
// controllers; other messages are ignored.
void	MidiFileTrack::process_std_midi (const ByteArray &msg)
{
	if (msg.empty ())
	{
		return;
	}

	const auto     status = msg.front ();
	assert (MidiState::is_status (status));
	const auto     categ  = MidiState::extract_cmd_categ (status);

	if (categ == midi::_st_chn_note_on || categ == midi::_st_chn_note_off)
	{
		assert (msg.size () == 3);
		const auto     chn_idx = MidiState::extract_channel (status);
		auto &         chn     = _chn_arr [chn_idx];
		const auto     note    = msg [1];
		const auto     velo    = msg [2];

		// A Note On with velocity 0 is the standard MIDI way to express a
		// Note Off. It must not reach start_note(), which requires velo > 0.
		if (categ == midi::_st_chn_note_on && velo > 0)
		{
			start_note (chn, note, velo);
		}
		else if (chn.is_note_playing (note))
		{
			// Note Off for a note which is not playing is silently ignored
			stop_note (chn, note, velo);
		}
	}

	else if (categ == midi::_st_chn_ctrl)
	{
		assert (msg.size () == 3);
		const auto     chn_idx = MidiState::extract_channel (status);
		auto &         chn     = _chn_arr [chn_idx];
		const auto     ctrl    = msg [1];
		if (ctrl == midi::_ct_all_notes_off || ctrl == midi::_ct_all_snd_off)
		{
			stop_all_notes (chn);
		}
	}
}



// Starts a new instance of a note on the given channel at the current
// timestamp. With the NoteRepeat::_cut strategy, any instance of the same
// note still playing is stopped first.
void	MidiFileTrack::start_note (Channel &chn, uint8_t note, uint8_t velo)
{
	assert (note < midi::_nbr_notes);
	assert (velo > 0);

	if (_rep_strategy == NoteRepeat::_cut)
	{
		while (chn.is_note_playing (note))
		{
			stop_note (chn, note, velo);
		}
	}

	// -1 marks a note whose end is not known yet (presumably the _duration
	// field, set later by stop_note()). The note index is registered only
	// once the note has been stored.
	auto &         notes    = chn._note_arr;
	const int      new_idx  = int (notes.size ());
	notes.emplace_back (Note { _timestamp, -1, note, velo });
	chn._note_alloc [note].push_back (new_idx);
}



// Stops the oldest playing instance of a note on the given channel and sets
// its duration from the current timestamp. The note must be playing.
// velo (release velocity) is currently unused.
void	MidiFileTrack::stop_note (Channel &chn, uint8_t note, uint8_t velo)
{
	fstb::unused (velo);
	assert (note < midi::_nbr_notes);
	assert (chn.is_note_playing (note));

	// Instances are queued in start order: the first one is closed first
	auto &         pending = chn._note_alloc [note];
	auto &         info    = chn._note_arr [pending.front ()];
	info._duration = _timestamp - info._timestamp;
	pending.erase (pending.begin ());
}



// Stops every playing instance of every note on the given channel, using
// the default velocity as release velocity.
void	MidiFileTrack::stop_all_notes (Channel &chn)
{
	for (int n = 0; n < midi::_nbr_notes; ++n)
	{
		const auto     note = uint8_t (n);
		while (chn.is_note_playing (note))
		{
			stop_note (chn, note, midi::_def_velo);
		}
	}
}



// Skips a SysEx event (continue_flag false: full or beginning part) or a
// SysEx continuation / escape sequence (continue_flag true).
// cur_ptr points right after the leading byte and is moved past the data.
// The data content is not interpreted yet.
// Throws: std::overflow_error if the announced length runs past end_ptr.
void	MidiFileTrack::process_sysex (const uint8_t * &cur_ptr, const uint8_t *end_ptr, bool continue_flag)
{
	const auto     len = read_varlen (cur_ptr, end_ptr);
	// The data must fit in what remains of the track. The test must reject
	// lengths larger than the available bytes; otherwise cur_ptr would be
	// moved past end_ptr.
	if (end_ptr - cur_ptr < len)
	{
		throw std::overflow_error ("Truncated sysex or escape sequence");
	}

	/*** To do ***/
	fstb::unused (continue_flag);
	cur_ptr += len;
}



// Reads a variable-length quantity as found in MIDI files. Each byte carries
// 7 bits of the value, most significant group first, and has its high bit
// set when another byte follows. At most 4 bytes (28 bits) are accepted.
// cur_ptr is updated and points right after the sequence.
// If an exception is thrown, cur_ptr is kept at the beginning of the sequence.
// Throws: std::runtime_error if the sequence is longer than 4 bytes,
// std::overflow_error if it runs past end_ptr.
int32_t	MidiFileTrack::read_varlen (const uint8_t * &cur_ptr, const uint8_t *end_ptr)
{
	int32_t        val      = 0;
	int            read_len = 0;
	uint8_t        b        = 0;
	uint8_t        v        = 0;
	do
	{
		if (read_len >= 4)
		{
			throw std::runtime_error ("Too long variable-length quantity");
		}
		// Checks the byte about to be read (cur_ptr [read_len]), not only the
		// first one, so a truncated multi-byte quantity is never read past
		// end_ptr.
		if (end_ptr - cur_ptr <= read_len)
		{
			throw std::overflow_error ("Truncated variable-length quantity");
		}
		b   = cur_ptr [read_len];
		v   = b & 0x7F;
		val = (val << 7) | v;
		++ read_len;
	}
	while (b != v); // High bit set: more bytes follow

	cur_ptr += read_len;

	return val;
}



// Returns the byte at cur_ptr and advances cur_ptr by one.
// Throws: std::overflow_error if cur_ptr has reached end_ptr (cur_ptr is
// then left unchanged).
uint8_t	MidiFileTrack::read_byte (const uint8_t * &cur_ptr, const uint8_t *end_ptr)
{
	if (! (cur_ptr < end_ptr))
	{
		throw std::overflow_error ("Too few bytes in the stream");
	}

	return *cur_ptr ++;
}



/*\\\ EOF \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/
