Skip to content

expectation_maximization_registration

Point Cloud Registration using Expectation-MaximizationLink

This module implements the Expectation-Maximization algorithm for point cloud registration, providing a probabilistic approach to align 3D point sets with robust handling of noise and outliers. This is an abstract base class that should be implemented by specific registration methods.

Key FeaturesLink

  • Probabilistic point cloud alignment using EM algorithm
  • Iterative refinement of transformation parameters
  • Automatic variance estimation for noise handling
  • Support for rigid and non-rigid transformations
  • Convergence control through iteration limits and tolerance settings

NotesLink

This is a part of the implementation of the stochastic registration algorithm based on the following paper: Myronenko A. and Song X. (2010) Point set registration: Coherent Point drift. IEEE Transactions on Pattern Analysis and Machine Intelligence. 32 (2): 2262-2275. DOI: 10.1109/TPAMI.2010.46

The library is based on the python implementation of the paper in pycpd package.

ExpectationMaximizationRegistration Link

ExpectationMaximizationRegistration(X, Y, sigma2=None, max_iterations=MAX_ITER, tolerance=TOL, w=0, *args, **kwargs)

Bases: object

Abstract base class for point cloud registration using Expectation-Maximization algorithm.

This class implements the core functionality of the Coherent Point Drift (CPD) algorithm for point set registration based on Myronenko and Song's paper. It uses a probabilistic approach where the alignment of two point sets is treated as a Maximum Likelihood (ML) estimation problem with a Gaussian Mixture Model (GMM) as the likelihood function.

The class serves as a base for various CPD registration methods (rigid, affine, etc.), providing common EM framework while requiring specific transformation models to be implemented in child classes.

Attributes:

Name Type Description
X ndarray

Reference point cloud coordinates, shape (N, D).

Y ndarray

Initial point cloud coordinates to optimize, shape (M, D).

TY ndarray

Transformed/registered version of Y after optimization, shape (M, D).

sigma2 float

Variance of the Gaussian Mixture Model (GMM), updated during registration.

N int

Number of points in reference cloud X.

M int

Number of points in source cloud Y.

D int

Dimensionality of the point clouds (e.g., 3 for 3D point clouds).

tolerance float

Convergence criterion threshold.

w float

Weight of the uniform distribution component for outlier handling.

max_iterations int

Maximum number of iterations for the algorithm.

iteration int

Current iteration number during registration process.

err float

Current registration error/distance between point sets.

P ndarray

Posterior probability matrix of point correspondences, shape (M, N).

Pt1 ndarray

Column-wise sum of posterior probability matrix, shape (N,).

P1 ndarray

Row-wise sum of posterior probability matrix, shape (M,).

Np float

Sum of all elements in the posterior probability matrix.

q float

Negative log-likelihood of the current estimate.

Notes

This is an abstract base class. Child classes must implement:

  • update_transform(): Update transformation parameters
  • transform_point_cloud(): Apply transformation to point cloud
  • update_variance(): Update GMM variance
  • get_registration_parameters(): Return registration parameters
References

Myronenko A. and Song X. (2010) Point set registration: Coherent Point drift. IEEE Transactions on Pattern Analysis and Machine Intelligence. 32 (2): 2262-2275. DOI: 10.1109/TPAMI.2010.46

See Also

skeleton_refinement.utilities.initialize_sigma2 : Function to initialize the variance parameter

Initialize the Expectation-Maximization registration algorithm.

Parameters:

Name Type Description Default
X ndarray

Reference point cloud (target), shape (N, D).

required
Y ndarray

Point cloud to be aligned (source), shape (M, D).

required
sigma2 float or None

Initial variance of the Gaussian Mixture Model (GMM). If None, it will be estimated from data. Default is None.

None
max_iterations int

Maximum number of EM iterations. Default is 100.

MAX_ITER
tolerance float

Convergence threshold for stopping iterations. Algorithm stops when change in error falls below this value. Default is 0.0001.

TOL
w float

Weight of the uniform distribution component (0 <= w < 1). Used to account for outliers and noise. A value of 0 means no outlier handling. Default is 0.

0

Raises:

Type Description
ValueError

