ソースを参照

Upload modula za Slicer3D program.

Poišče  in shrani markerje znotraj telesa po odprtih CT in CBCT volumnih. Poišče optimalno rotacijo, katera je bila izbrana v dropdown meniju v slicer modulu, nato pa z rotacijo poišče translacijo. Rotacijo in translacijo združi v transformacijsko matriko, katero aplicira na CBCT volumen. CBCT volumen mora vsebovati besedo CBCT v imenu, da ga prepozna.

Mapo shranite v željeno lokacijo, pokažite Slicerju kje se nahaja in potem ponovno zaženite program. Algoritem se bo pokazal pod Modules/Image processing/Seek Transform Module
Luka 2 ヶ月 前
コミット
492a30525a

BIN
SeekTransformModule/Resources/Icons/SeekTransformModule.png


+ 503 - 0
SeekTransformModule/SeekTransformModule.py

@@ -0,0 +1,503 @@
+import os
+import numpy as np
+import scipy
+from scipy.spatial.distance import cdist
+from scipy.spatial.transform import Rotation as R
+import slicer
+from DICOMLib import DICOMUtils
+from collections import deque
+import vtk
+from slicer.ScriptedLoadableModule import *
+import qt
+
+#exec(open("C:/Users/lkomar/Documents/Prostata/FirstTryRegister.py").read())
+
+class SeekTransformModule(ScriptedLoadableModule):
+    """
+    Module description shown in the module panel.
+    """
+    def __init__(self, parent):
+        ScriptedLoadableModule.__init__(self, parent)
+        self.parent.title = "Seek Transform module"
+        self.parent.categories = ["Image Processing"]
+        self.parent.contributors = ["Luka Komar (Onkološki Inštitut Ljubljana, Fakulteta za Matematiko in Fiziko Ljubljana)"]
+        self.parent.helpText = "This module applies rigid transformations to CBCT volumes based on reference CT volumes."
+        self.parent.acknowledgementText = "Supported by doc. Primož Peterlin & prof. Andrej Studen"
+
+class SeekTransformModuleWidget(ScriptedLoadableModuleWidget):
+    """
+    GUI of the module.
+    """
+    def setup(self):
+        ScriptedLoadableModuleWidget.setup(self)
+
+        # Dropdown menu za izbiro metode
+        self.rotationMethodComboBox = qt.QComboBox()
+        self.rotationMethodComboBox.addItems(["SVD", "Horn", "Quaternion"])
+        self.layout.addWidget(self.rotationMethodComboBox)
+        
+        # Load button
+        self.applyButton = qt.QPushButton("Find markers and transform")
+        self.applyButton.toolTip = "Finds markers, computes optimal rigid transform and applies it to CBCT volumes."
+        self.applyButton.enabled = True
+        self.layout.addWidget(self.applyButton)
+
+        # Connect button to logic
+        self.applyButton.connect('clicked(bool)', self.onApplyButton)
+
+        self.layout.addStretch(1)
+
+    def onApplyButton(self):
+        logic = MyTransformModuleLogic()
+        selectedMethod = self.rotationMethodComboBox.currentText #izberi metodo izračuna rotacije
+        logic.run(selectedMethod)
+
+class MyTransformModuleLogic(ScriptedLoadableModuleLogic):
+    """
+    Core logic of the module.
+    """
+    def run(self, selectedMethod):
+        def group_points(points, threshold):
+            # Function to group points that are close to each other
+            grouped_points = []
+            while points:
+                point = points.pop()  # Take one point from the list
+                group = [point]  # Start a new group
+                
+                # Find all points close to this one
+                distances = cdist([point], points)  # Calculate distances from this point to others
+                close_points = [i for i, dist in enumerate(distances[0]) if dist < threshold]
+                
+                # Add the close points to the group
+                group.extend([points[i] for i in close_points])
+                
+                # Remove the grouped points from the list
+                points = [point for i, point in enumerate(points) if i not in close_points]
+                
+                # Add the group to the result
+                grouped_points.append(group)
+            
+            return grouped_points
+
+        def region_growing(image_data, seed, intensity_threshold, max_distance):
+            dimensions = image_data.GetDimensions()
+            visited = set()
+            region = []
+            queue = deque([seed])
+
+            while queue:
+                x, y, z = queue.popleft()
+                if (x, y, z) in visited:
+                    continue
+
+                visited.add((x, y, z))
+                voxel_value = image_data.GetScalarComponentAsDouble(x, y, z, 0)
+                
+                if voxel_value >= intensity_threshold:
+                    region.append((x, y, z))
+                    # Add neighbors within bounds
+                    for dx, dy, dz in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]:
+                        nx, ny, nz = x + dx, y + dy, z + dz
+                        if 0 <= nx < dimensions[0] and 0 <= ny < dimensions[1] and 0 <= nz < dimensions[2]:
+                            if (nx, ny, nz) not in visited:
+                                queue.append((nx, ny, nz))
+
+            return region
+
+        def detect_points_region_growing(volume_name, intensity_threshold=3000, x_min=90, x_max=380, y_min=190, y_max=380, z_min=80, z_max=120, max_distance=9, centroid_merge_threshold=5):
+            volume_node = slicer.util.getNode(volume_name)
+            if not volume_node:
+                raise RuntimeError(f"Volume {volume_name} not found.")
+            
+            image_data = volume_node.GetImageData()
+            matrix = vtk.vtkMatrix4x4()
+            volume_node.GetIJKToRASMatrix(matrix)
+
+            dimensions = image_data.GetDimensions()
+            detected_regions = []
+
+            # Check if it's CT or CBCT
+            is_cbct = "cbct" in volume_name.lower()
+
+            if is_cbct:
+                valid_x_min, valid_x_max = 0, dimensions[0] - 1
+                valid_y_min, valid_y_max = 0, dimensions[1] - 1
+                valid_z_min, valid_z_max = 0, dimensions[2] - 1
+            else:
+                valid_x_min, valid_x_max = max(x_min, 0), min(x_max, dimensions[0] - 1)
+                valid_y_min, valid_y_max = max(y_min, 0), min(y_max, dimensions[1] - 1)
+                valid_z_min, valid_z_max = max(z_min, 0), min(z_max, dimensions[2] - 1)
+
+            visited = set()
+
+            def grow_region(x, y, z):
+                if (x, y, z) in visited:
+                    return None
+
+                voxel_value = image_data.GetScalarComponentAsDouble(x, y, z, 0)
+                if voxel_value < intensity_threshold:
+                    return None
+
+                region = region_growing(image_data, (x, y, z), intensity_threshold, max_distance=max_distance)
+                if region:
+                    for point in region:
+                        visited.add(tuple(point))
+                    return region
+                return None
+
+            regions = []
+            for z in range(valid_z_min, valid_z_max + 1):
+                for y in range(valid_y_min, valid_y_max + 1):
+                    for x in range(valid_x_min, valid_x_max + 1):
+                        region = grow_region(x, y, z)
+                        if region:
+                            regions.append(region)
+
+            # Collect centroids using intensity-weighted average
+            centroids = []
+            for region in regions:
+                points = np.array([matrix.MultiplyPoint([*point, 1])[:3] for point in region])
+                intensities = np.array([image_data.GetScalarComponentAsDouble(*point, 0) for point in region])
+                
+                if intensities.sum() > 0:
+                    weighted_centroid = np.average(points, axis=0, weights=intensities)
+                    max_intensity = intensities.max()
+                    centroids.append((np.round(weighted_centroid, 2), max_intensity))
+
+            unique_centroids = []
+            for centroid, intensity in centroids:
+                if not any(np.linalg.norm(centroid - existing_centroid) < centroid_merge_threshold for existing_centroid, _ in unique_centroids):
+                    unique_centroids.append((centroid, intensity))
+                    
+            markups_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", f"Markers_{volume_name}")
+            for centroid, intensity in unique_centroids:
+                markups_node.AddControlPoint(*centroid)
+                #print(f"Detected Centroid (RAS): {centroid}, Max Intensity: {intensity}")
+
+            return unique_centroids
+
+        def compute_Kabsch_rotation(moving_points, fixed_points):
+            """
+            Computes the optimal rotation matrix to align moving_points to fixed_points.
+            
+            Parameters:
+            moving_points (list or ndarray): List of points to be rotated CBCT
+            fixed_points (list or ndarray): List of reference points CT
+
+            Returns:
+            ndarray: Optimal rotation matrix.
+            """
+            assert len(moving_points) == len(fixed_points), "Point lists must be the same length."
+
+            # Convert to numpy arrays
+            moving = np.array(moving_points)
+            fixed = np.array(fixed_points)
+
+            # Compute centroids
+            centroid_moving = np.mean(moving, axis=0)
+            centroid_fixed = np.mean(fixed, axis=0)
+
+            # Center the points
+            moving_centered = moving - centroid_moving
+            fixed_centered = fixed - centroid_fixed
+
+            # Compute covariance matrix
+            H = np.dot(moving_centered.T, fixed_centered)
+
+            # SVD decomposition
+            U, _, Vt = np.linalg.svd(H)
+            Rotate_optimal = np.dot(Vt.T, U.T)
+
+            # Correct improper rotation (reflection)
+            if np.linalg.det(Rotate_optimal) < 0:
+                Vt[-1, :] *= -1
+                Rotate_optimal = np.dot(Vt.T, U.T)
+
+            return Rotate_optimal
+
+        def compute_Horn_rotation(moving_points, fixed_points):
+            """
+            Computes the optimal rotation matrix using Horn's method.
+
+            Parameters:
+            moving_points (list or ndarray): List of points to be rotated, CBCT
+            fixed_points (list or ndarray): List of reference points, CT
+
+            Returns:
+            ndarray: Optimal rotation matrix.
+            """
+            assert len(moving_points) == len(fixed_points), "Point lists must be the same length."
+            
+            moving = np.array(moving_points)
+            fixed = np.array(fixed_points)
+            
+            # Compute centroids
+            centroid_moving = np.mean(moving, axis=0)
+            centroid_fixed = np.mean(fixed, axis=0)
+            
+            # Center the points
+            moving_centered = moving - centroid_moving
+            fixed_centered = fixed - centroid_fixed
+            
+            # Compute cross-dispersion matrix
+            H = np.dot(moving_centered.T, fixed_centered)
+            
+            # Compute SVD of H
+            U, _, Vt = np.linalg.svd(H)
+            
+            # Compute rotation matrix
+            R = np.dot(Vt.T, U.T)
+            
+            # Ensure a proper rotation (avoid reflection)
+            if np.linalg.det(R) < 0:
+                Vt[-1, :] *= -1
+                R = np.dot(Vt.T, U.T)
+            
+            return R
+
+        def compute_quaternion_rotation(moving_points, fixed_points):
+            """
+            Computes the optimal rotation matrix using quaternions.
+
+            Parameters:
+            moving_points (list or ndarray): List of points to be rotated.
+            fixed_points (list or ndarray): List of reference points.
+
+            Returns:
+            ndarray: Optimal rotation matrix.
+            """
+            assert len(moving_points) == len(fixed_points), "Point lists must be the same length."
+            
+            moving = np.array(moving_points)
+            fixed = np.array(fixed_points)
+            
+            # Compute centroids
+            centroid_moving = np.mean(moving, axis=0)
+            centroid_fixed = np.mean(fixed, axis=0)
+            
+            # Center the points
+            moving_centered = moving - centroid_moving
+            fixed_centered = fixed - centroid_fixed
+            
+            # Construct the cross-dispersion matrix
+            M = np.dot(moving_centered.T, fixed_centered)
+            
+            # Construct the N matrix for quaternion solution
+            A = M - M.T
+            delta = np.array([A[1, 2], A[2, 0], A[0, 1]])
+            trace = np.trace(M)
+            
+            N = np.zeros((4, 4))
+            N[0, 0] = trace
+            N[1:, 0] = delta
+            N[0, 1:] = delta
+            N[1:, 1:] = M + M.T - np.eye(3) * trace
+            
+            # Compute the eigenvector corresponding to the maximum eigenvalue
+            eigvals, eigvecs = np.linalg.eigh(N)
+            q_optimal = eigvecs[:, np.argmax(eigvals)]  # Optimal quaternion
+            
+            # Convert quaternion to rotation matrix
+            w, x, y, z = q_optimal
+            R = np.array([
+                [1 - 2*(y**2 + z**2), 2*(x*y - z*w), 2*(x*z + y*w)],
+                [2*(x*y + z*w), 1 - 2*(x**2 + z**2), 2*(y*z - x*w)],
+                [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x**2 + y**2)]
+            ])
+            
+            return R
+
+        def compute_translation(moving_points, fixed_points, rotation_matrix):
+            """
+            Computes the translation vector to align moving_points to fixed_points given a rotation matrix.
+            
+            Parameters:
+            moving_points (list or ndarray): List of points to be translated.
+            fixed_points (list or ndarray): List of reference points.
+            rotation_matrix (ndarray): Rotation matrix.
+
+            Returns:
+            ndarray: Translation vector.
+            """
+            # Convert to numpy arrays
+            moving = np.array(moving_points)
+            fixed = np.array(fixed_points)
+
+            # Compute centroids
+            centroid_moving = np.mean(moving, axis=0)
+            centroid_fixed = np.mean(fixed, axis=0)
+
+            # Compute translation
+            translation = centroid_fixed - np.dot(centroid_moving, rotation_matrix)
+
+            return translation
+
+        def create_vtk_transform(rotation_matrix, translation_vector):
+            """
+            Creates a vtkTransform from a rotation matrix and a translation vector.
+            """
+            # Create a 4x4 transformation matrix
+            transform_matrix = np.eye(4)  # Start with an identity matrix
+            transform_matrix[:3, :3] = rotation_matrix  # Set rotation part
+            transform_matrix[:3, 3] = translation_vector  # Set translation part
+
+            # Convert to vtkMatrix4x4
+            vtk_matrix = vtk.vtkMatrix4x4()
+            for i in range(4):
+                for j in range(4):
+                    vtk_matrix.SetElement(i, j, transform_matrix[i, j])
+            print("Transform matrix: ")
+            print(vtk_matrix)
+            # Create vtkTransform and set the matrix
+            transform = vtk.vtkTransform()
+            transform.SetMatrix(vtk_matrix)
+            return transform
+
+
+        # Initialize lists and dictionary
+        cbct_list = []
+        ct_list = []
+        volume_points_dict = {}
+
+        # Process loaded volumes
+        for volumeNode in slicer.util.getNodesByClass("vtkMRMLScalarVolumeNode"):
+            volumeName = volumeNode.GetName()
+            shNode = slicer.vtkMRMLSubjectHierarchyNode.GetSubjectHierarchyNode(slicer.mrmlScene)
+            imageItem = shNode.GetItemByDataNode(volumeNode)
+            
+            modality = shNode.GetItemAttribute(imageItem, 'DICOM.Modality')
+            #print(modality)
+            
+            # Check if the volume is loaded into the scene
+            if not slicer.mrmlScene.IsNodePresent(volumeNode):
+                print(f"Volume {volumeName} not present in the scene.")
+                continue
+            
+            # Determine scan type
+            if "cbct" in volumeName.lower():
+                cbct_list.append(volumeName)
+                scan_type = "CBCT"
+            else:
+                ct_list.append(volumeName)
+                scan_type = "CT"
+            
+            # Detect points using region growing
+            grouped_points = detect_points_region_growing(volumeName, intensity_threshold=3000)
+            volume_points_dict[(scan_type, volumeName)] = grouped_points
+
+        # Print the results
+        # print(f"\nCBCT Volumes: {cbct_list}")
+        # print(f"CT Volumes: {ct_list}")
+        # print("\nDetected Points by Volume:")
+        # for (scan_type, vol_name), points in volume_points_dict.items():
+        #     print(f"{scan_type} Volume '{vol_name}': {len(points)} points detected.")
+
+
+        if cbct_list and ct_list:
+            # Izberi prvi CT volumen kot referenco
+            ct_volume_name = ct_list[0]
+            ct_points = [centroid for centroid, _ in volume_points_dict[("CT", ct_volume_name)]]
+
+            if len(ct_points) < 3:
+                print("CT volumen nima dovolj točk za registracijo.")
+            else:
+                print("CT points: ", np.array(ct_points))
+                
+                for cbct_volume_name in cbct_list:
+                    # Izvleci točke za trenutni CBCT volumen
+                    cbct_points = [centroid for centroid, _ in volume_points_dict[("CBCT", cbct_volume_name)]]
+
+                    print(f"\nProcessing CBCT Volume: {cbct_volume_name}")
+                    if len(cbct_points) < 3:
+                        print(f"CBCT Volume '{cbct_volume_name}' nima dovolj točk za registracijo.")
+                        continue
+
+                    #print("CBCT points: ", np.array(cbct_points))
+
+                    # Display the results for the current CBCT volume
+                    # print("\nSVD Method:")
+                    # print("Rotation Matrix:\n", svd_rotation_matrix)
+                    # print("Translation Vector:\n", svd_translation_vector)
+
+                    # print("\nHorn Method:")
+                    # print("Rotation Matrix:\n", horn_rotation_matrix)
+                    # print("Translation Vector:\n", horn_translation_vector)
+
+                    # print("\nQuaternion Method:")
+                    # print("Rotation Matrix:\n", quaternion_rotation_matrix)
+                    # print("Translation Vector:\n", quaternion_translation_vector)
+
+                    # Izberi metodo glede na uporabnikov izbor
+                    if selectedMethod == "SVD":
+                        chosen_rotation_matrix = compute_Kabsch_rotation(cbct_points, ct_points)
+                        chosen_translation_vector = compute_translation(cbct_points, ct_points, chosen_rotation_matrix)
+                        print("\nSVD Method:")
+                        print("Rotation Matrix:\n", chosen_rotation_matrix)
+                        print("Translation Vector:\n", chosen_translation_vector)
+                    elif selectedMethod == "Horn":
+                        chosen_rotation_matrix = compute_Horn_rotation(cbct_points, ct_points)
+                        chosen_translation_vector = compute_translation(cbct_points, ct_points, chosen_rotation_matrix)
+                        print("\nHorn Method:")
+                        print("Rotation Matrix:\n", chosen_rotation_matrix)
+                        print("Translation Vector:\n", chosen_translation_vector)
+                    elif selectedMethod == "Quaternion":
+                        chosen_rotation_matrix = compute_quaternion_rotation(cbct_points, ct_points)
+                        chosen_translation_vector = compute_translation(cbct_points, ct_points, chosen_rotation_matrix)
+                        print("\nQuaternion Method:")
+                        print("Rotation Matrix:\n", chosen_rotation_matrix)
+                        print("Translation Vector:\n", chosen_translation_vector)
+
+                    imeTransformNoda = cbct_volume_name + " Transform"
+                    # Ustvari vtkTransformNode in ga poveži z CBCT volumenom
+                    transform_node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", imeTransformNoda)
+                    # Kliči funkcijo, ki uporabi matriki
+                    vtk_transform = create_vtk_transform(chosen_rotation_matrix, chosen_translation_vector)
+                    # Dodaj transform v node
+                    transform_node.SetAndObserveTransformToParent(vtk_transform)
+
+                    # Pridobi CBCT volumen in aplikacijo transformacije
+                    cbct_volume_node = slicer.util.getNode(cbct_volume_name)
+                    cbct_volume_node.SetAndObserveTransformNodeID(transform_node.GetID()) # Pripni transform node na volumen
+
+                    # Uporabi transformacijo na volumnu (fizična aplikacija)
+                    slicer.vtkSlicerTransformLogic().hardenTransform(cbct_volume_node) # Uporabi transform na volumen
+                    print("Transform uspešen na", cbct_volume_name)
+                    
+                    
+                    #transformed_cbct_image = create_vtk_transform(cbct_image_sitk, chosen_rotation_matrix, chosen_translation_vector)
+
+        else:
+            print("CBCT ali CT volumen ni bil najden.")
+
+
+    # def compute_rigid_transform(moving_points, fixed_points):
+    #     assert len(moving_points) == len(fixed_points), "Point lists must be the same length."
+
+    #     # Convert to numpy arrays
+    #     moving = np.array(moving_points)
+    #     fixed = np.array(fixed_points)
+
+    #     # Compute centroids
+    #     centroid_moving = np.mean(moving, axis=0)
+    #     centroid_fixed = np.mean(fixed, axis=0)
+
+    #     # Center the points
+    #     moving_centered = moving - centroid_moving
+    #     fixed_centered = fixed - centroid_fixed
+
+    #     # Compute covariance matrix
+    #     H = np.dot(moving_centered.T, fixed_centered)
+
+    #     # SVD decomposition
+    #     U, _, Vt = np.linalg.svd(H)
+    #     Rotate_optimal = np.dot(Vt.T, U.T)
+
+    #     # Correct improper rotation (reflection)
+    #     if np.linalg.det(Rotate_optimal) < 0:
+    #         Vt[-1, :] *= -1
+    #         Rotate_optimal = np.dot(Vt.T, U.T)
+
+    #     # Compute translation
+    #     translation = centroid_fixed - np.dot(centroid_moving, Rotate_optimal)
+
+    #     return Rotate_optimal, translation