Discover biases in metrics by site

Discover biases in metrics by site#

uniharmony allows you to stratify the performance metrics by site, unraveling hidden patterns. In this example, we will not simulate site effects.

Imports#

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

from uniharmony import verbosity
from uniharmony.datasets import make_multisite_classification
from uniharmony.metrics import report_metrics_by_site


sns.set_theme(style="whitegrid")
verbosity("error")

clf = LogisticRegression()

Data generation#

Let’s create the first scenario: a dataset with 3 good sites and 1 bad site (signal strength = 0) to show the effect of having a bad site in the dataset

n_bad_sites = 1
X_bad, y_bad, sites_bad = make_multisite_classification(
    n_sites=n_bad_sites,
    signal_strength=0,
    site_effect_strength=0,
)

# Used to simulate "good" sites
signal_strength = 1
X_good, y_good, sites_good = make_multisite_classification(
    n_sites=3,
    signal_strength=signal_strength,
    site_effect_strength=0,
)
# Increase site labels for good sites to avoid overlap with bad sites
sites_good = sites_good + n_bad_sites

X = np.concatenate([X_bad, X_good], axis=0)
y = np.concatenate([y_bad, y_good], axis=0)
sites = np.concatenate([sites_bad, sites_good], axis=0)

X_train, X_test, y_train, y_test, sites_train, sites_test = train_test_split(X, y, sites)

clf.fit(X_train, y_train)
y_pred_s1 = clf.predict(X_test)
metric_s1 = report_metrics_by_site(
    y_test,
    y_pred_s1,
    sites_test,
    balanced_accuracy_score,
    overall_performance=True,
)
print(f"Overall bACC for Scenario 1: {metric_s1['balanced_accuracy_score']['overall']:.3}")
Overall bACC for Scenario 1: 0.64

Now let’s create a second scenario: a dataset with 3 bad sites and 1 good site (signal strength = 1)

n_bad_sites = 3
X_bad, y_bad, sites_bad = make_multisite_classification(
    n_sites=n_bad_sites,
    signal_strength=0,
    site_effect_strength=0,
)

# Used to simulate "good" sites
signal_strength = 1
X_good, y_good, sites_good = make_multisite_classification(
    n_sites=1,
    signal_strength=signal_strength,
    site_effect_strength=0,
)
# Increase site labels for good sites to avoid overlap with bad sites
sites_good = sites_good + n_bad_sites

X = np.concatenate([X_bad, X_good], axis=0)
y = np.concatenate([y_bad, y_good], axis=0)
sites = np.concatenate([sites_bad, sites_good], axis=0)

X_train, X_test, y_train, y_test, sites_train, sites_test = train_test_split(X, y, sites)

clf.fit(X_train, y_train)
y_pred_s2 = clf.predict(X_test)
metric_s2 = report_metrics_by_site(y_test, y_pred_s2, sites_test, balanced_accuracy_score, overall_performance=True)
print(f"Overall bACC for Scenario 2: {metric_s2['balanced_accuracy_score']['overall']:.3}")

# # %%
# # Let's plot the results obtained in the each site

# sites_unique = np.unique(sites)
# # Extract global performance for both scenarios
# metric_global_s1 = metric_s1['balanced_accuracy_score'].pop("overall")
# metric_global_s2 = metric_s2['balanced_accuracy_score'].pop("overall")

# # Visualize both scenarios
# fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# # Scenario 1
# site_scores_s1 = [metric_s1[s] for s in sites_unique]

# sns.barplot(
#     x=sites_unique,
#     y=site_scores_s1,
#     color="steelblue",
#     label="Site Scores",
#     ax=axes[0],
# )
# axes[0].axhline(
#     metric_global_s1,
#     color="black",
#     linestyle="--",
#     label=f"Global: {metric_global_s1:.3f}",
# )
# axes[0].axhline(
#     0.5,
#     color="red",
#     linestyle="--",
#     alpha=0.7,
#     label="Chance level: 0.5",
# )
# axes[0].set_xlabel("Site")
# axes[0].set_ylabel("Balanced Accuracy")
# axes[0].set_title("Scenario 1: Good Overall, One Site Fails")
# axes[0].legend()
# axes[0].grid(True, alpha=1, axis="y")
# axes[0].set_ylim([0, 1])

# # Scenario 2
# site_scores_s2 = [metric_s2[s] for s in sites_unique]
# sns.barplot(
#     x=sites_unique,
#     y=site_scores_s2,
#     color="coral",
#     label="Site Scores",
#     ax=axes[1],
# )
# axes[1].axhline(
#     metric_global_s2,
#     color="black",
#     linestyle="--",
#     label=f"Global: {metric_global_s2:.3f}",
# )
# axes[1].axhline(
#     0.5,
#     color="red",
#     linestyle="--",
#     alpha=0.7,
#     label="Chance level: 0.5",
# )
# axes[1].set_xlabel("Site")
# axes[1].set_ylabel("Balanced Accuracy")
# axes[1].set_title("Scenario 2: Bad Overall, One Site Excels")
# axes[1].legend()
# axes[1].set_ylim([0, 1])

# plt.tight_layout()
# plt.show()
Overall bACC for Scenario 2: 0.655

But, how is it possible that they have an similar overall performance? Where is the catch? The sites have different number of samples!

In the first scenario, even when the first site is bigger, the other 3 compensates the bad performance. In the second scenario, the last site (good one) is bigger an pushes the overall performance up.

If we had only reported the overall performance, we would not be able to unravel the site’s behavior.

Total running time of the script: (0 minutes 2.564 seconds)

Gallery generated by Sphinx-Gallery