/*
Program Purpose:
SAS/IML modules that initialize the parameters used by the distrib macro
*/

**Library that receives the compiled IML module catalog (path . is relative to the SAS session working directory);
libname _modules ".";

proc iml;

	**Load MCMC parameters from SAS datasets into IML;
	start load_parameters(multivar_mcmc_model);
	
		**Reads the multivar MCMC output datasets whose names start with multivar_mcmc_model
		  and returns their contents as one named list.
		  Datasets read: _vars, _covars, _covarsnc (only with never-consumers), _subjects, _iters,
		  _beta1 to _betaN where N = 2*num_episodic + num_daily, _sigma_e, _sigma_u,
		  _alpha1 (only with never-consumers) and _u_post (only when num_post is positive);
		
		**extract variables;
		use (cat(multivar_mcmc_model, "_vars"));
			read all var {num_episodic} into num_episodic;
			read all var {num_daily} into num_daily;
			
			if num_episodic > 0 then do;
				read all var (cat("episodic_var", 1:num_episodic)) into episodic_variables;
			end;
			else do;
				episodic_variables = {};
			end;
			
			if num_daily > 0 then do;
				read all var (cat("daily_var", 1:num_daily)) into daily_variables;
			end;
			else do;
				daily_variables = {};
			end;
			
			read all var {has_never_consumers} into has_never_consumers;
		close (cat(multivar_mcmc_model, "_vars"));
		
		**extract covariates: observation var_j of _covars describes variable var_j.
		  The dataset is opened once for the whole loop instead of once per variable;
		num_variables = 2*num_episodic + num_daily;
		covariates = ListCreate(num_variables);
		intercepts = ListCreate(num_variables);
		use (cat(multivar_mcmc_model, "_covars"));
			do var_j = 1 to num_variables;
				read point(var_j) var {num_covariates} into num_covariates;
				
				if num_covariates > 0 then do;
					read point(var_j) var (cat("covariate", 1:num_covariates)) into covariates_j;
				end;
				else do;
					covariates_j = {};
				end;
				
				read point(var_j) var {intercept} into intercept_j;
				
				covariates$var_j = covariates_j;
				intercepts$var_j = intercept_j;
			end;
		close (cat(multivar_mcmc_model, "_covars"));
		
		**extract never-consumer covariates, if the model has never-consumers;
		if has_never_consumers = 1 then do;
			use (cat(multivar_mcmc_model, "_covarsnc"));
				read all var {num_covariates} into num_covariates;
				
				if num_covariates > 0 then do;
					read all var (cat("covariate", 1:num_covariates)) into never_consumer_covariates;
				end;
				else do;
					never_consumer_covariates = {};
				end;
				
				read all var {intercept} into never_consumer_intercept;
			close (cat(multivar_mcmc_model, "_covarsnc"));
		end;
		else do;
			never_consumer_covariates = {};
			never_consumer_intercept = {};
		end;
		
		**extract MCMC subjects;
		use (cat(multivar_mcmc_model, "_subjects"));
			read all var {subject} into mcmc_subjects;
		close (cat(multivar_mcmc_model, "_subjects"));
		
		**extract number of MCMC iterations, burn-in, thinning and stored post-burn-in U matrices;
		use (cat(multivar_mcmc_model, "_iters"));
			read all var {num_mcmc_iterations} into num_mcmc_iterations;
			read all var {num_burn} into num_burn;
			read all var {num_thin} into num_thin;
			read all var {num_post} into num_post;
		close (cat(multivar_mcmc_model, "_iters"));
		
		**Load MCMC parameter traces: every column of each trace dataset is read as a numeric matrix.
		  Rows are presumably MCMC iterations - TODO confirm against the multivar MCMC output;
		beta = ListCreate(num_variables);
		do var_j = 1 to num_variables;
			use (cat(multivar_mcmc_model, "_beta", var_j));
				read all var _ALL_ into beta_j;
			close (cat(multivar_mcmc_model, "_beta", var_j));
			
			beta$var_j = beta_j;
		end;
		
		use (cat(multivar_mcmc_model, "_sigma_e"));
			read all var _ALL_ into sigma_e;
		close (cat(multivar_mcmc_model, "_sigma_e"));
		
		use (cat(multivar_mcmc_model, "_sigma_u"));
			read all var _ALL_ into sigma_u;
		close (cat(multivar_mcmc_model, "_sigma_u"));
		
		if has_never_consumers = 1 then do;
			use (cat(multivar_mcmc_model, "_alpha1"));
				read all var _ALL_ into alpha1;
			close (cat(multivar_mcmc_model, "_alpha1"));
		end;
		else do;
			alpha1 = {};
		end;
		
		**Load MCMC U matrices: the observations with post_mcmc_iteration = i become
		  list item i, using columns u_col1 to u_colN;
		if num_post > 0 then do;
			u_matrices_post = ListCreate(num_post);
			use (cat(multivar_mcmc_model, "_u_post"));
				read all var {post_mcmc_iteration} into post_mcmc_iteration;
			
				do i = 1 to num_post;
					read point(loc(post_mcmc_iteration = i)) var (cat("u_col", 1:num_variables)) into u_post_i;
					u_matrices_post$i = u_post_i;
				end;
			close (cat(multivar_mcmc_model, "_u_post"));
		end;
		else do;
			u_matrices_post = ListCreate();
		end;
		
		mcmc_parameters = [#"episodic_variables" = episodic_variables,
		                   #"daily_variables" = daily_variables,
		                   #"covariates" = covariates,
		                   #"intercepts" = intercepts,
		                   #"mcmc_subjects" = mcmc_subjects,
		                   #"beta" = beta,
		                   #"sigma_e" = sigma_e,
		                   #"sigma_u" = sigma_u,
		                   #"u_matrices_post" = u_matrices_post,
		                   #"num_episodic" = num_episodic,
		                   #"num_daily" = num_daily,
		                   #"num_mcmc_iterations" = num_mcmc_iterations,
		                   #"num_burn" = num_burn,
		                   #"num_thin" = num_thin,
		                   #"num_post" = num_post,
		                   #"has_never_consumers" = has_never_consumers,
		                   #"never_consumer_covariates" = never_consumer_covariates,
		                   #"never_consumer_intercept" = never_consumer_intercept,
		                   #"alpha1" = alpha1];
		
		return mcmc_parameters;
	finish;

	**Calculate population variables (records, subject IDs, normalized nuisance weights, dietary supplement data, sample sizes, number of replicates);
	start calculate_population_variables(distrib_population,
	                                     id,
	                                     nuisance_weight,
	                                     use_mcmc_u_matrices,
	                                     dietary_supplements,
	                                     num_simulated_u,
	                                     num_post);
	
		**Reads record IDs, optional nuisance weights and optional dietary supplement columns
		  from distrib_population and returns them in a named list together with counts.
		  id                  - name of the subject ID variable
		  nuisance_weight     - name of the nuisance weight variable, or blank for equal weights
		  dietary_supplements - dataset with columns variable and supplement, or blank for none.
		                        The supplement values name the population columns to read.
		  num_replicates is num_post unless use_mcmc_u_matrices is N, in which case it is num_simulated_u;
		
		**read the supplement specification first so the population dataset is opened only once;
		if dietary_supplements ^= "" then do;
			use (dietary_supplements);
				read all var {variable} into variables_to_supplement;
				read all var {supplement} into supplements;
			close (dietary_supplements);
			variables_to_supplement = variables_to_supplement`;
		end;
		else do;
			variables_to_supplement = {};
			dietary_supplement_data = {};
		end;
		
		**single pass over the population dataset: IDs, nuisance weights and supplement columns;
		use (distrib_population);
			read all var (id) into records;
			
			if nuisance_weight ^= "" then do;
				read all var (nuisance_weight) into nuisance_weighting;
			end;
			
			if dietary_supplements ^= "" then do;
				read all var (supplements) into dietary_supplement_data;
			end;
		close (distrib_population);
		
		**subjects are the first record of each ID group. UNIQUEBY expects records sorted by ID.
		  Subscripting the in-memory IDs gives the same values as re-reading those observations;
		subjects = records[uniqueby(records)];
		
		**calculate total number of records and subjects;
		num_records = nrow(records);
		num_subjects = nrow(subjects);
		
		**equal weights when no nuisance weight variable is given;
		if nuisance_weight = "" then do;
			nuisance_weighting = j(num_records, 1, 1);
		end;
		
		**normalize the sum of nuisance weights for each subject to 1;
		do i = 1 to num_subjects;
			subj = loc(records = subjects[i]);
			nuisance_weighting[subj] = nuisance_weighting[subj]/sum(nuisance_weighting[subj]);
		end;
		
		**Calculate number of replicates;
		if upcase(use_mcmc_u_matrices) ^= "N" then do;
			num_replicates = num_post;
		end;
		else do;
			num_replicates = num_simulated_u;
		end;
		
		**return population variables in list;
		population_variables = [#"records" = records,
		                        #"subjects" = subjects,
		                        #"nuisance_weighting" = nuisance_weighting,
		                        #"variables_to_supplement" = variables_to_supplement,
		                        #"dietary_supplement_data" = dietary_supplement_data,
		                        #"num_records" = num_records,
		                        #"num_subjects" = num_subjects,
		                        #"num_replicates" = num_replicates];
		
		return population_variables;
	finish;
	
	**Extract covariate data for each variable from input dataset;
	start extract_covariate_matrices(distrib_population,
	                                 mcmc_covariates,
	                                 mcmc_intercepts,
	                                 has_never_consumers,
	                                 never_consumer_covariates,
	                                 never_consumer_intercept,
	                                 num_records,
	                                 num_episodic,
	                                 num_daily);
	
		**Builds the design matrix of each of the 2*num_episodic + num_daily variables, plus the
		  never-consumer design matrix when has_never_consumers = 1, from distrib_population.
		  A design matrix is the listed covariate columns, with a leading column of ones
		  when the matching intercept flag is Y.
		  Returns a named list: variables (list of matrices) and never_consumer (matrix or empty).
		  The population dataset is opened once for all matrices instead of once per matrix;
		num_variables = 2*num_episodic + num_daily;
		variable_covariate_matrices = ListCreate(num_variables);
		never_consumer_covariate_matrix = {};
		
		use (distrib_population);
		
			**recall covariate matrices;
			do var_j = 1 to num_variables;
				covariates_j = mcmc_covariates$var_j;
				intercept_j = mcmc_intercepts$var_j;
				
				if ^IsEmpty(covariates_j) then do;
					read all var (covariates_j) into covariate_matrix_j;
				end;
				else do;
					covariate_matrix_j = {};
				end;
				
				if upcase(intercept_j) = "Y" then do;
					covariate_matrix_j = j(num_records, 1, 1) || covariate_matrix_j;
				end;
				
				variable_covariate_matrices$var_j = covariate_matrix_j;
			end;
			
			**never-consumer covariate matrix, only when never-consumers are modeled;
			if has_never_consumers = 1 then do;
				if ^IsEmpty(never_consumer_covariates) then do;
					read all var (never_consumer_covariates) into never_consumer_covariate_matrix;
				end;
				
				if upcase(never_consumer_intercept) = "Y" then do;
					never_consumer_covariate_matrix = j(num_records, 1, 1) || never_consumer_covariate_matrix;
				end;
			end;
		close (distrib_population);
		
		covariate_matrices = [#"variables" = variable_covariate_matrices,
		                      #"never_consumer" = never_consumer_covariate_matrix];
		
		return covariate_matrices;
	finish;
	
	**Calculates posterior means of beta, sigma-u, sigma-e, and alpha1;
	start calculate_mcmc_means(beta,
	                           sigma_u,
	                           sigma_e,
	                           has_never_consumers,
	                           alpha1,
	                           num_mcmc_iterations,
	                           num_burn,
	                           num_thin,
	                           num_episodic,
	                           num_daily);
	
		**Posterior means over iterations num_burn+1, num_burn+1+num_thin, ... up to num_mcmc_iterations.
		  beta is a list with one trace matrix per variable. sigma_u and sigma_e are traces whose
		  column means are rebuilt into symmetric matrices with SQRSYM. alpha1 is used only
		  when has_never_consumers = 1. Trace rows are indexed by iteration number
		  (assumes one row per MCMC iteration - TODO confirm).
		  Returns a named list: beta (list of column vectors), sigma_u, sigma_e, alpha1;
		num_variables = 2*num_episodic + num_daily;
		
		**iterations to use in means: after burn-in, every num_thin-th iteration;
		thinned_iterations = do(num_burn+1, num_mcmc_iterations, num_thin);
		
		**per-variable beta means as column vectors;
		beta_mean = ListCreate(num_variables);
		do var_j = 1 to num_variables;
			beta_j = beta$var_j;
			beta_thinned_j = beta_j[thinned_iterations,];
			beta_mean$var_j = beta_thinned_j[:,]`;
		end;
		
		sigma_e_thinned = sigma_e[thinned_iterations,];
		sigma_e_mean = sqrsym(sigma_e_thinned[:,]`);
		
		sigma_u_thinned = sigma_u[thinned_iterations,];
		sigma_u_mean = sqrsym(sigma_u_thinned[:,]`);
		
		**if never-consumers are allowed, calculate alpha1 mean;
		if has_never_consumers = 1 then do;
			alpha1_thinned = alpha1[thinned_iterations,];
			alpha1_mean = alpha1_thinned[:,]`;
		end;
		else do;
			alpha1_mean = {};
		end;
		
		**output list of MCMC means;
		mcmc_means = [#"beta" = beta_mean,
		              #"sigma_u" = sigma_u_mean,
		              #"sigma_e" = sigma_e_mean,
		              #"alpha1" = alpha1_mean];
		
		return mcmc_means;
	finish;
	
	**Calculates XBeta, the U standard deviation (symmetric square root of Sigma-U), and consumer probabilities;
	start initialize_distrib_parameters(covariate_matrices,
	                                    beta_mean,
	                                    sigma_u_mean,
	                                    sigma_e_mean,
	                                    num_records,
	                                    num_episodic,
	                                    num_daily,
	                                    has_never_consumers,
	                                    g_matrix,
	                                    alpha1_mean);
	
		**Computes the distrib parameters from the posterior means.
		  covariate_matrices - named list whose variables item holds one design matrix per variable
		  beta_mean          - list of coefficient column vectors, one per variable
		  g_matrix           - never-consumer design matrix, used only when has_never_consumers = 1
		  Returns a named list: xbeta (num_records rows, one column per variable),
		  u_standard_deviation, sigma_e_mean (passed through) and consumer_probabilities;
		
		**calculate XBeta, one column per variable;
		num_variables = 2*num_episodic + num_daily;
		
		xbeta = j(num_records, num_variables, .);
		do var_j = 1 to num_variables;
			xbeta[,var_j] = covariate_matrices$"variables"$var_j * beta_mean$var_j;
		end;
		
		**standard deviation of U: symmetric square root of sigma_u_mean from its eigen decomposition,
		  eigenvectors * diag(sqrt(eigenvalues)) * transpose(eigenvectors).
		  Eigenvalues are floored at zero (element maximum) so tiny negative values caused by
		  rounding are not passed to SQRT as invalid arguments;
		call eigen(sigma_u_eigvals, sigma_u_eigvecs, sigma_u_mean);
		u_standard_deviation = sigma_u_eigvecs * diag(sqrt(sigma_u_eigvals <> 0)) * sigma_u_eigvecs`;
		
		**if never-consumers are allowed, consumer probabilities are the standard normal CDF of G*alpha1;
		if has_never_consumers = 1 then do;
			g_alpha = g_matrix * alpha1_mean;
			consumer_probabilities = cdf("Normal", g_alpha);
		end;
		else do;
			consumer_probabilities = {};
		end;
		
		**output list of distrib parameters;
		distrib_parameters = [#"xbeta" = xbeta,
		                      #"u_standard_deviation" = u_standard_deviation,
		                      #"sigma_e_mean" = sigma_e_mean,
		                      #"consumer_probabilities" = consumer_probabilities];
		
		return distrib_parameters;
	finish;
	
	**Initializes parameters for multivar distrib main loop;
	start initialize_distrib(distrib_population,
	                         id,
	                         nuisance_weight,
	                         mcmc_covariates,
	                         mcmc_intercepts,
	                         beta,
	                         sigma_u,
	                         sigma_e,
	                         use_mcmc_u_matrices,
	                         dietary_supplements,
	                         num_simulated_u,
	                         num_episodic,
	                         num_daily,
	                         num_mcmc_iterations,
	                         num_burn,
	                         num_thin,
	                         num_post,
	                         distrib_seed=,
	                         has_never_consumers,
	                         never_consumer_covariates,
	                         never_consumer_intercept,
	                         alpha1);
	
		**Runs the initialization steps for the multivar distrib main loop and returns one named list
		  with the distrib parameters, the population variables and the seed that was used.
		  distrib_seed is optional. When it is omitted, a seed is drawn and reported in the output;
		
		**1. Set seed. Without a user seed, draw an integer between 1 and 20000000.
		  reinit = 1 makes RANDSEED take effect even when the random number stream is already
		  initialized (for example by the RANDGEN call below or an earlier call in the session).
		  Without it, RANDSEED would be ignored in that case and the reported seed would not
		  reproduce the run;
		if IsEmpty(distrib_seed) then do;
			distrib_seed = j(1, 1, .);
			call randgen(distrib_seed, "Uniform", 0, 1);
			distrib_seed = int(distrib_seed * 20000000 + 1);
		end;
		call randseed(distrib_seed, 1);
		
		**2. Calculate population variables (records, subjects, nuisance weights, supplement data, sample sizes, number of replicates);
		population_variables = calculate_population_variables(distrib_population,
		                                                      id,
		                                                      nuisance_weight,
		                                                      use_mcmc_u_matrices,
		                                                      dietary_supplements,
		                                                      num_simulated_u,
		                                                      num_post);
		
		**3. Extract covariate matrices;
		covariate_matrices = extract_covariate_matrices(distrib_population,
		                                                mcmc_covariates,
		                                                mcmc_intercepts,
		                                                has_never_consumers,
		                                                never_consumer_covariates,
		                                                never_consumer_intercept,
		                                                population_variables$"num_records",
		                                                num_episodic,
		                                                num_daily);
		
		**4. Calculate means of MCMC multivar parameters;
		mcmc_parameter_means = calculate_mcmc_means(beta,
		                                            sigma_u,
		                                            sigma_e,
		                                            has_never_consumers,
		                                            alpha1,
		                                            num_mcmc_iterations,
		                                            num_burn,
		                                            num_thin,
		                                            num_episodic,
		                                            num_daily);
		
		**5. Initialize distrib parameters;
		distrib_parameters = initialize_distrib_parameters(covariate_matrices,
		                                                   mcmc_parameter_means$"beta",
		                                                   mcmc_parameter_means$"sigma_u",
		                                                   mcmc_parameter_means$"sigma_e",
		                                                   population_variables$"num_records",
		                                                   num_episodic,
		                                                   num_daily,
		                                                   has_never_consumers,
		                                                   covariate_matrices$"never_consumer",
		                                                   mcmc_parameter_means$"alpha1");
		
		**6. Output distrib parameters;
		output_parameters = [#"xbeta" = distrib_parameters$"xbeta",
		                     #"u_standard_deviation" = distrib_parameters$"u_standard_deviation",
		                     #"sigma_e_mean" = distrib_parameters$"sigma_e_mean",
		                     #"consumer_probabilities" = distrib_parameters$"consumer_probabilities",
		                     #"records" = population_variables$"records",
		                     #"subjects" = population_variables$"subjects",
		                     #"nuisance_weighting" = population_variables$"nuisance_weighting",
		                     #"variables_to_supplement" = population_variables$"variables_to_supplement",
		                     #"dietary_supplement_data" = population_variables$"dietary_supplement_data",
		                     #"num_records" = population_variables$"num_records",
		                     #"num_subjects" = population_variables$"num_subjects",
		                     #"num_replicates" = population_variables$"num_replicates",
		                     #"distrib_seed" = distrib_seed];
		
		return output_parameters;
	finish;
	
	**Write the six modules defined above to the catalog _modules.distrib_modules,
	  presumably loaded later by the distrib macro with a LOAD statement - confirm;
	reset storage=_modules.distrib_modules;
	store module=(load_parameters calculate_population_variables extract_covariate_matrices
	              calculate_mcmc_means initialize_distrib_parameters initialize_distrib);
quit;