If X or Y are not 2D numpy arrays, or if their dimensions don't match.

Source code in skeleton_refinement/expectation_maximization_registration.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def __init__(self, X, Y, sigma2=None, max_iterations=MAX_ITER, tolerance=TOL, w=0, *args, **kwargs):
    """Initialize the Expectation-Maximization registration algorithm.

    Parameters
    ----------
    X : numpy.ndarray
        Reference point cloud (target), shape ``(N, D)``.
    Y : numpy.ndarray
        Point cloud to be aligned (source), shape ``(M, D)``.
    sigma2 : float or None, optional
        Initial variance of the Gaussian Mixture Model (GMM).
        If ``None``, it will be estimated from data.
        Default is ``None``.
    max_iterations : int, optional
        Maximum number of EM iterations. Default is ``100``.
    tolerance : float, optional
        Convergence threshold for stopping iterations.
        Algorithm stops when change in error falls below this value.
        Default is ``0.0001``.
    w : float, optional
        Weight of the uniform distribution component (0 <= w < 1).
        Used to account for outliers and noise.
        A value of ``0`` means no outlier handling.
        Default is ``0``.

    Raises
    ------
    ValueError
        If `X` or `Y` are not 2D numpy arrays, or if their dimensions don't match.
    """
    if not isinstance(X, np.ndarray) or X.ndim != 2:
        raise ValueError("The target point cloud (X) must be at a 2D numpy array.")
    if not isinstance(Y, np.ndarray) or Y.ndim != 2:
        raise ValueError("The source point cloud (Y) must be a 2D numpy array.")
    if X.shape[1] != Y.shape[1]:
        raise ValueError("Both point clouds need to have the same number of dimensions.")

    self.X = X
    self.Y = Y
    self.sigma2 = sigma2
    (self.N, self.D) = self.X.shape
    (self.M, _) = self.Y.shape
    self.tolerance = tolerance
    self.w = w
    self.max_iterations = max_iterations
    self.iteration = 0
    self.err = self.tolerance + 1
    self.P = np.zeros((self.M, self.N))
    self.Pt1 = np.zeros((self.N,))
    self.P1 = np.zeros((self.M,))
    self.Np = 0

    self.TY = None

expectation Link

expectation()

Perform the Expectation step of the EM algorithm.

The expectation step estimates the posterior probability (P) that each point in the source set corresponds to each point in the reference set, based on the current transformation and GMM variance.

This step also handles outlier detection based on the uniform distribution weight parameter w.

Notes

Updates the following attributes:

  • P: Posterior probability matrix of point correspondences
  • Pt1: Column-wise sum of P
  • P1: Row-wise sum of P
  • Np: Sum of all elements in P
Source code in skeleton_refinement/expectation_maximization_registration.py
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
def expectation(self):
    """Perform the Expectation step of the EM algorithm.

    The expectation step estimates the posterior probability (P) that each
    point in the source set corresponds to each point in the reference set,
    based on the current transformation and GMM variance.

    This step also handles outlier detection based on the uniform distribution
    weight parameter w.

    Notes
    -----
    Updates the following attributes:

    - P: Posterior probability matrix of point correspondences
    - Pt1: Column-wise sum of P
    - P1: Row-wise sum of P
    - Np: Sum of all elements in P
    """
    # Initialize posterior probability matrix (M source points × N reference points)
    P = np.zeros((self.M, self.N))

    # Calculate squared Mahalanobis distances between transformed source points and reference points
    for i in range(0, self.M):
        # Calculate differences between current transformed point and all reference points
        diff = self.X - np.tile(self.TY[i, :], (self.N, 1))
        # Square the differences
        diff = np.multiply(diff, diff)
        # Sum squared differences across dimensions for each point pair
        P[i, :] = P[i, :] + np.sum(diff, axis=1)

    # Calculate uniform distribution component for outlier handling
    c = (2 * np.pi * self.sigma2) ** (self.D / 2)  # Normalization factor for Gaussian
    c = c * self.w / (1 - self.w)  # Scale by outlier ratio
    c = c * self.M / self.N  # Normalize by point cloud sizes

    # Convert distances to probabilities using Gaussian kernel
    P = np.exp(-P / (2 * self.sigma2))

    # Calculate denominator for posterior probability normalization
    den = np.sum(P, axis=0)
    den = np.tile(den, (self.M, 1))
    # Avoid division by zero
    den[den == 0] = np.finfo(float).eps
    # Add uniform component for outlier handling
    den += c

    # Compute normalized posterior probabilities
    self.P = np.divide(P, den)

    # Calculate marginal probabilities and total correspondence strength
    self.Pt1 = np.sum(self.P, axis=0)  # Column-wise sum - probability mass for each reference point
    self.P1 = np.sum(self.P, axis=1)  # Row-wise sum - probability mass for each source point
    self.Np = np.sum(self.P1)  # Total correspondence probability mass

