roc.R 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  #' Create a ROC curve
  #'
  #' Cases with a missing label or score are dropped, the remaining cases are
  #' sorted by score, and cumulative true/false positive rates are computed
  #' row by row along that ordering.
  #'
  #' @param labels vector of outcomes; coerced with `as.logical()`, `TRUE`
  #'   marking a positive case
  #' @param scores vector of variable values, same length as `labels`
  #' @param decreasing ordering of scores used to build the curve (passed to
  #'   `order()`). Note that the placement values `V` always treat larger
  #'   scores as pointing to a positive case, whatever this argument is.
  #'
  #' @return data frame with columns TPR, FPR, labels, scores, V.
  #' V is combination of V0 (negative cases) and V1 (positive cases):
  #' V0 is portion of positive cases with value larger than score,
  #' V1 is portion of negative cases with value below score;
  #' ties are counted as one half. When one of the two classes is empty the
  #' corresponding rates and portions are NaN (division by zero).
  #'
  #' @export
  simple.roc <- function(labels, scores, decreasing = TRUE) {
    # Remove cases where either the label or the score is NA
    keep <- stats::complete.cases(labels, scores)
    labels <- labels[keep]
    scores <- scores[keep]

    # Compute the ordering once and apply it to both vectors
    ord <- base::order(scores, decreasing = decreasing)
    labels <- base::as.logical(labels[ord])
    scores <- scores[ord]

    x <- base::data.frame(
      TPR = base::cumsum(labels) / base::sum(labels),
      FPR = base::cumsum(!labels) / base::sum(!labels),
      labels, scores
    )

    s1 <- scores[labels]
    s0 <- scores[!labels]
    n0 <- base::length(s0)
    n1 <- base::length(s1)

    # vapply (not sapply) so that an empty class yields numeric(0) instead of
    # list(), which would break the divisions below.
    # V0: for each negative case, count of positive cases scoring above it
    # (ideally n1), converted to a portion in [0, 1].
    V0 <- base::vapply(s0, psi01, base::numeric(1), s = s1) / n1
    # V1: for each positive case, count of negative cases scoring below it
    # (ideally n0), converted to a portion in [0, 1].
    V1 <- base::vapply(s1, psi10, base::numeric(1), s = s0) / n0

    # Interleave the two components back into row order
    V <- base::numeric(base::length(labels))
    V[labels] <- V1
    V[!labels] <- V0
    x$V <- V
    x
  }

  # Number of elements of `s` strictly above `y`; ties count as one half.
  psi01 <- function(y, s) {
    above <- base::sum(s > y, na.rm = TRUE)
    tied <- base::sum(s == y, na.rm = TRUE)
    above + 0.5 * tied
  }

  # Number of elements of `s` strictly below `x`; ties count as one half.
  psi10 <- function(x, s) {
    below <- base::sum(s < x, na.rm = TRUE)
    tied <- base::sum(s == x, na.rm = TRUE)
    below + 0.5 * tied
  }
  #' Calculate AUC
  #'
  #' The AUC is taken as the mean of the `V` column over the negative cases
  #' (rows whose `labels` value is `FALSE`).
  #'
  #' @param roc data frame as returned by simple.roc
  #'
  #' @return AUC (numeric)
  #'
  #' @export
  simple.getAUC <- function(roc) {
    is_negative <- !roc$labels
    base::mean(roc[is_negative, "V"])
  }
  #' Calculate metrics associated with ROC
  #'
  #' Picks the row of the curve where |TPR - FPR| is largest (Youden index)
  #' and reports the rates, the score at that row and 95% confidence
  #' intervals for both rates (`stats::prop.test` without continuity
  #' correction).
  #'
  #' @param ROC roc as returned by simple.roc
  #'
  #' @return a list with elements FPR, TPR, threshold (score at the optimum
  #' point), lFPR / hFPR (lower / upper CI95 bound of FPR) and lTPR / hTPR
  #' (lower / upper CI95 bound of TPR), all evaluated at the optimum point
  #'
  #' @export
  simple.compute_roc_metrics <- function(ROC) {
    # Class sizes: positives drive the TPR interval, negatives the FPR one
    n_pos <- base::sum(ROC$labels)
    n_neg <- base::sum(!ROC$labels)

    # Optimum point: first row with the largest distance from the diagonal
    best <- base::which.max(base::abs(ROC$TPR - ROC$FPR))
    tpr <- ROC$TPR[best]
    fpr <- ROC$FPR[best]

    # Rates are turned back into counts for the proportion test
    ci_tpr <- stats::prop.test(
      x = tpr * n_pos, n = n_pos, conf.level = 0.95, correct = FALSE
    )$conf.int
    ci_fpr <- stats::prop.test(
      x = fpr * n_neg, n = n_neg, conf.level = 0.95, correct = FALSE
    )$conf.int

    base::list(
      FPR = fpr,
      TPR = tpr,
      threshold = ROC$scores[best],
      lFPR = ci_fpr[1],
      hFPR = ci_fpr[2],
      lTPR = ci_tpr[1],
      hTPR = ci_tpr[2]
    )
  }
  #' Calculate standard error of AUC (DeLong)
  #'
  #' Uses the `V` components stored in the roc data frame: the spread of the
  #' positive-case and negative-case components around the AUC is computed
  #' separately, each divided by its class size, summed and square-rooted.
  #'
  #' @param roc roc as calculated by simple.roc
  #'
  #' @return standard deviation of AUC
  #'
  #' @export
  simple.sAUC <- function(roc) {
    is_positive <- roc$labels
    v_pos <- roc[is_positive, "V"]
    v_neg <- roc[!is_positive, "V"]

    # AUC estimated from the positive-case components
    auc <- base::mean(v_pos)

    n_neg <- base::length(v_neg)
    n_pos <- base::length(v_pos)

    # Sample variances of each component taken around the AUC
    dev_neg <- v_neg - auc
    dev_pos <- v_pos - auc
    var_neg <- base::sum(dev_neg * dev_neg) / (n_neg - 1)
    var_pos <- base::sum(dev_pos * dev_pos) / (n_pos - 1)

    base::sqrt(var_neg / n_neg + var_pos / n_pos)
  }
  #' Calculate standard error of AUC (approximation)
  #'
  #' Closed-form estimate that needs only the AUC and the total number of
  #' cases: sqrt((AUC * (1 - AUC) + (n - 1) * (AUC - 0.5)^2) / n).
  #'
  #' @param roc roc as calculated by simple.roc
  #'
  #' @return standard deviation of AUC
  #'
  #' @export
  simple.sAUCapprox <- function(roc) {
    # n counts all cases, positive and negative together
    n_cases <- base::length(roc$labels)
    a <- simple.getAUC(roc)
    spread <- a * (1 - a) + (n_cases - 1) * (a - 0.5)^2
    base::sqrt(spread / n_cases)
  }
  #' Check whether two AUCs are statistically significantly different
  #'
  #' Paired comparison of two tests evaluated on the same cases, following
  #' the NCSS procedure "Comparing Two ROC Curves - Paired Design":
  #' https://www.ncss.com/wp-content/themes/ncss/pdf/Procedures/NCSS/Comparing_Two_ROC_Curves-Paired_Design.pdf
  #' The variance of each AUC and their covariance are combined into the
  #' variance of the difference, and a two-sided z test is applied. The AUCs
  #' and the intermediate variances are printed as a side effect.
  #'
  #' @param rocA roc of the first test as returned by simple.roc
  #' @param rocB roc of the second test as returned by simple.roc
  #'
  #' @return p of the test (two-sided); a low p indicates that the two AUCs
  #' differ. Stops with an error when the two curves do not contain the same
  #' number of negative and positive cases.
  #'
  #' @export
  simple.sAUC2 <- function(rocA, rocB) {
    # Each curve must be split with its OWN labels: simple.roc sorts rows by
    # score, so the label order of rocA and rocB generally differs.
    V0A <- rocA[!rocA$labels, "V"]
    V0B <- rocB[!rocB$labels, "V"]
    V1A <- rocA[rocA$labels, "V"]
    V1B <- rocB[rocB$labels, "V"]

    AUCA <- base::mean(V0A)
    AUCB <- base::mean(V0B)
    n0 <- base::length(V0A)
    n1 <- base::length(V1A)

    # A paired design needs identical class sizes; without this check the
    # element-wise products below would silently recycle the shorter vector.
    if (base::length(V0B) != n0 || base::length(V1B) != n1) {
      base::stop(
        "rocA and rocB must contain the same number of negative and positive cases",
        call. = FALSE
      )
    }

    base::print(base::sprintf('A=%f B=%f', AUCA, AUCB))

    # variance of A: 0 component and 1 component
    S0A <- base::sum((V0A - AUCA) * (V0A - AUCA)) / (n0 - 1)
    S1A <- base::sum((V1A - AUCA) * (V1A - AUCA)) / (n1 - 1)
    SA <- S0A / n0 + S1A / n1

    # variance of B: 0 component and 1 component
    S0B <- base::sum((V0B - AUCB) * (V0B - AUCB)) / (n0 - 1)
    S1B <- base::sum((V1B - AUCB) * (V1B - AUCB)) / (n1 - 1)
    SB <- S0B / n0 + S1B / n1

    # covariance of A and B: 0 component and 1 component
    # NOTE(review): these products pair the i-th negative (positive) row of
    # rocA with the i-th negative (positive) row of rocB. That is the same
    # case only if both curves list their cases in the same order, which
    # simple.roc does not guarantee (each curve is sorted by its own scores)
    # - TODO confirm how callers align the two curves.
    S0AB <- base::sum((V0A - AUCA) * (V0B - AUCB)) / (n0 - 1)
    S1AB <- base::sum((V1A - AUCA) * (V1B - AUCB)) / (n1 - 1)
    SAB <- S0AB / n0 + S1AB / n1

    # variance of the difference AUCA - AUCB
    S <- SA + SB - 2 * SAB

    # two-sided test: is there a significant difference?
    # (one-sided "is A larger than B" would use z = (AUCA - AUCB) / sqrt(S)
    # and p = pnorm(z, lower.tail = FALSE))
    z <- base::abs(AUCA - AUCB) / base::sqrt(S)
    p <- 2 * stats::pnorm(z, mean = 0, sd = 1, lower.tail = FALSE)

    base::print(base::sprintf('SA2=%f SB2=%f SAB2=%f S2=%f S=%f z=%f p=%f',SA,SB,SAB,S,base::sqrt(S),z,p))
    p
  }
  #' Plot a ROC curve with associated annotations
  #'
  #' Builds the ROC curve of `var` against `target`, prints sensitivity and
  #' specificity (with CI95) at the optimum threshold, and adds the curve and
  #' the optimum point to the current plot via `graphics::lines` and
  #' `graphics::points`.
  #'
  #'@param df data frame
  #'@param var variable to use to stratify patients
  #'@param col color of the line drawn
  #'@param x x coordinate of legend (not used by the current implementation)
  #'@param y y coordinate of legend (not used by the current implementation)
  #'@param unit what unit to associate to threshold on legend (ml); not used
  #'  by the current implementation
  #'@param precise number of decimal places to use when reporting opt
  #'  threshold, TRUE:2, FALSE:0; not used by the current implementation,
  #'  the threshold is always formatted with two decimals
  #'@param target column that holds binary outcomes
  #'
  #'@return list object with items: roc object as created by simple.roc, thr optimal threshold, legend_text text to be put on final legend
  #'@export
  simple.plotROC <- function(df, var, col = "black", x = 0.65, y = 0.1,
                             unit = "ml", precise = FALSE, target = 'alive') {
    roc <- simple.roc(df[, target], df[, var])
    auc <- simple.getAUC(roc)
    # sigma of AUC (DeLong) and, for historical reasons, its approximation
    sigma_auc <- simple.sAUC(roc)
    sigma_auc_approx <- simple.sAUCapprox(roc)
    # optimum operating point (Youden index)
    opt <- simple.compute_roc_metrics(roc)

    # report sensitivity/specificity with CI95 at the optimum threshold;
    # specificity is 1 - FPR, so its CI bounds swap places
    report <- base::sprintf(
      '[%s] Opt: sens %.2f (%.2f,%.2f), spec %.2f (%.2f,%.2f)',
      var, opt$TPR, opt$lTPR, opt$hTPR,
      1 - opt$FPR, 1 - opt$hFPR, 1 - opt$lFPR
    )
    base::print(report)

    legend_text <- base::sprintf(
      "[%s] AUC: %.2f (+- %.2f/%.2f), OPT THR: %.2f",
      var, auc, sigma_auc, sigma_auc_approx, opt$threshold
    )

    # mark the optimum point first, then draw the curve itself
    graphics::points(opt$FPR, opt$TPR, pch = 1, col = col, cex = 2)
    graphics::lines(roc$FPR, roc$TPR, col = col)

    base::list(roc = roc, thr = opt$threshold, legend_text = legend_text)
  }
  #'Plot an assembly of ROCs
  #'
  #' For every entry of `vars` that is a column of `df`, the ROC curve
  #' against `target` is added to a ggplot together with a marker at the
  #' optimum (Youden) point; the legend entry reports AUC, its standard
  #' error and the optimum threshold. Entries of `vars` that are not columns
  #' of `df` are skipped, and colors are taken from `cols` in the order in
  #' which curves are actually drawn.
  #'
  #'@param df data frame
  #'@param vars vector of variables
  #'@param cols vector of color names
  #'@param x coordinate for legend
  #'@param y coordinate for legend
  #'@param unit unit for threshold in labels (not used by the current
  #'  implementation)
  #'@param precise number of decimal places to use when reporting opt
  #'  threshold, TRUE:2, FALSE:0 (not used by the current implementation)
  #'@param target column that holds binary outcomes
  #'
  #'@return ggplot2 graphical object, or NULL when ggplot2 is not installed
  #'
  #'@export
  simple.plotROCgg <- function(df, vars, cols, x = 0.7, y = 0.3, unit = "ml",
                               precise = "FALSE", target = "alive") {
    # ggplot2 is optional; the argument is spelled out in full (`quietly`)
    # rather than relying on partial argument matching
    if (!requireNamespace('ggplot2', quietly = TRUE)) {
      print('ggplot2 not available. Use simple.plotROC function')
      return(NULL)
    }
    # legend labels and the colors assigned to them, in drawing order
    cvalues <- base::c()
    colors_used <- base::c()
    i <- 1
    g <- ggplot2::ggplot()
    for (var in vars) {
      if (var %in% base::names(df)) {
        # Original note: rows with NA values are dealt with before the ROC
        # is computed. mapNA is defined outside this section; judging by the
        # call it maps NA in `var` to 0 - TODO confirm.
        df <- mapNA(df, var, 0)
        roc <- simple.roc(df[, target], df[, var])
        col <- cols[i]
        roc_metrics <- simple.compute_roc_metrics(roc)
        auc <- simple.getAUC(roc)
        sAUC <- simple.sAUC(roc)
        lab <- base::sprintf("[%s] AUC: %.2f (+- %.2f), OPT THR: %.2f",
                             var, auc, sAUC, roc_metrics$threshold)
        # !! injects the current values into the aesthetics, so each layer
        # keeps its own data instead of the last loop iteration's.
        # https://rlang.r-lib.org/reference/topic-inject.html
        g <- g + ggplot2::geom_line(
          ggplot2::aes(x = !!roc$FPR, y = !!roc$TPR, color = !!lab)
        )
        aes_p <- ggplot2::aes(
          x = !!roc_metrics$FPR, y = !!roc_metrics$TPR, color = !!lab
        )
        g <- g + ggplot2::geom_point(aes_p, size = 4, shape = 1,
                                     show.legend = FALSE)
        colors_used <- base::c(colors_used, col)
        cvalues <- base::c(cvalues, lab)
        base::print(sprintf('%d/%d', base::length(colors_used),
                            base::length(cvalues)))
        i <- i + 1
      }
    }
    # scale_color_manual matches colors to curves through these names
    base::names(colors_used) <- cvalues
    g + ggplot2::xlab('1-specificity') +
      ggplot2::ylab('sensitivity') +
      ggplot2::scale_color_manual(name = 'Variables', values = colors_used) +
      ggplot2::guides(color = ggplot2::guide_legend(position = "inside")) +
      myTheme() +
      ggplot2::theme(legend.position.inside = base::c(x, y))
  }
  # Text sizes shared by the ROC plots. element_text is namespace-qualified
  # like every other ggplot2 call in this file, so the theme also works when
  # ggplot2 is installed but not attached.
  myTheme <- function() {
    ggplot2::theme(
      axis.text = ggplot2::element_text(size = 14),
      axis.title = ggplot2::element_text(size = 16, face = "bold"),
      legend.title = ggplot2::element_text(size = 16, face = "bold"),
      legend.text = ggplot2::element_text(size = 12),
      plot.title = ggplot2::element_text(size = 16, face = "bold", hjust = 0.5)
    )
  }