cyrilvallez HF Staff commited on
Commit
a4fcf5b
·
verified ·
1 Parent(s): 245ec5a

Fix cache format

Browse files

Transformers now only uses the Cache classes

Files changed (1) hide show
  1. custom_generate/generate.py +2 -4
custom_generate/generate.py CHANGED
@@ -282,10 +282,8 @@ def _contrastive_search(
282
  f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
283
  "for contrastive search."
284
  )
285
- elif (
286
- not isinstance(past_key_values[0], (tuple, torch.Tensor))
287
- or past_key_values[0][0].shape[0] != batch_size
288
- ):
289
  raise ValueError(
290
  f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
291
  "used for contrastive search without further modifications."
 
282
  f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
283
  "for contrastive search."
284
  )
285
+ # We now only use Cache classes, but a few models have custom cache class, so we use this check instead of an instance check
286
+ elif not hasattr(past_key_values, "update"):
 
 
287
  raise ValueError(
288
  f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
289
  "used for contrastive search without further modifications."