get_registration_parameters Link

get_registration_parameters()

Get the current registration transformation parameters.

This is an abstract method that must be implemented by child classes to return the specific transformation parameters used in the registration.

Returns:

Type Description
dict

Dictionary containing the transformation parameters specific to the registration method.

Raises:

Type Description
NotImplementedError

If called from the base class without being overridden.

Source code in skeleton_refinement/expectation_maximization_registration.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def get_registration_parameters(self):
    """Get the current registration transformation parameters.

    This is an abstract method that must be implemented by child classes
    to return the specific transformation parameters used in the registration.

    Returns
    -------
    dict
        Dictionary containing the transformation parameters specific to
        the registration method.

    Raises
    ------
    NotImplementedError
        If called from the base class without being overridden.
    """
    raise NotImplementedError("Registration parameters should be defined in child classes.")

iterate Link

iterate()

Perform one Expectation-Maximization iteration.

This method runs a single EM iteration consisting of:

  1. Expectation step: compute point correspondences
  2. Maximization step: update transformation parameters

The iteration counter is incremented after each call.

Source code in skeleton_refinement/expectation_maximization_registration.py
284
285
286
287
288
289
290
291
292
293
294
295
296
def iterate(self):
    """Perform one Expectation-Maximization iteration.

    This method runs a single EM iteration consisting of:

    1. Expectation step: compute point correspondences
    2. Maximization step: update transformation parameters

    The iteration counter is incremented after each call.
    """
    self.expectation()
    self.maximization()
    self.iteration += 1

maximization Link

maximization()

Perform the Maximization step of the EM algorithm.

The maximization step updates the transformation parameters and variance to maximize the probability that the transformed source points were drawn from the GMM centered at the reference points.

This method calls the abstract methods that should be implemented by child classes:

  1. update_transform()
  2. transform_point_cloud()
  3. update_variance()
Source code in skeleton_refinement/expectation_maximization_registration.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def maximization(self):
    """Perform the Maximization step of the EM algorithm.

    The maximization step updates the transformation parameters and variance
    to maximize the probability that the transformed source points were drawn
    from the GMM centered at the reference points.

    This method calls the abstract methods that should be implemented by child classes:

    1. update_transform()
    2. transform_point_cloud()
    3. update_variance()
    """
    self.update_transform()
    self.transform_point_cloud()
    self.update_variance()

register Link

register(callback=lambda **kwargs: None)

Perform the point set registration.

This method runs the EM algorithm to align the source point cloud (Y) to the reference point cloud (X). The algorithm iteratively estimates point correspondences and updates the transformation parameters until convergence or maximum iterations are reached.

Parameters:

Name Type Description Default
callback callable

Function to call after each iteration with registration state information. The function should accept keyword arguments: iteration, error, X, Y. Default is a no-op function.

lambda **kwargs: None

Returns:

Type Description
ndarray

The transformed point cloud (TY).

dict

Registration parameters specific to the registration method.

Notes

The registration is considered converged when the change in error between iterations falls below the tolerance threshold or the maximum number of iterations is reached.

