# discrete-normal model 
# Posterior mean of theta given (x, s) under the discrete prior G.
#
# Arguments:
#   x      numeric vector of sample means.
#   s      numeric vector aligned with x; it is squared before entering the
#          gamma density, i.e. it is on the standard-deviation scale.
#   m      sample size behind each (x, s) pair: the mean has variance v^2 / m
#          and s^2 is gamma distributed with shape (m - 1) / 2.
#   G      list describing the prior: column 1 of G$x holds the support points
#          for theta, column 2 the matching scale v, and G$y the prior masses.
#   cnull  unused here; accepted so PM and Lfdr can be called the same way.
#
# Returns a numeric vector with one posterior mean per element of x
# (numeric(0) for empty input).
PM <- function(x, s, m, G,  cnull) {
  u <- G$x[, 1]
  v <- G$x[, 2]
  fuv <- G$y
  r <- (m - 1) / 2
  # Likelihood matrix: rows index observations, columns index support points.
  # outer() evaluates the same element-wise densities as an explicit double
  # loop, but also copes with zero-length x.
  A <- outer(seq_along(x), seq_along(u), function(i, j) {
    dnorm(x[i], mean = u[j], sd = sqrt(v[j]^2 / m)) *
      dgamma(s[i]^2, shape = r, scale = v[j]^2 / r)
  })
  # Posterior mean = sum_j u_j * w_ij / sum_j w_ij with w_ij = A_ij * fuv_j.
  as.vector((A %*% (u * fuv)) / (A %*% fuv))
}

# Local fdr rule: requires a null cutoff (cnull).

# Posterior tail probability used by the local-fdr rule.
#
# Arguments:
#   x      numeric vector of sample means.
#   s      numeric vector aligned with x; it is squared before entering the
#          gamma density, i.e. it is on the standard-deviation scale.
#   m      sample size behind each (x, s) pair: the mean has variance v^2 / m
#          and s^2 is gamma distributed with shape (m - 1) / 2.
#   G      list describing the prior: column 1 of G$x holds the support points
#          for theta, column 2 the matching scale v, and G$y the prior masses.
#   cnull  cutoff; support points with theta < cnull count as null.
#
# Returns a numeric vector of 1 - P(theta < cnull | x, s), i.e. the posterior
# probability that theta >= cnull. Note this is one minus the local false
# discovery rate, so large values favour rejection. Empty input gives
# numeric(0).
Lfdr <- function(x, s, m, G, cnull) {
  u <- G$x[, 1]
  v <- G$x[, 2]
  fuv <- G$y
  r <- (m - 1) / 2
  # Likelihood matrix: rows index observations, columns index support points.
  # outer() evaluates the same element-wise densities as an explicit double
  # loop, but also copes with zero-length x.
  A <- outer(seq_along(x), seq_along(u), function(i, j) {
    dnorm(x[i], mean = u[j], sd = sqrt(v[j]^2 / m)) *
      dgamma(s[i]^2, shape = r, scale = v[j]^2 / r)
  })
  # Posterior mass on the null support points, then its complement.
  1 - c((A %*% (fuv * (u < cnull))) / (A %*% fuv))
}

# Linear shrinkage rule (implementing the empirical Bayes version of the posterior mean)
hyperMLE <- function(para, y = y, s = s, t = t){
  # Negative log marginal likelihood of the hyperparameters of a
  # normal-inverse-chi-squared (NIX) prior, for panel data from a normal model
  # with conjugate prior.
  # Reference: Gu, Zaheer, Li, IEEE (2014), "Multiple population moment
  # estimation: exploiting inter population correlation for efficient moment
  # estimation in analog/mixed-signal validation".
  #
  # Arguments:
  #   para  c(theta0, kappa0, nu0, sig0sq); kappa0, nu0 and sig0sq are expected
  #         to be positive (otherwise the logs below produce NaN).
  #   y     sample means, one per unit.
  #   s     sample variances, one per unit.
  #   t     number of observations per unit (scalar or vector aligned with y).
  #
  # Returns a single number suitable for minimisation with optim().
  theta0 <- para[1]
  kappa0 <- para[2]
  nu0 <- para[3]
  sig0sq <- para[4]
  kappaT <- kappa0 + t
  nuT <- nu0 + t
  sigTsq <- (nu0 * sig0sq + (t - 1) * s +
    t * kappa0 * (theta0 - y)^2 / (kappa0 + t)) / nuT
  # Work on the log scale: gamma(nuT / 2) and the (.)^(nu / 2) powers overflow
  # to Inf (and their ratio to NaN) already for moderate degrees of freedom.
  # The normalising constant uses the per-unit sample size t, not the global T.
  loglik <- lgamma(nuT / 2) - lgamma(nu0 / 2) +
    0.5 * (log(kappa0) - log(kappaT)) +
    (nu0 / 2) * log(nu0 * sig0sq) -
    (nuT / 2) * log(nuT * sigTsq) -
    (t / 2) * log(pi)
  -sum(loglik)
}

