(** Assembly output writer. *)

(*
    il4c  --  Compiler for the IL4 Lisp-ahtava language
    Copyright (C) 2007 Jere Sanisalo

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

include Imbc

(** Lines that open every generated file: a banner comment, the NASM
    target directives (32-bit code, 486 instruction set) and the
    entry-point / imported-symbol declarations. *)
let file_header =
	let banner = [
		";";
		"; This file is generated by the il4c compiler";
		"; Designed to be compiled with NASM";
		";";
		"";
	] in
	let directives = [
		"\tUSE32";
		"\tCPU 486";
		"";
		"\tGLOBAL WinMainCRTStartup";
		"\tEXTERN _ExitProcess@4";
	] in
	banner @ directives

(** Separator comment followed by the NASM directive opening the code section. *)
let section_code =
	let separator = ";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;" in
	[separator; "\tSECTION .code"]

(** Entry code of the bytecode VM.

    [VmStart] saves EBP/ESI/EDI/EBX, points EBP at the saved registers and
    ESI at the function's bytecode (passed in EBX). It skips the
    parameter-count byte, reads the number of local stack entries and
    reserves that many dwords on the stack. It then loads EBX with the
    address of [VmMain].

    [VmMain] is the dispatch loop. It fetches one opcode byte, looks up its
    16-bit offset in [OpcodeTable] (offsets are relative to [VmMain]) and
    jumps to the handler. *)
let vm_main_code =
	(* Instruction lines are indented with a single tab; labels and
	   top-level comments are not. *)
	let ins s = "\t" ^ s in
	[
		"; VM start";
		"; Receives the address of the bytecode in EBX (which must be preserved)";
		"; On return, the return value is in EAX";
		"VmStart:";
		ins "push ebp";
		ins "push esi";
		ins "push edi";
		ins "push ebx";
		ins "mov ebp, esp";
		ins "mov esi, ebx";
		ins "inc esi\t\t; Skip the number of parameters";
		ins "lodsb\t\t; Get the number of local stack entries";
		ins "movzx eax, al";
		ins "shl eax, 2";
		ins "sub esp, eax";
		"";
		ins "mov ebx, VmMain\t\t; Keep the VM start address in EBX";
		"VmMain:\t; -- VM main loop --";
		ins "lodsb";
		ins "movzx eax,al";
		ins "mov ax, [OpcodeTable + eax * 2]";
		ins "add eax, ebx";
		ins "jmp eax";
	]

(** Shared float-comparison helper, [VmFloatTest].

    Compares the 32-bit float at [esp] with the one at [esp+4] ([fcomp]),
    copies the FPU status word into AX and tests its high byte against the
    flag mask supplied in DL. AL becomes 1 when an odd number of the masked
    bits is set (parity flag clear), which for a single-bit mask means "that
    flag is set", and 0 otherwise.

    The two operands are replaced on the stack by that 0/1 dword, and
    control goes back to the address in EBX. This code is only emitted when
    some [AsmFun] opcode asks for it. *)
let vm_float_compare_code =
	let ins s = "\t" ^ s in
	[
		ins "; Floating point test.";
		ins "; DL is the set of floating point flags to test.";
		"VmFloatTest:";
		ins "fld dword [esp]";
		ins "fcomp dword [esp+4]";
		ins "fnstsw ax";
		ins "test ah, dl";
		ins "setnp al";
		ins "movzx eax, al";
		ins "pop edx";
		ins "mov [esp], eax";
		ins "jmp ebx";
	]

(** Separator comment followed by a NASM [SECTION] directive for [name]. *)
let section_header_lines name = [
	";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;";
	"\tSECTION " ^ name;
	]

(** Header of the initialised data section. *)
let section_data = section_header_lines ".data"

(** Header of the numbered data section [.data<idx>]. *)
let section_data_idx idx = section_header_lines (Printf.sprintf ".data%d" idx)

(** Header of the uninitialised (BSS) section. *)
let section_bss = section_header_lines ".bss"

(** Generates the assembly handler for one VM opcode.

    The returned lines start with the label [Opcode_<idx>:], where [<idx>]
    is the position of [opcode] in [imbc.imbc_opcodes] (the index that the
    bytecode and [OpcodeTable] use), followed by the handler body.

    Most handlers end with [jmp ebx] to dispatch the next bytecode, because
    EBX holds the address of [VmMain] while the VM runs. No trailing jump is
    emitted for:
    - [Return], which returns from [VmStart];
    - [AsmFun]s flagged with [FunAttr_UseVmFloatTest] (presumably they
      finish through [VmFloatTest], which ends in [jmp ebx]; confirm);
    - the [Fixed*] opcodes, which jump into another handler.

    Handlers that fetch a one-byte operand with [lodsb] define the NASM
    local label [.fixed_jp] right after the fetch. The matching [Fixed*]
    opcode puts its baked-in operand into AL and jumps to
    [Opcode_<idx>.fixed_jp], so every handler targeted that way must define
    the label.

    @raise Failure if [opcode] is a [Label]. *)
let gen_opcode_asm imbc opcode =
	(* Index of an opcode within the program's opcode table. *)
	let get_opcode_idx op =
		Util.list_find_idx (fun o -> opcode_eq o op) imbc.imbc_opcodes
	in

	(* Body of a "fixed" opcode: put the baked-in operand [v] into AL and
	   jump to the [.fixed_jp] label of the handler for [generic], skipping
	   that handler's own operand fetch. *)
	let fixed_jump v generic =
		[Printf.sprintf "\tmov al, %d" v;
		Printf.sprintf "\tjmp Opcode_%d.fixed_jp" (get_opcode_idx generic)]
	in

	(* Generate the code. *)
	let opcode_code =
		match opcode with
		| Label _ ->
			failwith "Label is not a valid assembly opcode!"
		| Jump _ ->
			(* The 16-bit operand is a signed displacement measured from the
			   end of the operand, which is where ESI points after [lodsw]. *)
			["\tlodsw";
			"\tmovsx eax,ax";
			"\tadd esi, eax"]
		| JumpIfNot _ ->
			["\tlodsw";
			"\tmovsx eax,ax";
			"";
			"\t; Test the value; jump in the bytecode if it's zero";
			"\tpop ecx";
			"\tor ecx, ecx";
			"\tjnz VmMain";
			"\tadd esi, eax"]
		| PushLocal _ ->
			(* Signed slot index, scaled to dwords relative to EBP. *)
			["\tlodsb";
			".fixed_jp:";
			"\tmovsx eax,al";
			"\tpush dword [ebp + eax * 4]"]
		| PushPresetGlobal _ ->
			["\tlodsb";
			".fixed_jp:";
			"\tmovzx eax,al";
			"\tpush dword [Globals_Preset + eax * 4]"]
		| PushUninitGlobal _ ->
			["\tlodsb";
			".fixed_jp:";
			"\tmovzx eax,al";
			"\tpush dword [Globals_Unset + eax * 4]"]
		| PushConstant _ ->
			["\tlodsb";
			".fixed_jp:";
			"\tmovzx eax,al";
			"\tpush dword [Constants + eax * 4]"]
		| PushConstantByte _ ->
			["\tlodsb";
			".fixed_jp:";
			"\tmovsx eax,al";
			"\tpush eax"]
		| Pop ->
			["\tpop eax"]
		| StoreLocal _ ->
			(* Stores leave the stored value on the stack (pop + push). *)
			["\tlodsb";
			".fixed_jp:";
			"\tmovsx eax,al";
			"\tpop edx";
			"\tpush edx";
			"\tmov dword [ebp + eax * 4], edx"]
		| StorePresetGlobal _ ->
			(* The label is required: FixedStorePresetGlobal jumps to it. *)
			["\tlodsb";
			".fixed_jp:";
			"\tmovzx eax,al";
			"\tpop edx";
			"\tpush edx";
			"\tmov dword [Globals_Preset + eax * 4], edx"]
		| StoreUninitGlobal _ ->
			["\tlodsb";
			".fixed_jp:";
			"\tmovzx eax,al";
			"\tpop edx";
			"\tpush edx";
			"\tmov dword [Globals_Unset + eax * 4], edx"]
		| Return ->
			(* Undo VmStart's prologue: EBP points at the saved registers. *)
			["\tpop eax\t\t; Get the return value";
			"\tmov esp,ebp";
			"\tpop ebx";
			"\tpop edi";
			"\tpop esi";
			"\tpop ebp";
			"\tret"]
		| Call _ ->
			(* VmStart preserves EBX, so after the call [ebx] still points at the
			   callee's first header byte, its parameter count. *)
			["\t; Get the address of the bytecode data to EBX and call the VM again";
			"\tlodsb";
			".fixed_jp:";
			"\tmovzx ebx, al";
			"\tmov bx, [BytecodeFunTable + ebx * 2]";
			"\tadd ebx, Bytecode";
			"\tcall VmStart";
			"";
			"\t; Pop the arguments";
			"\tmov cl, [ebx]";
			"\tmovzx ecx, cl";
			"\tshl ecx, 2";
			"\tadd esp, ecx";
			"";
			"\t; Push the return value (in EAX)";
			"\tpush eax";
			"";
			"\tmov ebx, VmMain\t\t; Restore EBX (the VM address)"]
		| Call_Cdecl _ ->
			(* The result is saved in ECX because EAX is reused below to
			   compute the argument byte count. *)
			["; Get the address and call the function";
			"\tpop eax";
			"\tcall eax";
			"\tmov ecx, eax\t\t; Store the answer";
			"";
			"; Pop the arguments";
			"\tlodsb";
			"\tmovzx eax, al";
			"\tshl eax, 2";
			"\tadd esp,eax";
			"";
			"\t; Push the return value (in ECX)";
			"\tpush ecx"]
		| Call_CdeclFP _ ->
			(* The result stays in ST0, so EAX is free for the argument count. *)
			["\t; Get the address and call the function";
			"\tpop eax";
			"\tcall eax";
			"";
			"\t; Pop the arguments";
			"\tlodsb";
			"\tmovzx eax, al";
			"\tshl eax, 2";
			"\tadd esp,eax";
			"";
			"\t; Push the return value (in ST0)";
			"\tpush eax\t\t; Just a dummy value; will be overridden";
			"\tfstp dword [esp]"]
		| Call_Stdcall ->
			["\t; Get the address and call the function";
			"\tpop eax";
			"\tcall eax";
			"";
			"\t; Push the return value (in EAX)";
			"\tpush eax"]
		| Call_StdcallFP ->
			["\t; Get the address and call the function";
			"\tpop eax";
			"\tcall eax";
			"";
			"\t; Push the return value (in ST0)";
			"\tpush eax\t\t; Just a dummy value; will be overridden";
			"\tfstp dword [esp]"]
		| AsmFun (_,_,code,_) ->
			List.map (fun l -> "\t" ^ l) code
		| FixedPushLocal v -> fixed_jump v (PushLocal 0)
		| FixedPushPresetGlobal v -> fixed_jump v (PushPresetGlobal 0)
		| FixedPushUninitGlobal v -> fixed_jump v (PushUninitGlobal 0)
		| FixedPushConstant v -> fixed_jump v (PushConstant 0)
		| FixedPushConstantByte v -> fixed_jump v (PushConstantByte 0)
		| FixedStoreLocal v -> fixed_jump v (StoreLocal 0)
		| FixedStorePresetGlobal v -> fixed_jump v (StorePresetGlobal 0)
		| FixedStoreUninitGlobal v -> fixed_jump v (StoreUninitGlobal 0)
		| FixedCall fn ->
			let fidx = Util.list_find_idx (fun (n,_) -> n = fn) imbc.imbc_funs in
			fixed_jump fidx (Call "")
	in
	(* Trailing dispatch back to VmMain, omitted where control never falls
	   off the end of the handler. *)
	let end_code =
		match opcode with
		| Return -> []
		| AsmFun (_,_,_,attrs) when Program.has_attr attrs Program.FunAttr_UseVmFloatTest -> []

		| FixedPushLocal _
		| FixedPushPresetGlobal _
		| FixedPushUninitGlobal _
		| FixedPushConstant _
		| FixedPushConstantByte _
		| FixedStoreLocal _
		| FixedStorePresetGlobal _
		| FixedStoreUninitGlobal _
		| FixedCall _ -> []
		| _ -> ["\tjmp ebx"]
	in
	List.flatten [[Printf.sprintf "Opcode_%d: ; ****************** %s ******************" (get_opcode_idx opcode) (string_of_opcode opcode)]; opcode_code; end_code]

(** Generates the bytecode data lines for one function.

    [f] is the function's IMBC definition and [fidx] its index in
    [imbc.imbc_funs]. The index keeps the [BCLabel_Fun<fidx>_<label>] jump
    labels unique across functions.

    The output is, in order:
    - a comment giving the total size in bytes;
    - two header bytes: the parameter count, then the number of local stack
      entries (both read by [VmStart]);
    - for every instruction, one [db] opcode byte followed by its operand,
      if it has one. Jumps take a 16-bit [dw] displacement; every other
      operand is a single [db] byte. [Label]s emit only an assembler label. *)
let gen_bytecode imbc f fidx =
	(* Returns the lines for one opcode together with the number of bytes
	   they assemble to, so that the size comment is exact. *)
	let make_code opcode =
		match opcode with
		| Label l ->
			(* A label only names a position; it occupies no bytes. *)
			([Printf.sprintf "BCLabel_Fun%d_%s:" fidx l], 0)
		| o ->
			let op_idx = Util.list_find_idx (fun o2 -> opcode_eq o o2) imbc.imbc_opcodes in
			let intro_line = Printf.sprintf "\tdb %d\t\t; %s" op_idx (string_of_opcode o) in
			let byte_operand v = ([Printf.sprintf "\tdb %d" v], 1) in
			let (data, data_size) =
				match o with
				| Jump s
				| JumpIfNot s ->
					(* The displacement is relative to the end of the operand,
					   where ESI points after the handler's [lodsw]. *)
					([Printf.sprintf "\tdw BCLabel_Fun%d_%s - $ - 2" fidx s], 2)
				| PushLocal v
				| PushPresetGlobal v
				| PushUninitGlobal v
				| PushConstant v
				| PushConstantByte v
				| StoreLocal v
				| StorePresetGlobal v
				| StoreUninitGlobal v
				| Call_Cdecl v
				| Call_CdeclFP v ->
					byte_operand v
				| Call s ->
					byte_operand (Util.list_find_idx (fun (n,_) -> n = s) imbc.imbc_funs)
				| _ -> ([], 0)
			in
			(intro_line :: data, 1 + data_size)
	in

	let str_params = Printf.sprintf "\tdb %d\t\t; Number of parameters" f.ifun_params in
	let str_stack = Printf.sprintf "\tdb %d\t\t; Number of stack entries" f.ifun_stack_entries in

	let pieces = List.map make_code f.ifun_code in
	let code = List.flatten (List.map fst pieces) in
	(* Start from the 2 header bytes, then add every opcode's size. Counting
	   output lines instead would be wrong: labels take no bytes and jump
	   operands take two. *)
	let total_bytes = List.fold_left (fun acc (_, n) -> acc + n) 2 pieces in
	let info = Printf.sprintf "\t; Total of %d bytes" total_bytes in
	info :: str_params :: str_stack :: code

(** Builds the complete NASM source for [imbc] as a list of lines.

    Layout, with blank lines between the parts:
    - the file header and one [EXTERN] per entry in [imbc.imbc_externals];
    - the code section:
      - the [WinMainCRTStartup] entry point, which runs the bytecode of
        "main" and then calls [_ExitProcess@4];
      - the VM loop and one handler per opcode in [imbc.imbc_opcodes];
      - [VmFloatTest], only when an [AsmFun] opcode carries
        [FunAttr_UseVmFloatTest];
    - the data section:
      - the opcode and function offset tables;
      - the bytecode of every function;
      - constants, preset globals and string constants;
      - the heap pointer;
    - the BSS section: the uninitialised globals and the heap. *)
let collect_code imbc =
	(* A label line followed by its body lines. *)
	let labelled label body = label :: body in

	(* Concatenates the pieces, putting one empty line between consecutive
	   pieces (none before the first or after the last). *)
	let join_with_blank_lines pieces =
		let rec join = function
			| [] -> []
			| [last] -> last
			| piece :: rest -> piece @ ("" :: join rest)
		in
		join pieces
	in

	let extern_lines =
		List.map (Printf.sprintf "\tEXTERN %s") imbc.imbc_externals
	in

	(* Process entry point: run "main" through the VM, then push 0 and call
	   _ExitProcess@4. *)
	let entry_lines =
		let main_idx = Util.list_find_idx (fun (name, _) -> name = "main") imbc.imbc_funs in
		labelled "WinMainCRTStartup:"
			["\tfinit";
			Printf.sprintf "\tmov ebx, BytecodeFun_%d" main_idx;
			"\tcall VmStart";
			"";
			"\tpush 0";
			"\tcall _ExitProcess@4"]
	in

	let handler_lines =
		let handlers = List.map (gen_opcode_asm imbc) imbc.imbc_opcodes in
		"\t; -- Opcode code --" :: List.flatten handlers
	in

	(* The float-test helper is only emitted when some opcode needs it. *)
	let extra_vm_lines =
		let wants_float_test = function
			| AsmFun (_,_,_,attrs) -> Program.has_attr attrs Program.FunAttr_UseVmFloatTest
			| _ -> false
		in
		if List.exists wants_float_test imbc.imbc_opcodes then vm_float_compare_code else []
	in

	(* 16-bit handler offsets relative to VmMain, indexed by opcode number. *)
	let opcode_table_lines =
		let entry idx op =
			Printf.sprintf "\tdw Opcode_%d-VmMain\t\t; %s" idx (string_of_opcode op)
		in
		labelled "OpcodeTable:" (Util.list_mapi entry imbc.imbc_opcodes @ [""])
	in

	(* 16-bit bytecode offsets relative to Bytecode, indexed by function number. *)
	let fun_table_lines =
		let entry idx (name, _) =
			Printf.sprintf "\tdw BytecodeFun_%d-Bytecode\t; '%s'" idx name
		in
		labelled "BytecodeFunTable:" (Util.list_mapi entry imbc.imbc_funs)
	in

	let bytecode_lines =
		let one_fun idx (name, fdef) =
			let header = Printf.sprintf "BytecodeFun_%d: ; **** Function '%s' ****" idx name in
			header :: gen_bytecode imbc fdef idx @ [""]
		in
		labelled "Bytecode:" (List.flatten (Util.list_mapi one_fun imbc.imbc_funs))
	in

	(* One [dd] line per value; floats are stored as their 32-bit IEEE bit pattern. *)
	let ivalue_line = function
		| Val_Int v -> Printf.sprintf "\tdd %ld\t\t; 0x%lx" v v
		| Val_Float v ->
			let bits = Int32.bits_of_float v in
			Printf.sprintf "\tdd %ld\t\t; 0x%lx, %f" bits bits v
		| Val_String_Const v -> "\tdd String_Const_" ^ string_of_int v
		| Val_External v -> "\tdd " ^ v
	in

	let constant_lines =
		labelled "Constants:" (List.map ivalue_line imbc.imbc_constants)
	in

	(* Each string is emitted byte by byte under its own label. *)
	let string_lines =
		let one_string idx s =
			let byte_lines =
				List.map (fun c -> Printf.sprintf "\tdb %d" (Char.code c)) (Util.string_to_chars s)
			in
			Printf.sprintf "String_Const_%d:" idx :: byte_lines
		in
		labelled "String_Constants:" (List.flatten (Util.list_mapi one_string imbc.imbc_strings))
	in

	let preset_lines =
		let one_global idx (name, value) =
			[Printf.sprintf "\t; #%d. '%s'" idx name; ivalue_line value]
		in
		labelled "Globals_Preset:" (List.flatten (Util.list_mapi one_global imbc.imbc_preset_globals))
	in

	let heap_ptr_lines = ["HeapPtr:"; "\tdd HeapMemory"] in

	let uninit_lines =
		["Globals_Unset:";
		Printf.sprintf "\tresd %d" (List.length imbc.imbc_uninit_globals)]
	in

	let heap_lines =
		["HeapMemory:";
		Printf.sprintf "\tresb %d" imbc.imbc_heapsize]
	in

	join_with_blank_lines
		[file_header @ extern_lines;

		section_code;
		entry_lines;
		vm_main_code @ handler_lines @ extra_vm_lines;

		section_data;
		opcode_table_lines;
		fun_table_lines;
		bytecode_lines;
		constant_lines;
		preset_lines;
		string_lines;
		heap_ptr_lines;

		section_bss;
		uninit_lines;
		heap_lines]

(** Generates the NASM source for [imbc] with [collect_code] and writes the
    resulting lines to the file named [fn] using [Util.write_file_lines]. *)
let write_asm imbc fn =
	Util.write_file_lines fn (collect_code imbc)
