Spaces:
Running
Running
Update working_yolo_pipeline.py
Browse files- working_yolo_pipeline.py +16 -1
working_yolo_pipeline.py
CHANGED
|
@@ -190,12 +190,20 @@ class MCQTagger(nn.Module):
|
|
| 190 |
self.spatial_proj = nn.Sequential(nn.Linear(11, SPATIAL_FEATURE_DIM), nn.ReLU(), nn.Dropout(0.1))
|
| 191 |
self.context_proj = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(0.1))
|
| 192 |
self.positional_encoding = nn.Embedding(512, POSITIONAL_DIM)
|
|
|
|
| 193 |
in_dim = (EMBED_DIM + self.char_enc.out_dim + BBOX_DIM + SPATIAL_FEATURE_DIM + 32 + POSITIONAL_DIM)
|
|
|
|
| 194 |
self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=3, batch_first=True, bidirectional=True, dropout=0.3)
|
| 195 |
self.spatial_attention = SpatialAttention(HIDDEN_SIZE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
self.ff = nn.Sequential(nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE), nn.ReLU(), nn.Dropout(0.3), nn.Linear(HIDDEN_SIZE, n_labels))
|
| 197 |
self.crf = CRF(n_labels)
|
| 198 |
self.dropout = nn.Dropout(p=0.5)
|
|
|
|
| 199 |
def forward(self, words, chars, bboxes, spatial_feats, context_feats, mask):
|
| 200 |
B, L = words.size()
|
| 201 |
wemb = self.word_emb(words)
|
|
@@ -203,17 +211,24 @@ class MCQTagger(nn.Module):
|
|
| 203 |
benc = self.bbox_proj(bboxes)
|
| 204 |
senc = self.spatial_proj(spatial_feats)
|
| 205 |
cxt_enc = self.context_proj(context_feats)
|
|
|
|
| 206 |
pos = torch.arange(L, device=words.device).unsqueeze(0).expand(B, -1)
|
| 207 |
pos_enc = self.positional_encoding(pos.clamp(max=511))
|
|
|
|
| 208 |
enc_in = self.dropout(torch.cat([wemb, cenc, benc, senc, cxt_enc, pos_enc], dim=-1))
|
|
|
|
| 209 |
lengths = mask.sum(dim=1).cpu()
|
| 210 |
packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
|
| 211 |
packed_out, _ = self.bilstm(packed_in)
|
| 212 |
lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
|
|
|
|
| 213 |
attn_out = self.spatial_attention(lstm_out, mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
|
| 215 |
return self.crf.viterbi_decode(emissions, mask=mask)
|
| 216 |
-
|
| 217 |
# --- INJECT DEPENDENCIES FOR PICKLE LOADING ---
|
| 218 |
import sys
|
| 219 |
from types import ModuleType
|
|
|
|
| 190 |
self.spatial_proj = nn.Sequential(nn.Linear(11, SPATIAL_FEATURE_DIM), nn.ReLU(), nn.Dropout(0.1))
|
| 191 |
self.context_proj = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(0.1))
|
| 192 |
self.positional_encoding = nn.Embedding(512, POSITIONAL_DIM)
|
| 193 |
+
|
| 194 |
in_dim = (EMBED_DIM + self.char_enc.out_dim + BBOX_DIM + SPATIAL_FEATURE_DIM + 32 + POSITIONAL_DIM)
|
| 195 |
+
|
| 196 |
self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=3, batch_first=True, bidirectional=True, dropout=0.3)
|
| 197 |
self.spatial_attention = SpatialAttention(HIDDEN_SIZE)
|
| 198 |
+
|
| 199 |
+
# --- FIX: register layer_norm so this module's parameter names match the saved checkpoint ---
|
| 200 |
+
self.layer_norm = nn.LayerNorm(HIDDEN_SIZE)
|
| 201 |
+
# -----------------------------------------------
|
| 202 |
+
|
| 203 |
self.ff = nn.Sequential(nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE), nn.ReLU(), nn.Dropout(0.3), nn.Linear(HIDDEN_SIZE, n_labels))
|
| 204 |
self.crf = CRF(n_labels)
|
| 205 |
self.dropout = nn.Dropout(p=0.5)
|
| 206 |
+
|
| 207 |
def forward(self, words, chars, bboxes, spatial_feats, context_feats, mask):
|
| 208 |
B, L = words.size()
|
| 209 |
wemb = self.word_emb(words)
|
|
|
|
| 211 |
benc = self.bbox_proj(bboxes)
|
| 212 |
senc = self.spatial_proj(spatial_feats)
|
| 213 |
cxt_enc = self.context_proj(context_feats)
|
| 214 |
+
|
| 215 |
pos = torch.arange(L, device=words.device).unsqueeze(0).expand(B, -1)
|
| 216 |
pos_enc = self.positional_encoding(pos.clamp(max=511))
|
| 217 |
+
|
| 218 |
enc_in = self.dropout(torch.cat([wemb, cenc, benc, senc, cxt_enc, pos_enc], dim=-1))
|
| 219 |
+
|
| 220 |
lengths = mask.sum(dim=1).cpu()
|
| 221 |
packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
|
| 222 |
packed_out, _ = self.bilstm(packed_in)
|
| 223 |
lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
|
| 224 |
+
|
| 225 |
attn_out = self.spatial_attention(lstm_out, mask)
|
| 226 |
+
|
| 227 |
+
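# NOTE: layer_norm is not applied in the forward pass logic shown here; it is defined in
|
| 228 |
+
# __init__ only so that strict state_dict loading finds matching keys. If the checkpoint was trained with it applied, outputs may differ (TODO: confirm against the training code).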
|
| 229 |
+
|
| 230 |
emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
|
| 231 |
return self.crf.viterbi_decode(emissions, mask=mask)
|
|
|
|
| 232 |
# --- INJECT DEPENDENCIES FOR PICKLE LOADING ---
|
| 233 |
import sys
|
| 234 |
from types import ModuleType
|