OptimalTransportDomainAdaptation
- class uniharmony.ot.OptimalTransportDomainAdaptation(ot_method: str | BaseTransport = 'emd', metric: Literal['sqeuclidean', 'euclidean', 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'wminkowski', 'yule'] = 'euclidean', reg: float | None = 1.0, eta: float | None = 0.1, max_iter: int | None = 10, cost_norm: Literal['median', 'max', 'log', 'loglog'] | None = None, limit_max: int | None = 10, copy: bool = True)
Optimal Transport for Domain Adaptation with reference site handling.
This class extends POT's BaseTransport to provide a harmonization interface in which data from multiple sites is aligned to one or more reference sites using optimal transport. The implementation supports both string-based OT method selection and direct injection of pre-configured OT instances.
- Parameters:
- ot_method : str or BaseTransport instance, optional (default 'emd')
Optimal transport method to use. Can be either:
A string: 'emd', 'sinkhorn' (alias 's'), 'sinkhorn_gl' (alias 's_gl'), or 'emd_laplace' (alias 'emd_l').
A pre-configured BaseTransport instance (e.g., ot.da.SinkhornTransport(reg_e=0.1)).
- metric : str, optional (default 'euclidean')
Distance metric used to compute the cost matrix. Supports all metrics from scipy.spatial.distance.cdist and POT's backend implementations.
- reg : float or None, optional (default 1.0)
Entropic regularization parameter, used by the Sinkhorn-based methods.
- eta : float or None, optional (default 0.1)
Regularization parameter for the Laplace or group-Lasso regularization term.
- max_iter : int or None, optional (default 10)
Maximum number of iterations for iterative solvers.
- cost_norm : str or None, optional (default None)
Cost matrix normalization method: 'median', 'max', 'log', or 'loglog'.
- limit_max : int or None, optional (default 10)
Controls the semi-supervised mode. Transport between labeled samples of different classes is given an effectively infinite cost of limit_max * max(cost), i.e. 10 * max(cost) with the default.
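The cost_norm and limit_max options can be pictured with a small NumPy sketch. It assumes the wrapper follows POT's conventions: the cost matrix is normalized first, then labeled source/target pairs of different classes are overwritten with limit_max times the largest normalized cost, with -1 marking an unlabeled sample. The helper names are hypothetical, and only the 'euclidean' and 'sqeuclidean' metrics are covered.

```python
import numpy as np


def cost_matrix(xs, xt, metric="euclidean"):
    # Pairwise costs between source rows (xs) and target rows (xt).
    sq = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
    if metric == "sqeuclidean":
        return sq
    if metric == "euclidean":
        return np.sqrt(sq)
    raise ValueError(f"metric not covered by this sketch: {metric}")


def normalize_cost(C, cost_norm=None):
    # The four cost_norm options; None leaves the matrix unchanged.
    if cost_norm is None:
        return C
    if cost_norm == "median":
        return C / np.median(C)
    if cost_norm == "max":
        return C / np.max(C)
    if cost_norm == "log":
        return np.log1p(C)
    if cost_norm == "loglog":
        return np.log1p(np.log1p(C))
    raise ValueError(f"unknown cost_norm: {cost_norm}")


def mask_cross_class(C, ys, yt, limit_max=10):
    # Labeled source/target pairs with different classes get limit_max * max(C);
    # label -1 marks an unlabeled sample, whose costs are left untouched.
    C = C.copy()
    labeled = (ys[:, None] != -1) & (yt[None, :] != -1)
    different = ys[:, None] != yt[None, :]
    C[labeled & different] = limit_max * C.max()
    return C


xs = np.array([[0.0, 0.0], [1.0, 0.0]])
xt = np.array([[0.0, 0.0], [0.0, 2.0]])
C = normalize_cost(cost_matrix(xs, xt), "max")
C_masked = mask_cross_class(C, ys=np.array([0, 1]), yt=np.array([0, -1]))
```

Here the second source sample (class 1) and the first target sample (class 0) end up with a cost ten times the largest normalized cost, which in practice keeps the solver from moving mass between them.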
- Attributes:
- ot_obj_ : BaseTransport
The fitted underlying OT object (set during fit).
- ref_site_ : str or list of str
The reference site(s) used for alignment.
- coupling_ : array, shape (n_source_samples, n_target_samples)
The optimal coupling matrix (forwarded from ot_obj_).
- cost_ : array
The computed cost matrix (forwarded from ot_obj_).
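To illustrate what coupling_ holds, here is a minimal Sinkhorn-Knopp sketch in NumPy for uniform sample weights. It assumes the reg parameter plays the role of the entropic regularization strength when a Sinkhorn-based method is selected (POT's own solver for this problem is ot.sinkhorn). The function name and the fixed iteration count are illustrative, not part of this API.

```python
import numpy as np


def sinkhorn_coupling(C, reg=1.0, n_iter=500):
    # Entropic OT between uniform weights via Sinkhorn-Knopp matrix scaling.
    ns, nt = C.shape
    a = np.full(ns, 1.0 / ns)   # source weights
    b = np.full(nt, 1.0 / nt)   # target weights
    K = np.exp(-C / reg)        # Gibbs kernel; smaller reg approaches the unregularized plan
    u, v = np.ones(ns), np.ones(nt)
    for _ in range(n_iter):
        u = a / (K @ v)         # rescale rows toward the source weights
        v = b / (K.T @ u)       # rescale columns toward the target weights
    return u[:, None] * K * v[None, :]


C = np.arange(12, dtype=float).reshape(3, 4) / 12.0
G = sinkhorn_coupling(C, reg=1.0)
```

The result has shape (n_source_samples, n_target_samples), matching coupling_; its rows sum to the source weights and its columns to the target weights.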
Methods
- fit(X, sites, ref_site[, y]): Fit optimal transport from non-reference sites to reference site(s).
- fit_transform(X, sites, ref_site[, y, ...]): Fit and transform in one step.
- get_metadata_routing(): Get metadata routing of this object.
- get_params([deep]): Get parameters for this estimator.
- inverse_transform(X[, sites, y, batch_size]): Transform data from the reference domain back to the original source domain.
- inverse_transform_labels([yt]): Propagate target labels \(\mathbf{y_t}\) to obtain estimated source labels \(\mathbf{y_s}\).
- set_fit_request(*[, ref_site, sites]): Configure whether metadata should be requested to be passed to the fit method.
- set_inverse_transform_request(*[, ...]): Configure whether metadata should be requested to be passed to the inverse_transform method.
- set_output(*[, transform]): Set output container.
- set_params(**params): Set the parameters of this estimator.
- set_transform_request(*[, batch_size, sites]): Configure whether metadata should be requested to be passed to the transform method.
- transform(X[, sites, y, batch_size]): Transform data using the fitted OT plan.
- transform_labels([ys]): Propagate source labels \(\mathbf{y_s}\) to obtain estimated target labels as in [27].
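The label propagation behind transform_labels can be sketched in NumPy. This assumes the wrapper forwards to POT's BaseTransport.transform_labels, which one-hot encodes the source labels, normalizes each coupling column by the mass that target sample receives, and multiplies the two. propagate_labels is a hypothetical helper; inverse_transform_labels would apply the same idea to the transposed coupling.

```python
import numpy as np


def propagate_labels(coupling, ys):
    # One-hot encode source labels: shape (n_classes, n_source).
    classes = np.unique(ys)
    onehot = (classes[:, None] == ys[None, :]).astype(float)
    # Normalize each column so a target sample's incoming mass sums to 1.
    col_mass = coupling.sum(axis=0, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        transp = coupling / col_mass
    transp[~np.isfinite(transp)] = 0.0  # targets that received no mass
    # Class scores per target sample: shape (n_target, n_classes).
    return (onehot @ transp).T, classes


coupling = np.array([[0.25, 0.25],
                     [0.0, 0.5]])
scores, classes = propagate_labels(coupling, np.array([3, 7]))
predicted = classes[scores.argmax(axis=1)]
```

Each row of scores is a distribution over source classes for one target sample; taking the argmax gives a hard label.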
Examples
>>> from uniharmony.ot import OptimalTransportDomainAdaptation
>>> # Using string-based method selection
>>> otda = OptimalTransportDomainAdaptation(ot_method="sinkhorn", reg=0.1, metric="sqeuclidean")
>>> otda.fit(X_train, sites_train, ref_site="site_A", y=labels)
>>> X_aligned = otda.transform(X_test, sites_test)
>>> # Using a pre-configured OT instance
>>> from ot.da import SinkhornTransport
>>> ot_solver = SinkhornTransport(reg_e=0.5, norm="median")
>>> otda = OptimalTransportDomainAdaptation(ot_method=ot_solver)
>>> otda.fit(X_train, sites_train, ref_site="site_A")
>>> X_aligned = otda.transform(X_test, sites_test)
- fit(X: ArrayLike, sites: ArrayLike, ref_site: str | list[str] | int | list[int], y: ArrayLike | None = None) -> OptimalTransportDomainAdaptation
Fit optimal transport from non-reference sites to reference site(s).
This method separates the data into reference (target) and non-reference (source) domains, then fits the optimal transport plan to map source distributions onto the reference distribution.
- Parameters:
- X : array-like, shape (n_samples, n_features)
Input data from all sites.
- sites : array-like, shape (n_samples,)
Site labels for each sample. Must align with X.
- ref_site : str, int, list of str, or list of int
Site identifier(s) to use as the reference (target domain). If a list is given, all specified sites are combined into a single reference distribution.
- y : array-like, shape (n_samples,) or (n_samples, n_classes), optional (default None)
Labels for supervised/semi-supervised transport, used in the cost matrix computation. Must align with X if provided.
- Returns:
- self : OptimalTransportDomainAdaptation
Fitted instance.
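The reference/non-reference split described for fit can be sketched with a boolean mask. This assumes sites is a flat array of identifiers and treats a scalar ref_site as a one-element list. split_by_reference is a hypothetical helper; its four outputs correspond to the Xs, Xt, ys, and yt arguments of POT's BaseTransport.fit, which the wrapper is assumed to call on ot_obj_.

```python
import numpy as np


def split_by_reference(X, sites, ref_site, y=None):
    # Treat a single identifier as a one-element list of reference sites.
    refs = list(ref_site) if isinstance(ref_site, (list, tuple)) else [ref_site]
    X, sites = np.asarray(X), np.asarray(sites)
    if sites.shape[0] != X.shape[0]:
        raise ValueError("sites must have one entry per row of X")
    is_ref = np.isin(sites, refs)
    if not is_ref.any():
        raise ValueError(f"reference site(s) {refs} not found in sites")
    # Source = non-reference rows, target = reference rows.
    Xs, Xt = X[~is_ref], X[is_ref]
    if y is None:
        return Xs, Xt, None, None
    y = np.asarray(y)
    return Xs, Xt, y[~is_ref], y[is_ref]


X = np.arange(12).reshape(6, 2)
sites = np.array(["A", "B", "A", "C", "B", "C"])
y = np.array([0, 1, 0, 1, 0, 1])
Xs, Xt, ys, yt = split_by_reference(X, sites, ["A", "C"], y)
```

Passing ["A", "C"] pools both sites into one reference distribution, so only the two rows from site "B" remain on the source side.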
- fit_transform(X: ArrayLike, sites: ArrayLike, ref_site: str | list[str] | int | list[int], y: ArrayLike | None = None, batch_size: int = 128) -> ndarray[tuple[Any, ...], dtype[_ScalarT]]
Fit and transform in one step.
- Parameters:
- X : array-like, shape (n_samples, n_features)
Input data.
- sites : array-like, shape (n_samples,)
Site labels.
- ref_site : str, int, list of str, or list of int
Reference site(s).
- y : array-like, optional (default None)
Labels for supervised transport.
- batch_size : int, optional (default 128)
Batch size for transformation.
- Returns:
- X_transformed : ndarray
Data aligned to reference distribution.
- inverse_transform(X: ArrayLike, sites: ArrayLike | None = None, y: ArrayLike | None = None, batch_size: int = 128) -> ndarray[tuple[Any, ...], dtype[_ScalarT]]
Transform data from the reference domain back to the original source domain.
- Parameters:
- X : array-like, shape (n_samples, n_features)
Reference domain data to map back.
- sites : array-like, shape (n_samples,), optional (default None)
Site labels for X. If provided, only reference site samples are inverse transformed; non-reference samples are returned as-is.
- y : array-like, shape (n_samples,) or (n_samples, n_classes), optional (default None)
Labels for supervised transport. Must align with X if provided. Used to ensure consistent handling of supervised transformations.
- batch_size : int, optional (default 128)
Batch size for transformation.
- Returns:
- X_inv_transformed : ndarray
Data mapped back to source distribution.
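For reference samples that were part of the fit, mapping back is commonly done with a backward barycentric map: each reference sample is replaced by the coupling-weighted mean of the source samples it exchanges mass with. The sketch below assumes the reference rows of X are exactly the target samples used in fit, in the same order, and lets non-reference rows pass through unchanged as described for sites. The helper name is hypothetical and batching is omitted.

```python
import numpy as np


def inverse_transform_in_sample(coupling, xs_fit, X, sites, ref_site):
    # Backward barycentric map: rows of coupling.T normalized to sum to 1.
    weights = coupling.T / coupling.T.sum(axis=1, keepdims=True)  # (n_target, n_source)
    refs = list(ref_site) if isinstance(ref_site, (list, tuple)) else [ref_site]
    is_ref = np.isin(np.asarray(sites), refs)
    X_out = np.array(X, dtype=float)  # copy; non-reference rows stay as-is
    # Assumption: the reference rows of X are the fitted target samples, in order.
    X_out[is_ref] = weights @ xs_fit
    return X_out


coupling = np.array([[0.25, 0.0],
                     [0.25, 0.5]])  # 2 source x 2 reference samples
xs_fit = np.array([[0.0, 0.0], [10.0, 10.0]])
X = np.array([[1.0, 1.0], [2.0, 2.0], [99.0, 99.0]])
X_back = inverse_transform_in_sample(coupling, xs_fit, X, ["R", "R", "S"], "R")
```

The first reference sample splits its mass evenly between both source samples and lands at their midpoint; the row from site "S" is returned untouched.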
- set_fit_request(*, ref_site: bool | None | str = '$UNCHANGED$', sites: bool | None | str = '$UNCHANGED$') -> OptimalTransportDomainAdaptation
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- ref_site : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the ref_site parameter in fit.
- sites : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the sites parameter in fit.
- Returns:
- self : object
The updated object.
- set_inverse_transform_request(*, batch_size: bool | None | str = '$UNCHANGED$', sites: bool | None | str = '$UNCHANGED$') -> OptimalTransportDomainAdaptation
Configure whether metadata should be requested to be passed to the inverse_transform method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to inverse_transform if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to inverse_transform.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- batch_size : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the batch_size parameter in inverse_transform.
- sites : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the sites parameter in inverse_transform.
- Returns:
- self : object
The updated object.
- set_transform_request(*, batch_size: bool | None | str = '$UNCHANGED$', sites: bool | None | str = '$UNCHANGED$') -> OptimalTransportDomainAdaptation
Configure whether metadata should be requested to be passed to the transform method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to transform if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to transform.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- batch_size : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the batch_size parameter in transform.
- sites : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the sites parameter in transform.
- Returns:
- self : object
The updated object.
- transform(X: ArrayLike, sites: ArrayLike | None = None, y: ArrayLike | None = None, batch_size: int = 128) -> ndarray[tuple[Any, ...], dtype[_ScalarT]]
Transform data using the fitted OT plan.
Transports samples from the source domain (non-reference sites) to the target domain (reference site). If sites are provided, only non-reference samples are transformed; reference samples are returned unchanged.
- Parameters:
- X : array-like, shape (n_samples, n_features)
Input data to transform.
- sites : array-like, shape (n_samples,), optional (default None)
Site labels for X. If provided, only non-reference sites are transformed; reference site samples are returned as-is.
- y : array-like, shape (n_samples,) or (n_samples, n_classes), optional (default None)
Labels for supervised transport. Must align with X if provided. Used to ensure consistent handling of supervised transformations.
- batch_size : int, optional (default 128)
Batch size for out-of-sample transformation.
- Returns:
- X_transformed : ndarray, shape (n_samples, n_features)
Transformed data aligned to reference distribution.
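Batched out-of-sample transformation can be sketched after POT's default BaseTransport.transform strategy: map the fitted source samples with the forward barycentric map, then shift each new sample by the displacement of its nearest fitted source sample, processing batch_size rows at a time. Reference-site rows are left untouched as described for sites. That the wrapper delegates to POT in exactly this way is an assumption, and both helper names are hypothetical.

```python
import numpy as np


def transform_out_of_sample(coupling, xs_fit, xt_fit, X, batch_size=128):
    # Forward barycentric map of the fitted source samples.
    weights = coupling / coupling.sum(axis=1, keepdims=True)  # (n_source, n_target)
    xs_mapped = weights @ xt_fit
    parts = []
    for start in range(0, X.shape[0], batch_size):
        batch = X[start:start + batch_size]
        # Nearest fitted source sample for each row (squared Euclidean distance).
        d = ((batch[:, None, :] - xs_fit[None, :, :]) ** 2).sum(axis=-1)
        idx = d.argmin(axis=1)
        # Apply that neighbour's displacement to the new sample.
        parts.append(xs_mapped[idx] + batch - xs_fit[idx])
    return np.concatenate(parts, axis=0)


def transform_with_sites(coupling, xs_fit, xt_fit, X, sites, ref_site, batch_size=128):
    refs = list(ref_site) if isinstance(ref_site, (list, tuple)) else [ref_site]
    X = np.asarray(X, dtype=float)
    is_src = ~np.isin(np.asarray(sites), refs)  # only non-reference rows move
    X_out = X.copy()
    if is_src.any():
        X_out[is_src] = transform_out_of_sample(coupling, xs_fit, xt_fit, X[is_src], batch_size)
    return X_out


coupling = np.diag([0.5, 0.5])
xs_fit = np.array([[0.0, 0.0], [10.0, 0.0]])
xt_fit = np.array([[0.0, 5.0], [10.0, 5.0]])
X_new = np.array([[1.0, 0.0], [9.0, 1.0], [3.0, 3.0]])
out = transform_with_sites(coupling, xs_fit, xt_fit, X_new, ["S", "S", "R"], "R", batch_size=1)
```

Batching only bounds the size of the distance matrix computed at once; the result does not depend on batch_size.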