Note
Go to the end to download the full example code.
Compute metrics by site
Imports
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from uniharmony import verbosity
from uniharmony.datasets import make_multisite_classification
from uniharmony.metrics import report_metrics_by_site
# Set uniharmony's verbosity level to "error" for the rest of the example.
verbosity("error")

# Apply seaborn's "whitegrid" theme.
sns.set_theme(style="whitegrid")

# Logistic regression with scikit-learn's default settings; it is fitted
# further down, once the data has been split.
clf = LogisticRegression()
Data generation
# Generate a synthetic multi-site classification dataset: features, labels
# and one site label per sample.
X, y, sites = make_multisite_classification(
    n_sites=10,
    signal_strength=0.5,
)

# Split the three arrays in a single call so that features, labels and site
# labels stay row-aligned in both the train and the test partition.
(
    X_train,
    X_test,
    y_train,
    y_test,
    sites_train,
    sites_test,
) = train_test_split(X, y, sites)
Metrics by site report
# Fit on the training split and predict hard labels for the held-out split.
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Continuous scores for the score-based metrics used later in the example.
# predict_proba orders its columns like clf.classes_, and roc_auc_score (as
# well as thresholding at 0.5) expects the score of the class with the greater
# label, so the positive class lives in column 1. Taking column 0 instead
# inverts the ranking and drives every score-based metric below chance.
# NOTE(review): assumes binary labels -- confirm against
# make_multisite_classification.
y_scores = clf.predict_proba(X_test)[:, 1]

# Balanced accuracy computed separately for every site in the test split.
metrics = report_metrics_by_site(y_test, y_pred, sites_test, balanced_accuracy_score)
# The report is keyed by metric name first and by site second, so the per-site
# values can be listed with:
# for site, score in metrics["balanced_accuracy_score"].items():
#     print(f"For site {site}: bACC {score:.4}")

# Compute the metrics again, this time also requesting the overall score.
metrics = report_metrics_by_site(y_test, y_pred, sites_test, balanced_accuracy_score, overall_performance=True)

# Compute the metric outside the function to compare.
bacc = balanced_accuracy_score(y_true=y_test, y_pred=y_pred)

# Overall comparison: the two printed values are expected to match.
print(f"Overall bACC: {bacc}")
# The overall performance is also stored in the metrics if requested.
print(f"Overall bACC: {metrics['balanced_accuracy_score']['overall']}")
Overall bACC: 0.6694415983606558
Overall bACC: 0.6694415983606558
If requested, the function also computes the overall performance and stores it as another entry in the dictionary.
# 1) Simplest call: a single metric and no extra arguments.
metrics = report_metrics_by_site(y_test, y_scores, sites_test, accuracy_score)
print(metrics)

# 2) A single metric together with keyword arguments for it.
metrics = report_metrics_by_site(
    y_test,
    y_scores,
    sites_test,
    f1_score,
    metric_kwargs={
        "threshold": 0.5,
        "average": "macro",
    },
)
print(metrics)

# 3) Several metrics at once; metric_kwargs is then a list with one dict per
#    metric, in the same order as the metrics list.
metrics = report_metrics_by_site(
    y_test,
    y_scores,
    sites_test,
    metrics=[roc_auc_score, accuracy_score],
    metric_kwargs=[
        {},
        {"threshold": 0.5},
    ],
)
print(metrics)

# 4) Several metrics, additionally requesting the overall performance.
metrics = report_metrics_by_site(
    y_test,
    y_scores,
    sites_test,
    metrics=[accuracy_score, roc_auc_score],
    overall_performance=True,
)
print(metrics)
{'accuracy_score': {'overall': 0.332, 0: 0.4166666666666667, 1: 0.3870967741935484, 2: 0.3076923076923077, 3: 0.26666666666666666, 4: 0.2222222222222222, 5: 0.3, 6: 0.16666666666666666, 7: 0.4642857142857143, 8: 0.3448275862068966, 9: 0.34615384615384615}}
{'f1_score': {'overall': 0.3288968188905499, 0: 0.37777777777777777, 1: 0.3433667781493868, 2: 0.26875, 3: 0.26666666666666666, 4: 0.2125, 5: 0.29292929292929293, 6: 0.14285714285714285, 7: 0.44664031620553357, 8: 0.315527950310559, 9: 0.3210445468509985}}
{'roc_auc_score': {'overall': 0.31070696721311475, 0: 0.20714285714285713, 1: 0.40454545454545454, 2: 0.34911242603550297, 3: 0.27777777777777773, 4: 0.2625, 5: 0.3232323232323232, 6: 0.20987654320987653, 7: 0.47692307692307695, 8: 0.21212121212121215, 9: 0.23809523809523808}, 'accuracy_score': {'overall': 0.332, 0: 0.4166666666666667, 1: 0.3870967741935484, 2: 0.3076923076923077, 3: 0.26666666666666666, 4: 0.2222222222222222, 5: 0.3, 6: 0.16666666666666666, 7: 0.4642857142857143, 8: 0.3448275862068966, 9: 0.34615384615384615}}
{'accuracy_score': {'overall': 0.332, 0: 0.4166666666666667, 1: 0.3870967741935484, 2: 0.3076923076923077, 3: 0.26666666666666666, 4: 0.2222222222222222, 5: 0.3, 6: 0.16666666666666666, 7: 0.4642857142857143, 8: 0.3448275862068966, 9: 0.34615384615384615}, 'roc_auc_score': {'overall': 0.31070696721311475, 0: 0.20714285714285713, 1: 0.40454545454545454, 2: 0.34911242603550297, 3: 0.27777777777777773, 4: 0.2625, 5: 0.3232323232323232, 6: 0.20987654320987653, 7: 0.47692307692307695, 8: 0.21212121212121215, 9: 0.23809523809523808}}
Total running time of the script: (0 minutes 2.199 seconds)