
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/03-combat-based/05-plot_combatgam_imbalance_across_sites.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_03-combat-based_05-plot_combatgam_imbalance_across_sites.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_03-combat-based_05-plot_combatgam_imbalance_across_sites.py:


Analysing ComBatGAM behaviour with imbalance across sites
=========================================================

.. GENERATED FROM PYTHON SOURCE LINES 7-9

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 9-23

.. code-block:: Python


    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    from uniharmony import verbosity
    from uniharmony.combat import ComBatGAM
    from uniharmony.datasets import make_multisite_classification


    sns.set_theme(style="whitegrid")
    verbosity("warning")









.. GENERATED FROM PYTHON SOURCE LINES 24-26

Data generation
---------------

.. GENERATED FROM PYTHON SOURCE LINES 26-41

.. code-block:: Python


    X, y, sites = make_multisite_classification(
        n_features=2,
        signal_strength=2,
        site_effect_strength=0,  # NO site effect
        balance_per_site=[0.1, 0.9],
    )
    df = pd.DataFrame({"Target": y, "Site": sites})

    plt.figure(figsize=[10, 6])
    plt.title("Unbalanced classes by site")
    sns.countplot(df, x="Target", hue="Site")
    plt.grid(axis="y", color="black", alpha=0.5, linestyle="--")





.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_05-plot_combatgam_imbalance_across_sites_001.png
   :alt: Unbalanced classes by site
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_05-plot_combatgam_imbalance_across_sites_001.png
   :class: sphx-glr-single-img
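
Even with ``site_effect_strength=0``, unequal class proportions alone are enough to make the sites look different.
The sketch below uses plain NumPy rather than ``make_multisite_classification`` and assumes, for illustration only,
a single feature whose class-1 mean is shifted by ``signal`` (the names ``n_per_site`` and ``signal`` are hypothetical).
Under that assumption the expected difference between the site means is ``(0.9 - 0.1) * signal``, so a harmoniser that
only aligns site means (a deliberately naive stand-in, not ComBatGAM) shrinks the pooled gap between the classes.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    n_per_site = 5000  # hypothetical sample size per site
    signal = 2.0  # assumed mean shift of class 1 relative to class 0

    # Two sites with 10% and 90% positives and NO site effect
    sites = np.repeat([0, 1], n_per_site)
    p_pos = np.where(sites == 0, 0.1, 0.9)
    y = (rng.random(sites.size) < p_pos).astype(int)
    x = signal * y + rng.normal(size=sites.size)

    # The site means differ although no site effect was simulated
    site_means = np.array([x[sites == s].mean() for s in (0, 1)])
    gap_before = x[y == 1].mean() - x[y == 0].mean()

    # Naive per-site mean centring (a stand-in, NOT ComBatGAM)
    x_centred = x - site_means[sites]
    gap_after = x_centred[y == 1].mean() - x_centred[y == 0].mean()

    print(f"site means: {site_means}")
    print(f"class gap before centring: {gap_before:.2f}")
    print(f"class gap after centring:  {gap_after:.2f}")

The within-site class gap is untouched by the centring; it is the pooled gap across sites that shrinks,
because part of the class signal has been attributed to the site.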





.. GENERATED FROM PYTHON SOURCE LINES 42-46

.. caution::

   Here we harmonise the whole dataset at once, which must be avoided in ML scenarios
   because information from the test samples leaks into the harmonisation model.
   We do it only to illustrate the effect of harmonisation.
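
In an ML workflow, the harmonisation model would instead be fitted on the training samples only and then
applied to the held-out samples. A minimal sketch of that pattern follows. It uses a hypothetical
``SiteMeanCentering`` class as a stand-in: it is not part of uniharmony and only removes additive per-site offsets.

.. code-block:: Python

    import numpy as np


    class SiteMeanCentering:
        """Hypothetical stand-in harmoniser: aligns per-site feature means."""

        def fit(self, X, sites):
            self.grand_mean_ = X.mean(axis=0)
            self.site_means_ = {s: X[sites == s].mean(axis=0) for s in np.unique(sites)}
            return self

        def transform(self, X, sites):
            # Subtract the offset of each sample's site, as estimated in fit()
            offsets = np.stack([self.site_means_[s] - self.grand_mean_ for s in sites])
            return X - offsets


    rng = np.random.default_rng(0)
    n_samples = 2000
    X = rng.normal(size=(n_samples, 2))
    sites = rng.integers(0, 2, size=n_samples)
    X[sites == 1] += 1.5  # simulated additive site effect

    # Split BEFORE harmonising, so the test samples never influence the fit
    idx = rng.permutation(n_samples)
    train, test = idx[:1500], idx[1500:]

    harmoniser = SiteMeanCentering().fit(X[train], sites[train])
    X_train_h = harmoniser.transform(X[train], sites[train])
    X_test_h = harmoniser.transform(X[test], sites[test])

This stand-in can only transform samples from sites that were present in the training split.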

.. GENERATED FROM PYTHON SOURCE LINES 48-50

Harmonisation
-------------

.. GENERATED FROM PYTHON SOURCE LINES 50-55

.. code-block:: Python


    combat = ComBatGAM()
    combat.fit(X.copy(), sites, smooth_covariates=y.reshape(-1, 1))
    X_harmonized = combat.transform(X, sites, smooth_covariates=y.reshape(-1, 1))








.. GENERATED FROM PYTHON SOURCE LINES 56-58

Plotting
--------

.. GENERATED FROM PYTHON SOURCE LINES 58-81

.. code-block:: Python


    df_orig = pd.DataFrame(X, columns=["Feature1", "Feature2"])
    df_orig["Site"] = sites
    df_orig["Target"] = y
    df_orig["Phase"] = "Original"

    df_harm = pd.DataFrame(X_harmonized, columns=["Feature1", "Feature2"])
    df_harm["Site"] = sites
    df_harm["Target"] = y
    df_harm["Phase"] = "Harmonized"


    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    sns.scatterplot(data=df_orig, x="Feature1", y="Feature2", hue="Target", alpha=0.6, ax=axes[0])
    axes[0].set_title("Original data by target")
    axes[0].grid(alpha=0.3, color="black", linestyle="--")

    sns.scatterplot(data=df_harm, x="Feature1", y="Feature2", hue="Target", alpha=0.6, ax=axes[1])
    axes[1].set_title("Harmonized data by target")
    axes[1].grid(alpha=0.3, color="black", linestyle="--")
    plt.tight_layout()





.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_05-plot_combatgam_imbalance_across_sites_002.png
   :alt: Original data by target, Harmonized data by target
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_05-plot_combatgam_imbalance_across_sites_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 82-84

Preserving the target as a covariate
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 86-90

.. caution::

   Passing the target as a covariate is also wrong in an ML context, where you do not
   have access to the full dataset (including the targets of unseen samples), but it
   may be a good option for statistical analysis.

.. GENERATED FROM PYTHON SOURCE LINES 92-112

.. code-block:: Python


    combat = ComBatGAM()
    # This is the key line: we need to include the target variable as a covariate
    # to preserve its relationship with the features during harmonization.

    combat.fit(X.copy(), sites, smooth_covariates=y.reshape(-1, 1))
    X_harmonized = combat.transform(X.copy(), sites, smooth_covariates=y.reshape(-1, 1))

    df_orig = pd.DataFrame(X, columns=["Feature1", "Feature2"])
    df_orig["Site"] = sites
    df_orig["Target"] = y

    df_orig["Phase"] = "Original"

    df_harm = pd.DataFrame(X_harmonized, columns=["Feature1", "Feature2"])
    df_harm["Site"] = sites
    df_harm["Target"] = y

    df_harm["Phase"] = "Harmonized"








.. GENERATED FROM PYTHON SOURCE LINES 113-115

Plotting
--------

.. GENERATED FROM PYTHON SOURCE LINES 115-124

.. code-block:: Python


    # Plot data distribution by site before and after harmonisation
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    sns.scatterplot(data=df_orig, x="Feature1", y="Feature2", hue="Site", alpha=0.6, ax=axes[0])
    axes[0].set_title("Original data by site")
    sns.scatterplot(data=df_harm, x="Feature1", y="Feature2", hue="Site", alpha=0.6, ax=axes[1])
    axes[1].set_title("Harmonized data by site")
    plt.tight_layout()




.. image-sg:: /auto_examples/03-combat-based/images/sphx_glr_05-plot_combatgam_imbalance_across_sites_003.png
   :alt: Original data by site, Harmonized data by site
   :srcset: /auto_examples/03-combat-based/images/sphx_glr_05-plot_combatgam_imbalance_across_sites_003.png
   :class: sphx-glr-single-img
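
Why a covariate helps can be sketched with ordinary least squares, used here as a simplified stand-in for the
location part of a ComBat-style model. This sketch ignores the scale adjustment, the empirical Bayes shrinkage
and the GAM smooth terms, and its data and variable names are illustrative assumptions rather than output of
``make_multisite_classification``. When the target is in the design matrix, the class signal is attributed
to the target and only the remaining site term is removed.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    n_per_site = 5000  # hypothetical sample size per site
    signal = 2.0  # assumed mean shift of class 1 relative to class 0

    # Two sites with 10% and 90% positives and NO site effect
    sites = np.repeat([0, 1], n_per_site)
    is_site1 = (sites == 1).astype(float)
    p_pos = np.where(sites == 0, 0.1, 0.9)
    y = (rng.random(sites.size) < p_pos).astype(float)
    x = signal * y + rng.normal(size=sites.size)

    # Without the covariate, the site term is just the difference in site means
    naive_site_coef = x[sites == 1].mean() - x[sites == 0].mean()

    # With the covariate: design matrix = intercept, site indicator, target
    design = np.column_stack([np.ones_like(y), is_site1, y])
    coef, *_ = np.linalg.lstsq(design, x, rcond=None)
    site_coef = coef[1]

    # Remove only the estimated site term; the target term stays in the data
    x_adj = x - site_coef * is_site1
    gap_adj = x_adj[y == 1].mean() - x_adj[y == 0].mean()

    print(f"site term without covariate: {naive_site_coef:.2f}")
    print(f"site term with covariate:    {site_coef:.2f}")
    print(f"class gap after adjustment:  {gap_adj:.2f}")
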





.. GENERATED FROM PYTHON SOURCE LINES 125-130

.. admonition:: Take-home message

   When the class balance differs across sites, ComBatGAM cannot preserve the
   target-related variance unless the target is passed as a covariate.
   Preserving the target as a covariate may suit statistical analysis, but not ML scenarios.
   The implementation emits a warning when a covariate is preserved.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.327 seconds)


.. _sphx_glr_download_auto_examples_03-combat-based_05-plot_combatgam_imbalance_across_sites.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 05-plot_combatgam_imbalance_across_sites.ipynb <05-plot_combatgam_imbalance_across_sites.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 05-plot_combatgam_imbalance_across_sites.py <05-plot_combatgam_imbalance_across_sites.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 05-plot_combatgam_imbalance_across_sites.zip <05-plot_combatgam_imbalance_across_sites.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
