contact_function.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
import ctypes
import os
import time
from dataclasses import dataclass

import numpy as np
  5. clib = ctypes.cdll.LoadLibrary('./overlap_algorithm.so')
  6. clib.contact_function.argtypes = [
  7. ctypes.c_double,
  8. ctypes.c_double,
  9. ctypes.c_double,
  10. ctypes.c_double,
  11. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  12. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  13. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  14. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  15. ctypes.c_int,
  16. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  17. np.ctypeslib.ndpointer(dtype=np.intc, ndim=1, flags='C_CONTIGUOUS'),
  18. np.ctypeslib.ndpointer(dtype=np.intc, ndim=1, flags='C_CONTIGUOUS'),
  19. ]
  20. clib.contact_function.restype = ctypes.c_double
  21. clib.contact_function_multi.argtypes = [
  22. ctypes.c_int,
  23. ctypes.c_double,
  24. ctypes.c_double,
  25. ctypes.c_double,
  26. ctypes.c_double,
  27. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  28. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  29. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  30. np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, flags='C_CONTIGUOUS'),
  31. ctypes.c_int,
  32. np.ctypeslib.ndpointer(dtype=np.intc, ndim=1, flags='C_CONTIGUOUS'),
  33. np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS'),
  34. ]
  35. clib.contact_function_multi.restype = ctypes.c_void_p
  36. def minimizer_map(minimizer: str) -> int:
  37. if minimizer == 'brent':
  38. return 0
  39. elif minimizer == 'brent_early':
  40. return 1
  41. elif minimizer == 'gss':
  42. return 2
  43. elif minimizer == 'gss_early':
  44. return 3
  45. else:
  46. raise ValueError('Unknown minimizer: "' + minimizer + '"')
  47. @dataclass
  48. class Contact:
  49. contact_f: float
  50. min_eig_T: float
  51. med_eig_T: float
  52. min_vec_scalar: float
  53. med_vec_scalar: float
  54. feval_min: int
  55. feval_med: int
  56. branch: int
  57. def contact_function(a0, a1, b0, b1, coord0, orient0, coord1, orient1, minimizer='brent') -> Contact:
  58. other_results = np.zeros(4, dtype=np.float64)
  59. fevals = np.zeros(2, dtype=np.intc)
  60. branch = np.zeros(1, dtype=np.intc)
  61. contact_f = clib.contact_function(a0, a1, b0, b1, coord0, orient0, coord1, orient1,
  62. minimizer_map(minimizer), other_results, fevals, branch)
  63. return Contact(contact_f, other_results[0], other_results[1], other_results[2], other_results[3],
  64. fevals[0], fevals[1], branch[0])
  65. @dataclass
  66. class ContactMulti:
  67. contact_f: np.ndarray
  68. avg_evals_mineig: float
  69. avg_evals_medeig: float
  70. time: float
  71. def contact_function_multi(a0, a1, b0, b1, coord0, orient0, coord1, orient1, minimizer='brent') -> ContactMulti:
  72. n = len(coord0)
  73. results = np.zeros(n, dtype=np.float64)
  74. all_evals = np.zeros(2, dtype=np.intc)
  75. t0 = time.perf_counter()
  76. clib.contact_function_multi(n, a0, a1, b0, b1, coord0, orient0, coord1, orient1,
  77. minimizer_map(minimizer), all_evals, results)
  78. t1 = time.perf_counter()
  79. return ContactMulti(results, all_evals[0] / n, all_evals[1] / n, t1 - t0)
  80. class EllipsePair:
  81. def __init__(self, a0, a1, epsilon):
  82. """
  83. Class in which optimization of ellipse configurations on a sphere is performed.
  84. :param a0: semi-major axis for the first elliptical cylinder
  85. :param a1: semi-major axis for the second elliptical cylinder
  86. :param epsilon: aspect ratio
  87. """
  88. self.a0 = max(a0, a1) # defined so that always a0 > a1
  89. self.a1 = min(a0, a1)
  90. self.b0 = self.a0 / epsilon
  91. self.b1 = self.a1 / epsilon
  92. self.size_ratio = self.a0 / self.a1
  93. self.epsilon = epsilon
  94. def contact(self, coord0, orient0, coord1=np.array([0., 0., 1.]), orient1=np.array([1., 0., 0.]),
  95. minimizer='brent') -> Contact:
  96. return contact_function(self.a0, self.a1, self.b0, self.b1, coord0, orient0, coord1, orient1, minimizer)
  97. def contact_multi_dist(self, distances, num_ang, axis=np.array([1, 0, 0]), minimizer='brent')\
  98. -> (np.ndarray, np.ndarray, np.ndarray):
  99. timings = np.zeros(len(distances))
  100. avg_evals_mineig = np.zeros(len(distances))
  101. avg_evals_medeig = np.zeros(len(distances))
  102. for i, dist in enumerate(distances):
  103. coords0, coords1, orients0, orients1 = generate_confgs(dist, num_ang, axis=axis)
  104. contact_multi = contact_function_multi(self.a0, self.a1, self.b0, self.b1,
  105. coords0, orients0, coords1, orients1,
  106. minimizer=minimizer)
  107. timings[i] = contact_multi.time
  108. avg_evals_mineig[i] = contact_multi.avg_evals_mineig
  109. avg_evals_medeig[i] = contact_multi.avg_evals_medeig
  110. print(f'Total time: {np.sum(timings)}')
  111. return timings, avg_evals_mineig, avg_evals_medeig
  112. def generate_confgs(dist, n, axis=np.array([1, 0, 0])):
  113. """
  114. At a given distance, generate all possible orientational configurations of two spherical ellipses,
  115. with n steps in a half-circle rotation. This results in n x n configurations.
  116. The first ellipse will always be centered at the north pole while the position of the second is determined by
  117. parameter axis (and obviously distance).
  118. """
  119. angles = np.linspace(1e-7, np.pi - 1e-7, n)
  120. coords0 = np.zeros((n, 3))
  121. coords0[:, 2] = 1
  122. orients0 = np.array([1., 0., 0.])[None, :] * np.cos(angles)[:, None] + \
  123. np.array([0., 1., 0.])[None, :] * np.sin(angles)[:, None]
  124. coords1 = coords0 * np.cos(dist) + \
  125. np.cross(axis, coords0) * np.sin(dist) + \
  126. axis[None, :] * np.einsum('m, nm -> n', axis, coords0)[:, None] * (1 - np.cos(dist))
  127. orients1 = orients0 * np.cos(dist) + \
  128. np.cross(axis, orients0) * np.sin(dist) + \
  129. axis[None, :] * np.einsum('m, nm -> n', axis, orients0)[:, None] * (1 - np.cos(dist))
  130. indices = np.indices((n, n))
  131. coords0 = coords0[indices[0].flatten()]
  132. orients0 = orients0[indices[0].flatten()]
  133. coords1 = coords1[indices[1].flatten()]
  134. orients1 = orients1[indices[1].flatten()]
  135. return coords0, coords1, orients0, orients1