open Util
open Syntax
open Syntaxutil
open Typecheckenv
open Prettyprint
open Prettyutil
open Metadata

module SH = Data.StringCols.Hash

(* compare, unify, and substitute types *)

(* ------------------------------------------------------------
 * Type Substitution
 * ------------------------------------------------------------ *)

(* A substitution being built up during unification.
   [binds]    maps a type variable's name (as produced by [id_str]) to the
              type it is bound to.
   [freevars] names the variables this substitution is allowed to bind;
              unification only records a binding for a variable listed
              here (see [is_var]). *)
type subst = {binds : typ strhash; freevars : string list}

(* [new_subst ids] makes a substitution with no bindings yet whose free
   (bindable) variables are the names of [ids]. *)
let new_subst ids : subst =
	let names = List.map id_str ids in
	{freevars = names; binds = SH.create ()}

(* The empty substitution.  It has no free variables, so unifying under it
   never records a binding; it is shared, so nothing should be added to it
   directly. *)
let nosubst : subst = new_subst []

(* Accessors on a substitution; every variable is keyed by [id_str id]. *)

(* Has [id] already been given a binding in [subst]? *)
let is_bound_var (subst : subst) id =
	let key = id_str id in
	SH.mem key subst.binds

(* The type bound to [id]; callers in this file check [is_bound_var] first. *)
let subst_find (subst : subst) id =
	let key = id_str id in
	SH.find key subst.binds

(* Record [ty] as the binding of [id]. *)
let subst_add subst id ty =
	let key = id_str id in
	SH.add key ty subst.binds

(* Is [id] one of the variables [subst] is allowed to bind? *)
let is_var subst id =
	let key = id_str id in
	member key subst.freevars

(* [mksubst env tyvars tyargs] builds a substitution whose free variables
   are [tyvars] and binds each one to the entry of [tyargs] at the same
   position.  Length mismatches are presumably reported by [tc_iter2]
   under the label "type arguments" -- TODO confirm. *)
let mksubst env tyvars tyargs =
	let result = new_subst tyvars in
	let bind tyvar tyarg = subst_add result tyvar tyarg in
	tc_iter2 env "type arguments" bind tyvars tyargs;
	result

(* Report, via [tyerr], every free variable of [subst] that has not been
   bound; do nothing when all of them are bound. *)
let rec check_subst_complete env subst =
	let unbound =
		List.filter (fun name -> not (SH.mem name subst.binds)) subst.freevars in
	if unbound <> [] then
		tyerr env (str "Some type vars left unbound:" <++>
					pprint_list str unbound)

(* Apply [subst] to a core type, producing a full type
   (specifiers, core, declarator modifiers).  A wild variable bound in
   [subst] is replaced by its binding; anything else comes back unchanged
   as ([SCoreHere], coretyp, []). *)
and subst_coretyp subst coretyp =
	let unchanged = ([SCoreHere], coretyp, []) in
	match coretyp with
	| TyWild id ->
		if is_bound_var subst id then subst_find subst id else unchanged
	| _ -> unchanged

(* Apply [subst] inside one declarator modifier.  It substitutes into every
   argument declaration of a function that has full argument info, and
   into every type argument of a DWithArgs.  Any other modifier is
   returned as-is. *)
