# three_parameter_screening.py (8.8 KB)
  1. import numpy as np
  2. from numpy.random import Generator
  3. from numpy.typing import NDArray
  4. from scipy import stats
  5. from scipy.optimize import minimize
  6. # region 3 Key Parameters
  def sensitivity_function(t, b0, b1, t_bar):
      """
      Calculate the sensitivity function β(t).

      β(t) is a logistic curve in time, centred on ``t_bar``::

          β(t) = 1 / (1 + exp(-b0 - b1 * (t - t_bar)))

      Its log-odds are ``b0 + b1 * (t - t_bar)``, so β(t_bar) = 1 / (1 + exp(-b0))
      and the sign of ``b1`` determines whether sensitivity rises or falls with t.

      Parameters
      ----------
      t : float or ndarray
          Time variable.
      b0 : float
          Intercept: log-odds of sensitivity at ``t == t_bar``.
      b1 : float
          Slope: change in log-odds of sensitivity per unit of ``t``.
      t_bar : float
          Average age at entry in the study group.

      Returns
      -------
      float or ndarray
          Sensitivity at time t, a value in [0, 1].
      """
      return 1 / (1 + np.exp(-b0 - b1 * (t - t_bar)))
  def transition_probability_function(t, w_max, mu, sigma):
      """
      Calculate the transition probability function w(t).

      w(t) is a log-normal probability density scaled by ``w_max``::

          w(t) = w_max / (sqrt(2*pi) * sigma * t) * exp(-(ln(t) - mu)**2 / (2 * sigma**2))

      Because the log-normal density integrates to 1 over (0, inf), the
      integral of w(t) over (0, inf) equals ``w_max``.

      Parameters
      ----------
      t : float or array_like
          Time variable; every element must be greater than 0.
      w_max : float
          Scale factor applied to the density (the total mass of w(t)).
      mu : float
          Mean of ln(t), i.e. of the underlying normal distribution.
      sigma : float
          Standard deviation of ln(t), i.e. of the underlying normal distribution.

      Returns
      -------
      float or ndarray
          Transition probability density at time t (same shape as ``t``).

      Raises
      ------
      ValueError
          If any element of ``t`` is less than or equal to 0.
      """
      t = np.asarray(t, dtype=float)
      # np.any lets the guard work for arrays as well as scalars; a bare
      # ``if t <= 0`` raises "truth value of an array is ambiguous".
      if np.any(t <= 0):
          raise ValueError("t must be greater than 0")
      norm = 1 / (np.sqrt(2 * np.pi) * sigma * t)
      return w_max * norm * np.exp(-((np.log(t) - mu) ** 2) / (2 * sigma ** 2))
  def sojourn_time_function(x, kappa, rho):
      """
      Calculate the sojourn time density q(x).

      q(x) is the log-logistic probability density with shape ``kappa`` and
      rate ``rho`` (equivalently, scale ``1 / rho``)::

          q(x) = kappa * x**(kappa - 1) * rho**kappa / (1 + (x * rho)**kappa)**2

      For ``kappa < 1`` the density diverges at ``x = 0``.

      Parameters
      ----------
      x : float or ndarray
          Sojourn time at which the density is evaluated (x >= 0).
      kappa : float
          Shape parameter.
      rho : float
          Rate parameter (the inverse of the distribution's scale).

      Returns
      -------
      float or ndarray
          Value of the sojourn-time probability density at x (a density,
          not a time).
      """
      return (kappa * x ** (kappa - 1) * rho ** kappa) / ((1 + (x * rho) ** kappa) ** 2)
  66. # endregion
  67. # region Random Variates
  def transition_survival_rvs(rng: Generator, w: float, size: int = 1) -> NDArray[np.float64]:
      """
      Generate random variates of transition survival.

      Each variate is ``U / w`` with ``U ~ Uniform[0, 1)``, so the returned
      times are uniformly distributed on ``[0, 1 / w)``.

      Parameters
      ----------
      rng : np.random.Generator
          Random number generator.
      w : float
          Scale parameter that represents the probability of transition
          from healthy state to preclinical state; must be > 0.
      size : int, optional
          Number of random variates to generate.

      Returns
      -------
      NDArray[np.float64]
          Times when a population transitions from healthy state to preclinical state.

      Raises
      ------
      ValueError
          If ``w`` is not strictly positive (dividing by it would give inf/nan).
      """
      # np.all/np.asarray also rejects NaN and works if w is passed as an array.
      if not np.all(np.asarray(w) > 0):
          raise ValueError("w must be greater than 0")
      return rng.uniform(low=0, high=1, size=size) / w
  def transition_survival_rvs_alt(rng: Generator, t: float, a: float, size: int = 1) -> NDArray[np.float64]:
      """
      Generate random variates of transition survival, censored at life expectancy.

      Alternative implementation of transition_survival_rvs that returns np.nan
      for times that are greater than ``t``. Each variate is ``t * U / a`` with
      ``U ~ Uniform[0, 1)``. It falls within ``[0, t]`` exactly when ``U <= a``,
      i.e. with probability ``a`` for ``0 < a <= 1``, consistent with ``a``
      being the lifetime risk.

      Parameters
      ----------
      rng : np.random.Generator
          Random number generator.
      t : float
          Life expectancy.
      a : float
          Lifetime risk; must be > 0.
      size : int, optional
          Number of random variates to generate.

      Returns
      -------
      NDArray[np.float64]
          Times when a population transitions from healthy state to preclinical
          state, or np.nan where the transition would occur after ``t``.

      Raises
      ------
      ValueError
          If ``a`` is not strictly positive.
      """
      if not np.all(np.asarray(a) > 0):
          raise ValueError("a must be greater than 0")
      # Times in units of life expectancy: a value > 1 lies beyond t.
      times = rng.uniform(low=0, high=1, size=size) / a
      # Transitions that would happen after the end of life never occur;
      # mark those (not the in-lifetime ones) as missing.
      times[times > 1] = np.nan
      return times * t
  def sojourn_survival_rvs(rng: Generator, mu: float, t_w: NDArray[np.float64]) -> NDArray[np.float64]:
      """
      Generate random variates of sojourn survival.

      Draws one exponential variate with mean ``mu`` per element of ``t_w``
      by inverse-transform sampling: ``-mu * ln(1 - U)`` with ``U ~ Uniform[0, 1)``.

      Parameters
      ----------
      rng : np.random.Generator
          Random number generator.
      mu : float
          Mean (scale) of the exponential sojourn-time distribution.
      t_w : array_like
          Times at which each person transitioned from healthy state to
          preclinical state. Only its length is used: one sojourn time is
          drawn per element.

      Returns
      -------
      NDArray[np.float64]
          Sojourn times, one per element of ``t_w``.
          NOTE(review): these are durations, not offset by ``t_w``; add ``t_w``
          to obtain the absolute time of leaving the preclinical state.
      """
      # 1 - U lies in (0, 1], so the logarithm is always finite.
      uniforms = rng.uniform(low=0, high=1, size=len(t_w))
      return -mu * np.log(1 - uniforms)
  # endregion
  # region Likelihood
  def likelihood_function(D, I, n, s, r, alpha, beta):
      """
      Calculate the likelihood function.

      The contribution of interval i is::

          D_i**s_i * I_i**r_i * (1 - D_i - I_i)**(n_i - s_i - r_i)
              * prod_j (alpha_j / beta)**s_ij

      where ``s_ij`` are the per-category preclinical-diagnosis counts of
      interval i and ``s_i = sum_j s_ij`` is their total. The likelihood is
      the product of the contributions over all intervals.

      Parameters
      ----------
      D : sequence of float
          Probabilities of preclinical diagnosis, one per interval.
      I : sequence of float
          Probabilities of clinical incidence, one per interval.
      n : sequence of int
          Total count for each interval (the multinomial number of trials:
          preclinical diagnoses + clinical incidences + the remainder).
      s : sequence
          Preclinical-diagnosis counts for each interval. Each entry is either
          a sequence of per-category counts ``s_ij`` (one per element of
          ``alpha``) or a scalar total, in which case the category product
          term is omitted.
      r : sequence of int
          Number of clinical incidences for each interval.
      alpha : sequence of float
          Per-category values for the product term (any number of categories).
          NOTE(review): presumably sum(alpha) == beta, so the category
          probabilities D_i * alpha_j / beta add up to D_i — confirm.
      beta : float
          The probability of detection.

      Returns
      -------
      float
          The likelihood value (1.0 when there are no intervals).

      Raises
      ------
      ValueError
          If D, I, n, s and r differ in length, or a per-category count vector
          does not have one entry per element of ``alpha``.
      """
      if not (len(I) == len(n) == len(s) == len(r) == len(D)):
          raise ValueError("D, I, n, s and r must all have the same length")
      alpha = np.asarray(alpha, dtype=float)
      L = 1.0
      for D_i, I_i, n_i, s_i, r_i in zip(D, I, n, s, r):
          # The scalar terms need the total diagnoses; the category product
          # needs the individual per-category counts.
          s_i = np.asarray(s_i)
          s_total = s_i.sum()
          if s_i.ndim == 0:
              product_term = 1.0
          elif s_i.shape == alpha.shape:
              product_term = np.prod((alpha / beta) ** s_i)
          else:
              raise ValueError("each per-category count vector in s must have one entry per alpha")
          term1 = D_i ** s_total
          term2 = I_i ** r_i
          term3 = (1 - D_i - I_i) ** (n_i - s_total - r_i)
          L *= term1 * term2 * term3 * product_term
      return L
  def probability_of_clinical_incidence(beta, w, mu, t, Q):
      """
      Calculate the probability of clinical incidence for each interval.

      ``t`` holds time points t_0, t_1, ..., t_k (assumed increasing). One
      probability is returned per interval (t_{i-1}, t_i], i = 1..k::

          I_i = w * mu * ((t_i - t_{i-1}) / mu
                          - beta * sum_{j=0}^{i-1} (1 - beta)**(i - j - 1)
                                   * (Q(t_{i-1} - t_j) - Q(t_i - t_j)))

      Parameters
      ----------
      beta : float
          The probability of detection.
      w : float
          Weighting factor.
      mu : float
          Mean of the distribution.
      t : sequence of float
          Time points bounding the intervals.
      Q : callable
          Function Q(x) representing the probability distribution
          (NOTE(review): presumably the sojourn-time survival function — confirm).

      Returns
      -------
      list
          ``len(t) - 1`` probabilities I_i, one per interval between consecutive
          elements of ``t``; empty when ``t`` has fewer than two elements.
      """
      incidence = []
      # Interval i spans (t[i-1], t[i]], so the last valid i is len(t) - 1;
      # running i up to len(t) would read past the end of t.
      for i in range(1, len(t)):
          sum_term = sum(
              (1 - beta) ** (i - j - 1) * (Q(t[i - 1] - t[j]) - Q(t[i] - t[j]))
              for j in range(i)
          )
          incidence.append(w * mu * ((t[i] - t[i - 1]) / mu - beta * sum_term))
      return incidence
  def probability_of_preclinical_diagnosis(beta, w, mu, t, Q):
      """
      Calculate the probability of preclinical diagnosis at each time point.

      For the i-th element of ``t`` (1-based)::

          D_1 = beta * w * mu
          D_i = beta * w * mu * (1 - beta * sum_{j=1}^{i-1} (1 - beta)**(i - j - 1) * Q(t_i - t_j))

      Parameters
      ----------
      beta : float
          The probability of detection.
      w : float
          Weighting factor.
      mu : float
          Mean of the distribution.
      t : sequence of float
          Time points (one probability is produced per element).
      Q : callable
          Function Q(x) representing the probability distribution.

      Returns
      -------
      list
          ``len(t)`` probabilities D_i, in the same order as ``t``.
      """

      def _detected(k):
          """Probability of preclinical diagnosis at t[k] (0-based index k)."""
          if k == 0:
              # No earlier time points exist, so there is no correction term.
              return beta * w * mu
          carried = sum((1 - beta) ** (k - m - 1) * Q(t[k] - t[m]) for m in range(k))
          return beta * w * mu * (1 - beta * carried)

      return [_detected(k) for k in range(len(t))]
  216. # endregion
  217. # region Optimization
  def loss_function(params, real_data, population_size, average_life_expectancy=None):
      """
      Sum of squared differences between simulated and observed counts.

      Parameters
      ----------
      params : sequence of 5 floats
          (sensitivity, specificity, risk, screening_interval, onset_interval),
          forwarded positionally to analytic_simulation.
      real_data : array_like
          Observed counts, in the same order as the array returned by
          analytic_simulation.
      population_size : int
          Size of the simulated population.
      average_life_expectancy : float
          Average life expectancy in years, forwarded to analytic_simulation,
          which requires it. The default exists only for signature
          compatibility; omitting the value raises TypeError.

      Returns
      -------
      float
          Sum of squared residuals.

      Raises
      ------
      TypeError
          If ``average_life_expectancy`` is not supplied.

      Notes
      -----
      analytic_simulation draws random numbers on each call, so this loss is
      stochastic: repeated evaluations at the same ``params`` generally differ.
      """
      if average_life_expectancy is None:
          raise TypeError("loss_function() requires average_life_expectancy (forwarded to analytic_simulation)")
      sensitivity, specificity, risk, screening_interval, onset_interval = params
      simulated_data = analytic_simulation(
          sensitivity, specificity, risk, screening_interval, onset_interval,
          population_size, average_life_expectancy,
      )
      return np.sum((simulated_data - real_data) ** 2)


  def fit_to_real_data(real_data, initial_guess, population_size, average_life_expectancy=None, method='BFGS'):
      """
      Fit the five screening parameters to observed counts by minimising loss_function.

      Parameters
      ----------
      real_data : array_like
          Observed counts (see loss_function).
      initial_guess : sequence of 5 floats
          Starting (sensitivity, specificity, risk, screening_interval, onset_interval).
      population_size : int
          Size of the simulated population.
      average_life_expectancy : float
          Average life expectancy in years; required by loss_function.
      method : str, optional
          scipy.optimize.minimize method (default 'BFGS', as before).

      Returns
      -------
      ndarray
          Fitted parameters in the same order as ``initial_guess``.

      Notes
      -----
      NOTE(review): the objective is stochastic (see loss_function), and BFGS
      without an analytic gradient uses finite differences, so the fit is
      noise-driven. A derivative-free method (e.g. ``method='Nelder-Mead'``) or
      fixed random numbers may be more appropriate; confirm.
      """
      result = minimize(
          loss_function,
          initial_guess,
          args=(real_data, population_size, average_life_expectancy),
          method=method,
      )
      return result.x
  225. # endregion
  226. # region Deprecated
  def analytic_simulation(sensitivity, specificity, risk, screening_interval, onset_interval, population_size, average_life_expectancy, random_state=None) -> np.ndarray:
      """
      Analytic simulation of screening programme.

      Parameters
      ----------
      sensitivity : float
          Sensitivity of the screening test.
      specificity : float
          Specificity of the screening test.
      risk : float
          Risk of cancer in the population.
      screening_interval : float
          Interval between screenings in years.
      onset_interval : float
          Interval between onset of cancer and detection in years.
      population_size : int
          Size of the population.
      average_life_expectancy : float
          Average life expectancy in years. ``int(average_life_expectancy *
          screening_interval)`` is used as the number of binomial trials
          (screening rounds) for false positives in each person without cancer.
      random_state : None, int or numpy.random.Generator, optional
          Source of randomness for every draw. ``None`` (default) uses SciPy's
          global random state as before; an int seeds a new
          ``numpy.random.default_rng`` shared by all draws.

      Returns
      -------
      np.ndarray
          A numpy array with the following elements:
          - The number of people with cancer.
          - The number of people with cancer detected through screening.
          - The number of people with cancer that was not detected through screening and was detected through symptoms.
          - The number of people without cancer.
          - The number of false positives.

      Notes
      -----
      NOTE(review): ``screening_interval`` is documented as years between
      screenings, yet it is multiplied by ``onset_interval`` and by
      ``average_life_expectancy`` to obtain screening counts. Dividing would
      match that documentation. Behaviour is unchanged here; confirm the
      intended unit (interval vs. screenings per year).
      """
      # Convert a seed once so successive draws continue one stream instead of
      # each rvs call re-seeding from the same value.
      if random_state is not None and not isinstance(random_state, (np.random.Generator, np.random.RandomState)):
          random_state = np.random.default_rng(random_state)
      population = stats.bernoulli.rvs(risk, size=population_size, random_state=random_state)
      # Geometric CDF: probability of detection within the first
      # floor(onset_interval * screening_interval) screens, each succeeding
      # with probability `sensitivity`.
      detected_probability = stats.geom.cdf(onset_interval * screening_interval, sensitivity)
      # cancer
      cancer_quantity = np.sum(population)
      screening_population = stats.bernoulli.rvs(p=detected_probability, size=cancer_quantity, random_state=random_state)
      screening_detected_q = np.sum(screening_population)
      interval_cancer_q = len(screening_population) - screening_detected_q
      # recall (false positives among people without cancer)
      recall_q = len(population) - cancer_quantity
      n_screenings = int(average_life_expectancy * screening_interval)
      recall_total = np.sum(stats.binom.rvs(n=n_screenings, p=(1 - specificity), size=recall_q, random_state=random_state))
      return np.array([cancer_quantity, screening_detected_q, interval_cancer_q, recall_q, recall_total])
  266. # endregion