# Linear shrinkage estimate of theta: the prior mean theta0 (weight kappa0)
# and the sample mean x (weight m) are combined into a weighted average.
# s and cnull are accepted for a uniform call signature but are not used.
LinearPM <- function(x, s, m, cnull, kappa0, theta0) {
  weighted_total <- kappa0 * theta0 + m * x
  total_weight <- kappa0 + m
  weighted_total / total_weight
}

# Posterior tail probability P(theta >= cnull | x, s) under a NIX prior for
# (theta, sigma^2): the posterior of theta is a shifted and scaled t
# distribution with df = nu0 + m, so the probability is
# 1 - F((cnull - post_mean) / sqrt(post_var / post_kappa)).
#
#   x      sample mean(s)
#   s      sample VARIANCE(s), not standard deviations
#   m      number of observations per unit
#   cnull  cutoff on theta
#   para   c(theta0, kappa0, nu0, sig0sq)
Linearvalpha <- function(x, s, m, cnull, para) {
  prior_mean <- para[1]
  prior_kappa <- para[2]
  prior_df <- para[3]
  prior_var <- para[4]

  post_kappa <- prior_kappa + m
  post_df <- prior_df + m
  post_mean <- (prior_kappa * prior_mean + m * x) / post_kappa

  # Contribution of the gap between the prior mean and the sample mean.
  gap_term <- m * prior_kappa * (prior_mean - x)^2 / (prior_kappa + m)
  post_var <- (prior_df * prior_var + (m - 1) * s + gap_term) / post_df

  t_stat <- (cnull - post_mean) / sqrt(post_var / post_kappa)
  1 - pt(t_stat, df = post_df)
}
	

# Some generic code for computing inverses with uniroot
    # Usage examples:
    # x0 = Finv(2, F = qnorm, sd = 2) 
    # x1 = mapply(Finv, 1:5/10, MoreArgs = list(F = qnorm, sd = 2)) 


# Draw n (x, s) pairs from the discrete-normal model.
#
#   n  number of units to simulate
#   T  number of observations per unit
#   G  prior: rows of G$x are (mean, scale) support points, G$y their masses
#
# A support point is drawn for every unit, then the sample mean x (normal) and
# the sample standard deviation s (square root of a gamma variate) are drawn
# conditional on it. Returns list(x, s, T, sampler), where sampler holds the
# row index of the support point drawn for each unit.
rG <- function(n, T, G) {
  half_df <- (T - 1) / 2
  sampler <- sample.int(nrow(G$x), n, replace = TRUE, prob = G$y)
  centres <- G$x[sampler, 1]
  scale_sq <- G$x[sampler, 2]^2
  x <- rnorm(n, centres, sd = sqrt(scale_sq / T))
  s <- sqrt(rgamma(n, shape = half_df, scale = scale_sq / half_df))
  list(x = x, s = s, T = T, sampler = sampler)
}

# Root function for choosing the cutoff c of a thresholding rule
# 1(value >= c) by FDR control: average of (1 - value) over the selected units
# minus the target level gamma. uniroot() can solve this for c given gamma.
# Intended to approximate the FDR when G is the true prior and value is a
# posterior tail probability; it is not meaningful for other rules.
# theta and cnull are accepted for a common signature but are not used.
FDRcut1 <- function(c, value, theta, cnull, gamma ) {
  selected <- value >= c
  excess <- (1 - value - gamma) * selected
  mean(excess) / mean(selected)
}