and subst_declmod subst declmod =
	match declmod with
	| DWithArgs typs ->
		DWithArgs (List.map (subst_typ subst) typs)
	| DFun (ArgsFull (m,tyifaces,args),va) ->
		let args' = List.map (subst_decl subst) args in
		DFun (ArgsFull (m,tyifaces,args'),va)
	| _ -> declmod


	
(* Apply [subst] to a full type.  The outer declarator modifiers are
   substituted in place.  The core type's substitution contributes
   specifiers, merged with [insert_specs], and modifiers, appended after
   the outer ones (empty unless a binding was substituted). *)
and subst_typ subst (specs,coretyp,declmods) =
	let core_specs, core, core_mods = subst_coretyp subst coretyp in
	let outer_mods = List.map (subst_declmod subst) declmods in
	(insert_specs specs core_specs, core, outer_mods @ core_mods)
		
(* Apply [subst] to one init-declarator, appending [extramods] (modifiers
   contributed by the substituted core type) after its own modifiers.
   The optional name and the initialiser pass through untouched. *)
and subst_init_decl subst extramods ((declmods,id_opt),init) =
	let mods = List.map (subst_declmod subst) declmods @ extramods in
	((mods, id_opt), init)

(* Apply [subst] to a declaration.  The shared core type is substituted
   once, and the modifiers it contributes are passed to every
   init-declarator. *)
and subst_decl subst (m,((specs,coretyp),init_decls)) =
	let core_specs, core, core_mods = subst_coretyp subst coretyp in
	let new_inits = List.map (subst_init_decl subst core_mods) init_decls in
	let new_specs = insert_specs specs core_specs in
	(m, ((new_specs, core), new_inits))

(* Substituting into a function signature is substituting into its
   declaration. *)
let subst_funsig subst funsig = subst_decl subst funsig


(* ------------------------------------------------------------
 * Type unification
 * ------------------------------------------------------------ *)

(* TODO: handle subtyping more elegantly *)
(* TODO: handle conversion between function and pointer to function
   more correctly *)


(* parentheses have no effect on a type's meaning *)
(* Remove every DParens from a list of declarator modifiers, keeping the
   remaining modifiers in their original order. *)
let strip_parenmods mods =
	List.filter (function DParens -> false | _ -> true) mods

(* Unify a single found declarator modifier against a wanted one.
   Functions compare kind then arguments.  Pointers compare their
   specifier/qualifier lists, and type-argument lists are unified
   element-wise.  An unqualified pointer and an array are interchangeable.
   Fat pointers and bit fields match their own kind without further
   checks.  Any other pairing is a mismatch. *)
let rec unify_declmod env subst modfound modwant =
	match modfound, modwant with
	| DFun (found_args,found_kind), DFun (want_args,want_kind) ->
		unify_funkind env found_kind want_kind;
		unify_funargs env subst found_args want_args
	| DPtr found_specs, DPtr want_specs ->
		unify_specquals env found_specs want_specs
	| DWithArgs found_typs, DWithArgs want_typs ->
		tc_iter2 env "type arguments" (unify_typs env subst)
			found_typs want_typs
	| (DPtr [], DArray _)
	| (DArray _, DPtr [])
	| (DFatPtr _, DFatPtr _)
	| (DBitField _, DBitField _) -> ()
	| _ ->
		error_mismatch env (str "declarator mismatch:" <++>
			pprint_declmod modfound empty <++> str "vs" <++>
			pprint_declmod modwant empty)

(* Check that a found function kind may stand for the wanted one: a
   DictMethod is accepted where an IfaceMethod is wanted, otherwise the
   two kinds must be equal. *)
and unify_funkind env kindfound kindwant =
	let compatible =
		match kindfound, kindwant with
		| DictMethod _, IfaceMethod -> true
		| _ -> kindfound = kindwant
	in
	if not compatible then
		error_mismatch env (str "function kind mismatch:" <++>
			pprint_funkind kindfound <++> str "vs" <++>
			pprint_funkind kindwant)


(* Unify the found declarator modifiers against the wanted ones, front to
   back, and return the wanted modifiers left over once [modsfound] runs
   out so the caller can let them flow into a wild type variable.
   The match cases are tried in order:
   - a DParens on the found side, then on the wanted side, is skipped;
   - a found DFun against a wanted DPtr drops the wanted DPtr, and a found
     DPtr against a wanted DFun drops the found DPtr (an approximation of
     function / pointer-to-function conversion);
   - otherwise the heads are unified with [unify_declmod];
   - a found modifier with nothing left on the wanted side is an error.
   NOTE(review): the DFun/DPtr rules apply at every position and can fire
   repeatedly, so e.g. a found [DFun] is accepted against a wanted
   [DPtr; DPtr; DFun].  Confirm whether the conversion should be limited
   to a single outermost step. *)
and unify_declmod_lists env subst modsfound modswant = 
	match modsfound,modswant with
	| DParens::xs,_ -> unify_declmod_lists env subst xs modswant
	| _,DParens::ys -> unify_declmod_lists env subst modsfound ys
	| DFun _::xs,DPtr _::ys -> unify_declmod_lists env subst modsfound ys
	| DPtr _::xs,DFun _::ys -> unify_declmod_lists env subst xs modswant
	| x::xs,y::ys -> 
			unify_declmod env subst x y; 
			unify_declmod_lists env subst xs ys
	(* leftover wanted modifiers are handed back to the caller *)
	| [],ys -> ys 
	| x::xs,[] -> error_mismatch env (str "extra declarator:" <++> 
			pprint_declmod x empty)

(* Unify a found core type against a whole wanted type.
   A wild variable that [subst] may bind is handled specially.  The
   [is_var] guard stops equality checks under [nosubst] from substituting.
   When [tywant] is not itself wild, [set_repchanged] is called first.
   The variable is then either checked against its existing binding with
   [same_type] or bound to [tywant].
   Any other core type matches only when [tywant] has no declarator
   modifiers and an equal core type. *)
and unify_coretyps env subst coretyfound tywant =
	match coretyfound with
	| TyWild id when is_var subst id ->
		if not (is_wild_typ tywant) then set_repchanged env;
		if is_bound_var subst id
		then same_type env (subst_find subst id) tywant
		else subst_add subst id tywant
	| _ ->
		let matches =
			typ_mods tywant = [] &&
			coretyp_equal coretyfound (typ_core tywant) in
		if not matches then
			error_mismatch env (str "core type mismatch:" <++>
				pprint_coretyp coretyfound <++> str "vs" <++>
				pprint_abstyp tywant)
		
(* Unify two function argument lists.  A found list without argument
   information matches anything.  Two full lists are unified declaration
   by declaration, and any other combination is a mismatch.
   NOTE(review): a wanted ArgsNoinfo against a found ArgsFull is rejected;
   confirm this asymmetry is intended. *)
and unify_funargs env subst args1 args2 =
	match args1, args2 with
	| ArgsNoinfo, _ -> ()
	| ArgsFull (_,_,found_decls), ArgsFull (_,_,want_decls) ->
		let unify_arg = unify_decl env subst in
		tc_iter2 env "function arguments" unify_arg found_decls want_decls
	| _ -> error_mismatch env (str "Fun argument mismatch")
	
(* TODO: properly handle subtyping on const and volatile *)
(* Require two specifier/qualifier lists to contain the same elements
   (order-insensitive, per [same_elements]).
   NOTE(review): this reports through [tyerr], whereas the other unify_*
   functions use [error_mismatch]; confirm that is intended. *)
and unify_specquals env sq1 sq2 =
	if same_elements sq1 sq2 then () else begin
		let show sq = pprint_seq (pprint_specifier empty) sq in
		tyerr env (str "Type specifier mismatch:" <++> show sq1 <++>
					str "vs" <++> show sq2)
	end
	
(* TODO: deal with specquals *)					
(* Unify a found full type against a wanted one.  Both are resolved with
   [resolve_typ] (found first).  Their declarator modifiers are unified,
   and the leftover wanted modifiers, together with the wanted specifiers
   and core, go to [unify_coretyps] so they can flow into a wild variable.
   The found specifiers are not compared. *)
and unify_typs env subst tyfound tywant =
	let _, found_core, found_mods = resolve_typ env tyfound in
	let want_specs, want_core, want_mods = resolve_typ env tywant in
	let leftover = unify_declmod_lists env subst found_mods want_mods in
	unify_coretyps env subst found_core (want_specs, want_core, leftover)
	
(* Unify one found init-declarator against one wanted init-declarator.
   Each side's full type is rebuilt from its declaration base
   (specifiers, core) and the init-declarator's own modifiers.  Names and
   initialisers are ignored. *)
and unify_initdcls env subst vbase fbase vinit_dcl finit_dcl =
	let (vspecs, vcore), (fspecs, fcore) = vbase, fbase in
	let ((vmods, _), _), ((fmods, _), _) = vinit_dcl, finit_dcl in
	unify_typs env subst (vspecs, vcore, vmods) (fspecs, fcore, fmods)

(* Unify two declarations.  The found and wanted init-declarator lists
   are unified position by position, each against its own declaration
   base.  The leading component of each declaration is ignored. *)
and unify_decl env subst (_,(vbase,vinits)) (_,(fbase,finits)) =
	let unify_pair = unify_initdcls env subst vbase fbase in
	tc_iter2 env "declarators" unify_pair vinits finits


(* ------------------------------------------------------------
 * Wrappers for type unification
 * ------------------------------------------------------------ *)

(* Shared driver for the fill_in_* and same_type checks.  It builds a
   mismatch description with [pprint_type_mismatch] from [expector],
   [tyfound] (shown under [subst]) and [tywant], attaches it to [env] via
   [with_unify], and then unifies.  The description is built before
   unification runs, so it shows only the bindings known on entry. *)
and generic_unify subst env expector tyfound tywant =
	let shown = subst_typ subst tyfound in
	let unify_env =
		with_unify env (pprint_type_mismatch expector shown tywant) in
	unify_typs unify_env subst tyfound tywant

(* Unify a found argument type against a wanted generic type, filling in
   [subst].  A found plain void pointer against a wanted pointer or
   function type is accepted without unifying (see [void_matches]).
   Otherwise unification runs under [with_repcheck].  If [get_repchanged]
   then reports true and [tyfound] is a function type, [set_needscast] is
   applied to the environment's metadata. *)
and fill_in_generic_arg_types subst env tyfound tywant =
	let env = with_repcheck env in
	if not (void_matches tyfound tywant) then begin
		generic_unify subst env (str "function") tyfound tywant;
		if get_repchanged env && is_function_typ tyfound then
			set_needscast (get_meta env) tyfound
	end

(* Unify a found dictionary type against the wanted one, filling in
   [subst]; mismatches are described as concerning the dictionary for
   [iface]. *)
and fill_in_generic_dict_type subst env iface tyfound tywant =
	let expector = str "dictionary for" <++> pprint_id iface in
	generic_unify subst env expector tyfound tywant

(* Unify a found type against a wanted one, filling in [subst]; mismatches
   are labelled "constructor". *)
and fill_in_pattern_typ subst env tyfound tywant =
	let expector = str "constructor" in
	generic_unify subst env expector tyfound tywant

(* True when [tyfound] is exactly a void pointer (core void, a single DPtr)
   and the first modifier of [tywant] is a DPtr or a DFun. *)
and void_matches tyfound tywant =
	match tyfound, tywant with
	| (_, TyVoid, [DPtr _]), (_, _, (DPtr _ | DFun _) :: _) -> true
	| _ -> false

(* ------------------------------------------------------------
 * Other functions relating to unification
 * ------------------------------------------------------------ *)
	
(* Check that [ty_b] (as found) unifies with [ty_a] (as wanted) without
   binding any variables, because [nosubst] has no free variables.  The
   check runs under [with_repcheck] and mismatches are labelled
   "context". *)
and same_type env ty_a ty_b =
	let repcheck_env = with_repcheck env in
	generic_unify nosubst repcheck_env (str "context") ty_b ty_a
					   
(* Check that two function signatures unify without binding variables.
   The description from [pprint_funsig_mismatch] is attached to [env]
   first. *)
and same_funsig env sigfound sigwant =
	let unify_env =
		with_unify env (pprint_funsig_mismatch sigfound sigwant) in
	unify_decl unify_env nosubst sigfound sigwant


