Pārlūkot izejas kodu

Alpha version, offset 2 mm still.

Luka 6 dienas atpakaļ
vecāks
revīzija
f43be9a756
1 mainītis faili ar 195 papildinājumiem un 24 dzēšanām
  1. 195 24
      SeekTransformModule/SeekTransformModule.py

+ 195 - 24
SeekTransformModule/SeekTransformModule.py

@@ -332,20 +332,75 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
 
             return moving, R, t
 
-        def match_points(cbct_points, ct_points):
-            """
-            Vrne cbct_points permutirane tako, da so najbolje poravnani s ct_points
-            glede na vsoto evklidskih razdalj.
-            """
+        def match_points(cbct_points, ct_points, auto_weights=True, fallback_if_worse=True, normalize_lengths=True):
+            def side_lengths(points):
+                lengths = [
+                    np.linalg.norm(points[0] - points[1]),
+                    np.linalg.norm(points[1] - points[2]),
+                    np.linalg.norm(points[2] - points[0])
+                ]
+                if normalize_lengths:
+                    total = sum(lengths) + 1e-6  # da ne delimo z 0
+                    lengths = [l / total for l in lengths]
+                return lengths
+
+            def triangle_angles(points):
+                a = np.linalg.norm(points[1] - points[2])
+                b = np.linalg.norm(points[0] - points[2])
+                c = np.linalg.norm(points[0] - points[1])
+                angle_A = np.arccos(np.clip((b**2 + c**2 - a**2) / (2 * b * c), -1.0, 1.0))
+                angle_B = np.arccos(np.clip((a**2 + c**2 - b**2) / (2 * a * c), -1.0, 1.0))
+                angle_C = np.pi - angle_A - angle_B
+                return [angle_A, angle_B, angle_C]
+
+            def permutation_score(perm, ct_lengths, ct_angles, w_len, w_ang):
+                perm_lengths = side_lengths(perm)
+                perm_angles = triangle_angles(perm)
+
+                # normaliziraj
+                def normalize(vec):
+                    norm = np.linalg.norm(vec)
+                    return vec / norm if norm > 0 else vec
+
+                perm_lengths = normalize(perm_lengths)
+                perm_angles = normalize(perm_angles)
+                ct_lengths_n = normalize(ct_lengths)
+                ct_angles_n = normalize(ct_angles)
+
+                score_len = sum(abs(a - b) for a, b in zip(perm_lengths, ct_lengths_n))
+                score_ang = sum(abs(a - b) for a, b in zip(perm_angles, ct_angles_n))
+                return w_len * score_len + w_ang * score_ang
+
+            cbct_points = list(cbct_points)
+            ct_lengths = side_lengths(np.array(ct_points))
+            ct_angles = triangle_angles(np.array(ct_points))
+
+            if auto_weights:
+                var_len = np.var(ct_lengths)
+                var_ang = np.var(ct_angles)
+                total_var = var_len + var_ang + 1e-6
+                weight_length = (1 - var_len / total_var)
+                weight_angle = (1 - var_ang / total_var)
+            else:
+                weight_length = 0.5
+                weight_angle = 0.5
+
             best_perm = None
-            min_total_dist = float('inf')
+            best_score = float('inf')
 
             for perm in itertools.permutations(cbct_points):
-                total_dist = sum(np.linalg.norm(np.array(p1) - np.array(p2)) for p1, p2 in zip(perm, ct_points))
-                if total_dist < min_total_dist:
-                    min_total_dist = total_dist
+                perm = np.array(perm)
+                score = permutation_score(perm, ct_lengths, ct_angles, weight_length, weight_angle)
+                if score < best_score:
+                    best_score = score
                     best_perm = perm
 
+            if fallback_if_worse:
+                original_score = permutation_score(np.array(cbct_points), ct_lengths, ct_angles, weight_length, weight_angle)
+                if original_score <= best_score or np.allclose(cbct_points, best_perm):
+                    print("Fallback to original points due to worse score of the permutation.")
+                    return list(cbct_points)
+
             return list(best_perm)
 
         def compute_translation(moving_points, fixed_points, rotation_matrix):
@@ -496,7 +551,7 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
             ijkToRasMatrix = vtk.vtkMatrix4x4()
             ct_volume_node.GetIJKToRASMatrix(ijkToRasMatrix)
 
-            #mid_ras = np.array(ijkToRasMatrix.MultiplyPoint([*mid_ijk, 1]))[:3]
+            
 
             # Sredinski Z slice
             mid_z_voxel = mid_ijk[2]
@@ -508,6 +563,7 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
 
             # Doda marker v RAS koordinatah
             #mid_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", f"Sredina_{ct_volume_name}")
+            #mid_ras = np.array(ijkToRasMatrix.MultiplyPoint([*mid_ijk, 1]))[:3]
             #mid_node.AddControlPoint(mid_ras)
 
             # Določi threshold glede na CBCT ali CT
@@ -545,9 +601,9 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
             table_ras = np.array(ijkToRasMatrix.MultiplyPoint([*table_ijk, 1]))[:3]
 
             # Doda marker za višino mize
