
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/05-ot-based/01-plot_binary_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_05-ot-based_01-plot_binary_classification.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_05-ot-based_01-plot_binary_classification.py:


Binary classification using OTDA
================================

.. GENERATED FROM PYTHON SOURCE LINES 7-9

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 9-24

.. code-block:: Python


    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
    from sklearn.model_selection import train_test_split

    from uniharmony import verbosity
    from uniharmony.datasets import make_multisite_classification
    from uniharmony.ot import OptimalTransportDomainAdaptation


    sns.set_theme(style="whitegrid")
    verbosity("warning")









.. GENERATED FROM PYTHON SOURCE LINES 25-27

Data generation
---------------

.. GENERATED FROM PYTHON SOURCE LINES 27-33

.. code-block:: Python


    X, y, sites = make_multisite_classification(n_sites=3, n_features=2, site_effect_strength=10)

    X_train, X_val, y_train, y_val, sites_train, sites_val = train_test_split(X, y, sites, test_size=0.5, random_state=42)








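This page does not document the generative model behind ``make_multisite_classification``. As a purely hypothetical illustration of what a site effect can look like, the NumPy sketch below gives every site the same two-class structure and then adds a random per-site offset. Its scale is set by ``site_shift``, a made-up stand-in for ``site_effect_strength``. The function ``simulate_multisite`` and all of its parameters are assumptions of this sketch, not uniharmony's implementation.

.. code-block:: Python

    import numpy as np


    def simulate_multisite(n_per_site=100, n_sites=3, n_features=2, site_shift=10.0, seed=0):
        """Hypothetical simulator: identical class structure per site plus a per-site offset."""
        rng = np.random.default_rng(seed)
        X_parts, y_parts, site_parts = [], [], []
        # One random offset vector per site; site 0 stays unshifted and acts as the reference.
        offsets = rng.normal(size=(n_sites, n_features)) * site_shift
        offsets[0] = 0.0
        for s in range(n_sites):
            y_s = rng.integers(0, 2, size=n_per_site)
            # Class 0 is centred at -1 and class 1 at +1 on every feature, with unit-variance noise.
            X_s = rng.normal(size=(n_per_site, n_features)) + (2 * y_s[:, None] - 1)
            X_parts.append(X_s + offsets[s])
            y_parts.append(y_s)
            site_parts.append(np.full(n_per_site, s))
        return np.vstack(X_parts), np.concatenate(y_parts), np.concatenate(site_parts)


    X, y, sites = simulate_multisite()
    print(X.shape, np.bincount(sites))  # → (300, 2) [100 100 100]

In this toy model, only the offsets differ between sites. That is the kind of nuisance variation a harmonisation step is meant to remove while leaving the class structure intact.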

.. GENERATED FROM PYTHON SOURCE LINES 34-36

Harmonisation
-------------

.. GENERATED FROM PYTHON SOURCE LINES 36-45

.. code-block:: Python


    otda = OptimalTransportDomainAdaptation()

    # Fit on the training data, adapting sites 1 and 2 to be similar to the reference site 0
    otda.fit(X=X_train, sites=sites_train, y=y_train, ref_site=0)
    # Apply the fitted transport to the held-out validation data
    X_harmonized = otda.transform(X=X_val, sites=sites_val)








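Here ``OptimalTransportDomainAdaptation`` is used as a black box. The snippet below sketches the general idea behind OT-based harmonisation, not necessarily what uniharmony does internally. It computes an entropy-regularised transport plan between a shifted site and the reference site with plain Sinkhorn iterations. It then moves each sample to the plan-weighted average of the reference samples, which is a barycentric projection. The helper names ``sinkhorn_plan`` and ``barycentric_map``, the uniform sample weights, the squared Euclidean cost and the ``reg`` value are all assumptions of this sketch. It also ignores the class labels that ``fit`` receives.

.. code-block:: Python

    import numpy as np


    def sinkhorn_plan(Xs, Xt, reg=0.05, n_iter=1000):
        """Entropic OT plan between uniform empirical distributions on Xs and Xt."""
        a = np.full(len(Xs), 1.0 / len(Xs))  # uniform weights on the site to be adapted
        b = np.full(len(Xt), 1.0 / len(Xt))  # uniform weights on the reference site
        # Pairwise squared Euclidean cost, divided by its maximum so that reg is relative.
        C = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)
        C /= C.max()
        K = np.exp(-C / reg)
        u = np.ones(len(Xs))
        for _ in range(n_iter):
            # Alternately rescale columns and rows so the plan matches marginals b and a.
            v = b / (K.T @ u)
            u = a / (K @ v)
        return u[:, None] * K * v[None, :]


    def barycentric_map(Xs, Xt, reg=0.05):
        """Move each sample of Xs to the plan-weighted mean of the reference samples Xt."""
        G = sinkhorn_plan(Xs, Xt, reg=reg)
        return (G @ Xt) / G.sum(axis=1, keepdims=True)


    rng = np.random.default_rng(0)
    X_ref = rng.normal(size=(200, 2))                     # plays the role of reference site 0
    X_shifted = rng.normal(size=(150, 2)) + [10.0, -5.0]  # another site with a strong offset
    X_mapped = barycentric_map(X_shifted, X_ref)
    print("mean before:", X_shifted.mean(axis=0).round(2))
    print("mean after: ", X_mapped.mean(axis=0).round(2))
    print("reference:  ", X_ref.mean(axis=0).round(2))

With this much regularisation the mapped points are pulled strongly toward the reference mean. A smaller ``reg`` preserves more within-site structure, but this naive implementation then needs more iterations and eventually underflows. A log-domain Sinkhorn, such as the one in the POT library, avoids that.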

.. GENERATED FROM PYTHON SOURCE LINES 46-48

Plotting
--------

.. GENERATED FROM PYTHON SOURCE LINES 48-62

.. code-block:: Python


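    # Left panel: training data before harmonisation
    df_orig = pd.DataFrame(X_train, columns=["Feature1", "Feature2"])
    df_orig["Site"] = sites_train

    # Right panel: validation data after applying the fitted transport
    df_harm = pd.DataFrame(X_harmonized, columns=["Feature1", "Feature2"])
    df_harm["Site"] = sites_val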

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    sns.scatterplot(data=df_orig, x="Feature1", y="Feature2", hue="Site", alpha=0.6, ax=axes[0])
    axes[0].set_title("Original data by site")
    sns.scatterplot(data=df_harm, x="Feature1", y="Feature2", hue="Site", alpha=0.6, ax=axes[1])
    axes[1].set_title("Harmonized data with OTDA by site")
    plt.tight_layout()




.. image-sg:: /auto_examples/05-ot-based/images/sphx_glr_01-plot_binary_classification_001.png
   :alt: Original data by site, Harmonized data with OTDA by site
   :srcset: /auto_examples/05-ot-based/images/sphx_glr_01-plot_binary_classification_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 63-66

.. admonition:: Take-home message

   After transport, the validation samples from sites ``1`` and ``2`` lie in the same region of feature space as those from the reference site ``0``.

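The take-home message rests on a visual impression. One simple way to attach a number to it is to measure, for each site, the distance between its centroid and the centroid of the reference site, before and after harmonisation. This is an assumption of this sketch, not a metric provided by uniharmony. The helper ``centroid_gaps`` is hypothetical.

.. code-block:: Python

    import numpy as np


    def centroid_gaps(X, sites, ref_site=0):
        """Euclidean distance from each site's mean feature vector to the reference site's mean."""
        X = np.asarray(X, dtype=float)
        sites = np.asarray(sites)
        ref_mean = X[sites == ref_site].mean(axis=0)
        return {
            site.item(): float(np.linalg.norm(X[sites == site].mean(axis=0) - ref_mean))
            for site in np.unique(sites)
        }


    # Toy check: the centroid of site 1 is offset by 3 units along the first feature.
    X_toy = np.array([[0.0, 0.0], [2.0, 0.0], [3.0, 0.0], [5.0, 0.0]])
    sites_toy = np.array([0, 0, 1, 1])
    print(centroid_gaps(X_toy, sites_toy))  # → {0: 0.0, 1: 3.0}

In this script, the check would be ``centroid_gaps(X_train, sites_train)`` against ``centroid_gaps(X_harmonized, sites_val)``. If the harmonisation works as the plot suggests, the gaps for sites ``1`` and ``2`` should shrink. Matching centroids is only a first-order check, because it ignores differences in spread and shape.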

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.254 seconds)


.. _sphx_glr_download_auto_examples_05-ot-based_01-plot_binary_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 01-plot_binary_classification.ipynb <01-plot_binary_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 01-plot_binary_classification.py <01-plot_binary_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 01-plot_binary_classification.zip <01-plot_binary_classification.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
