remove dependencies on deepspeed and wandb
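Removes the hard runtime dependencies on `deepspeed` and `wandb` from `modeling_magma.py`: the `import wandb`, the rank-0 `wandb.init(...)` in `MagmaForCausalLM.__init__`, the rank-0 `wandb.log(...)` of `action_accuracy` in the forward pass, and the DeepSpeed-only `load_special_module_from_ckpt` helper. The model can now be loaded and run without either package installed. Hedged sketches of how to reproduce the removed behavior without these dependencies follow the affected hunks below.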
modeling_magma.py  CHANGED  +0 -48
@@ -24,7 +24,6 @@ import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import wandb
 import torch.distributed as dist
 from transformers.modeling_utils import PreTrainedModel
 from transformers.activations import ACT2FN
@@ -282,12 +281,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
 
-        try:
-            if dist.get_rank() == 0:
-                wandb.init(project=os.environ['WANDB_PROJECT'])
-        except:
-            pass
-
         self.post_init()
 
     # def from_pretrained(self, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -325,40 +318,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
 
     def tie_weights(self):
        return self.language_model.tie_weights()
-
-    def load_special_module_from_ckpt(self, ckpt_path, torch_dtype=None):
-        from deepspeed.runtime.zero import Init
-        from deepspeed import zero
-        # Defer initialization for ZeRO-3 compatibility
-        # with Init(data_parallel_group=None):
-        #     # Initialize the special module
-        #     self.vision_tower = MagmaImageTower(self.config.vision_config, require_pretrained=False)
-
-        # Load checkpoint weights into the special module
-        checkpoint = torch.load(ckpt_path, map_location='cpu')
-        state_dict = {k.replace('visual.', ''): v for k, v in checkpoint.items() if 'visual.' in k}
-
-        # Convert checkpoint weights to match model's parameter dtype
-        if torch_dtype is None:
-            model_dtype = next(self.vision_tower.clip_vision_model.parameters()).dtype
-            for k, v in state_dict.items():
-                state_dict[k] = v.to(model_dtype)
-        else:
-            for k, v in state_dict.items():
-                state_dict[k] = v.to(torch_dtype)
-
-        # Temporarily gather parameters for loading (if ZeRO-3 is active)
-        with zero.GatheredParameters(list(self.vision_tower.parameters()), modifier_rank=0):
-            # Load the state dictionary
-            self.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
-        # After loading, ensure the module is on the correct device
-        for param in self.vision_tower.parameters():
-            param.data = param.data.to(self.device).to(torch_dtype)
-
-        # import pdb; pdb.set_trace()
-        # If using a DeepSpeed engine, attach the updated module
-        if hasattr(self, "deepspeed_engine"):
-            self.deepspeed_engine.module = self
 
     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
         model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
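The hunk above deletes `load_special_module_from_ckpt`, which required DeepSpeed's ZeRO utilities. For anyone who relied on it, here is a minimal plain-PyTorch sketch of equivalent single-process loading, assuming the same checkpoint layout as the removed code (`visual.`-prefixed keys loaded into `vision_tower.clip_vision_model`). The standalone helper name `load_vision_weights` is hypothetical, not part of the repository, and ZeRO-3 parameter gathering is intentionally omitted since DeepSpeed is no longer assumed.

```python
import torch

def load_vision_weights(model, ckpt_path, torch_dtype=None):
    # Hypothetical helper: reproduces the removed load_special_module_from_ckpt
    # without DeepSpeed (no ZeRO-3 parameter gathering).
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Keep only the vision weights and strip the 'visual.' key prefix,
    # exactly as the removed code did.
    state_dict = {k.replace("visual.", ""): v
                  for k, v in checkpoint.items() if "visual." in k}

    # Cast to the requested dtype, or match the vision tower's current dtype.
    if torch_dtype is None:
        torch_dtype = next(model.vision_tower.clip_vision_model.parameters()).dtype
    state_dict = {k: v.to(torch_dtype) for k, v in state_dict.items()}

    model.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
    # Move the whole vision tower to the model's device in the target dtype.
    model.vision_tower.to(device=model.device, dtype=torch_dtype)
```

Usage would be something like `load_vision_weights(model, "vision_ckpt.pt", torch.bfloat16)` after `from_pretrained` (the checkpoint path here is illustrative).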
@@ -832,13 +791,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
             # concatenate the action accuracy across all devices
             action_accuracy = torch.cat(action_accuracy_gather)
 
-            if dist.get_rank() == 0:
-                # remove zero values
-                if action_accuracy.mean() == 0:
-                    wandb.log({"action_accuracy": action_accuracy.mean().item()})
-                else:
-                    action_accuracy = action_accuracy[action_accuracy != 0]
-                    wandb.log({"action_accuracy": action_accuracy.mean().item()})
         else:
             logits = self.language_model.lm_head(hidden_states)
             logits = logits.float()