Stroke Classification ResNet Model

This is a Keras model for classifying MRI images into:

  • Hemorrhagic Stroke
  • Ischemic Stroke
  • No Stroke

The model is a fine-tuned ResNet50.
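The card does not describe the classification head or the training setup. For orientation, a common way to fine-tune ResNet50 for three classes in Keras is to reuse the convolutional backbone and attach a new softmax layer. The sketch below is a minimal illustration under stated assumptions, not this model's documented architecture. It assumes ImageNet-pretrained weights, a 224×224×3 input (matching the default `target_size` in the preprocessing code), global average pooling, a dropout rate of 0.2, and integer labels. `build_resnet50_classifier` is a hypothetical name. The sketch also leaves out input normalization: the preprocessing code in this card passes raw pixel values in the range 0 to 255, and the card does not say whether the saved model applies `tf.keras.applications.resnet50.preprocess_input` internally.

```python
import tensorflow as tf

# Assumed to match CLASS_NAMES: hemorrhagic_stroke, ischemic_stroke, no_stroke
NUM_CLASSES = 3


def build_resnet50_classifier(num_classes=NUM_CLASSES, input_shape=(224, 224, 3), weights="imagenet"):
    """Hypothetical ResNet50 fine-tuning setup; not the documented architecture of this model."""
    # Convolutional backbone only; the 1000-class ImageNet head is dropped
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=input_shape
    )
    base.trainable = False  # freeze the backbone while the new head trains

    inputs = tf.keras.Input(shape=input_shape)
    # training=False keeps the backbone's BatchNorm layers in inference mode
    x = base(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)  # (batch, 2048) feature vector
    x = tf.keras.layers.Dropout(0.2)(x)  # assumed regularization strength
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="sparse_categorical_crossentropy",  # assumes integer labels 0..num_classes-1
        metrics=["accuracy"],
    )
    return model


if __name__ == "__main__":
    # weights=None skips the ImageNet download for a quick shape check
    model = build_resnet50_classifier(weights=None)
    print(model.outputs[0].shape)
```

Freezing the backbone first keeps the randomly initialized head from disrupting the pretrained features. A second pass typically unfreezes the top ResNet blocks and trains them with a smaller learning rate.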

How to use this model (example in Python)

import numpy as np
import tensorflow as tf
from PIL import Image

# Load the model
# The relative path below is resolved against the current working directory;
# provide the full path if 'stroke_classification_model.h5' is stored elsewhere.
model = tf.keras.models.load_model('stroke_classification_model.h5')

# Class names, in the same order as the model's output units (must match training)
CLASS_NAMES = ['hemorrhagic_stroke', 'ischemic_stroke', 'no_stroke']

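def preprocess_image_for_prediction(image_path, target_size=(224, 224), pixel_threshold=40):
    """Crop dark borders, resize, and return a (1, H, W, 3) float32 batch for model.predict."""
    # Load as single-channel grayscale so thresholding works on intensity
    img = Image.open(image_path).convert("L")
    original_width, original_height = img.size
    data = np.array(img)

    # Rows/columns that contain at least one pixel brighter than the threshold
    rows_with_content = np.any(data > pixel_threshold, axis=1)
    cols_with_content = np.any(data > pixel_threshold, axis=0)
    try:
        min_row = np.where(rows_with_content)[0][0]
        max_row = np.where(rows_with_content)[0][-1]
        min_col = np.where(cols_with_content)[0][0]
        max_col = np.where(cols_with_content)[0][-1]
    except IndexError:
        # No pixel exceeds the threshold: keep the full image
        cropped_img = img
    else:
        # Pad the bounding box by a few pixels, clipped to the image bounds
        buffer = 5
        min_row = max(0, min_row - buffer)
        max_row = min(original_height - 1, max_row + buffer)
        min_col = max(0, min_col - buffer)
        max_col = min(original_width - 1, max_col + buffer)
        # PIL crop boxes are (left, upper, right, lower), with right/lower exclusive
        cropped_img = img.crop((min_col, min_row, max_col + 1, max_row + 1))

    processed_img = cropped_img.resize(target_size, Image.LANCZOS)
    # ResNet50 expects 3 channels: replicate the grayscale channel into RGB
    processed_img = processed_img.convert('RGB')
    # Float32 array of shape (H, W, 3) with values in [0, 255]
    img_array = tf.keras.utils.img_to_array(processed_img)
    # Add a batch dimension: (1, H, W, 3)
    img_array = tf.expand_dims(img_array, 0)
    return img_array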

# Example usage:
# image_path = "path/to/your/new_mri_image.jpg"
# preprocessed_img = preprocess_image_for_prediction(image_path)
# predictions = model.predict(preprocessed_img)
# predicted_class_index = np.argmax(predictions[0])
# predicted_class_name = CLASS_NAMES[predicted_class_index]
# confidence = np.max(predictions[0]) * 100  # a percentage only if the output layer is softmax
# print(f"Predicted: {predicted_class_name} with {confidence:.2f}% confidence")
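The cropping step in `preprocess_image_for_prediction` trims the dark border around the scan. It keeps the smallest box that contains every pixel brighter than `pixel_threshold`, pads that box by 5 pixels, and clips it to the image. To check this behavior without loading the model, the same logic can be isolated in plain NumPy. `content_bbox` is a hypothetical helper written for this illustration, not part of the model's code.

```python
import numpy as np


def content_bbox(data, pixel_threshold=40, buffer=5):
    """Crop box around pixels brighter than pixel_threshold, in PIL (left, upper, right, lower) order.

    Mirrors the cropping step of preprocess_image_for_prediction. Returns None when
    no pixel exceeds the threshold (the original code then keeps the full image).
    """
    rows = np.where(np.any(data > pixel_threshold, axis=1))[0]  # row indices with content
    cols = np.where(np.any(data > pixel_threshold, axis=0))[0]  # column indices with content
    if rows.size == 0:  # no bright pixel anywhere, so cols is empty too
        return None
    height, width = data.shape
    # Pad by `buffer` pixels on every side, clipped to the image bounds
    min_row = max(0, int(rows[0]) - buffer)
    max_row = min(height - 1, int(rows[-1]) + buffer)
    min_col = max(0, int(cols[0]) - buffer)
    max_col = min(width - 1, int(cols[-1]) + buffer)
    # PIL treats right/lower as exclusive, hence the +1
    return (min_col, min_row, max_col + 1, max_row + 1)


# Synthetic 100x120 "scan": a bright rectangle on a black background
image = np.zeros((100, 120), dtype=np.uint8)
image[30:50, 40:70] = 200  # rows 30-49, columns 40-69
print(content_bbox(image))  # (35, 25, 75, 55)
```

Passing the returned box to `Image.crop` reproduces the crop made inside `preprocess_image_for_prediction`. Returning `None` instead of catching `IndexError` makes the all-dark case explicit.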