# Root function for choosing the cutoff c of a thresholding rule
# 1(value >= c): realised false discovery proportion (share of selected units
# whose true theta is below cnull) minus the target level gamma.
# uniroot() can solve this for c given gamma. It is an oracle quantity,
# because the true theta is not observed in practice.
FDRcut2 <- function(c, value, theta, cnull, gamma ) {
  selected <- value >= c
  false_discovery <- selected & theta < cnull
  mean(false_discovery) / mean(selected) - gamma
}

# ---- One simulated data set: fit hyperparameters and score every unit ----
set.seed(23)
T <- 9                      # observations per unit (note: masks the T shorthand for TRUE)
n <- 50000                  # number of units
r <- (T - 1) / 2
G <- list(
  x = matrix(c(-1, 4, 5, 6, 2, 4), nrow = 3),
  y = c(0.85, 0.1, 0.05),
  sigma = 1
)
# Alternative: matrix(c(-1, 4, 5, 6, 2, 6), ...) gives a monotone rejection
# region. What matters here is that the support points 4 and 5 are close, so
# the posterior-mean rule has trouble separating them.

D <- rG(n, T, G)
truecnull <- 4.7
alpha <- 0.05

# Empirical-Bayes fit of the NIX hyperparameters from the sample means and
# sample variances, followed by the cutoff implied by the fitted prior.
hyperest <- optim(
  par = c(0, 7, 30, 30),
  fn = hyperMLE,
  method = "L-BFGS-B",
  y = D$x,
  s = D$s^2,
  t = rep(T, n)
)$par
cnull <- hyperest[1] +
  sqrt(hyperest[4] / hyperest[2]) * qt(1 - alpha, df = hyperest[3])

# Per-unit statistics: v* are posterior tail probabilities, p* posterior means;
# the "l" variants come from the linear (NIX) rule and take the sample variance.
valpha <- rep(NA, n)
palpha <- rep(NA, n)
plalpha <- rep(NA, n)
vlalpha <- rep(NA, n)
for (i in seq_len(n)) {
  valpha[i] <- Lfdr(D$x[i], D$s[i], m = D$T, G = G, cnull = truecnull)
  palpha[i] <- PM(D$x[i], D$s[i], m = D$T, G = G, cnull = truecnull)
  plalpha[i] <- LinearPM(
    D$x[i], D$s[i]^2,
    m = T, cnull = cnull, kappa0 = hyperest[2], theta0 = hyperest[1]
  )
  vlalpha[i] <- Linearvalpha(
    D$x[i], D$s[i]^2,
    m = T, cnull = cnull, para = hyperest
  )
}
	

# ---- Cutoffs for each rule ----
# For every statistic: *0 is the capacity cutoff (select the top alpha share),
# *1 is the cutoff at which the realised false discovery proportion equals
# gamma (found with uniroot on FDRcut2, using the true theta of each unit).
gamma <- 0.1

# Posterior tail probability under the true prior G.
TLfdr0 <- quantile(valpha, 1 - alpha)
TLfdr1 <- uniroot(
  FDRcut2,
  interval = c(0, max(valpha) - 1e-03),
  value = valpha,
  theta = G$x[D$sampler, 1],
  cnull = truecnull,
  gamma = gamma
)$root

# Posterior mean under the true prior G.
Tpm0 <- quantile(palpha, 1 - alpha)
Tpm1 <- uniroot(
  FDRcut2,
  interval = c(0, max(palpha) - 1e-03),
  value = palpha,
  theta = G$x[D$sampler, 1],
  cnull = truecnull,
  gamma = gamma
)$root

# Linear posterior mean (NIX prior); uses the estimated cutoff cnull.
Tpml0 <- quantile(plalpha, 1 - alpha)
Tpml1 <- uniroot(
  FDRcut2,
  interval = c(min(plalpha) + 1e-03, max(plalpha) - 1e-03),
  value = plalpha,
  theta = G$x[D$sampler, 1],
  cnull = cnull,
  gamma = gamma
)$root

# Linear tail probability (NIX prior); uses the estimated cutoff cnull.
TLfdrl0 <- quantile(vlalpha, 1 - alpha)
TLfdrl1 <- uniroot(
  FDRcut2,
  interval = c(0, max(vlalpha) - 1e-03),
  value = vlalpha,
  theta = G$x[D$sampler, 1],
  cnull = cnull,
  gamma = gamma
)$root

