Update inference.py
Browse files- inference.py +3 -5
inference.py
CHANGED
|
@@ -2,10 +2,8 @@ import streamlit as st
|
|
| 2 |
import torch
|
| 3 |
from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
|
| 4 |
|
| 5 |
-
def load_model(model_name='
|
| 6 |
-
|
| 7 |
-
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2)
|
| 8 |
-
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 9 |
model.eval() # Set the model to inference mode
|
| 10 |
return model
|
| 11 |
|
|
@@ -45,7 +43,7 @@ def predict_and_annotate(model, tokenizer, text):
|
|
| 45 |
st.title("BERT Token Classification for Anchor Text Prediction")
|
| 46 |
|
| 47 |
# Load the model and tokenizer
|
| 48 |
-
model = load_model('
|
| 49 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # Use BertTokenizerFast
|
| 50 |
|
| 51 |
# User input text area
|
|
|
|
| 2 |
import torch
|
| 3 |
from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
|
| 4 |
|
| 5 |
+
def load_model(model_name='dejanseo/LinkBERT'):
|
| 6 |
+
model = BertForTokenClassification.from_pretrained(model_name, num_labels=2)
|
|
|
|
|
|
|
| 7 |
model.eval() # Set the model to inference mode
|
| 8 |
return model
|
| 9 |
|
|
|
|
| 43 |
st.title("BERT Token Classification for Anchor Text Prediction")
|
| 44 |
|
| 45 |
# Load the model and tokenizer
|
| 46 |
+
model = load_model('dejanseo/LinkBERT')
|
| 47 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # Use BertTokenizerFast
|
| 48 |
|
| 49 |
# User input text area
|