-            #table_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", f"VišinaMize_{ct_volume_name}")
-            #table_node.AddControlPoint(table_ras)
-            #table_node.SetDisplayVisibility(False)
+            table_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", f"VišinaMize_{ct_volume_name}")
+            table_node.AddControlPoint(table_ras)
+            table_node.SetDisplayVisibility(False)
 
             # Izračun višine v mm
             #image_center_y = dims[1] // 2
@@ -621,7 +677,110 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
             for i in range(3):
                 print([matrix.GetElement(i, j) for j in range(3)])
 
+        def prealign_by_centroid(cbct_points, ct_points):
+            """
+            Predporavna CBCT markerje na CT markerje glede na centrične točke.
+
+            Args:
+                cbct_points: List ali ndarray točk iz CBCT.
+                ct_points: List ali ndarray točk iz CT.
+
+            Returns:
+                List: CBCT točke premaknjene tako, da so centrične točke usklajene.
+            """
+            cbct_points = np.array(cbct_points)
+            ct_points = np.array(ct_points)
+            cbct_centroid = np.mean(cbct_points, axis=0)
+            ct_centroid = np.mean(ct_points, axis=0)
+            translation_vector = ct_centroid - cbct_centroid
+            aligned_cbct = cbct_points + translation_vector
+            return aligned_cbct
+        
+        def choose_best_translation(cbct_points, ct_points, rotation_matrix):
+            """
+            Izbere boljšo translacijo: centroidno ali povprečno po rotaciji (retranslation).
+            
+            Args:
+                cbct_points (array-like): Točke iz CBCT (še ne rotirane).
+                ct_points (array-like): Ciljne CT točke.
+                rotation_matrix (ndarray): Rotacijska matrika.
+
+            Returns:
+                tuple: (best_translation_vector, transformed_cbct_points, used_method)
+            """
+            cbct_points = np.array(cbct_points)
+            ct_points = np.array(ct_points)
+            
+            # 1. Rotiraj CBCT točke
+            rotated_cbct = np.dot(cbct_points, rotation_matrix.T)
+            
+            # 2. Centroid translacija
+            centroid_moving = np.mean(cbct_points, axis=0)
+            centroid_fixed = np.mean(ct_points, axis=0)
+            translation_centroid = centroid_fixed - np.dot(centroid_moving, rotation_matrix)
+            transformed_centroid = rotated_cbct + translation_centroid
+            error_centroid = np.mean(np.linalg.norm(transformed_centroid - ct_points, axis=1))
+
+            # 3. Retranslacija (srednja razlika)
+            translation_recomputed = np.mean(ct_points - rotated_cbct, axis=0)
+            transformed_recomputed = rotated_cbct + translation_recomputed
+            error_recomputed = np.mean(np.linalg.norm(transformed_recomputed - ct_points, axis=1))
+
+            # 4. Izberi boljšo
+            if error_recomputed < error_centroid:
+                print(f"✅ Using retranslation (error: {error_recomputed:.2f} mm)")
+                return translation_recomputed, transformed_recomputed, "retranslation"
+            else:
+                print(f"✅ Using centroid-based translation (error: {error_centroid:.2f} mm)")
+                return translation_centroid, transformed_centroid, "centroid"
+
+        def visualize_point_matches_in_slicer(cbct_points, ct_points, study_name="MatchVisualization"):
+            assert len(cbct_points) == len(ct_points), "Mora biti enako število točk!"
+
+            # Ustvari markups za CBCT
+            cbct_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", f"{study_name}_CBCT")
+            cbct_node.GetDisplayNode().SetSelectedColor(0, 0, 1)  # modra
 
+            # Ustvari markups za CT
+            ct_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", f"{study_name}_CT")
+            ct_node.GetDisplayNode().SetSelectedColor(1, 0, 0)  # rdeča
+
+            # Dodaj točke
+            for i, (cbct, ct) in enumerate(zip(cbct_points, ct_points)):
+                cbct_node.AddControlPoint(*cbct, f"CBCT_{i}")
+                ct_node.AddControlPoint(*ct, f"CT_{i}")
+
+            # Ustvari model z linijami med pari
+            points = vtk.vtkPoints()
+            lines = vtk.vtkCellArray()
+
+            for i, (p1, p2) in enumerate(zip(cbct_points, ct_points)):
+                id1 = points.InsertNextPoint(p1)
+                id2 = points.InsertNextPoint(p2)
+
+                line = vtk.vtkLine()
+                line.GetPointIds().SetId(0, id1)
+                line.GetPointIds().SetId(1, id2)
+                lines.InsertNextCell(line)
+
+            polyData = vtk.vtkPolyData()
+            polyData.SetPoints(points)
+            polyData.SetLines(lines)
+
+            # Model node
+            modelNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLModelNode", f"{study_name}_Connections")
+            modelNode.SetAndObservePolyData(polyData)
+
+            modelDisplay = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLModelDisplayNode")
+            modelDisplay.SetColor(0, 0, 0)  # črna
+            modelDisplay.SetLineWidth(2)
+            modelDisplay.SetVisibility(True)
+
+            modelNode.SetAndObserveDisplayNodeID(modelDisplay.GetID())
+            modelNode.SetAndObservePolyData(polyData)
+
+            print(f"✅ Vizualizacija dodana za {study_name} (točke + povezave)")
+        
         # Globalni seznami za končno statistiko
         prostate_size_est = []
         ctcbct_distance = []
