heerjtdev committed on
Commit
255ee09
·
verified ·
1 Parent(s): 676a735

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +16 -1
working_yolo_pipeline.py CHANGED
@@ -190,12 +190,20 @@ class MCQTagger(nn.Module):
190
  self.spatial_proj = nn.Sequential(nn.Linear(11, SPATIAL_FEATURE_DIM), nn.ReLU(), nn.Dropout(0.1))
191
  self.context_proj = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(0.1))
192
  self.positional_encoding = nn.Embedding(512, POSITIONAL_DIM)
 
193
  in_dim = (EMBED_DIM + self.char_enc.out_dim + BBOX_DIM + SPATIAL_FEATURE_DIM + 32 + POSITIONAL_DIM)
 
194
  self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=3, batch_first=True, bidirectional=True, dropout=0.3)
195
  self.spatial_attention = SpatialAttention(HIDDEN_SIZE)
 
 
 
 
 
196
  self.ff = nn.Sequential(nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE), nn.ReLU(), nn.Dropout(0.3), nn.Linear(HIDDEN_SIZE, n_labels))
197
  self.crf = CRF(n_labels)
198
  self.dropout = nn.Dropout(p=0.5)
 
199
  def forward(self, words, chars, bboxes, spatial_feats, context_feats, mask):
200
  B, L = words.size()
201
  wemb = self.word_emb(words)
@@ -203,17 +211,24 @@ class MCQTagger(nn.Module):
203
  benc = self.bbox_proj(bboxes)
204
  senc = self.spatial_proj(spatial_feats)
205
  cxt_enc = self.context_proj(context_feats)
 
206
  pos = torch.arange(L, device=words.device).unsqueeze(0).expand(B, -1)
207
  pos_enc = self.positional_encoding(pos.clamp(max=511))
 
208
  enc_in = self.dropout(torch.cat([wemb, cenc, benc, senc, cxt_enc, pos_enc], dim=-1))
 
209
  lengths = mask.sum(dim=1).cpu()
210
  packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
211
  packed_out, _ = self.bilstm(packed_in)
212
  lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
 
213
  attn_out = self.spatial_attention(lstm_out, mask)
 
 
 
 
214
  emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
215
  return self.crf.viterbi_decode(emissions, mask=mask)
216
-
217
  # --- INJECT DEPENDENCIES FOR PICKLE LOADING ---
218
  import sys
219
  from types import ModuleType
 
190
  self.spatial_proj = nn.Sequential(nn.Linear(11, SPATIAL_FEATURE_DIM), nn.ReLU(), nn.Dropout(0.1))
191
  self.context_proj = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(0.1))
192
  self.positional_encoding = nn.Embedding(512, POSITIONAL_DIM)
193
+
194
  in_dim = (EMBED_DIM + self.char_enc.out_dim + BBOX_DIM + SPATIAL_FEATURE_DIM + 32 + POSITIONAL_DIM)
195
+
196
  self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=3, batch_first=True, bidirectional=True, dropout=0.3)
197
  self.spatial_attention = SpatialAttention(HIDDEN_SIZE)
198
+
199
+ # --- FIX: define layer_norm so this module's state_dict keys match the saved checkpoint ---
200
+ self.layer_norm = nn.LayerNorm(HIDDEN_SIZE)
201
+ # ------------------------------------------------------------------------------------------
202
+
203
  self.ff = nn.Sequential(nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE), nn.ReLU(), nn.Dropout(0.3), nn.Linear(HIDDEN_SIZE, n_labels))
204
  self.crf = CRF(n_labels)
205
  self.dropout = nn.Dropout(p=0.5)
206
+
207
  def forward(self, words, chars, bboxes, spatial_feats, context_feats, mask):
208
  B, L = words.size()
209
  wemb = self.word_emb(words)
 
211
  benc = self.bbox_proj(bboxes)
212
  senc = self.spatial_proj(spatial_feats)
213
  cxt_enc = self.context_proj(context_feats)
214
+
215
  pos = torch.arange(L, device=words.device).unsqueeze(0).expand(B, -1)
216
  pos_enc = self.positional_encoding(pos.clamp(max=511))
217
+
218
  enc_in = self.dropout(torch.cat([wemb, cenc, benc, senc, cxt_enc, pos_enc], dim=-1))
219
+
220
  lengths = mask.sum(dim=1).cpu()
221
  packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
222
  packed_out, _ = self.bilstm(packed_in)
223
  lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
224
+
225
  attn_out = self.spatial_attention(lstm_out, mask)
226
+
227
+ # NOTE: layer_norm is never applied anywhere in this forward pass; it is defined in
228
+ # __init__ solely so load_state_dict(strict=True) finds the checkpoint's 'layer_norm.*' keys.
229
+
230
  emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
231
  return self.crf.viterbi_decode(emissions, mask=mask)
 
232
  # --- INJECT DEPENDENCIES FOR PICKLE LOADING ---
233
  import sys
234
  from types import ModuleType