/*
Program Purpose: Store modules for MCMC error-checking in IML
*/

/* Library pointing at the current directory ("."). The RESET STORAGE statement
   at the end of this step saves the compiled modules there as _modules.mcmc_modules */
libname _modules ".";

proc iml;

	**Updates logger from IML;
	/*
	Appends one entry to the _log_record data set. For positive codes it also echoes
	the message to the SAS log through CALL EXECUTE.
	  error_code - numeric code. 0 = record only, 1 = NOTE, 2 = WARNING,
	               any other positive value = ERROR
	  section    - character label of the check that produced the entry
	  message    - character message text
	The arguments are named new_* inside the module so that the local column vectors
	can be called section, message and error_code. CREATE ... VAR names the data set
	columns after the matrices it writes. Arguments are positional, so callers are
	unaffected by the rename.
	*/
	start update_logger(new_error_code,
											new_section,
											new_message);
	
		**Read the existing log, or start an empty one if the data set does not exist yet;
		free section message error_code;
		if exist("_log_record") then do;
		
			use _log_record;
				read all var {section message error_code};
			close _log_record;
		end;
		
		**Append the new entry as the last row;
		section = section // new_section;
		message = message // new_message;
		error_code = error_code // new_error_code;
		
		/* CREATE ... FROM accepts a single matrix, and one matrix cannot mix character
		   and numeric values, so the mixed-type data set is written with the VAR clause */
		create _log_record var {section message error_code};
			append;
		close _log_record;
		
		**Echo positive codes to the SAS log as NOTE, WARNING or ERROR;
		if new_error_code > 0 then do;
		
			if new_error_code = 1 then error_class = "NOTE";
			else if new_error_code = 2 then error_class = "WARNING";
			else error_class = "ERROR";
			
			print_command = cat('%put ', error_class, ': ', new_message, ';');
			call execute(print_command);
		end;
	finish update_logger;
	
	**Checks design matrices;
	/*
	Checks each design matrix for being non-empty and for having linearly independent
	columns. Each check is logged with update_logger in section "Design Matrices":
	code 3 for a failure, code 0 for a pass.
	  covariate_matrices  - list with item "recall", a named list that holds, for each
	                        variable, a list of per-recall design matrices. When
	                        has_never_consumers = 1 it must also hold item
	                        "never_consumer", a single design matrix.
	  has_never_consumers - 1 to also check the never-consumer design matrix
	Linear dependence is detected when any eigenvalue of the cross-product matrix
	is below 1e-8.
	NOTE(review): the 1e-8 threshold is absolute, so its meaning depends on covariate
	scale. Consider a tolerance relative to the largest eigenvalue.
	*/
	start check_design_matrices(covariate_matrices,
															has_never_consumers);
	
		**Eigenvalues of the cross-product matrix below this value are treated as zero;
		zero_eigenvalue_tol = 10**-8;
	
		recall_covariate_matrices = covariate_matrices$"recall";
		num_variables = ListLen(recall_covariate_matrices);
		variable_names = ListGetAllNames(recall_covariate_matrices);
		do var_j = 1 to num_variables;
		
			variable = variable_names[var_j];
			num_recalls = ListLen(recall_covariate_matrices$var_j);
			
			**No recall matrices, or an empty first recall matrix, means no covariates;
			has_covariates = 0;
			if num_recalls > 0 then has_covariates = ^IsEmpty(recall_covariate_matrices$var_j$1);
			
			if ^has_covariates then do;
			
				call update_logger(3, "Design Matrices", cat("No covariates or intercept found for ", variable));
			end;
			else do;
			
				call update_logger(0, "Design Matrices", cat("Covariates and/or intercept found for ", variable));
				
				**Summing per-recall cross-products equals the cross-product of the stacked recalls, without building the stacked matrix;
				cross = 0;
				do day_k = 1 to num_recalls;
				
					recall_matrix = recall_covariate_matrices$var_j$day_k;
					if ^IsEmpty(recall_matrix) then cross = cross + recall_matrix` * recall_matrix;
				end;
				free recall_matrix;
				
				if any(eigval(cross) < zero_eigenvalue_tol) then do;
				
					call update_logger(3, "Design Matrices", cat("Linearly dependent columns in design matrix for ", variable));
				end;
				else do;
				
					call update_logger(0, "Design Matrices", cat("Full rank design matrix for ", variable));
				end;
			end;
		end;
		
		if has_never_consumers = 1 then do;
		
			never_consumer_covariate_matrix = covariate_matrices$"never_consumer";
			if IsEmpty(never_consumer_covariate_matrix) then do;
			
				call update_logger(3, "Design Matrices", "No never-consumer covariates or intercept found");
			end;
			else do;
			
				call update_logger(0, "Design Matrices", "Never-consumer covariates and/or intercept found");
				
				cross = never_consumer_covariate_matrix` * never_consumer_covariate_matrix;
				if any(eigval(cross) < zero_eigenvalue_tol) then do;
				
					call update_logger(3, "Design Matrices", "Linearly dependent columns in never-consumer design matrix");
				end;
				else do;
				
					call update_logger(0, "Design Matrices", "Full rank never-consumer design matrix");
				end;
			end;
		end;
	finish;
	
	/*
	Formats a byte count as a short readable string such as "1.5 kB". The value is
	scaled in 1024-based steps (B, kB, MB, GB) and rounded to one decimal place.
	Counts of 2**30 bytes or more are always shown in GB.
	*/
	start format_units(bytes);
	
		unit_labels = {"B" "kB" "MB" "GB"};
		
		**Move up one unit while the count reaches the next 1024-based threshold;
		step = 0;
		do while(step < 3 & bytes >= 2**(10*(step + 1)));
			step = step + 1;
		end;
		
		scaled = round(bytes / 2**(10*step), 0.1);
		formatted = catx(" ", scaled, unit_labels[step + 1]);
		return formatted;
	finish;
	
	/*
	Logs NOTE entries (code 1, section "U Storage") with the estimated memory
	needed for stored U matrices. Each U matrix is sized like u_matrix_prior,
	at 8 bytes per element.
	  save_u_main         - main-chain entry is logged only when this equals "Y" (case-insensitive)
	  save_all_u          - "Y" counts every iteration. Otherwise the count is
	                        int((num_mcmc_iterations - num_burn)/num_thin).
	  num_post            - number of post-MCMC U matrices. Logged only when positive.
	*/
	start log_u_storage(u_matrix_prior,
											save_u_main,
											save_all_u,
											num_mcmc_iterations,
											num_burn,
											num_thin,
											num_post);
	
		**Bytes for one U matrix at 8 bytes per element;
		bytes_per_u = nrow(u_matrix_prior)*ncol(u_matrix_prior)*8;
		
		if upcase(save_u_main) = "Y" then do;
		
			**All iterations when save_all_u is Y, otherwise the thinned post-burn-in count;
			if upcase(save_all_u) = "Y" then num_main = num_mcmc_iterations;
			else num_main = int((num_mcmc_iterations - num_burn)/num_thin);
			
			main_bytes = num_main*bytes_per_u;
			call update_logger(1, "U Storage", cat("Estimated memory for main chain MCMC U matrices: ", format_units(main_bytes)));
		end;
		
		if num_post > 0 then do;
		
			post_bytes = num_post*bytes_per_u;
			call update_logger(1, "U Storage", cat("Estimated memory for post-MCMC U matrices: ", format_units(post_bytes)));
		end;
	finish;
	
	/*
	Runs the initialization checks in order:
	  1. check_design_matrices on covariate_matrices
	  2. log_u_storage using the "u_matrix_prior" item of mcmc_priors
	Both report through update_logger. A DATA step then scans every row of
	_log_record and executes ABORT ABEND 2 at the first row with error_code 3.
	Errors logged by earlier calls therefore also trigger the abort.
	*/
	start check_initialized(covariate_matrices,
													mcmc_priors,
													save_u_main,
													save_all_u,
													num_mcmc_iterations,
													num_burn,
													num_thin,
													num_post,
													has_never_consumers);
	
		call check_design_matrices(covariate_matrices, has_never_consumers);
		
		u_prior = mcmc_priors$"u_matrix_prior";
		call log_u_storage(u_prior, save_u_main, save_all_u,
											 num_mcmc_iterations, num_burn, num_thin, num_post);
		
		**Stop processing if any logged entry carries error code 3;
		submit;
			data _NULL_;
				set _log_record;
		
				if error_code = 3 then abort abend 2;
			run;
		endsubmit;
	finish;
	
	**Save the modules defined above to the storage catalog _modules.mcmc_modules so that later IML steps can load them;
	reset storage = _modules.mcmc_modules;
	store module=(update_logger
								check_design_matrices
								format_units
								log_u_storage
								check_initialized);
quit;