OptimalTransportDomainAdaptation

class uniharmony.ot.OptimalTransportDomainAdaptation(ot_method: str | BaseTransport = 'emd', metric: Literal['sqeuclidean', 'euclidean', 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'wminkowski', 'yule'] = 'euclidean', reg: float | None = 1.0, eta: float | None = 0.1, max_iter: int | None = 10, cost_norm: Literal['median', 'max', 'log', 'loglog'] | None = None, limit_max: int | None = 10, copy: bool = True)

Optimal Transport for Domain Adaptation with reference site handling.

This class extends POT’s BaseTransport to provide a harmonization interface in which data from multiple sites are aligned to one or more reference sites using optimal transport. The OT method can be selected either by name (a string) or by passing a pre-configured OT instance directly.
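For intuition about what “aligned to a reference site” means with the default “emd” method, the sketch below solves exact optimal transport in the special case of two equally sized, uniformly weighted point clouds, where an optimal coupling is a scaled permutation matrix that can be found by brute force. It is a NumPy-only toy for a handful of samples, not POT’s solver; `emd_uniform_bruteforce` is a hypothetical name, and the final barycentric step is one common way to turn a coupling into aligned coordinates.

```python
import itertools

import numpy as np


def emd_uniform_bruteforce(Xs, Xt):
    """Exact OT between two equally sized, uniformly weighted point clouds.

    With uniform weights and n == m, an optimal coupling is a permutation
    matrix scaled by 1/n, so all permutations can be searched (tiny n only).
    """
    n = Xs.shape[0]
    assert Xt.shape[0] == n, "toy solver needs equal sample sizes"
    # Squared Euclidean cost between every source/reference pair.
    C = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(n)):
        cost = C[np.arange(n), list(perm)].sum()
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    G = np.zeros((n, n))
    G[np.arange(n), list(best_perm)] = 1.0 / n
    return G


rng = np.random.default_rng(0)
Xt = rng.normal(size=(5, 2))            # "reference site" samples
shift = np.array([3.0, -1.0])
Xs = Xt[::-1] + shift                   # shifted copy of the reference, rows reversed
G = emd_uniform_bruteforce(Xs, Xt)
# Barycentric mapping: each source sample moves to its coupled reference sample.
Xs_aligned = (G / G.sum(axis=1, keepdims=True)) @ Xt
```

Because this source site is a pure translation of the reference, the coupling undoes the shift exactly; real sites differ in more than a translation, and POT’s solvers also handle unequal sample sizes and non-uniform weights.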

Parameters:
ot_method : str or BaseTransport instance, optional (default “emd”)

Optimal transport method to use. Can be either:

  • A string: “emd”, “sinkhorn”/“s”, “sinkhorn_gl”/“s_gl”, “emd_laplace”/“emd_l”

  • A pre-configured BaseTransport instance (e.g., ot.da.SinkhornTransport(reg_e=0.1))

metric : str, optional (default “euclidean”)

Distance metric for cost matrix computation. Supports all metrics from scipy.spatial.distance.cdist and POT’s backend implementations.
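A minimal sketch of the cost matrix for the default “euclidean” metric and for “sqeuclidean”, assuming rows of `Xs` are source samples and rows of `Xt` are reference samples; the helper `cost_matrix` is hypothetical. For the other names in the list, `scipy.spatial.distance.cdist(Xs, Xt, metric=...)` computes the corresponding matrix, subject to which metrics the installed SciPy version still provides.

```python
import numpy as np


def cost_matrix(Xs, Xt, metric="euclidean"):
    """Pairwise cost C[i, j] between source row i and reference row j."""
    diff = Xs[:, None, :] - Xt[None, :, :]   # shape (n_source, n_reference, n_features)
    sq = (diff ** 2).sum(axis=-1)
    if metric == "sqeuclidean":
        return sq
    if metric == "euclidean":
        return np.sqrt(sq)
    raise ValueError(f"metric {metric!r} is not covered by this sketch")


Xs = np.array([[0.0, 0.0], [1.0, 1.0]])
Xt = np.array([[3.0, 4.0]])
C = cost_matrix(Xs, Xt)   # [[5.0], [sqrt(13)]]
```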

reg : float or None, optional (default 1.0)

Entropic regularization parameter. Used for Sinkhorn-based methods.
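To show what the entropic regularization `reg` controls, here is a textbook Sinkhorn iteration in NumPy, assuming uniform sample weights and a fixed iteration count. It sketches the algorithm family behind the “sinkhorn” option rather than POT’s implementation, and `sinkhorn_plan` is a hypothetical name. Larger `reg` yields a smoother, more spread-out coupling; smaller `reg` moves it toward the exact EMD plan.

```python
import numpy as np


def sinkhorn_plan(C, reg, n_iter=500):
    """Entropic OT coupling between uniform marginals via Sinkhorn scaling."""
    n_s, n_t = C.shape
    a = np.full(n_s, 1.0 / n_s)   # uniform weights on source samples
    b = np.full(n_t, 1.0 / n_t)   # uniform weights on reference samples
    K = np.exp(-C / reg)          # Gibbs kernel
    u = np.ones(n_s)
    v = np.ones(n_t)
    for _ in range(n_iter):
        u = a / (K @ v)           # rescale rows toward marginal a
        v = b / (K.T @ u)         # rescale columns to marginal b
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
Xs = rng.normal(size=(6, 2))
Xt = rng.normal(loc=2.0, size=(4, 2))
C = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)
C /= C.max()                      # keep exp(-C / reg) far from underflow
G_smooth = sinkhorn_plan(C, reg=1.0)
G_sharp = sinkhorn_plan(C, reg=0.05)
```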

eta : float or None, optional (default 0.1)

Regularization parameter for Laplace or group Lasso regularization.

max_iter : int or None, optional (default 10)

Maximum number of iterations for iterative solvers.

cost_norm : str or None, optional (default None)

Cost matrix normalization method: “median”, “max”, “log” or “loglog”. If None, the cost matrix is not normalized.
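Our understanding is that these options follow POT’s cost normalization: divide by the median or by the maximum entry, or apply log(1 + C) once (“log”) or twice (“loglog”). The NumPy sketch below encodes that assumption with a hypothetical helper, `normalize_cost`.

```python
import numpy as np


def normalize_cost(C, norm=None):
    """Rescale a cost matrix according to the cost_norm option (assumed semantics)."""
    C = np.asarray(C, dtype=float)
    if norm is None:
        return C                          # no normalization
    if norm == "median":
        return C / np.median(C)
    if norm == "max":
        return C / np.max(C)
    if norm == "log":
        return np.log1p(C)                # log(1 + C)
    if norm == "loglog":
        return np.log1p(np.log1p(C))      # log(1 + log(1 + C))
    raise ValueError(f"unknown cost_norm {norm!r}")


C = np.array([[0.0, 2.0], [4.0, 8.0]])
C_max = normalize_cost(C, "max")          # [[0.0, 0.25], [0.5, 1.0]]
C_median = normalize_cost(C, "median")    # median of the entries is 3.0
```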

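limit_max : int or None, optional (default 10)

Controls the semi-supervised mode. When labels are provided, transport between samples of different classes is assigned a cost of limit_max times the maximum of the cost matrix (10 * max(cost) with the default), which acts as an effectively infinite cost.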
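A NumPy sketch of this penalty, assuming that a label of -1 marks an unlabeled sample (the convention POT uses for semi-supervised transport, to our understanding) and that the penalty equals limit_max times the largest entry of the cost matrix; `penalize_cross_class` is a hypothetical helper.

```python
import numpy as np


def penalize_cross_class(C, ys, yt, limit_max=10):
    """Raise the cost between labeled source/reference samples of different classes."""
    C = np.array(C, dtype=float)                 # work on a copy
    penalty = limit_max * C.max()
    labeled_s = ys != -1                         # -1 assumed to mean "unlabeled"
    labeled_t = yt != -1
    mismatch = ys[:, None] != yt[None, :]
    mask = labeled_s[:, None] & labeled_t[None, :] & mismatch
    C[mask] = penalty
    return C


C = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
ys = np.array([0, 1])
yt = np.array([0, 1, -1])                        # last reference sample is unlabeled
C_pen = penalize_cross_class(C, ys, yt)
# C_pen == [[1., 60., 3.], [60., 5., 6.]]
```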

Attributes:
ot_obj_ : BaseTransport

The fitted underlying OT object (set during fit).

ref_site_ : str or list of str

The reference site(s) used for alignment.

coupling_ : array, shape (n_source_samples, n_target_samples)

The optimal coupling matrix (forwarded from ot_obj_).

cost_ : array

The computed cost matrix (forwarded from ot_obj_).

Methods

fit(X, sites, ref_site[, y])

Fit optimal transport from non-reference sites to reference site(s).

fit_transform(X, sites, ref_site[, y, ...])

Fit and transform in one step.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

inverse_transform(X[, sites, y, batch_size])

Transform data from the reference domain back to the original source domain.

inverse_transform_labels([yt])

Propagate target labels \(\mathbf{y_t}\) to obtain estimated source labels \(\mathbf{y_s}\).

set_fit_request(*[, ref_site, sites])

Configure whether metadata should be requested to be passed to the fit method.

set_inverse_transform_request(*[, ...])

Configure whether metadata should be requested to be passed to the inverse_transform method.

set_output(*[, transform])

Set output container.

set_params(**params)

Set the parameters of this estimator.

set_transform_request(*[, batch_size, sites])

Configure whether metadata should be requested to be passed to the transform method.

transform(X[, sites, y, batch_size])

Transform data using the fitted OT plan.

transform_labels([ys])

Propagate source labels \(\mathbf{y_s}\) to obtain estimated target labels as in [27].
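The two label-propagation methods above can be pictured with a small NumPy sketch based on our reading of the corresponding POT methods: the coupling is normalized so that the mass arriving at each target sample sums to one, and the one-hot source labels are averaged with those weights, giving a soft class score per target sample. Passing the transposed coupling together with target labels gives the reverse direction, which is our assumption about inverse_transform_labels. `propagate_labels` is a hypothetical helper.

```python
import numpy as np


def propagate_labels(G, ys):
    """Soft labels for the columns of a coupling G of shape (n_rows, n_cols).

    Each column of G is normalized to sum to 1, then used to average the
    one-hot labels of the rows. Returns an array of shape (n_cols, n_classes).
    """
    classes = np.unique(ys)
    onehot = (ys[:, None] == classes[None, :]).astype(float)   # (n_rows, n_classes)
    col_mass = G.sum(axis=0, keepdims=True)
    weights = np.divide(G, col_mass, out=np.zeros_like(G), where=col_mass > 0)
    return weights.T @ onehot


G = np.array([[0.25, 0.00],
              [0.25, 0.10],
              [0.00, 0.40]])       # 3 source samples, 2 target samples
ys = np.array([0, 0, 1])
scores_t = propagate_labels(G, ys)                  # [[1.0, 0.0], [0.2, 0.8]]
scores_s = propagate_labels(G.T, np.array([0, 1]))  # reverse direction
```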

Examples

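>>> import numpy as np
>>> from uniharmony.ot import OptimalTransportDomainAdaptation
>>> # Toy placeholder data: 3 sites, 5 features, binary labels
>>> rng = np.random.default_rng(0)
>>> X_train = rng.normal(size=(60, 5))
>>> sites_train = np.repeat(["site_A", "site_B", "site_C"], 20)
>>> labels = rng.integers(0, 2, size=60)
>>> X_test = rng.normal(size=(30, 5))
>>> sites_test = np.repeat(["site_A", "site_B", "site_C"], 10)
>>> # Using string-based method selection
>>> otda = OptimalTransportDomainAdaptation(ot_method="sinkhorn", reg=0.1, metric="sqeuclidean")
>>> otda.fit(X_train, sites_train, ref_site="site_A", y=labels)
>>> X_aligned = otda.transform(X_test, sites_test)
>>> # Using a pre-configured OT instance
>>> from ot.da import SinkhornTransport
>>> ot_solver = SinkhornTransport(reg_e=0.5, norm="median")
>>> otda = OptimalTransportDomainAdaptation(ot_method=ot_solver)
>>> otda.fit(X_train, sites_train, ref_site="site_A")
>>> X_aligned = otda.transform(X_test, sites_test)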

fit(X: ArrayLike, sites: ArrayLike, ref_site: str | list[str] | int | list[int], y: ArrayLike | None = None) -> OptimalTransportDomainAdaptation

Fit optimal transport from non-reference sites to reference site(s).

This method separates the data into reference (target) and non-reference (source) domains, then fits the optimal transport plan to map source distributions onto the reference distribution.
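The split described above, including pooling when ref_site is a list, can be sketched with NumPy as follows. `split_by_reference` is a hypothetical helper, and raising an error when no sample belongs to the reference is an assumption about sensible behaviour rather than documented behaviour.

```python
import numpy as np


def split_by_reference(X, sites, ref_site):
    """Return (X_source, X_reference, ref_mask) for one or several reference sites."""
    X = np.asarray(X)
    sites = np.asarray(sites)
    # A single identifier and a list of identifiers are handled the same way:
    # every sample whose site is listed belongs to the pooled reference domain.
    ref_sites = ref_site if isinstance(ref_site, (list, tuple)) else [ref_site]
    ref_mask = np.isin(sites, ref_sites)
    if not ref_mask.any():
        raise ValueError(f"no samples found for reference site(s) {ref_sites!r}")
    return X[~ref_mask], X[ref_mask], ref_mask


X = np.arange(12.0).reshape(6, 2)
sites = np.array(["A", "B", "C", "A", "B", "C"])
Xs, Xt, ref_mask = split_by_reference(X, sites, ref_site=["A", "C"])
# Xt holds rows 0, 2, 3, 5 (sites A and C); Xs holds rows 1 and 4 (site B)
```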

Parameters:
X : array-like, shape (n_samples, n_features)

Input data from all sites.

sites : array-like, shape (n_samples,)

Site labels for each sample. Must align with X.

ref_site : str, int, list of str or list of int

Site identifier(s) to use as the reference (target domain). If a list is given, all listed sites are pooled into a single reference distribution.

y : array-like, shape (n_samples,) or (n_samples, n_classes), optional (default None)

Labels for supervised/semi-supervised transport. Used for cost matrix computation. Must align with X if provided.

Returns:
self : OptimalTransportDomainAdaptation

Fitted instance.

fit_transform(X: ArrayLike, sites: ArrayLike, ref_site: str | list[str] | int | list[int], y: ArrayLike | None = None, batch_size: int = 128) -> ndarray[tuple[Any, ...], dtype[_ScalarT]]

Fit and transform in one step.

Parameters:
X : array-like, shape (n_samples, n_features)

Input data.

sites : array-like, shape (n_samples,)

Site labels.

ref_site : str, int, list of str or list of int

Reference site(s).

y : array-like, optional (default None)

Labels for supervised transport.

batch_size : int, optional (default 128)

Batch size for transformation.

Returns:
X_transformed : ndarray, shape (n_samples, n_features)

Data aligned to reference distribution.

inverse_transform(X: ArrayLike, sites: ArrayLike | None = None, y: ArrayLike | None = None, batch_size: int = 128) -> ndarray[tuple[Any, ...], dtype[_ScalarT]]

Transform data from the reference domain back to the original source domain.

Parameters:
X : array-like, shape (n_samples, n_features)

Reference domain data to map back.

sites : array-like, shape (n_samples,), optional (default None)

Site labels for X. If provided, only reference site samples are inverse transformed. Non-reference samples are returned as-is.
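Assuming inverse_transform mirrors transform by reading the coupling in the opposite direction, a barycentric sketch looks like this: each reference sample is mapped to the coupling-weighted mean of the source samples that send mass to it, and only rows whose site is the reference are moved, as described above. The function name is hypothetical, a single reference site is assumed, and out-of-sample handling is omitted.

```python
import numpy as np


def inverse_transform_fitted(G, Xs_fit, X, sites, ref_site):
    """Map the reference rows of X back onto the source domain.

    G has shape (n_source, n_reference). Assumes the reference rows of X are
    exactly the reference samples used in fit, in the same order.
    """
    X_out = np.array(X, dtype=float)           # copy; non-reference rows stay as-is
    ref_mask = np.asarray(sites) == ref_site
    Gt = G.T                                   # (n_reference, n_source)
    mass = Gt.sum(axis=1, keepdims=True)       # mass received by each reference sample
    weights = np.divide(Gt, mass, out=np.zeros_like(Gt), where=mass > 0)
    X_out[ref_mask] = weights @ Xs_fit         # weighted mean of coupled source samples
    return X_out


Xs_fit = np.array([[0.0, 0.0], [2.0, 0.0]])   # source samples seen in fit
G = np.array([[0.3, 0.2],
              [0.2, 0.3]])
X = np.array([[10.0, 10.0], [1.0, 1.0], [3.0, 1.0]])
sites = np.array(["B", "A", "A"])
X_inv = inverse_transform_fitted(G, Xs_fit, X, sites, ref_site="A")
# X_inv == [[10., 10.], [0.8, 0.], [1.2, 0.]]
```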

y : array-like, shape (n_samples,) or (n_samples, n_classes), optional (default None)

Labels for supervised transport. Must align with X if provided. Used to ensure consistent handling of supervised transformations.

batch_size : int, optional (default 128)

Batch size for transformation.

Returns:
X_inv_transformed : ndarray

Data mapped back to source distribution.

set_fit_request(*, ref_site: bool | None | str = '$UNCHANGED$', sites: bool | None | str = '$UNCHANGED$') -> OptimalTransportDomainAdaptation

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
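To make the request values concrete, here is a toy Python model of what a meta-estimator does with one sub-estimator’s requests, written only from the option list above. It is not scikit-learn’s implementation, and `update_request` and `route` are hypothetical names; the sentinel string matches the '$UNCHANGED$' default shown in the signature.

```python
UNCHANGED = "$UNCHANGED$"  # sentinel used as the default in set_*_request


def update_request(current, new):
    """Apply set_*_request-style updates; UNCHANGED keeps the existing value."""
    merged = dict(current)
    for name, value in new.items():
        if value != UNCHANGED:
            merged[name] = value
    return merged


def route(requests, provided):
    """Return the metadata a meta-estimator would forward to the sub-estimator."""
    routed = {}
    for name, request in requests.items():
        if request is True:
            if name in provided:          # requested: forwarded only if given
                routed[name] = provided[name]
        elif request is False:
            continue                      # not requested: never forwarded
        elif request is None:
            if name in provided:          # None: passing it is an error
                raise ValueError(f"{name!r} was passed but is not requested")
        elif isinstance(request, str):
            if request in provided:       # alias: user passes it under the alias
                routed[name] = provided[request]
    return routed


requests = update_request({"ref_site": None, "sites": None},
                          {"ref_site": "reference", "sites": True})
routed = route(requests, {"sites": ["A", "B"], "reference": "A"})
# routed == {'ref_site': 'A', 'sites': ['A', 'B']}
```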

Added in scikit-learn version 1.3.

Parameters:
ref_site : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for ref_site parameter in fit.

sites : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for sites parameter in fit.

Returns:
self : object

The updated object.

set_inverse_transform_request(*, batch_size: bool | None | str = '$UNCHANGED$', sites: bool | None | str = '$UNCHANGED$') -> OptimalTransportDomainAdaptation

Configure whether metadata should be requested to be passed to the inverse_transform method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to inverse_transform if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to inverse_transform.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
batch_size : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for batch_size parameter in inverse_transform.

sites : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for sites parameter in inverse_transform.

Returns:
self : object

The updated object.

set_transform_request(*, batch_size: bool | None | str = '$UNCHANGED$', sites: bool | None | str = '$UNCHANGED$') -> OptimalTransportDomainAdaptation

Configure whether metadata should be requested to be passed to the transform method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to transform if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to transform.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
batch_size : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for batch_size parameter in transform.

sites : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for sites parameter in transform.

Returns:
self : object

The updated object.

transform(X: ArrayLike, sites: ArrayLike | None = None, y: ArrayLike | None = None, batch_size: int = 128) -> ndarray[tuple[Any, ...], dtype[_ScalarT]]

Transform data using the fitted OT plan.

Transports samples from the source domain (non-reference sites) to the target domain (reference site). If sites is provided, only non-reference samples are transformed; reference samples are returned unchanged.
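For samples that were part of fit, a common way to apply a fitted plan (and, as we understand it, what POT’s BaseTransport.transform does for them) is the barycentric mapping: each source sample moves to the coupling-weighted mean of the reference samples. The sketch below combines that with the sites rule just described. The helper names are hypothetical, a single reference site is assumed, and the non-reference rows of X are assumed to be the fitted source samples in their original order.

```python
import numpy as np


def barycentric_transform(G, Xt_fit):
    """Move each fitted source sample to the G-weighted mean of the reference samples."""
    mass = G.sum(axis=1, keepdims=True)
    weights = np.divide(G, mass, out=np.zeros_like(G), where=mass > 0)
    return weights @ Xt_fit


def transform_with_sites(X, sites, ref_site, G, Xt_fit):
    """Transport non-reference rows of X; reference rows are returned unchanged."""
    X_out = np.array(X, dtype=float)             # copy of the input
    src_mask = np.asarray(sites) != ref_site
    X_out[src_mask] = barycentric_transform(G, Xt_fit)
    return X_out


Xt_fit = np.array([[0.0, 0.0], [4.0, 0.0]])     # reference samples seen in fit
G = np.array([[0.5, 0.0],
              [0.1, 0.4]])                       # coupling, shape (n_source, n_reference)
X = np.array([[9.0, 9.0], [0.0, 0.0], [7.0, 7.0], [4.0, 0.0]])
sites = np.array(["B", "A", "B", "A"])
X_aligned = transform_with_sites(X, sites, "A", G, Xt_fit)
# X_aligned == [[0., 0.], [0., 0.], [3.2, 0.], [4., 0.]]
```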

Parameters:
X : array-like, shape (n_samples, n_features)

Input data to transform.

sites : array-like, shape (n_samples,), optional (default None)

Site labels for X. If provided, only non-reference sites are transformed. Reference site samples are returned as-is.

y : array-like, shape (n_samples,) or (n_samples, n_classes), optional (default None)

Labels for supervised transport. Must align with X if provided. Used to ensure consistent handling of supervised transformations.

batch_size : int, optional (default 128)

Batch size for out-of-sample transformation.
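For samples not seen during fit, POT’s BaseTransport.transform, as we understand it, finds each new point’s nearest fitted source sample, reuses that sample’s barycentric displacement, and processes the new points in chunks of batch_size to bound memory. A NumPy sketch of that rule, with hypothetical names and the squared Euclidean distance assumed for the nearest-neighbour search:

```python
import numpy as np


def transform_out_of_sample(X_new, Xs_fit, Xt_fit, G, batch_size=128):
    """Shift unseen source points by the displacement of their nearest fitted neighbour."""
    mass = G.sum(axis=1, keepdims=True)
    weights = np.divide(G, mass, out=np.zeros_like(G), where=mass > 0)
    displacement = weights @ Xt_fit - Xs_fit      # barycentric image minus original
    out = []
    for start in range(0, X_new.shape[0], batch_size):
        batch = X_new[start:start + batch_size]
        # Squared Euclidean distance from each batch point to each fitted source point.
        d2 = ((batch[:, None, :] - Xs_fit[None, :, :]) ** 2).sum(axis=-1)
        nearest = d2.argmin(axis=1)
        out.append(batch + displacement[nearest])
    return np.vstack(out)


Xs_fit = np.array([[0.0, 0.0], [10.0, 0.0]])
Xt_fit = np.array([[1.0, 1.0], [10.0, -2.0]])
G = np.diag([0.5, 0.5])                           # each source sample maps to one reference sample
X_new = np.array([[0.5, 0.0], [9.0, 0.0], [3.0, 3.0]])
X_new_aligned = transform_out_of_sample(X_new, Xs_fit, Xt_fit, G, batch_size=2)
# X_new_aligned == [[1.5, 1.], [9., -2.], [4., 4.]]
```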

Returns:
X_transformed : ndarray, shape (n_samples, n_features)

Transformed data aligned to reference distribution.

Examples

Binary classification using OTDA