contact_function.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from dataclasses import dataclass
  2. import numpy as np
  3. import time
  4. import ctypes
  5. clib = ctypes.cdll.LoadLibrary('./overlap_algorithm.so')
  6. clib.contact_function.argtypes = [
  7. ctypes.c_double,
  8. ctypes.c_double,
  9. ctypes.c_double,
  10. ctypes.c_double,
  11. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  12. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  13. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  14. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  15. ctypes.c_int,
  16. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  17. np.ctypeslib.ndpointer(dtype=np.intc, ndim=1, flags='C_CONTIGUOUS'),
  18. np.ctypeslib.ndpointer(dtype=np.intc, ndim=1, flags='C_CONTIGUOUS'),
  19. ]
  20. clib.contact_function.restype = ctypes.c_double
  21. clib.contact_function_multi.argtypes = [
  22. ctypes.c_int,
  23. ctypes.c_double,
  24. ctypes.c_double,
  25. ctypes.c_double,
  26. ctypes.c_double,
  27. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  28. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  29. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  30. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  31. ctypes.c_int,
  32. np.ctypeslib.ndpointer(dtype=np.intc, ndim=1, flags='C_CONTIGUOUS'),
  33. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  34. ]
  35. clib.contact_function_multi.restype = ctypes.c_void_p
  36. def minimizer_map(minimizer: str) -> int:
  37. """Maps minimizer name to an integer for use in C functions from "overlap_algorithm.c"."""
  38. if minimizer == 'brent':
  39. return 0
  40. elif minimizer == 'brent_early':
  41. return 1
  42. elif minimizer == 'gss':
  43. return 2
  44. elif minimizer == 'gss_early':
  45. return 3
  46. else:
  47. raise ValueError('Unknown minimizer: "' + minimizer + '"')
  48. @dataclass
  49. class Contact:
  50. """Stores contact data."""
  51. contact_f: float
  52. min_eig_T: float
  53. med_eig_T: float
  54. min_vec_scalar: float # result of scalar product test fot the eigenvector corresponding to minimal eigenvalue
  55. med_vec_scalar: float # result of scalar product test fot the eigenvector corresponding to median eigenvalue
  56. feval_min: int # number of minimal eigenvalue evaluations
  57. feval_med: int # number of median eigenvalue evaluations
  58. branch: int # contact function solution branch (0: min eig, 1: med eig, 2: Omega)
  59. def contact_function(a0: float, a1: float, b0: float, b1: float,
  60. coord0: np.ndarray, orient0: np.ndarray, coord1: np.ndarray, orient1: np.ndarray,
  61. minimizer: str = 'brent') -> Contact:
  62. """Calculates contact data for a pair of spherical ellipses."""
  63. other_results = np.zeros(4, dtype=np.float64)
  64. fevals = np.zeros(2, dtype=np.intc)
  65. branch = np.zeros(1, dtype=np.intc)
  66. contact_f = clib.contact_function(a0, a1, b0, b1, coord0, orient0, coord1, orient1,
  67. minimizer_map(minimizer), other_results, fevals, branch)
  68. return Contact(contact_f, other_results[0], other_results[1], other_results[2], other_results[3],
  69. fevals[0], fevals[1], branch[0])
  70. @dataclass
  71. class ContactMulti:
  72. """
  73. Stores the contact function for many pairs of ellipses, along with average numbers of eigenvalue evaluations
  74. and total calculation time.
  75. """
  76. contact_f: np.ndarray
  77. avg_evals_mineig: float
  78. avg_evals_medeig: float
  79. time: float
  80. def contact_function_multi(a0: float, a1: float, b0: float, b1: float,
  81. coord0: np.ndarray, orient0: np.ndarray, coord1: np.ndarray, orient1: np.ndarray,
  82. minimizer: str = 'brent') -> ContactMulti:
  83. """Calculates contact function for multiple configurations of spherical ellipses."""
  84. n = len(coord0)
  85. results = np.zeros(n, dtype=np.float64)
  86. all_evals = np.zeros(2, dtype=np.intc)
  87. t0 = time.perf_counter()
  88. clib.contact_function_multi(n, a0, a1, b0, b1, coord0, orient0, coord1, orient1,
  89. minimizer_map(minimizer), all_evals, results)
  90. t1 = time.perf_counter()
  91. return ContactMulti(results, all_evals[0] / n, all_evals[1] / n, t1 - t0)
  92. class EllipsePair:
  93. def __init__(self, a0: float, a1: float, epsilon: float):
  94. """
  95. Class to calculate the contact function between two spherical ellipses.
  96. :param a0: semi-major axis for the first elliptical cylinder
  97. :param a1: semi-major axis for the second elliptical cylinder
  98. :param epsilon: aspect ratio
  99. """
  100. self.a0 = max(a0, a1) # defined so that always a0 > a1
  101. self.a1 = min(a0, a1)
  102. self.b0 = self.a0 / epsilon
  103. self.b1 = self.a1 / epsilon
  104. self.size_ratio = self.a0 / self.a1
  105. self.epsilon = epsilon
  106. def contact(self, coord0, orient0, coord1=np.array([0., 0., 1.]), orient1=np.array([1., 0., 0.]),
  107. minimizer='brent') -> Contact:
  108. """Evaluates the contact function along with oth contact data."""
  109. return contact_function(self.a0, self.a1, self.b0, self.b1, coord0, orient0, coord1, orient1, minimizer)
  110. def contact_multi_dist(self, distances, num_ang, axis=np.array([1, 0, 0]), minimizer='brent')\
  111. -> (np.ndarray, np.ndarray, np.ndarray):
  112. """
  113. Evaluates the contact function at different distances between two ellipses as well as multiple
  114. mutual orientations. Returns mean calculation time, mean number of minimum eigenvalue evaluations and
  115. mean number of median eigenvalue evaluations at each distance.
  116. """
  117. timings = np.zeros(len(distances))
  118. avg_evals_mineig = np.zeros(len(distances))
  119. avg_evals_medeig = np.zeros(len(distances))
  120. for i, dist in enumerate(distances):
  121. coords0, coords1, orients0, orients1 = generate_confgs(dist, num_ang, axis=axis)
  122. contact_multi = contact_function_multi(self.a0, self.a1, self.b0, self.b1,
  123. coords0, orients0, coords1, orients1,
  124. minimizer=minimizer)
  125. timings[i] = contact_multi.time
  126. avg_evals_mineig[i] = contact_multi.avg_evals_mineig
  127. avg_evals_medeig[i] = contact_multi.avg_evals_medeig
  128. print(f'Total time: {np.sum(timings)}')
  129. return timings, avg_evals_mineig, avg_evals_medeig
  130. def generate_confgs(dist: np.ndarray, n: int, axis: np.ndarray = np.array([1, 0, 0])) \
  131. -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
  132. """
  133. At a given distance, generate all possible orientational configurations of two spherical ellipses,
  134. with n steps in a half-circle rotation. This results in n x n configurations.
  135. The first ellipse will always be centered at the north pole while the position of the second is determined by
  136. parameter axis (and obviously distance).
  137. """
  138. angles = np.linspace(1e-7, np.pi - 1e-7, n)
  139. coords0 = np.zeros((n, 3))
  140. coords0[:, 2] = 1
  141. orients0 = np.array([1., 0., 0.])[None, :] * np.cos(angles)[:, None] + \
  142. np.array([0., 1., 0.])[None, :] * np.sin(angles)[:, None]
  143. coords1 = coords0 * np.cos(dist) + \
  144. np.cross(axis, coords0) * np.sin(dist) + \
  145. axis[None, :] * np.einsum('m, nm -> n', axis, coords0)[:, None] * (1 - np.cos(dist))
  146. orients1 = orients0 * np.cos(dist) + \
  147. np.cross(axis, orients0) * np.sin(dist) + \
  148. axis[None, :] * np.einsum('m, nm -> n', axis, orients0)[:, None] * (1 - np.cos(dist))
  149. indices = np.indices((n, n))
  150. coords0 = coords0[indices[0].flatten()]
  151. orients0 = orients0[indices[0].flatten()]
  152. coords1 = coords1[indices[1].flatten()]
  153. orients1 = orients1[indices[1].flatten()]
  154. return coords0, coords1, orients0, orients1