"""Linear classification heads for polyreactivity prediction.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any |
|
|
|
|
|
import numpy as np |
|
|
from sklearn.linear_model import LogisticRegression |
|
|
from sklearn.svm import LinearSVC |
|
|
|
|
|
|
|
|
@dataclass(slots=True)
class LinearModelConfig:
    """Configuration options for linear heads."""

    head: str = "logreg"
    C: float = 1.0
    class_weight: Any = "balanced"
    max_iter: int = 1000


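# Example configurations (illustrative only; the parameter values below are
# arbitrary assumptions, not tuned defaults). The two supported ``head`` values
# select between the branches of ``build_estimator`` further down:
#
#   logreg_cfg = LinearModelConfig(head="logreg", C=0.5)
#   svm_cfg = LinearModelConfig(head="linear_svm", C=1.0, max_iter=5000)

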
@dataclass(slots=True)
class TrainedModel:
    """Container for trained estimators and optional calibration."""

    estimator: Any
    calibrator: Any | None = None
    vectorizer_name: str = ""
    feature_meta: dict[str, Any] = field(default_factory=dict)
    metrics_cv: dict[str, float] = field(default_factory=dict)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class labels, preferring the calibrator when one is attached."""
        if self.calibrator is not None and hasattr(self.calibrator, "predict"):
            return self.calibrator.predict(X)
        return self.estimator.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return positive-class probabilities for ``X``.

        Preference order: the calibrator's ``predict_proba``, then the
        estimator's ``predict_proba``, then a sigmoid applied to
        ``decision_function`` scores (e.g. for ``LinearSVC``).
        """
        if self.calibrator is not None and hasattr(self.calibrator, "predict_proba"):
            probs = self.calibrator.predict_proba(X)
            return probs[:, 1]
        if hasattr(self.estimator, "predict_proba"):
            probs = self.estimator.predict_proba(X)
            return probs[:, 1]
        if hasattr(self.estimator, "decision_function"):
            scores = self.estimator.decision_function(X)
            return 1.0 / (1.0 + np.exp(-scores))
        msg = "Estimator does not support probability prediction"
        raise AttributeError(msg)


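# One way to populate ``calibrator`` (a sketch, not prescribed by this module):
# fit scikit-learn's ``CalibratedClassifierCV`` on the training features and store
# it alongside the raw estimator. ``cfg``, ``X_train`` and ``y_train`` are
# placeholder names, e.g.
#
#   from sklearn.calibration import CalibratedClassifierCV
#
#   base = build_estimator(config=cfg)
#   base.fit(X_train, y_train)
#   calibrated = CalibratedClassifierCV(build_estimator(config=cfg), method="sigmoid", cv=5)
#   calibrated.fit(X_train, y_train)
#   model = TrainedModel(estimator=base, calibrator=calibrated)

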
def build_estimator(
    *, config: LinearModelConfig, random_state: int | None = 42
) -> Any:
    """Construct an unfitted linear estimator based on configuration."""

    if config.head == "logreg":
        estimator = LogisticRegression(
            C=config.C,
            max_iter=config.max_iter,
            class_weight=config.class_weight,
            solver="liblinear",
            random_state=random_state,
        )
    elif config.head == "linear_svm":
        estimator = LinearSVC(
            C=config.C,
            class_weight=config.class_weight,
            max_iter=config.max_iter,
            random_state=random_state,
        )
    else:
        msg = f"Unsupported head type: {config.head}"
        raise ValueError(msg)
    return estimator


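# Quick check of both branches (illustrative only):
#
#   build_estimator(config=LinearModelConfig(head="logreg"))      # LogisticRegression
#   build_estimator(config=LinearModelConfig(head="linear_svm"))  # LinearSVC
#   build_estimator(config=LinearModelConfig(head="mlp"))         # raises ValueError
#
# ``LinearSVC`` exposes no ``predict_proba``; ``TrainedModel.predict_proba`` above
# falls back to a sigmoid over its ``decision_function`` scores in that case.

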
def train_linear_model(
    X: np.ndarray,
    y: np.ndarray,
    *,
    config: LinearModelConfig,
    random_state: int | None = 42,
) -> TrainedModel:
    """Fit a linear classifier on the provided feature matrix."""

    estimator = build_estimator(config=config, random_state=random_state)
    # Switch logistic regression from liblinear to the lbfgs solver once the
    # training set reaches 1000 rows.
    if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
        estimator.set_params(solver="lbfgs")
    estimator.fit(X, y)
    return TrainedModel(estimator=estimator)
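

# Illustrative smoke test on synthetic data (not part of the original module API;
# real usage would pass feature matrices produced by the surrounding pipeline).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X_demo = rng.normal(size=(200, 16))
    y_demo = (X_demo[:, 0] + 0.5 * X_demo[:, 1] > 0).astype(int)

    model = train_linear_model(
        X_demo, y_demo, config=LinearModelConfig(head="logreg", C=1.0)
    )
    print("labels:", model.predict(X_demo[:5]))
    print("P(positive):", np.round(model.predict_proba(X_demo[:5]), 3))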