
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/03-combat-based/04-plot_combatgam_binary_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_03-combat-based_04-plot_combatgam_binary_classification.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_03-combat-based_04-plot_combatgam_binary_classification.py:


Binary classification with ComBatGAM
====================================

.. GENERATED FROM PYTHON SOURCE LINES 7-9

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 9-23

.. code-block:: Python


    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    from uniharmony import verbosity
    from uniharmony.combat import ComBatGAM
    from uniharmony.datasets import make_multisite_classification


    sns.set_theme(style="whitegrid")
    verbosity("warning")









.. GENERATED FROM PYTHON SOURCE LINES 24-26

Data generation
---------------

.. GENERATED FROM PYTHON SOURCE LINES 26-40

.. code-block:: Python


    X, y, sites = make_multisite_classification(
        n_features=2,
        site_effect_strength=10,
        signal_strength=0,
    )
    df = pd.DataFrame({"Target": y, "Site": sites})

    plt.figure(figsize=[10, 6])
    plt.title("Generated data by site")
    sns.countplot(df, x="Target", hue="Site")
    plt.grid(axis="y", color="black", alpha=0.5, linestyle="--")





.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_001.png
   :alt: Generated data by site
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-05-18 13:04:40 [warning  ] signal_strength is 0. Adding a delta (1e-6) to signal_strength to avoid degenerate data.


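``make_multisite_classification`` controls how different the sites look through ``site_effect_strength``. Its exact generative model is not described on this page. A common way to think about a site effect, and the one ComBat-style methods assume, is a per-site additive shift plus a multiplicative rescaling of each feature. The hypothetical helper ``add_site_effects`` below is a minimal NumPy sketch of that idea applied to site-free data. It is an illustration only, not the generator used above.

.. code-block:: Python

    import numpy as np


    def add_site_effects(X, sites, shift, scale):
        """Hypothetical helper: rescale and shift each site's features.

        For every site s, each row becomes ``X * scale[s] + shift[s]``, i.e. a
        multiplicative (scale) and an additive (location) site effect.
        """
        X = np.asarray(X, dtype=float)
        sites = np.asarray(sites)
        X_site = X.copy()
        for site in np.unique(sites).tolist():
            mask = sites == site
            X_site[mask] = X[mask] * scale[site] + shift[site]
        return X_site


    rng = np.random.default_rng(42)
    n_per_site = 250
    X_clean = rng.normal(size=(2 * n_per_site, 2))  # site-free features, mean 0, SD 1
    site_labels = np.repeat([0, 1], n_per_site)
    X_site = add_site_effects(
        X_clean, site_labels, shift={0: 1.0, 1: -1.0}, scale={0: 1.5, 1: 0.8}
    )

    for site in (0, 1):
        rows = X_site[site_labels == site]
        print(f"site {site}: mean={rows.mean(axis=0).round(2)}, sd={rows.std(axis=0).round(2)}")

Harmonization aims to undo exactly this kind of per-site shift and rescaling while keeping the remaining structure of the data.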


.. GENERATED FROM PYTHON SOURCE LINES 41-43

Harmonization
-------------

.. GENERATED FROM PYTHON SOURCE LINES 43-63

.. code-block:: Python


    combat = ComBatGAM()
    X_harmonized = combat.fit_transform(X, sites, smooth_covariates=y.reshape(-1, 1))

    df_orig = pd.DataFrame(X, columns=["Feature1", "Feature2"])
    df_orig["Site"] = sites
    df_orig["Phase"] = "Original"

    df_harm = pd.DataFrame(X_harmonized, columns=["Feature1", "Feature2"])
    df_harm["Site"] = sites
    df_harm["Phase"] = "Harmonized"

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    sns.scatterplot(data=df_orig, x="Feature1", y="Feature2", hue="Site", alpha=0.6, ax=axes[0])
    axes[0].set_title("Original data by site")
    sns.scatterplot(data=df_harm, x="Feature1", y="Feature2", hue="Site", alpha=0.6, ax=axes[1])
    axes[1].set_title("Harmonized data by site")
    plt.tight_layout()





.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_002.png
   :alt: Original data by site, Harmonized data by site
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_002.png
   :class: sphx-glr-single-img

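ComBat-style methods model each feature as a common location, plus covariate effects, plus a site-specific additive shift and multiplicative scale, and then remove the site terms. The sketch below illustrates that location-scale idea in plain NumPy. ``simple_location_scale_harmonize`` is a hypothetical, simplified helper and not uniharmony's ``ComBatGAM``: it skips the empirical Bayes shrinkage of ComBat and treats covariates as linear terms, whereas a GAM-based method fits smooth covariate effects. Its ``covariates`` argument only loosely plays the role of ``smooth_covariates`` above.

.. code-block:: Python

    import numpy as np


    def simple_location_scale_harmonize(X, sites, covariates=None):
        """Remove per-site mean and variance differences (simplified ComBat-style).

        Hypothetical helper: no empirical Bayes shrinkage, and covariates enter
        linearly instead of through smooth (GAM) terms.
        """
        X = np.asarray(X, dtype=float)
        sites = np.asarray(sites)
        labels = np.unique(sites)
        n, n_sites = X.shape[0], len(labels)

        # One indicator column per site (these absorb the intercept) plus covariates.
        site_design = (sites[:, None] == labels[None, :]).astype(float)
        if covariates is None:
            design = site_design
        else:
            cov = np.asarray(covariates, dtype=float).reshape(n, -1)
            design = np.hstack([site_design, cov])

        # Ordinary least squares fit of every feature on the design matrix.
        beta, *_ = np.linalg.lstsq(design, X, rcond=None)
        site_intercepts = beta[:n_sites]
        covariate_effect = design[:, n_sites:] @ beta[n_sites:]

        # Sample-size-weighted average of the site intercepts: the common location.
        counts = site_design.sum(axis=0)
        grand_mean = counts @ site_intercepts / n

        # Residuals sum to zero within each site because of the indicator columns.
        residuals = X - design @ beta
        pooled_sd = np.sqrt((residuals**2).sum(axis=0) / (n - design.shape[1]))

        X_harmonized = np.empty_like(X)
        for label in labels:
            mask = sites == label
            site_sd = residuals[mask].std(axis=0, ddof=1)
            # Replace the site intercept by the grand mean and rescale to the pooled SD.
            X_harmonized[mask] = (
                grand_mean + covariate_effect[mask] + residuals[mask] / site_sd * pooled_sd
            )
        return X_harmonized


    rng = np.random.default_rng(0)
    demo_sites = np.repeat([0, 1], 300)
    X_demo = rng.normal(size=(600, 2))
    X_demo[demo_sites == 0] = X_demo[demo_sites == 0] * 2.0 + 3.0  # site 0: shifted and wider
    X_demo_harmonized = simple_location_scale_harmonize(X_demo, demo_sites)
    for site in (0, 1):
        rows = X_demo_harmonized[demo_sites == site]
        print(f"site {site}: mean={rows.mean(axis=0).round(3)}, sd={rows.std(axis=0, ddof=1).round(3)}")

Without covariates, this sketch makes the per-site means and standard deviations identical by construction. Real ComBat variants instead shrink the per-site estimates towards each other, which matters when some sites have few samples.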




.. GENERATED FROM PYTHON SOURCE LINES 64-66

Plotting
--------

.. GENERATED FROM PYTHON SOURCE LINES 66-77

.. code-block:: Python


    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    sns.boxplot(data=df_orig, y="Feature1", hue="Site", ax=axes[0])
    axes[0].set_title("Original data by site")
    axes[0].grid(axis="y", color="black", alpha=0.5, linestyle="--")
    sns.boxplot(data=df_harm, y="Feature1", hue="Site", ax=axes[1])
    axes[1].set_title("Harmonized data by site")
    axes[1].grid(axis="y", color="black", alpha=0.5, linestyle="--")
    plt.tight_layout()





.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_003.png
   :alt: Original data by site, Harmonized data by site
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_003.png
   :class: sphx-glr-single-img


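The boxplots compare medians and interquartile ranges, while the printed summaries below report only means. Because ComBat-style harmonization also targets differences in scale, a per-site table of location and spread can complement the plots. ``site_spread_summary`` is a hypothetical pandas helper; in this example it could be called as ``site_spread_summary(df_orig, "Feature1")`` and ``site_spread_summary(df_harm, "Feature1")``. The toy data below only demonstrates the call.

.. code-block:: Python

    import numpy as np
    import pandas as pd


    def site_spread_summary(df, feature, site_col="Site"):
        """Hypothetical helper: per-site mean, SD and IQR of one feature column."""
        grouped = df.groupby(site_col)[feature]
        return pd.DataFrame(
            {
                "mean": grouped.mean(),
                "std": grouped.std(),
                # Interquartile range: the height of the box in a boxplot.
                "iqr": grouped.quantile(0.75) - grouped.quantile(0.25),
            }
        )


    # Toy data only to demonstrate the call; site 0 is shifted up and more spread out.
    rng = np.random.default_rng(0)
    toy = pd.DataFrame(
        {
            "Feature1": np.concatenate([rng.normal(3.5, 2.0, 200), rng.normal(2.3, 1.0, 200)]),
            "Site": np.repeat([0, 1], 200),
        }
    )
    summary = site_spread_summary(toy, "Feature1")
    print(summary.round(2))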



.. GENERATED FROM PYTHON SOURCE LINES 78-85

.. code-block:: Python


    print("Feature means by site before harmonization:")
    print(df_orig["Feature1"].groupby(df_orig["Site"]).mean())
    print("Feature means by site after harmonization:")
    print(df_harm["Feature1"].groupby(df_harm["Site"]).mean())






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Feature means by site before harmonization:
    Site
    0    3.454432
    1    2.313463
    Name: Feature1, dtype: float64
    Feature means by site after harmonization:
    Site
    0    2.881748
    1    2.886076
    Name: Feature1, dtype: float64




.. GENERATED FROM PYTHON SOURCE LINES 86-97

.. code-block:: Python


    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    sns.boxplot(data=df_orig, y="Feature2", hue="Site", ax=axes[0])
    axes[0].set_title("Original data by site")
    axes[0].grid(axis="y", color="black", alpha=0.5, linestyle="--")
    sns.boxplot(data=df_harm, y="Feature2", hue="Site", ax=axes[1])
    axes[1].set_title("Harmonized data by site")
    axes[1].grid(axis="y", color="black", alpha=0.5, linestyle="--")
    plt.tight_layout()





.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_004.png
   :alt: Original data by site, Harmonized data by site
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_04-plot_combatgam_binary_classification_004.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 98-104

.. code-block:: Python


    print("Feature means by site before harmonization:")
    print(df_orig["Feature2"].groupby(df_orig["Site"]).mean())
    print("Feature means by site after harmonization:")
    print(df_harm["Feature2"].groupby(df_harm["Site"]).mean())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Feature means by site before harmonization:
    Site
    0    3.599556
    1    2.287443
    Name: Feature2, dtype: float64
    Feature means by site after harmonization:
    Site
    0    2.944512
    1    2.942473
    Name: Feature2, dtype: float64

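The two pairs of printouts can be condensed into one number per feature: the absolute gap between the two site means, before and after harmonization. The values below are copied from the output of this particular run and will change if the data are regenerated; ``site_gap`` and ``gap_reduction`` are hypothetical helpers.

.. code-block:: Python

    # Per-site means (site 0, site 1) copied from the printed output of this run.
    means_before = {"Feature1": (3.454432, 2.313463), "Feature2": (3.599556, 2.287443)}
    means_after = {"Feature1": (2.881748, 2.886076), "Feature2": (2.944512, 2.942473)}


    def site_gap(means):
        """Hypothetical helper: absolute difference between two site means."""
        return abs(means[0] - means[1])


    def gap_reduction(before, after):
        """Fraction of the between-site mean gap removed by harmonization."""
        return 1.0 - site_gap(after) / site_gap(before)


    for feature in means_before:
        gap_before = site_gap(means_before[feature])
        gap_after = site_gap(means_after[feature])
        reduction = gap_reduction(means_before[feature], means_after[feature])
        print(f"{feature}: gap {gap_before:.4f} -> {gap_after:.4f} ({reduction:.1%} smaller)")

For this run, the gap between site means shrinks by more than 99% for both features.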



.. GENERATED FROM PYTHON SOURCE LINES 105-108

.. admonition:: Take-home message

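   As expected, ComBatGAM brings the per-site means of both features into close
   agreement. In this run, the ``Feature1`` site means move from about 3.45 and 2.31
   to 2.88 and 2.89, and the ``Feature2`` site means from about 3.60 and 2.29 to
   2.94 and 2.94.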


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.475 seconds)


.. _sphx_glr_download_auto_examples_03-combat-based_04-plot_combatgam_binary_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 04-plot_combatgam_binary_classification.ipynb <04-plot_combatgam_binary_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 04-plot_combatgam_binary_classification.py <04-plot_combatgam_binary_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 04-plot_combatgam_binary_classification.zip <04-plot_combatgam_binary_classification.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