# ---- Evaluate the four rules on an (s, y) grid for the contour plots ----
ygrid <- seq(1, 10, length.out = 200)
sgrid <- seq(0.01, 10, length.out = 200)
# The matrices are filled as [s index, y index], which is also the layout
# contour(sgrid, ygrid, z) expects, so they are allocated with one row per
# sgrid value and one column per ygrid value. (Allocating them the other way
# round only works while both grids have the same length.)
vys <- matrix(NA, nrow = length(sgrid), ncol = length(ygrid))
pys <- matrix(NA, nrow = length(sgrid), ncol = length(ygrid))
plys <- matrix(NA, nrow = length(sgrid), ncol = length(ygrid))
vlys <- matrix(NA, nrow = length(sgrid), ncol = length(ygrid))
for (i in seq_along(sgrid)) {
  for (j in seq_along(ygrid)) {
    vys[i, j] <- Lfdr(ygrid[j], sgrid[i], m = D$T, G = G, cnull = truecnull)
    pys[i, j] <- PM(ygrid[j], sgrid[i], m = D$T, G = G, cnull = truecnull)
    # The linear rules take the sample variance, hence sgrid[i]^2.
    plys[i, j] <- LinearPM(
      ygrid[j], sgrid[i]^2,
      m = D$T, kappa0 = hyperest[2], theta0 = hyperest[1]
    )
    vlys[i, j] <- Linearvalpha(
      ygrid[j], sgrid[i]^2,
      m = D$T, cnull = cnull, para = hyperest
    )
  }
}
# ---- Write the level-curve and selection-region figure to a PDF ----
pdf("level_rejreg_normal_discrete_panel.pdf", height = 5, width = 15)
par(mfrow = c(1, 3))

# Panel 1: level curves of the posterior tail probability.
contour(
  sgrid, ygrid, vys,
  levels = c(0.01, 0.05, 1:9/10),
  xlab = "s", ylab = "y", main = "level curves tailp"
)

# Panel 2: level curves of the posterior mean.
contour(
  sgrid, ygrid, pys,
  levels = c(4.8, 4.5, 4.1, 4, 3.5, 2.5, 1, 0),
  col = 2, lty = 2,
  xlab = "s", ylab = "y", main = "level curves PM"
)

# Panel 3: boundary of the selection region of each rule, using the stricter
# of its capacity and FDR cutoffs; the posterior-mean rule is overlaid dashed.
contour(
  sgrid, ygrid, vys,
  levels = round(max(TLfdr0, TLfdr1), digits = 3),
  xlab = "s", ylab = "y", main = "selection region"
)
contour(
  sgrid, ygrid, pys,
  levels = round(max(Tpm0, Tpm1), digits = 3),
  add = TRUE, col = 2, lty = 2
)
# Disabled overlay for the linear tail-probability rule:
#contour(sgrid,ygrid,vlys, levels = round(max(TLfdrl0,TLfdrl1),digits = 3),add=TRUE,col=4, lty = 2)

# Disabled alternative panel using the capacity cutoffs only:
#contour(sgrid,ygrid,vys, levels = round(TLfdr0,digits = 3),xlab = "s", ylab = "y", main = "rejection region (capacity)")
#contour(sgrid, ygrid,pys,levels = round(Tpm0, digits = 3), add = TRUE, col = 2, lty = 2)

dev.off()
	

# visualize the function vys
#require(rgl)
#ys = expand.grid(s = sgrid, y = ygrid)
#ys$v = c(vys)

#plot3d(ys[,1],ys[,2],ys[,3])


	
# power analysis 
   	# P(non-null case & reject)/P(non-null) : interpret as among all non-null cases, what's the probability of correct rejection. 
# use sample counterpart to approximate 



