Bio_ClinicalBERT_MIMIC_IV_death_in_7_prediction_lora_ti

This model predicts 7-day mortality upon hospital discharge. It is trained on discharge notes from the MIMIC-IV dataset, a publicly available collection of de-identified Electronic Health Records (EHRs). The model was trained with a novel tabular-infused LoRA, in which pre-operative tabular features (e.g., patient demographics and insurance information) were used to initialize the newly introduced LoRA parameters instead of initializing them randomly.
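The tabular-infused initialization can be sketched as below. This is a minimal illustration, not the authors' exact method: the class name `TabularInitLoRALinear`, the rank/alpha values, and the tiling used to map the feature vector onto the LoRA A matrix are all assumptions.

```python
import torch
import torch.nn as nn

class TabularInitLoRALinear(nn.Module):
    """Linear layer with a LoRA adapter whose A matrix is seeded from
    tabular features instead of a random initialization (illustrative only)."""

    def __init__(self, base: nn.Linear, tabular_features: torch.Tensor,
                 rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pretrained weights

        in_f, out_f = base.in_features, base.out_features
        # Tile/truncate the tabular feature vector to fill the (rank, in_features)
        # A matrix -- a stand-in for whatever mapping the actual method uses.
        flat = tabular_features.flatten()
        needed = rank * in_f
        reps = (needed + flat.numel() - 1) // flat.numel()
        seed = flat.repeat(reps)[:needed].reshape(rank, in_f)

        self.lora_A = nn.Parameter(seed.clone())               # tabular-seeded
        self.lora_B = nn.Parameter(torch.zeros(out_f, rank))   # zeros, as in standard LoRA
        self.scaling = alpha / rank

    def forward(self, x):
        # Base output plus the low-rank update; B starts at zero, so the
        # adapter contributes nothing until training begins.
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)
```

Because `lora_B` starts at zero, the layer initially reproduces the frozen base model exactly, while the tabular signal sits in `lora_A` ready to shape the update during fine-tuning.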

Model Details

How to use the model

from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("cja5553/Bio_ClinicalBERT_MIMIC_IV_death_in_7_prediction_lora_ti")
model = AutoModelForSequenceClassification.from_pretrained("cja5553/Bio_ClinicalBERT_MIMIC_IV_death_in_7_prediction_lora_ti")

You can then use the function below to get a prediction for a single test point:

import torch

def get_outcome(tokenizer, model, text, device="cuda:0", max_length=512):
    """Return class probabilities for a single discharge note."""
    # Fall back to CPU if CUDA is not available on this machine
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Tokenize and truncate/pad to the model's maximum input length
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding="max_length"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)[0]  # shape (2,)

    probs = probs.detach().cpu().numpy()
    # Index 0 = survives past 7 days, index 1 = death within 7 days
    result = {
        "False": float(probs[0]),
        "True": float(probs[1])
    }

    return result
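To turn the returned probability dict into a binary decision, you can threshold the "True" entry. The snippet below mirrors the softmax step inside `get_outcome` on made-up logits; the 0.5 threshold is an illustrative default, not a calibrated choice for this model.

```python
import torch

def to_label(result: dict, threshold: float = 0.5) -> bool:
    """Map the probability dict from get_outcome to a binary mortality label."""
    return result["True"] >= threshold

# Made-up logits standing in for model output on one note
logits = torch.tensor([2.0, -1.0])
probs = torch.softmax(logits, dim=-1)
result = {"False": float(probs[0]), "True": float(probs[1])}

print(to_label(result))  # this example is strongly negative-class
```

In a clinical screening setting you would typically tune the threshold on a validation set rather than use 0.5.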

Questions?

Contact me at alba@wustl.edu
