/*
Program Purpose: Store modules for initializing MCMC
*/

/* Library reference that the RESET STORAGE statement at the end of this program points to. */
/* "." is a relative path, so the catalog location depends on the current directory of the SAS session. */
libname _modules ".";

/* Every module below is defined inside this single PROC IML step. */
proc iml;

	/*
	Builds the subject-level inputs for MCMC initialization.

	Arguments:
		input_data    - name of the SAS dataset to read
		id            - name of the subject identifier variable in input_data
		repeat_obs    - name of the variable that distinguishes repeated observations
		weight        - name of the weight variable, or "" to give every subject a weight of 1
		variable_data - name of the SAS dataset that holds num_episodic, num_daily, and has_never_consumers

	Returns a named list with the items subjects, recalls, recall_availability, weighting,
	num_subjects, num_episodic, num_daily, num_recalls, and has_never_consumers.

	The weight lookup calls UNIQUEBY on the subject column without an index argument, so the rows of
	input_data are expected to be grouped by subject already (initialize_mcmc sorts the data first).
	*/
	start initialize_subject_data(input_data,
			id,
			repeat_obs,
			weight,
			variable_data);

		**Read the subject and repeated-observation columns in one pass over the dataset;
		use (input_data);
			read all var id into subjects;
			read all var repeat_obs into recalls_raw;
		close (input_data);

		**Distinct subjects and distinct recall labels as column vectors;
		subject_levels = unique(subjects)`;
		num_subjects = nrow(subject_levels);

		recall_levels = unique(recalls_raw)`;
		num_recalls = nrow(recall_levels);

		**Recode every row to its 1-based recall number and flag which subjects have that recall;
		recalls = j(nrow(recalls_raw), 1, .);
		recall_availability = j(num_subjects, num_recalls, .);
		do recall_num = 1 to num_recalls;

			rows_in_recall = loc(recalls_raw = recall_levels[recall_num]);
			recalls[rows_in_recall] = recall_num;
			recall_availability[,recall_num] = element(subject_levels, subjects[rows_in_recall]);
		end;

		**One weight per subject taken from the first row of each subject, or all ones when no weight variable is named;
		if weight ^= "" then do;

			use (input_data);
				read point(uniqueby(subjects)) var weight into subject_weighting;
			close (input_data);
		end;
		else do;

			subject_weighting = j(num_subjects, 1, 1);
		end;

		**Rescale the weights so that their mean is 1;
		mean_weight = mean(subject_weighting);
		subject_weighting = subject_weighting/mean_weight;

		**Model dimensions and the never-consumer flag come from the variable dataset;
		use (variable_data);
			read all var {num_episodic num_daily has_never_consumers} into model_settings;
		close (variable_data);

		num_episodic = model_settings[,1];
		num_daily = model_settings[,2];
		has_never_consumers = model_settings[,3];

		subject_data = [#"subjects" = subjects,
				#"recalls" = recalls,
				#"recall_availability" = recall_availability,
				#"weighting" = subject_weighting,
				#"num_subjects" = num_subjects,
				#"num_episodic" = num_episodic,
				#"num_daily" = num_daily,
				#"num_recalls" = num_recalls,
				#"has_never_consumers" = has_never_consumers];
		return subject_data;
	finish;
	
	/*
	Reads the model variable names out of the variable dataset.

	Argument:
		variable_data - name of the SAS dataset that holds num_episodic, num_daily, has_never_consumers,
		                and the name columns episodic_var1.., episodic_ind1.., episodic_amt1..,
		                daily_var1.., daily_amt1..

	Returns a named list with the items episodic_variables, episodic_indicators, episodic_amounts,
	daily_variables, daily_amounts, num_episodic, num_daily, and has_never_consumers.
	A name item is an empty matrix when the corresponding count is not positive.
	*/
	start extract_variable_data(variable_data);

		**Start from empty name sets and fill them in only when the model has variables of that kind;
		episodic_variables = {};
		episodic_indicators = {};
		episodic_amounts = {};
		daily_variables = {};
		daily_amounts = {};

		use (variable_data);
			read all var {num_episodic} into num_episodic;
			read all var {num_daily} into num_daily;
			read all var {has_never_consumers} into has_never_consumers;

			if num_episodic > 0 then do;

				episodic_index = 1:num_episodic;
				read all var (cat("episodic_var", episodic_index)) into episodic_variables;
				read all var (cat("episodic_ind", episodic_index)) into episodic_indicators;
				read all var (cat("episodic_amt", episodic_index)) into episodic_amounts;
			end;

			if num_daily > 0 then do;

				daily_index = 1:num_daily;
				read all var (cat("daily_var", daily_index)) into daily_variables;
				read all var (cat("daily_amt", daily_index)) into daily_amounts;
			end;
		close (variable_data);

		variables = [#"episodic_variables" = episodic_variables,
				#"episodic_indicators" = episodic_indicators,
				#"episodic_amounts" = episodic_amounts,
				#"daily_variables" = daily_variables,
				#"daily_amounts" = daily_amounts,
				#"num_episodic" = num_episodic,
				#"num_daily" = num_daily,
				#"has_never_consumers" = has_never_consumers];

		return variables;
	finish;
	
	/*
	Builds, for every recall, the subject-by-variable data matrices used by the MCMC function.

	Arguments:
		input_data          - name of the SAS dataset holding the recall data
		subjects            - subject identifier column (not used inside this module)
		recalls             - 1-based recall number of every row of input_data
		recall_availability - num_subjects x num_recalls 0/1 matrix, 1 when the subject has that recall
		variable_data       - name of the SAS dataset holding the episodic_ind, episodic_amt, and daily_amt name columns
		num_subjects, num_episodic, num_daily, num_recalls - model dimensions

	Returns a named list with the items episodic_indicators, episodic_amounts, and daily_amounts.
	Each item is a list with one entry per recall. An entry is a num_subjects-row matrix, or an empty
	matrix when the model has no variables of that kind. Rows of subjects without the recall stay 0.

	IML passes module arguments by reference, so the values read from input_data are kept in a local
	matrix (recall_values) and the variable_data argument is never assigned to.
	*/
	start initialize_variable_matrices(input_data,
			subjects,
			recalls,
			recall_availability,
			variable_data,
			num_subjects,
			num_episodic,
			num_daily,
			num_recalls);

		**Names of the input_data columns that hold each kind of model variable;
		use (variable_data);
			if num_episodic > 0 then do;
				read all var (cat("episodic_ind", 1:num_episodic)) into episodic_indicators;
				read all var (cat("episodic_amt", 1:num_episodic)) into episodic_amounts;
			end;

			if num_daily > 0 then do;
				read all var (cat("daily_amt", 1:num_daily)) into daily_amounts;
			end;
		close (variable_data);

		episodic_indicator_data = ListCreate(num_recalls);
		episodic_amount_data = ListCreate(num_recalls);
		daily_amount_data = ListCreate(num_recalls);
		do day_k = 1 to num_recalls;

			**Rows of the output matrices and rows of input_data that belong to this recall;
			subjects_in_recall = loc(recall_availability[,day_k]);
			rows_in_recall = loc(recalls = day_k);

			if num_episodic > 0 then do;

				**Episodic indicator data;
				episodic_indicator_data_k = j(num_subjects, num_episodic, 0);

				use (input_data);
					read point(rows_in_recall) var episodic_indicators into recall_values;
				close (input_data);

				episodic_indicator_data_k[subjects_in_recall,] = recall_values;
				episodic_indicator_data$day_k = episodic_indicator_data_k;

				**Episodic amount data;
				episodic_amount_data_k = j(num_subjects, num_episodic, 0);

				use (input_data);
					read point(rows_in_recall) var episodic_amounts into recall_values;
				close (input_data);

				episodic_amount_data_k[subjects_in_recall,] = recall_values;

				**Amounts are forced to 0 wherever the indicator is 0. LOC returns an empty matrix when no indicator is 0 and an empty subscript is an error, so guard it;
				not_consumed = loc(^episodic_indicator_data_k);
				if ncol(not_consumed) > 0 then do;

					episodic_amount_data_k[not_consumed] = 0;
				end;
				episodic_amount_data$day_k = episodic_amount_data_k;
			end;
			else do;

				episodic_indicator_data$day_k = {};
				episodic_amount_data$day_k = {};
			end;

			if num_daily > 0 then do;

				**Daily amount data;
				daily_amount_data_k = j(num_subjects, num_daily, 0);

				use (input_data);
					read point(rows_in_recall) var daily_amounts into recall_values;
				close (input_data);

				daily_amount_data_k[subjects_in_recall,] = recall_values;
				daily_amount_data$day_k = daily_amount_data_k;
			end;
			else do;

				daily_amount_data$day_k = {};
			end;
		end;

		variable_matrices = [#"episodic_indicators" = episodic_indicator_data,
				#"episodic_amounts" = episodic_amount_data,
				#"daily_amounts" = daily_amount_data];
		return variable_matrices;
	finish;
	
	/*
	Builds the covariate matrices used by the MCMC function.

	Arguments:
		input_data                    - name of the SAS dataset holding the recall data
		subjects                      - subject identifier of every row of input_data (rows grouped by subject)
		recalls                       - 1-based recall number of every row of input_data
		recall_availability           - num_subjects x num_recalls 0/1 matrix, 1 when the subject has that recall
		subject_weighting             - num_subjects x 1 subject weights
		covariate_data                - name of the SAS dataset whose row j describes model variable j through
		                                the columns variable, num_covariates, covariate1.., and intercept
		has_never_consumers           - 1 when never-consumers are modelled
		never_consumer_covariate_data - name of the SAS dataset with num_covariates, covariate1.., and intercept
		                                for the never-consumer model (read only when has_never_consumers = 1)
		num_subjects, num_episodic, num_daily, num_recalls - model dimensions

	Returns a named list with the items:
		recall         - list per model variable (named by the variable column) of per-recall covariate matrices
		wt_recall      - list per model variable of per-recall transposed matrices
		                 (weight # availability # covariates)
		wt_sq_sums     - per model variable, the sum over recalls of wt_recall * recall
		never_consumer - subject-level never-consumer covariate matrix, or an empty matrix

	An intercept flag of Y in either case (Y or y) adds an intercept column. For recall matrices the
	intercept column is the availability indicator of the recall. For the never-consumer matrix it is
	a column of ones.
	*/
	start initialize_covariate_matrices(input_data,
			subjects,
			recalls,
			recall_availability,
			subject_weighting,
			covariate_data,
			has_never_consumers,
			never_consumer_covariate_data,
			num_subjects,
			num_episodic,
			num_daily,
			num_recalls);

		**Two model variables (indicator and amount) per episodic food plus one per daily component;
		num_variables = 2*num_episodic + num_daily;

		**recall covariate matrices;
		recall_covariate_matrices = ListCreate(num_variables);
		do var_j = 1 to num_variables;

			recall_covariate_matrices$var_j = ListCreate(num_recalls);

			**Row var_j of the covariate dataset describes the covariates of model variable var_j;
			use (covariate_data);
				read point(var_j) var {num_covariates} into num_covariates;
				if num_covariates > 0 then do;
					read point(var_j) var (cat("covariate", 1:num_covariates)) into covariates_j;
				end;
				read point(var_j) var {intercept} into intercept_j;
				read point(var_j) var {variable} into variable;
			close (covariate_data);

			do day_k = 1 to num_recalls;

				subjects_in_recall = loc(recall_availability[,day_k]);

				**covariates - rows of subjects without this recall stay 0;
				if num_covariates > 0 then do;

					recall_covariate_matrix_j_k = j(num_subjects, num_covariates, 0);

					use (input_data);
						read point(loc(recalls = day_k)) var covariates_j into covariate_matrix;
					close (input_data);

					recall_covariate_matrix_j_k[subjects_in_recall,] = covariate_matrix;
				end;
				else do;

					recall_covariate_matrix_j_k = {};
				end;

				**intercept - the availability column is 1 only for subjects that have this recall;
				if upcase(intercept_j) = "Y" then do;

					recall_covariate_matrix_j_k = recall_availability[,day_k] || recall_covariate_matrix_j_k;
				end;

				recall_covariate_matrices$var_j$day_k = recall_covariate_matrix_j_k;
			end;

			call ListSetName(recall_covariate_matrices, var_j, variable);
		end;

		**weighted covariate matrices;
		wt_recall_covariate_matrices = ListCreate(num_variables);
		wt_recall_covariate_sq_sums = ListCreate(num_variables);
		do var_j = 1 to num_variables;

			**Column count of the finished matrix, which includes the intercept column when present;
			num_covariates = ncol(recall_covariate_matrices$var_j$1);

			if num_covariates > 0 then do;

				wt_recall_covariate_matrices$var_j = ListCreate(num_recalls);
				wt_recall_covariate_sq_sums$var_j = j(num_covariates, num_covariates, 0);
				do day_k = 1 to num_recalls;

					wt_recall_covariate_matrices$var_j$day_k = (subject_weighting # recall_availability[,day_k] # recall_covariate_matrices$var_j$day_k)`;
					wt_recall_covariate_sq_sums$var_j = wt_recall_covariate_sq_sums$var_j + wt_recall_covariate_matrices$var_j$day_k * recall_covariate_matrices$var_j$day_k;
				end;
			end;
			else do;

				wt_recall_covariate_matrices$var_j = ListCreate(num_recalls);
				wt_recall_covariate_sq_sums$var_j = {};
			end;
		end;

		**never-consumer covariate matrix;
		if has_never_consumers = 1 then do;

			use (never_consumer_covariate_data);
				read all var {num_covariates} into num_covariates;
				if num_covariates > 0 then do;
					read all var (cat("covariate", 1:num_covariates)) into never_consumer_covariates;
				end;
				read all var {intercept} into never_consumer_intercept;
			close (never_consumer_covariate_data);

			if num_covariates > 0 then do;

				**Subject-level covariates are taken from the first row of each subject;
				use (input_data);
					read point(uniqueby(subjects)) var never_consumer_covariates into never_consumer_covariate_matrix;
				close (input_data);
			end;
			else do;

				never_consumer_covariate_matrix = {};
			end;

			**Case-insensitive so that the flag is treated the same way as the recall intercept flag above;
			if upcase(never_consumer_intercept) = "Y" then do;

				never_consumer_covariate_matrix = j(num_subjects, 1, 1) || never_consumer_covariate_matrix;
			end;
		end;
		else do;

			never_consumer_covariate_matrix = {};
		end;

		**output covariate data;
		covariate_matrices = [#"recall" = recall_covariate_matrices,
				#"wt_recall" = wt_recall_covariate_matrices,
				#"wt_sq_sums" = wt_recall_covariate_sq_sums,
				#"never_consumer" = never_consumer_covariate_matrix];
		return covariate_matrices;
	finish initialize_covariate_matrices;
	
	/*
	Initializes MCMC priors and starting values: beta, sigma-e, r, theta, V, sigma-u, U, XBeta, XBeta + U, W,
	and the never-consumer parameters (alpha1 and consumer probabilities).
	Priors and starting values are defined in Appendices A.2 and A.3 of Zhang, et al. (2011) and in
	Section 2.3.3 of Bhadra, et al. (2020) for never-consumers.

	Arguments:
		episodic_indicator_matrices, episodic_amount_matrices, daily_amount_matrices
		                                - lists with one subject-by-variable matrix per recall
		recall_covariate_matrices       - list per model variable of per-recall covariate matrices
		sigma_u_prior_data              - name of a SAS dataset holding the sigma-u prior, or "" for the default
		num_subjects, num_episodic, num_daily, num_recalls - model dimensions
		recall_availability             - num_subjects x num_recalls 0/1 matrix
		has_never_consumers             - 1 when never-consumers are modelled
		never_consumer_covariate_matrix - subject-level never-consumer covariate matrix

	Returns a named list with the items alpha1_prior, consumer_probabilities_prior, beta_prior, sigma_e_prior,
	r_matrix_prior, theta_matrix_prior, v_matrix_prior, sigma_u_prior, u_matrix_prior, xbeta_prior,
	xbeta_u_prior, and w_matrix_prior.

	Random draws come from the current RANDGEN stream: first the U normals, then, recall by recall,
	one column of normals for every episodic indicator variable.
	*/
	start initialize_priors(episodic_indicator_matrices,
			episodic_amount_matrices,
			daily_amount_matrices,
			recall_covariate_matrices,
			sigma_u_prior_data,
			num_subjects,
			num_episodic,
			num_daily,
			num_recalls,
			recall_availability,
			has_never_consumers,
			never_consumer_covariate_matrix);

		num_variables = 2*num_episodic + num_daily;

		**Beta prior per model variable: zero mean (covariates x 1) and 100 times the identity as covariance (covariates x covariates);
		beta_mean_prior = ListCreate(num_variables);
		beta_covariance_prior = ListCreate(num_variables);
		do var_j = 1 to num_variables;

			num_columns = ncol(recall_covariate_matrices$var_j$1);
			beta_mean_prior$var_j = j(num_columns, 1, 0);
			beta_covariance_prior$var_j = 100 # i(num_columns);
		end;
		beta_prior = [#"mean" = beta_mean_prior,
				#"covariance" = beta_covariance_prior];

		**r prior ((episodic foods - 1) x 1) and theta prior ((episodic foods - 1)^2 x 1) are zero vectors and stay empty with fewer than two episodic foods;
		r_matrix_prior = {};
		theta_matrix_prior = {};
		if num_episodic >= 2 then do;

			r_matrix_prior = j(num_episodic - 1, 1, 0);
			theta_matrix_prior = j((num_episodic - 1)##2, 1, 0);
		end;

		**V prior is the identity and the sigma-e prior is V times its transpose (both variables x variables);
		v_matrix_prior = i(num_variables);
		sigma_e_prior = v_matrix_prior * v_matrix_prior`;

		**Sigma-u prior (variables x variables) comes from the supplied dataset or defaults to 1 on the diagonal and 0.5 elsewhere;
		if sigma_u_prior_data ^= "" then do;

			use (sigma_u_prior_data);
				read all var _ALL_ into sigma_u_prior;
			close (sigma_u_prior_data);
		end;
		else do;

			sigma_u_prior = j(num_variables, num_variables, 0.5) + i(num_variables) # 0.5;
		end;

		**U starting value (subjects x variables): standard normals times the symmetric square root of sigma-u;
		u_prior_normals = j(num_subjects, num_variables, .);
		call randgen(u_prior_normals, "Normal", 0, 1);

		call eigen(eigenvalues, eigenvectors, sigma_u_prior);
		sqrt_sigma_u = eigenvectors * diag(sqrt(eigenvalues)) * eigenvectors`;

		u_matrix_prior = u_prior_normals * sqrt_sigma_u;

		**XBeta, XBeta + U, and W starting values (subjects x variables for each recall);
		xbeta_prior = ListCreate(num_recalls);
		xbeta_u_prior = ListCreate(num_recalls);
		w_matrix_prior = ListCreate(num_recalls);
		last_episodic = 2*num_episodic;
		do day_k = 1 to num_recalls;

			**XBeta: covariates times the beta prior mean, one column per model variable;
			xbeta_k = j(num_subjects, num_variables, .);
			do var_j = 1 to num_variables;

				xbeta_k[,var_j] = recall_covariate_matrices$var_j$day_k * beta_mean_prior$var_j;
			end;

			xbeta_u_k = xbeta_k + u_matrix_prior;

			**W: odd episodic columns are indicators, even episodic columns are amounts, remaining columns are daily amounts;
			w_k = j(num_subjects, num_variables, .);
			do var_j = 1 to num_variables;

				if var_j > last_episodic then do;

					nutr_j = var_j - last_episodic;
					w_k[,var_j] = daily_amount_matrices$day_k[,nutr_j];
				end;
				else if mod(var_j, 2) = 1 then do;

					**Indicator column: magnitude from XBeta + U plus a normal draw, sign from the indicator, zero when the recall is unavailable;
					food_j = (var_j + 1)/2;
					indicator_draws = j(num_subjects, 1, .);
					call randgen(indicator_draws, "Normal", 0, 1);
					indicator_j = episodic_indicator_matrices$day_k[,food_j];
					w_k[,var_j] = (2 # indicator_j - 1) # abs(xbeta_u_k[,var_j] + indicator_draws) # recall_availability[,day_k];
				end;
				else do;

					food_j = var_j/2;
					w_k[,var_j] = episodic_amount_matrices$day_k[,food_j];
				end;
			end;

			xbeta_prior$day_k = xbeta_k;
			xbeta_u_prior$day_k = xbeta_u_k;
			w_matrix_prior$day_k = w_k;
		end;

		**never-consumer priors;
		if has_never_consumers = 1 then do;

			num_nc_columns = ncol(never_consumer_covariate_matrix);

			**alpha1 prior: zero mean (never-consumer covariates x 1) and identity covariance;
			alpha1_mean_prior = j(num_nc_columns, 1, 0);
			alpha1_covariance_prior = i(num_nc_columns);

			**Consumer probabilities (subjects x 1) are the standard normal CDF of the linear predictor;
			g_alpha_prior = never_consumer_covariate_matrix * alpha1_mean_prior;
			consumer_probabilities_prior = cdf("Normal", g_alpha_prior);
		end;
		else do;

			alpha1_mean_prior = {};
			alpha1_covariance_prior = {};
			consumer_probabilities_prior = {};
		end;
		alpha1_prior = [#"mean" = alpha1_mean_prior,
				#"covariance" = alpha1_covariance_prior];

		mcmc_priors = [#"alpha1_prior" = alpha1_prior,
				#"consumer_probabilities_prior" = consumer_probabilities_prior,
				#"beta_prior" = beta_prior,
				#"sigma_e_prior" = sigma_e_prior,
				#"r_matrix_prior" = r_matrix_prior,
				#"theta_matrix_prior" = theta_matrix_prior,
				#"v_matrix_prior" = v_matrix_prior,
				#"sigma_u_prior" = sigma_u_prior,
				#"u_matrix_prior" = u_matrix_prior,
				#"xbeta_prior" = xbeta_prior,
				#"xbeta_u_prior" = xbeta_u_prior,
				#"w_matrix_prior" = w_matrix_prior];
		return mcmc_priors;
	finish;
	
	/*
	Entry point: prepares everything the MCMC function needs.

	Arguments:
		input_data         - name of the SAS dataset holding the recall data
		id                 - name of the subject identifier variable
		repeat_obs         - name of the variable that distinguishes repeated observations
		weight             - name of the weight variable, or "" for equal weights
		variable_data      - name of the SAS dataset describing the model variables
		covariate_data     - name of the SAS dataset describing the covariates. The never-consumer
		                     covariate dataset name is this name with "nc" appended.
		mcmc_seed          - optional random seed. When it is empty a seed between 1 and 20000000 is drawn.
		sigma_u_prior_data - name of a SAS dataset holding the sigma-u prior, or "" for the default

	Returns a named list with the items subject_data, covariate_matrices, variable_matrices, priors,
	and mcmc_seed.

	Side effect: writes a copy of input_data sorted by id and repeat_obs to the dataset _input.
	*/
	start initialize_mcmc(input_data,
			id,
			repeat_obs,
			weight,
			variable_data,
			covariate_data,
			mcmc_seed=,
			sigma_u_prior_data);

		**1. Set seed for MCMC;
		if IsEmpty(mcmc_seed) then do;

			mcmc_seed = j(1, 1, .);
			call randgen(mcmc_seed, "Uniform", 0, 1);
			mcmc_seed = int(mcmc_seed * 20000000 + 1);
		end;

		**Reinitialize the stream (second argument 1). Without it RANDSEED is ignored once the stream exists, which is always the case after the RANDGEN call above or after an earlier run in the same session, and the returned seed would not reproduce the run;
		call randseed(mcmc_seed, 1);

		**2. Sort input data;
		submit input_data id repeat_obs;
			proc sort data=&input_data out=_input; by &id &repeat_obs; run;
		endsubmit;

		**3. Define subject-level data;
		subject_data = initialize_subject_data("_input",
				id,
				repeat_obs,
				weight,
				variable_data);

		**4. Define variable matrices;
		variable_matrices = initialize_variable_matrices("_input",
				subject_data$"subjects",
				subject_data$"recalls",
				subject_data$"recall_availability",
				variable_data,
				subject_data$"num_subjects",
				subject_data$"num_episodic",
				subject_data$"num_daily",
				subject_data$"num_recalls");

		**5. Define covariate matrices;
		covariate_matrices = initialize_covariate_matrices("_input",
				subject_data$"subjects",
				subject_data$"recalls",
				subject_data$"recall_availability",
				subject_data$"weighting",
				covariate_data,
				subject_data$"has_never_consumers",
				cats(covariate_data, "nc"),
				subject_data$"num_subjects",
				subject_data$"num_episodic",
				subject_data$"num_daily",
				subject_data$"num_recalls");

		**6. Initialize MCMC priors;
		mcmc_priors = initialize_priors(variable_matrices$"episodic_indicators",
				variable_matrices$"episodic_amounts",
				variable_matrices$"daily_amounts",
				covariate_matrices$"recall",
				sigma_u_prior_data,
				subject_data$"num_subjects",
				subject_data$"num_episodic",
				subject_data$"num_daily",
				subject_data$"num_recalls",
				subject_data$"recall_availability",
				subject_data$"has_never_consumers",
				covariate_matrices$"never_consumer");

		mcmc_parameters = [#"subject_data" = subject_data,
				#"covariate_matrices" = covariate_matrices,
				#"variable_matrices" = variable_matrices,
				#"priors" = mcmc_priors,
				#"mcmc_seed" = mcmc_seed];
		return mcmc_parameters;
	finish;
	
	/* Point IML module storage at the catalog _modules.mcmc_modules and save the compiled modules there. */
	/* NOTE(review): extract_variable_data is defined above but is not in this list, so it is not stored by this program - confirm whether that is intended. */
	reset storage = _modules.mcmc_modules;
	store module=(initialize_subject_data
								initialize_variable_matrices
								initialize_covariate_matrices
								initialize_priors
								initialize_mcmc);
quit;