/*
Program Purpose:
IML modules for the main MCMC loop
*/

/* Assign the _modules libref to the "." path; the modules defined below are stored in this library at the end of the file */
libname _modules ".";

/* Start the IML session in which the modules are defined */
proc iml;

	/*
	Main Gibbs sampler loop to fit the multivariate MCMC model.

	The chain runs for num_mcmc_iterations main iterations followed by num_post
	post-MCMC iterations. Within each iteration the updates are made in this order:
		1. Ni (conni1), alpha1 and the consumer probabilities
		2. the W matrix
		3. (main iterations only) the r, theta and V matrices, Sigma-e = V * V-transpose, and Sigma-u
		4. the U matrix
		5. (main iterations only) beta
		6. (main iterations only) XBeta, then XBeta + U for every recall
	The order of the sampler calls is part of the algorithm and must not be changed.

	At iteration num_mcmc_iterations the posterior means of beta, Sigma-e and Sigma-u
	are computed from the trace rows num_burn + 1, num_burn + 1 + num_thin, ... and
	those parameters are held fixed at these means for the post-MCMC iterations.

	Arguments:
		num_mcmc_iterations              number of main-chain iterations (one trace row per iteration)
		num_burn                         iterations before num_burn + 1 are excluded from the posterior means
		num_thin                         spacing between iterations used for the posterior means
		num_post                         number of extra iterations run after the main chain
		num_subjects, num_episodic,
		num_daily, num_recalls           dimensions passed through to the sampler modules
		recall_availability,
		subject_weighting,
		episodic_indicator_matrices,
		covariate_matrices,
		weighted_covariate_matrices,
		weighted_covariate_sq_sums,
		never_consumer_covariate_matrix  data objects passed through to the sampler modules
		has_never_consumers              1 if alpha1 and consumer probabilities are tracked
		*_prior arguments                starting values (and priors passed to the update modules)
		sigma_u_constant                 passed through to update_sigma_u
		save_u_main                      "Y" (any case) to keep U matrices from the main chain
		save_all_u                       "Y" (any case) to keep every main-chain U matrix instead of
		                                 only the thinned post-burn-in iterations
		do_log_likelihood                passed through to marginal_likelihood

	Returns:
		a list with named items alpha1, consumer_probabilities, beta, sigma_e, sigma_u
		(parameter traces), u_matrices_main, saved_u_main, u_matrices_post and log_likelihood.
	*/
	start mcmc_main_loop(num_mcmc_iterations,
			num_burn,
			num_thin,
			num_post,
			num_subjects,
			num_episodic,
			num_daily,
			num_recalls,
			recall_availability,
			subject_weighting,
			episodic_indicator_matrices,
			has_never_consumers,
			covariate_matrices,
			weighted_covariate_matrices,
			weighted_covariate_sq_sums,
			never_consumer_covariate_matrix,
			alpha1_mean_prior,
			alpha1_covariance_prior,
			consumer_probabilities_prior,
			beta_mean_prior,
			beta_covariance_prior,
			r_matrix_prior,
			theta_matrix_prior,
			v_matrix_prior,
			sigma_e_prior,
			sigma_u_prior,
			u_matrix_prior,
			w_matrix_prior,
			xbeta_prior,
			xbeta_u_prior,
			sigma_u_constant,
			save_u_main,
			save_all_u,
			do_log_likelihood);

		**Initialize parameters to their priors;
		**Never-consumer parameters are left empty when the model has no never-consumers;
		if has_never_consumers = 1 then do;

			alpha1 = alpha1_mean_prior;
			consumer_probabilities = consumer_probabilities_prior;
		end;
		else do;

			alpha1 = {};
			consumer_probabilities = {};
		end;

		beta = beta_mean_prior;

		**r and theta are only initialized from their priors when there is more than one episodic variable;
		if num_episodic > 1 then do;

			r_matrix = r_matrix_prior;
			theta_matrix = theta_matrix_prior;
		end;
		else do;

			r_matrix = {};
			theta_matrix = {};
		end;

		v_matrix = v_matrix_prior;
		sigma_e = sigma_e_prior;
		sigma_u = sigma_u_prior;
		u_matrix = u_matrix_prior;
		w_matrix = w_matrix_prior;
		xbeta = xbeta_prior;
		xbeta_u = xbeta_u_prior;

		**Initialize parameter trace data;
		**thinned_iterations holds the iteration numbers used for the posterior means;
		thinned_iterations = do(num_burn + 1, num_mcmc_iterations, num_thin);
		num_thinned_iterations = int((num_mcmc_iterations - (num_burn + 1))/num_thin) + 1;

		if has_never_consumers = 1 then do;

			alpha1_trace = j(num_mcmc_iterations, nrow(alpha1), .);

			consumer_probabilities_trace = j(num_mcmc_iterations, 1, .);
		end;
		else do;

			alpha1_trace = {};
			consumer_probabilities_trace = {};
		end;

		beta_trace = ListCreate(ListLen(beta));
		do var_j = 1 to ListLen(beta);

			beta_trace$var_j = j(num_mcmc_iterations, nrow(beta$var_j), .);
		end;

		**One column per lower-triangular element of the covariance matrix;
		**Filled with missing values (not the J default of 1) to match the other traces;
		sigma_e_trace = j(num_mcmc_iterations, sum(1:nrow(sigma_e)), .);

		sigma_u_trace = j(num_mcmc_iterations, sum(1:nrow(sigma_u)), .);

		**Initialize U matrix trace data for main MCMC and post-MCMC iterations;
		**saved_u_main lists the iteration numbers whose U matrix is kept;
		if upcase(save_u_main) = "Y" then do;

			if upcase(save_all_u) = "Y" then do;

				u_matrix_main = ListCreate(num_mcmc_iterations);
				saved_u_main = 1:num_mcmc_iterations;
			end;
			else do;

				u_matrix_main = ListCreate(num_thinned_iterations);
				saved_u_main = thinned_iterations;
			end;
		end;
		else do;

			u_matrix_main = ListCreate();
			saved_u_main = {};
		end;

		if num_post > 0 then do;

			u_matrix_post = ListCreate(num_post);
		end;
		else do;

			u_matrix_post = ListCreate();
		end;

		**Gibbs sampler loop;
		**Iterations after num_mcmc_iterations are the post-MCMC iterations;
		do iter = 1 to num_mcmc_iterations + num_post;

			**Update Ni;
			conni1 = update_conni1(has_never_consumers,
					alpha1,
					consumer_probabilities,
					xbeta_u,
					never_consumer_covariate_matrix,
					episodic_indicator_matrices,
					recall_availability,
					num_subjects,
					num_recalls);

			**Update alpha1;
			alpha1 = update_alpha1(has_never_consumers,
					alpha1_mean_prior,
					alpha1_covariance_prior,
					conni1,
					never_consumer_covariate_matrix,
					subject_weighting);

			**Update consumer probabilities;
			consumer_probabilities = update_consumer_probabilities(has_never_consumers,
					never_consumer_covariate_matrix,
					alpha1);

			**Update W matrix;
			w_matrix = update_w_matrix(w_matrix,
					xbeta_u,
					sigma_e,
					recall_availability,
					episodic_indicator_matrices,
					num_subjects,
					num_episodic,
					num_daily,
					num_recalls);

			**Covariance parameters are only sampled during the main chain;
			if iter <= num_mcmc_iterations then do;

				**Calculate W-XBeta-U cross-residual sum;
				**The W-XBeta-U calculation corresponds to the error (epsilon) term in Equation 3.5 of Zhang, et al. (2011);
				w_cross_residual_sum = calculate_w_cross_residual_sum(w_matrix,
						xbeta_u,
						recall_availability,
						subject_weighting,
						num_episodic,
						num_daily,
						num_recalls);

				**Update r matrix;
				r_matrix = update_r_matrix(r_matrix,
						theta_matrix,
						v_matrix,
						w_cross_residual_sum,
						recall_availability,
						subject_weighting);

				**Update theta matrix;
				theta_matrix = update_theta_matrix(theta_matrix,
						r_matrix,
						v_matrix,
						w_cross_residual_sum);

				**Update v matrix;
				v_matrix = update_v_matrix(v_matrix,
						r_matrix,
						theta_matrix,
						w_cross_residual_sum,
						recall_availability,
						subject_weighting,
						num_episodic,
						num_daily);

				**Update sigma-e;
				sigma_e = v_matrix * v_matrix`;

				**Update sigma-u;
				sigma_u = update_sigma_u(sigma_u,
						sigma_u_prior,
						u_matrix,
						subject_weighting,
						sigma_u_constant,
						num_subjects);
			end;

			**Update U matrix (main and post-MCMC iterations);
			u_matrix = update_u_matrix(sigma_u,
					sigma_e,
					w_matrix,
					xbeta,
					recall_availability,
					num_subjects,
					num_episodic,
					num_daily,
					num_recalls,
					has_never_consumers,
					u_matrix,
					conni1);

			if iter <= num_mcmc_iterations then do;

				**Update beta;
				beta = update_beta(weighted_covariate_matrices,
						weighted_covariate_sq_sums,
						recall_availability,
						w_matrix,
						u_matrix,
						sigma_e,
						xbeta,
						beta_mean_prior,
						beta_covariance_prior,
						num_subjects,
						num_episodic,
						num_daily,
						num_recalls,
						has_never_consumers,
						conni1,
						beta$1,
						covariate_matrices);
			end;

			**Saving parameter traces for this iteration;
			if iter <= num_mcmc_iterations then do;

				if has_never_consumers = 1 then do;

					alpha1_trace[iter,] = alpha1`;

					**Only the mean consumer probability is traced;
					consumer_probabilities_trace[iter] = mean(consumer_probabilities);
				end;

				**List items are copied out, updated and written back;
				do var_j = 1 to ListLen(beta);

					beta_trace_j = beta_trace$var_j;
					beta_trace_j[iter,] = beta$var_j`;
					beta_trace$var_j = beta_trace_j;
				end;

				**Covariance matrices are stored as their lower-triangular elements;
				sigma_e_trace[iter,] = symsqr(sigma_e)`;

				sigma_u_trace[iter,] = symsqr(sigma_u)`;

				**if specified, save U matrix;
				if upcase(save_u_main) = "Y" then do;

					if element(iter, saved_u_main) = 1 then do;

						**Position of this iteration within the saved list;
						u_iter = loc(saved_u_main = iter);
						u_matrix_main$u_iter = u_matrix;
					end;
				end;
			end;
			else do;

				**save U matrix (post-MCMC);
				post_iter = iter - num_mcmc_iterations;
				u_matrix_post$post_iter = u_matrix;
			end;

			**Calculate posterior means of MCMC parameters using thinned iterations after burn-in;
			if iter = num_mcmc_iterations then do;

				beta_mean = ListCreate(ListLen(beta_trace));
				do var_j = 1 to ListLen(beta_trace);
					beta_trace_j = beta_trace$var_j;
					beta_thinned_j = beta_trace_j[thinned_iterations,];
					beta_mean$var_j = beta_thinned_j[:,]`;
				end;

				sigma_e_thinned = sigma_e_trace[thinned_iterations,];
				sigma_e_mean = sqrsym(sigma_e_thinned[:,]`);

				sigma_u_thinned = sigma_u_trace[thinned_iterations,];
				sigma_u_mean = sqrsym(sigma_u_thinned[:,]`);

				**fix beta, sigma-e, and sigma-u at posterior means;
				beta = beta_mean;
				sigma_e = sigma_e_mean;
				sigma_u = sigma_u_mean;
			end;

			**Update Xbeta;
			**On the last main iteration this uses the posterior mean of beta set above;
			if iter <= num_mcmc_iterations then do;

				do day_k = 1 to num_recalls;

					xbeta_k = xbeta$day_k;
					do var_j = 1 to 2*num_episodic + num_daily;

						xbeta_k[,var_j] = covariate_matrices$var_j$day_k * beta$var_j;
					end;
					xbeta$day_k = xbeta_k;
				end;
			end;

			**Update Xbeta-U;
			do day_k = 1 to num_recalls;

				xbeta_u$day_k = xbeta$day_k + u_matrix;
			end;
		end;

		**Calculate log-likelihood;
		log_likelihood = marginal_likelihood(do_log_likelihood,
				w_matrix,
				episodic_indicator_matrices,
				xbeta,
				sigma_u,
				sigma_e,
				recall_availability,
				subject_weighting,
				num_subjects,
				num_episodic,
				num_daily,
				num_recalls,
				has_never_consumers,
				never_consumer_covariate_matrix,
				alpha1);

		**Output parameter traces;
		mcmc_output = [#"alpha1" = alpha1_trace,
				#"consumer_probabilities" = consumer_probabilities_trace,
				#"beta" = beta_trace,
				#"sigma_e" = sigma_e_trace,
				#"sigma_u" = sigma_u_trace,
				#"u_matrices_main" = u_matrix_main,
				#"saved_u_main" = saved_u_main,
				#"u_matrices_post" = u_matrix_post,
				#"log_likelihood" = log_likelihood];
		return mcmc_output;
	finish;
	
	/*
	Write the MCMC results to SAS data sets whose names start with outname.

	Data sets read (must already exist):
		<outname>_vars       num_episodic, num_daily and the episodic_ind*, episodic_amt*, daily_amt* names
		<outname>_covars     one row per model variable: num_covariates, covariate1..., intercept
		<outname>_covarsnc   never-consumer covariates (read only when has_never_consumers = 1)

	Data sets written:
		<outname>_ll         log-likelihood (only when it is not missing)
		<outname>_beta<j>    beta trace for model variable j
		<outname>_sigma_e    Sigma-e trace
		<outname>_sigma_u    Sigma-u trace
		<outname>_u_main     saved main-chain U matrices (only when saved_u_main is not empty)
		<outname>_u_post     post-MCMC U matrices (only when num_post > 0)
		<outname>_alpha1     alpha1 trace (only when has_never_consumers = 1)
		<outname>_conprob1   consumer probability trace (only when has_never_consumers = 1)
		<outname>_subjects   unique subject identifiers and their weights
		<outname>_iters      num_trace, num_mcmc_iterations, num_burn, num_thin, num_post

	Arguments:
		outname                                  prefix for all input and output data set names
		log_likelihood, beta, sigma_u, sigma_e,
		alpha1, consumer_probabilities           values to store (beta is a list with one item per model variable)
		u_matrices_main, u_matrices_post         lists of U matrices
		mcmc_subjects, subject_weighting         subject identifiers and weights
		num_mcmc_iterations, num_burn,
		num_thin, num_post                       iteration settings stored in <outname>_iters
		saved_u_main                             iteration numbers of the matrices in u_matrices_main
		has_never_consumers                      1 to store alpha1 and consumer probabilities

	Column names and labels are built from the variable and covariate names. A model
	variable with no covariates gets only the intercept column name.
	*/
	start output_mcmc(outname,
			log_likelihood,
			beta,
			sigma_u,
			sigma_e,
			alpha1,
			consumer_probabilities,
			u_matrices_main,
			u_matrices_post,
			mcmc_subjects,
			subject_weighting,
			num_mcmc_iterations,
			num_burn,
			num_thin,
			num_post,
			saved_u_main,
			has_never_consumers);

		**extract variable names;
		use (cat(outname, "_vars"));
			read all var {num_episodic} into num_episodic;
			read all var {num_daily} into num_daily;

			if num_episodic > 0 then do;
				read all var (cat("episodic_ind", 1:num_episodic)) into episodic_indicators;
				read all var (cat("episodic_amt", 1:num_episodic)) into episodic_amounts;
			end;

			if num_daily > 0 then do;
				read all var (cat("daily_amt", 1:num_daily)) into daily_amounts;
			end;
		close (cat(outname, "_vars"));

		**Model variables are ordered as indicator and amount for each episodic variable, then the daily variables;
		num_variables = 2*num_episodic + num_daily;
		variables = j(1, num_variables, BlankStr(32));
		do food_j = 1 to num_episodic;
			variables[food_j*2 - 1] = episodic_indicators[food_j];
			variables[food_j*2] = episodic_amounts[food_j];
		end;
		do nutr_j = 1 to num_daily;
			variables[nutr_j + 2*num_episodic] = daily_amounts[nutr_j];
		end;

		**extract covariate names;
		**Row var_j of the covars data set describes model variable var_j;
		covariates = ListCreate(num_variables);
		intercepts = ListCreate(num_variables);
		do var_j = 1 to num_variables;
			use (cat(outname, "_covars"));
				read point(var_j) var {num_covariates} into num_covariates;

				if num_covariates > 0 then do;
					read point(var_j) var (cat("covariate", 1:num_covariates)) into covariates_j;
				end;
				else do;
					covariates_j = {};
				end;

				read point(var_j) var {intercept} into intercept_j;
			close (cat(outname, "_covars"));

			covariates$var_j = covariates_j;
			intercepts$var_j = intercept_j;
		end;

		**store log-likelihood;
		if log_likelihood ^= . then do;

			create (cat(outname, "_ll")) from log_likelihood[colname={"log_likelihood"}];
				append from log_likelihood;
			close (cat(outname, "_ll"));
		end;

		do var_j = 1 to num_variables;

			beta_j = beta$var_j;

			variable_j = variables[var_j];
			covariates_j = covariates$var_j;
			intercept_j = intercepts$var_j;

			**names and labels for beta;
			**With no covariates the index 1:0 would be the descending vector {1 0} and give two bogus names, so that case is handled separately;
			if ncol(covariates_j) > 0 then do;
				beta_names = cat("beta", var_j, "_covariate", 1:ncol(covariates_j));
				beta_labels = cat("beta_", trim(variable_j), "_", trim(covariates_j));
			end;
			else do;
				beta_names = {};
				beta_labels = {};
			end;
			**The intercept column comes first. Concatenating with an empty matrix returns the other operand;
			if upcase(intercept_j) = "Y" then do;
				beta_names = cat("beta", var_j, "_intercept") || beta_names;
				beta_labels = cat("beta_", trim(variable_j), "_intercept") || beta_labels;
			end;

			**store MCMC beta;
			beta_j_out = cat(outname, "_beta", var_j);
			create (beta_j_out) from beta_j[colname=beta_names];
				append from beta_j;
			close (beta_j_out);

			**Attach the descriptive labels with a DATA step;
			label_list = catx(" = ", trim(beta_names), trim(beta_labels));
			submit beta_j_out label_list;
				data &beta_j_out;
					set &beta_j_out;

					label &label_list;
				run;
			endsubmit;
		end;

		**Names and labels for Sigma-e and Sigma-u;
		**One entry per lower-triangular element, in row-major order;
		sigma_names = j(1, sum(1:num_variables), BlankStr(32));
		sigma_labels = j(1, sum(1:num_variables), BlankStr(256));
		index = 0;
		do row = 1 to num_variables;
			do col = 1 to row;
				index = index + 1;
				sigma_names[index] = cat("row", row, "_col", col);
				sigma_labels[index] = cat(trim(variables[row]), "_", trim(variables[col]));
			end;
		end;

		sigma_e_names = cat("sigma_e_", trim(sigma_names));
		sigma_u_names = cat("sigma_u_", trim(sigma_names));

		sigma_e_labels = cat("sigma_e_", trim(sigma_labels));
		sigma_u_labels = cat("sigma_u_", trim(sigma_labels));

		**store MCMC Sigma-e;
		sigma_e_out = cat(outname, "_sigma_e");
		create (sigma_e_out) from sigma_e[colname=sigma_e_names];
			append from sigma_e;
		close (sigma_e_out);

		label_list = catx(" = ", trim(sigma_e_names), trim(sigma_e_labels));
		submit sigma_e_out label_list;
			data &sigma_e_out;
				set &sigma_e_out;

				label &label_list;
			run;
		endsubmit;

		**store MCMC Sigma-u;
		sigma_u_out = cat(outname, "_sigma_u");
		create (sigma_u_out) from sigma_u[colname=sigma_u_names];
			append from sigma_u;
		close (sigma_u_out);

		label_list = catx(" = ", trim(sigma_u_names), trim(sigma_u_labels));
		submit sigma_u_out label_list;
			data &sigma_u_out;
				set &sigma_u_out;

				label &label_list;
			run;
		endsubmit;

		**store MCMC main chain U matrices;
		if ^IsEmpty(saved_u_main) then do;

			u_main_names = "iteration" || cat("u_col", 1:num_variables);
			u_main_labels = "iteration" || cat("u_", trim(variables));

			**Each saved U matrix is appended with its iteration number repeated on every row;
			u_main_out = cat(outname, "_u_main");
			create (u_main_out) var (u_main_names);
				do i = 1 to ListLen(u_matrices_main);
					u_main_i = u_matrices_main$i;
					iteration = repeat(saved_u_main[i], nrow(u_main_i), 1);

					append from iteration u_main_i;
				end;
			close (u_main_out);

			label_list = catx(" = ", trim(u_main_names), trim(u_main_labels));
			submit u_main_out label_list;
				data &u_main_out;
					set &u_main_out;

					label &label_list;
				run;
			endsubmit;
		end;

		**store post-MCMC U matrices;
		if num_post > 0 then do;

			u_post_names = "post_mcmc_iteration" || cat("u_col", 1:num_variables);
			u_post_labels = "post_mcmc_iteration" || cat("u_", trim(variables));

			u_post_out = cat(outname, "_u_post");
			create (u_post_out) var (u_post_names);
				do i = 1 to num_post;
					u_post_i = u_matrices_post$i;
					post_mcmc_iteration = repeat(i, nrow(u_post_i), 1);

					append from post_mcmc_iteration u_post_i;
				end;
			close (u_post_out);

			label_list = catx(" = ", trim(u_post_names), trim(u_post_labels));
			submit u_post_out label_list;
				data &u_post_out;
					set &u_post_out;

					label &label_list;
				run;
			endsubmit;
		end;

		if has_never_consumers = 1 then do;

			**extract never-consumer covariates;
			use (cat(outname, "_covarsnc"));
				read all var {num_covariates} into num_covariates;

				if num_covariates > 0 then do;
					read all var (cat("covariate", 1:num_covariates)) into never_consumer_covariates;
				end;
				else do;
					never_consumer_covariates = {};
				end;

				read all var {intercept} into never_consumer_intercept;
			close (cat(outname, "_covarsnc"));

			**names and labels for alpha1;
			**Same zero-covariate guard as for beta, since 1:0 is not an empty index;
			if ncol(never_consumer_covariates) > 0 then do;
				alpha1_names = cat("alpha1_covariate", 1:ncol(never_consumer_covariates));
				alpha1_labels = cat("alpha_", trim(variables[1]), "_", trim(never_consumer_covariates));
			end;
			else do;
				alpha1_names = {};
				alpha1_labels = {};
			end;
			if upcase(never_consumer_intercept) = "Y" then do;
				alpha1_names = {"alpha1_intercept"} || alpha1_names;
				alpha1_labels = cat("alpha_", trim(variables[1]), "_intercept") || alpha1_labels;
			end;

			**store MCMC alpha1;
			alpha1_out = cat(outname, "_alpha1");
			create (alpha1_out) from alpha1[colname=alpha1_names];
				append from alpha1;
			close (alpha1_out);

			label_list = catx(" = ", trim(alpha1_names), trim(alpha1_labels));
			submit alpha1_out label_list;
				data &alpha1_out;
					set &alpha1_out;

					label &label_list;
				run;
			endsubmit;

			**name and label for consumer probabilities;
			consumer_prob_name = "consumer_probability1";
			consumer_prob_label = cat("consumer_probability_", trim(variables[1]));

			**store consumer probabilities;
			consumer_prob_out = cat(outname, "_conprob1");
			create (consumer_prob_out) from consumer_probabilities[colname=consumer_prob_name];
				append from consumer_probabilities;
			close (consumer_prob_out);

			label_list = catx(" = ", consumer_prob_name, consumer_prob_label);
			submit consumer_prob_out label_list;
				data &consumer_prob_out;
					set &consumer_prob_out;

					label &label_list;
				run;
			endsubmit;
		end;

		**store subject IDs and weighting;
		unique_subjects = unique(mcmc_subjects)`;
		create (cat(outname, "_subjects")) from unique_subjects subject_weighting[colname={"subject" "weight"}];
			append from unique_subjects subject_weighting;
		close (cat(outname, "_subjects"));

		**store iteration numbers;
		**num_trace equals num_mcmc_iterations because one trace row is kept per main iteration;
		num_trace = num_mcmc_iterations;

		create (cat(outname, "_iters")) var {"num_trace" "num_mcmc_iterations" "num_burn" "num_thin" "num_post"};
			append from num_trace num_mcmc_iterations num_burn num_thin num_post;
		close (cat(outname, "_iters"));
	finish;
	
	**Store both modules in the _modules.mcmc_modules storage catalog, then end the IML session;
	reset storage = _modules.mcmc_modules;
	store module=(mcmc_main_loop
								output_mcmc);
quit;