Source code in skeleton_refinement/expectation_maximization_registration.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def register(self, callback=lambda **kwargs: None):
    """Perform the point set registration.

    This method runs the EM algorithm to align the source point cloud (Y)
    to the reference point cloud (X). The algorithm iteratively estimates
    point correspondences and updates the transformation parameters until
    convergence or maximum iterations are reached.

    Parameters
    ----------
    callback : callable, optional
        Function to call after each iteration with registration state information.
        The function should accept keyword arguments: iteration, error, X, Y.
        Default is a no-op function.

    Returns
    -------
    numpy.ndarray
        The transformed point cloud (TY).
    dict
        Registration parameters specific to the registration method.

    Notes
    -----
    The registration is considered converged when the change in error between
    iterations falls below the tolerance threshold or the maximum number of
    iterations is reached.
    """
    # Initialize by transforming points according to current parameters
    self.transform_point_cloud()

    # If variance is not provided, calculate initial variance based on point clouds
    if self.sigma2 is None:
        self.sigma2 = initialize_sigma2(self.X, self.TY)

    # Initialize negative log-likelihood (q) based on current error and variance
    self.q = -self.err - self.N * self.D / 2 * np.log(self.sigma2)

    # Create progress bar
    pbar = tqdm(total=self.max_iterations, desc="Registration")

    # Main EM loop - continue until convergence or max iterations
    while self.iteration < self.max_iterations and self.err > self.tolerance:
        # Run one iteration of Expectation-Maximization algorithm
        self.iterate()
        # If callback is provided, execute it with current registration state
        if callable(callback):
            kwargs = {'iteration': self.iteration, 'error': self.err, 'X': self.X, 'Y': self.TY}
            callback(**kwargs)
        # Update progress bar
        pbar.update(1)
        pbar.set_postfix({"error": f"{self.err:.6f}", 'tol.': f'{self.tolerance}'})
        # If we've reached convergence, update to max to close the bar
        if self.err <= self.tolerance:
            pbar.n = self.max_iterations
            pbar.set_postfix({"error": f"{self.err:.6f}/{self.tolerance}", "total n_iter": f"{self.iteration}"})
            pbar.refresh()

    # Close the progress bar
    pbar.close()
    return

transform_point_cloud Link

transform_point_cloud()

Apply the current transformation to the source point cloud.

This is an abstract method that must be implemented by child classes to apply the specific transformation to the point cloud Y and update TY.

Raises:

Type Description
NotImplementedError

If called from the base class without being overridden.

Source code in skeleton_refinement/expectation_maximization_registration.py
176
177
178
179
180
181
182
183
184
185
186
187
def transform_point_cloud(self):
    """Apply the current transformation to the source point cloud.

    This is an abstract method that must be implemented by child classes
    to apply the specific transformation to the point cloud Y and update TY.

    Raises
    ------
    NotImplementedError
        If called from the base class without being overridden.
    """
    raise NotImplementedError("This method should be defined in child classes.")

update_transform Link

update_transform()

Update transformation parameters based on the current point correspondence.

This is an abstract method that must be implemented by child classes to update the specific transformation parameters (e.g., rotation matrix, scaling factor, etc.) based on the current state of the registration.

Raises:

Type Description
NotImplementedError

If called from the base class without being overridden.

Source code in skeleton_refinement/expectation_maximization_registration.py
162
163
164
165
166
167
168
169
170
171
172
173
174
def update_transform(self):
    """Update transformation parameters based on the current point correspondence.

    This is an abstract method that must be implemented by child classes
    to update the specific transformation parameters (e.g., rotation matrix,
    scaling factor, etc.) based on the current state of the registration.

    Raises
    ------
    NotImplementedError
        If called from the base class without being overridden.
    """
    raise NotImplementedError("This method should be defined in child classes.")

update_variance Link

update_variance()

Update the variance of the GMM model (sigma2).

This is an abstract method that must be implemented by child classes to update the variance parameter based on the current state of the registration process.

Raises:

Type Description
NotImplementedError

If called from the base class without being overridden.

Source code in skeleton_refinement/expectation_maximization_registration.py
189
190
191
192
193
194
195
196
197
198
199
200
201
def update_variance(self):
    """Update the variance of the GMM model (sigma2).

    This is an abstract method that must be implemented by child classes
    to update the variance parameter based on the current state of the
    registration process.

    Raises
    ------
    NotImplementedError
        If called from the base class without being overridden.
    """
    raise NotImplementedError("This method should be defined in child classes.")