@@ -660,11 +819,9 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                     dicomUIDs = volumeNode.GetAttribute("DICOM.instanceUIDs")
                     if not dicomUIDs:
                         print("❌ This is an NRRD volume!")
-                        continue  # Preskoči, če ni DICOM volume
+                        continue  # Preskoči, če ni DICOM volume             
                     
-                    
-                        
-                        
+                                            
                     volumeName = volumeNode.GetName()
                     imageItem = shNode.GetItemByDataNode(volumeNode)
                     modality = shNode.GetItemAttribute(imageItem, "DICOM.Modality")                                 #deluje!
@@ -677,7 +834,6 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                         print("Not a CT")
                         continue  # Preskoči, če ni CT
 
-                    
                     # Preveri, če volume obstaja v sceni
                     if not slicer.mrmlScene.IsNodePresent(volumeNode):
                         print(f"Volume {volumeName} not present in the scene.")
@@ -708,7 +864,7 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                         
                         
                     # Detekcija točk v volumnu
-                    ustvari_marker = False  # Ustvari markerje
+                    ustvari_marker = not yesCbct  # Ustvari markerje pred poravnavo na mizo
                     grouped_points = detect_points_region_growing(volumeName, yesCbct, ustvari_marker, intensity_threshold=3000)
                     #print(f"Populating volume_points_dict with key ('{scan_type}', '{volumeName}')")
                     volume_points_dict[(scan_type, volumeName)] = grouped_points
@@ -737,7 +893,7 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                         scan_type = "CBCT" #redundant but here we are
                         cbct_volume_node = slicer.util.getNode(cbct_volume_name)
                         
-                        mm_offset, pixel_offset = find_table_top_z(cbct_volume_name, writefilecheck, yesCbct)                        
+                        mm_offset, pixel_offset = find_table_top_z(cbct_volume_name, writefilecheck, yesCbct)                   
                                             
                         cbct_points = [centroid for centroid, _ in volume_points_dict[("CBCT", cbct_volume_name)]] #zastareli podatki
                         cbct_points_array = np.array(cbct_points)  # Pretvorba v numpy array
@@ -770,8 +926,8 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                         #Sortiramo točke po X/Y/Z da se izognemo težavam pri poravnavi
                         cbct_points = match_points(cbct_points, ct_points)
                         
-                        #for i, (cb, ct) in enumerate(zip(cbct_points, ct_points)):
-                        #    print(f"Pair {i}: CBCT {cb}, CT {ct}, diff: {np.linalg.norm(cb - ct):.2f}")
+                        #visualize_point_matches_in_slicer(cbct_points, ct_points, studyName) #poveže pare markerjev
+        
                         
                         # Shranjevanje razdalj
                         distances_ct_cbct = []
@@ -813,7 +969,12 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                             scaling_factors = compute_optimal_scaling_per_axis(cbct_points, ct_points)
                             #print("Scaling factors: ", scaling_factors)
                             cbct_points = compute_scaling(cbct_points, scaling_factors)
-
+                        
+                        initial_error = np.mean(np.linalg.norm(np.array(cbct_points) - np.array(ct_points), axis=1))
+                        if initial_error > 30:
+                            print("⚠️ Initial distance too large, applying centroid prealignment.")
+                            cbct_points = prealign_by_centroid(cbct_points, ct_points)
+                        
                         if applyRotation:
                             if selectedMethod == "Kabsch":
                                 chosen_rotation_matrix = compute_Kabsch_rotation(cbct_points, ct_points)
@@ -824,8 +985,16 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
                             #print("Rotation Matrix:\n", chosen_rotation_matrix)
 
                         if applyTranslation:
-                            chosen_translation_vector = compute_translation(cbct_points, ct_points, chosen_rotation_matrix)
+                            chosen_translation_vector, cbct_points_transformed, method_used = choose_best_translation(cbct_points, ct_points, chosen_rotation_matrix) #Izbere optimalno translacijo
+                            
+                            per_axis_error = np.abs(np.array(cbct_points_transformed) - np.array(ct_points))
+                            dz_mean = np.mean(per_axis_error[:, 2])
+                            print(f"per-axis error: {per_axis_error}, dz_mean err: {dz_mean:.2f} mm")
+                            
+                            #chosen_translation_vector = compute_translation(cbct_points, ct_points, chosen_rotation_matrix)
                             #print("Translation Vector:\n", chosen_translation_vector)
+
+                        
                         
                         
                         # Ustvari vtkTransformNode in ga poveži z CBCT volumenom
@@ -854,6 +1023,8 @@ class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
 
                         print("Individualne napake:", errors)
                         print("📏 Povprečna napaka poravnave: {:.2f} mm".format(mean_error))
+                        
+                        
 
             else:
                 print(f"Study {studyItem} doesn't have any appropriate CT or CBCT volumes.")