Access to this model is gated. The repository is publicly accessible, but you must agree to share your contact information and accept the conditions before you can access its files and content.

Model Card for smb-vision-v0-risk

This model card describes a variant of standardmodelbio/smb-vision-v0-mim adapted for survival regression tasks. It pairs the pre-trained SMB-Vision-v0 vision backbone with a randomly initialized survival regression head.

The model is designed to be fine-tuned on downstream time-to-event datasets. The core methodology is to freeze the vision backbone and train only the lightweight survival head, enabling efficient adaptation to specific tasks like predicting patient survival or disease progression from medical scans.

We provide a complete, ready-to-run Google Colab pipeline for fine-tuning SMB-Vision series models on your own downstream tasks.

Model Description

  • Backbone: The model uses the SMB-Vision-v0 backbone, a Transformer-based architecture pre-trained with self-supervised learning on a large, diverse dataset of radiology images (CT, MRI, X-ray). Its parameters are meant to stay frozen during fine-tuning; freeze them explicitly after loading the model, as the example under "How to Use" does.
  • Head: A linear regression head is attached to the backbone's pooled output. It is randomly initialized and must be trained on a user-provided dataset. It outputs a single scalar value per input study, which represents the log-risk for survival analysis.
  • Objective: To provide a strong starting point for survival analysis in medical imaging. By leveraging the rich, general-purpose features of the frozen backbone, users can train effective models with smaller datasets and less computational overhead compared to end-to-end fine-tuning.
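The frozen-backbone-plus-linear-head design described above can be illustrated with a toy PyTorch module. This is a sketch, not the SMB-Vision implementation: the single `nn.Linear` "encoder" is a hypothetical stand-in for the Transformer backbone, and mean pooling over each study's patch tokens is an assumption about how the pooled output is formed. The `encoder`/`classifier` attribute names mirror those used in the fine-tuning example.

```python
import torch
from torch import nn


class FrozenBackboneRiskModel(nn.Module):
    """Toy stand-in: frozen encoder -> per-study pooling -> linear log-risk head."""

    def __init__(self, in_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Hypothetical stand-in for the SMB-Vision-v0 Transformer backbone.
        self.encoder = nn.Linear(in_dim, hidden_dim)
        # Randomly initialized head: pooled study embedding -> one scalar log-risk.
        self.classifier = nn.Linear(hidden_dim, 1)
        # Freeze the backbone so only the head receives gradient updates.
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        tokens = self.encoder(hidden_states)        # (total_patches, hidden_dim)
        counts = grid_thw.prod(dim=-1).tolist()     # T*H*W patches per study
        # Assumed pooling: mean over each study's patch tokens.
        pooled = torch.stack([t.mean(dim=0) for t in torch.split(tokens, counts)])
        return self.classifier(pooled)              # (num_studies, 1)


model = FrozenBackboneRiskModel(in_dim=32, hidden_dim=16)
grid_thw = torch.tensor([[1, 2, 2], [2, 2, 2]])     # two studies: 4 and 8 patches
hidden_states = torch.randn(int(grid_thw.prod(dim=-1).sum()), 32)
log_risk = model(hidden_states, grid_thw)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(log_risk.shape, trainable)  # torch.Size([2, 1]) 17
```

Only the head's 16 weights and 1 bias are trainable, which is what keeps this kind of adaptation cheap compared with updating the whole backbone.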

Intended Use

This model is intended for researchers and developers to build and evaluate survival prediction models from medical imaging data. The primary use case is to fine-tune the survival regression head on a specific task with time-to-event labels.
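A common way to evaluate survival models is Harrell's concordance index (C-index): the fraction of comparable pairs in which the study that fails earlier was assigned the higher log-risk. The sketch below is a simplified pure-Python version that skips pairs with tied times; the function name is hypothetical. The `lifelines` package installed below also provides `lifelines.utils.concordance_index(event_times, predicted_scores, event_observed)`, which treats higher scores as longer survival, so log-risk predictions should be negated before passing them in.

```python
def harrell_c_index(times, events, risks):
    """Simplified Harrell's C-index for log-risk predictions (higher = riskier).

    A pair (i, j) is comparable when study i has an observed event and
    t_i < t_j; it is concordant when risk_i > risk_j. Tied risks count 0.5.
    Pairs with tied times are skipped in this simplified version.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored study cannot be the earlier member of a pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


times = [5.0, 8.0, 12.0, 20.0]   # follow-up times (synthetic)
events = [1, 0, 1, 0]            # 1 = event observed, 0 = censored
risks = [0.4, 0.3, 0.9, -0.5]    # predicted log-risks (synthetic)
print(harrell_c_index(times, events, risks))  # 0.75
```

A C-index of 0.5 corresponds to random ordering and 1.0 to perfect ranking.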

Example Applications:

  • Predicting overall patient survival from 3D CT scans of tumors.
  • Estimating time-to-disease progression from a series of brain MRIs.
  • Modeling the risk of an adverse event over time from chest X-rays.

Users are responsible for providing their own labeled dataset (images, event times, and event status) and implementing an appropriate survival loss function (e.g., negative Cox partial log-likelihood) to train the regression head.
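As one possible survival loss, the negative Cox partial log-likelihood can be written in a few lines of PyTorch. This is a minimal sketch, not a loss shipped with the model: the function name is hypothetical, the loss is averaged over observed events, and tied event times receive no Breslow or Efron correction, so with ties the result depends on how they are ordered.

```python
import torch


def neg_cox_partial_log_likelihood(
    log_risk: torch.Tensor, event_times: torch.Tensor, events: torch.Tensor
) -> torch.Tensor:
    """Negative Cox partial log-likelihood, averaged over observed events.

    loss = -(1 / N_events) * sum_{i: event_i = 1} [eta_i - log sum_{j: t_j >= t_i} exp(eta_j)]
    No tie correction is applied.
    """
    eta = log_risk.reshape(-1)
    # Sort by time, longest first, so positions 0..i form the risk set of study i.
    order = torch.argsort(event_times, descending=True)
    eta = eta[order]
    events = events[order].to(eta.dtype)
    # Numerically stable log of the cumulative risk-set sum of exp(eta).
    log_risk_set = torch.logcumsumexp(eta, dim=0)
    num_events = events.sum()
    if num_events == 0:
        # No observed events in the batch: return a zero that keeps the graph.
        return eta.sum() * 0.0
    return -((eta - log_risk_set) * events).sum() / num_events


# Usage with synthetic predictions and labels for three studies.
log_risk = torch.tensor([[0.8], [0.1], [-0.4]], requires_grad=True)
event_times = torch.tensor([3.0, 10.0, 7.0])
events = torch.tensor([1, 0, 1])  # 1 = event observed, 0 = censored
loss = neg_cox_partial_log_likelihood(log_risk, event_times, events)
loss.backward()
print(loss.item(), log_risk.grad.shape)
```

Sorting once and using `torch.logcumsumexp` computes every risk-set sum in a single pass, instead of building an n-by-n risk-set mask.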


How to Use

The following example demonstrates how to load the model, freeze the backbone, and perform a forward pass to get log-risk predictions.

Installation

First, ensure you have the necessary libraries installed.

pip install -q -U transformers "torch==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1"
pip install nibabel monai lifelines
pip install "git+https://github.com/standardmodelbio/smb-biopan-utils.git"

Fine-tuning Example

import torch
from transformers import AutoModelForSequenceClassification

# 1. Load the survival regression model from the Hugging Face Hub
#    `trust_remote_code=True` is required to load the SMB-Vision architecture.
model = AutoModelForSequenceClassification.from_pretrained(
    "standardmodelbio/smb-vision-v0-risk",  # Hub repo ID of this survival model
    trust_remote_code=True,
    dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",  # optional; requires the flash-attn package
)

# 2. Freeze the parameters of the vision backbone
#    This is the key step for efficient fine-tuning.
for param in model.encoder.parameters():
    param.requires_grad = False

# Ensure the classification head (used here for regression) is trainable
for param in model.classifier.parameters():
    param.requires_grad = True

# 3. Prepare dummy input data
#    - `images` are the flattened 3D patch tokens.
#    - `grid_thw` describes the 3D grid dimensions of the patches for each study.
#    This example simulates a batch of two studies.
#    In a real application, you would generate these from your NIfTI/DICOM files.
num_patches_study1 = 4 * 10 * 10  # T*H*W
num_patches_study2 = 5 * 12 * 12  # T*H*W
patch_embedding_dim = 4096

images = torch.randn(num_patches_study1 + num_patches_study2, patch_embedding_dim)
grid_thw = torch.tensor([
    [4, 10, 10], # Grid for study 1
    [5, 12, 12]  # Grid for study 2
])

# 4. Move model and data to GPU, if available.
#    Cast the patch tokens to bfloat16 so they match the model weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
images = images.to(device=device, dtype=torch.bfloat16)
grid_thw = grid_thw.to(device)

# 5. Perform a forward pass to get the log-risk predictions.
#    The first output (the logits) holds one predicted log-risk per study.
#    torch.no_grad() is for inference only; remove it when training the head.
with torch.no_grad():
    outputs = model(hidden_states=images, grid_thw=grid_thw)
log_risk = outputs[0]

print(f"Log-risk predictions shape: {log_risk.shape}")
# Expected output: Log-risk predictions shape: torch.Size([2, 1])
print(f"Predicted log-risks:\n{log_risk}")

# In a real training loop, you would now use `log_risk` along with your
# ground-truth event times and event statuses to compute a survival loss
# (e.g., Cox partial log-likelihood) and run the optimizer.
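The training loop left to the reader above might look like the following sketch. To stay runnable without downloading the model, it uses hypothetical stand-ins: a frozen `nn.Linear` in place of `model.encoder`, a trainable `nn.Linear` in place of `model.classifier`, and synthetic pooled features and labels. The AdamW optimizer and learning rate are arbitrary choices, and the Cox loss applies no tie correction.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins mirroring the model.encoder / model.classifier split.
encoder = nn.Linear(32, 16).requires_grad_(False)  # frozen "backbone"
head = nn.Linear(16, 1)                            # trainable log-risk head


def neg_cox_partial_log_likelihood(log_risk, event_times, events):
    # Negative Cox partial log-likelihood (no tie correction), averaged over events.
    eta = log_risk.reshape(-1)
    order = torch.argsort(event_times, descending=True)
    eta, events = eta[order], events[order].to(eta.dtype)
    log_risk_set = torch.logcumsumexp(eta, dim=0)
    return -((eta - log_risk_set) * events).sum() / events.sum().clamp(min=1.0)


# Give the optimizer only the parameters that still require gradients.
params = [p for p in list(encoder.parameters()) + list(head.parameters()) if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-3)

# Synthetic batch: one input vector per study plus survival labels.
features = torch.randn(8, 32)
event_times = torch.rand(8) * 60.0      # e.g. months of follow-up (synthetic)
events = torch.randint(0, 2, (8,))      # 1 = event observed, 0 = censored
events[0] = 1                           # guarantee at least one observed event

encoder_before = encoder.weight.detach().clone()
head_before = head.weight.detach().clone()
for step in range(5):
    optimizer.zero_grad()
    with torch.no_grad():               # frozen backbone: no graph needed
        embeddings = encoder(features)
    log_risk = head(embeddings)         # gradients flow only into the head
    loss = neg_cox_partial_log_likelihood(log_risk, event_times, events)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {loss.item():.4f}")
```

With the real model, the same pattern would build the optimizer from `p for p in model.parameters() if p.requires_grad` and call `model(hidden_states=..., grid_thw=...)` outside `torch.no_grad()`, taking the log-risk from the first output as in the example above.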

Limitations and Bias

  • Not a Clinical Tool: This model is provided for research purposes only and is not a certified medical device. It should not be used for clinical decision-making.
  • Dataset Dependency: The performance of the fine-tuned model is highly dependent on the quality, size, and characteristics of the user-provided training data.
  • Inherited Bias: The model may reflect biases present in the backbone's pre-training data. Performance may vary across different patient demographics, imaging protocols, or scanner manufacturers.
  • Frozen Backbone: While efficient, the frozen backbone approach may be suboptimal for downstream tasks that require visual features significantly different from those learned during the initial self-supervised pre-training.

Citation

If you use SMB-Vision in your research, please cite the original repository.

Model size: 0.6B params (Safetensors, tensor type BF16)