# ---- Monte Carlo power study ----
# For every (alpha, gamma, cnull) configuration, R data sets are simulated and
# four ranking statistics are compared: posterior mean (p), posterior tail
# probability (v), a z-type score (z) and the raw sample mean (y).
# Each statistic is thresholded in three ways:
#   (no suffix) the larger of the capacity and the FDR cutoff,
#   1           the capacity cutoff only (top alpha share),
#   2           the FDR cutoff only (realised FDP equal to gamma).
# Columns of power / mfdr / rejprop are ordered
#   p, v, z, y, p1, v1, z1, y1, p2, v2, z2, y2.
set.seed(22)
require(LaplacesDemon)
R <- 200
alpha.gamma <- expand.grid(c(0.05, 0.1, 0.15), c(0.01, 0.05, 0.1))
alpha.gamma$cnull <- rep(c(4.7, 3.7, 3.7), 3)
power <- array(NA, c(R, 12, nrow(alpha.gamma)))
mfdr <- array(NA, c(R, 12, nrow(alpha.gamma)))
rejprop <- array(NA, c(R, 12, nrow(alpha.gamma)))
FDRcomp <- array(NA, c(R, 12, nrow(alpha.gamma)))
for (j in seq_len(nrow(alpha.gamma))) {
  for (sim in seq_len(R)) {
    T <- 9
    n <- 50000
    r <- (T - 1) / 2
    G <- list(
      x = matrix(c(-1, 4, 5, 6, 2, 4), nrow = 3),
      y = c(0.85, 0.1, 0.05),
      sigma = 1
    )

    D <- rG(n, T, G)
    theta <- G$x[D$sampler, 1]
    cnull <- alpha.gamma[j, 3]
    alpha <- alpha.gamma[j, 1]
    gamma <- alpha.gamma[j, 2]
    valpha <- rep(NA, n)
    palpha <- rep(NA, n)
    zscore <- rep(NA, n)
    MLE <- D$x
    for (i in seq_len(n)) {
      valpha[i] <- Lfdr(D$x[i], D$s[i], m = D$T, G = G, cnull = cnull)
      palpha[i] <- PM(D$x[i], D$s[i], m = D$T, G = G, cnull = cnull)
      # z-type score centred at cnull + 0.3 (labelled "oracle z value" by the
      # original author).
      zscore[i] <- (D$x[i] - (cnull + 0.3)) / (D$s[i] / sqrt(T))
    }

    # Cutoffs: *0 capacity (quantile), *1 FDR (root of FDRcut2).
    TLfdr0 <- quantile(valpha, 1 - alpha)
    TLfdr1 <- uniroot(
      FDRcut2,
      interval = c(0, max(valpha) - 1e-03),
      value = valpha, theta = theta, cnull = cnull, gamma = gamma,
      extendInt = "yes"
    )$root

    Tpm0 <- quantile(palpha, 1 - alpha)
    Tpm1 <- uniroot(
      FDRcut2,
      interval = c(0, max(palpha) - 1e-03),
      value = palpha, theta = theta, cnull = cnull, gamma = gamma,
      extendInt = "yes"
    )$root

    # The root search may fail for the z and y rules, so those two calls are
    # wrapped in try() and the failure is handled further down.
    Tz0 <- quantile(zscore, 1 - alpha)
    Tz1 <- try(
      uniroot(
        FDRcut2,
        interval = c(min(zscore) + 1e-03, max(zscore) - 1e-03),
        value = zscore, theta = theta, cnull = cnull, gamma = gamma,
        extendInt = "yes"
      )$root,
      silent = TRUE
    )

    Ty0 <- quantile(MLE, 1 - alpha)
    Ty1 <- try(
      uniroot(
        FDRcut2,
        interval = c(min(MLE) + 1e-03, max(MLE) - 1e-03),
        value = MLE, theta = theta, cnull = cnull, gamma = gamma,
        extendInt = "yes"
      )$root
    )

    # Power = P(non-null & rejected) / P(non-null): among all non-null units,
    # the share that is correctly rejected. Sample proportions stand in for
    # the probabilities.
    Pnonnull <- mean(theta >= cnull)
    selp <- which(palpha > max(Tpm0, Tpm1))
    selv <- which(valpha > max(TLfdr0, TLfdr1))

    selp1 <- which(palpha > Tpm0)
    selv1 <- which(valpha > TLfdr0)

    selp2 <- which(palpha > Tpm1)
    selv2 <- which(valpha > TLfdr1)

    trueset <- which(theta >= cnull)
    discp <- length(intersect(trueset, selp))
    discv <- length(intersect(trueset, selv))
    powerp <- discp / n / Pnonnull
    powerv <- discv / n / Pnonnull
    powerp1 <- length(intersect(trueset, selp1)) / n / Pnonnull
    powerv1 <- length(intersect(trueset, selv1)) / n / Pnonnull
    powerp2 <- length(intersect(trueset, selp2)) / n / Pnonnull
    powerv2 <- length(intersect(trueset, selv2)) / n / Pnonnull

    # Share of units rejected by each rule.
    rejp <- length(selp) / n
    rejv <- length(selv) / n
    rejp1 <- length(selp1) / n
    rejp2 <- length(selp2) / n
    rejv1 <- length(selv1) / n
    rejv2 <- length(selv2) / n

    # Realised false discovery proportion (NaN when nothing is selected).
    FDRp <- length(setdiff(selp, trueset)) / length(selp)
    FDRv <- length(setdiff(selv, trueset)) / length(selv)
    FDRp1 <- length(setdiff(selp1, trueset)) / length(selp1)
    FDRp2 <- length(setdiff(selp2, trueset)) / length(selp2)
    FDRv1 <- length(setdiff(selv1, trueset)) / length(selv1)
    FDRv2 <- length(setdiff(selv2, trueset)) / length(selv2)

    # z rule: all nine summaries are NA when the FDR cutoff could not be found.
    if (inherits(Tz1, "try-error")) {
      powerz <- NA
      powerz1 <- NA
      powerz2 <- NA
      FDRz <- NA
      FDRz1 <- NA
      FDRz2 <- NA
      rejz <- NA
      rejz1 <- NA
      rejz2 <- NA
    } else {
      selz <- which(zscore > max(Tz0, Tz1))
      selz1 <- which(zscore > Tz0)
      selz2 <- which(zscore > Tz1)
      powerz <- length(intersect(trueset, selz)) / n / Pnonnull
      powerz1 <- length(intersect(trueset, selz1)) / n / Pnonnull
      powerz2 <- length(intersect(trueset, selz2)) / n / Pnonnull
      rejz <- length(selz) / n
      rejz1 <- length(selz1) / n
      rejz2 <- length(selz2) / n
      FDRz <- length(setdiff(selz, trueset)) / length(selz)
      FDRz1 <- length(setdiff(selz1, trueset)) / length(selz1)
      FDRz2 <- length(setdiff(selz2, trueset)) / length(selz2)
    }

    # Sample-mean rule: same treatment as the z rule.
    if (inherits(Ty1, "try-error")) {
      powery <- NA
      powery1 <- NA
      powery2 <- NA
      FDRy <- NA
      FDRy1 <- NA
      FDRy2 <- NA
      rejy <- NA
      rejy1 <- NA
      rejy2 <- NA
    } else {
      sely <- which(MLE > max(Ty0, Ty1))
      sely1 <- which(MLE > Ty0)
      sely2 <- which(MLE > Ty1)
      powery <- length(intersect(trueset, sely)) / n / Pnonnull
      powery1 <- length(intersect(trueset, sely1)) / n / Pnonnull
      powery2 <- length(intersect(trueset, sely2)) / n / Pnonnull
      rejy <- length(sely) / n
      rejy1 <- length(sely1) / n
      rejy2 <- length(sely2) / n
      FDRy <- length(setdiff(sely, trueset)) / length(sely)
      FDRy1 <- length(setdiff(sely1, trueset)) / length(sely1)
      FDRy2 <- length(setdiff(sely2, trueset)) / length(sely2)
    }

    power[sim, , j] <- c(
      powerp, powerv, powerz, powery,
      powerp1, powerv1, powerz1, powery1,
      powerp2, powerv2, powerz2, powery2
    )
    mfdr[sim, , j] <- c(
      FDRp, FDRv, FDRz, FDRy,
      FDRp1, FDRv1, FDRz1, FDRy1,
      FDRp2, FDRv2, FDRz2, FDRy2
    )
    rejprop[sim, , j] <- c(
      rejp, rejv, rejz, rejy,
      rejp1, rejv1, rejz1, rejy1,
      rejp2, rejv2, rejz2, rejy2
    )
    print(c(j, sim))   # progress indicator
  }
}


#POWER = round(apply(power, 2, mean),digits = 4)
#MFDR = round(apply(mfdr,2,mean), digits = 4)
#REJ = round( apply(rejprop,2,mean), digits = 4)

#require(xtable)
#xtable(rbind(POWER,MFDR,REJ), digits = 4)