Upload predict.py
Browse files- predict.py +7 -1
predict.py
CHANGED
|
@@ -333,6 +333,7 @@ def predict(model, text, tokenizer=None,
|
|
| 333 |
sft=True, convo_template = "",
|
| 334 |
device = "cuda",
|
| 335 |
model_name="AquilaChat2-7B",
|
|
|
|
| 336 |
**kwargs):
|
| 337 |
|
| 338 |
vocab = tokenizer.get_vocab()
|
|
@@ -352,7 +353,7 @@ def predict(model, text, tokenizer=None,
|
|
| 352 |
topk = 1
|
| 353 |
temperature = 1.0
|
| 354 |
if sft:
|
| 355 |
-
tokens = covert_prompt_to_input_ids_with_history(text, history=
|
| 356 |
tokens = torch.tensor(tokens)[None,].to(device)
|
| 357 |
else :
|
| 358 |
tokens = tokenizer.encode_plus(text)["input_ids"]
|
|
@@ -433,4 +434,9 @@ def predict(model, text, tokenizer=None,
|
|
| 433 |
|
| 434 |
convert_tokens = convert_tokens[1:]
|
| 435 |
probs = probs[1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
return out
|
|
|
|
| 333 |
sft=True, convo_template = "",
|
| 334 |
device = "cuda",
|
| 335 |
model_name="AquilaChat2-7B",
|
| 336 |
+
history=[],
|
| 337 |
**kwargs):
|
| 338 |
|
| 339 |
vocab = tokenizer.get_vocab()
|
|
|
|
| 353 |
topk = 1
|
| 354 |
temperature = 1.0
|
| 355 |
if sft:
|
| 356 |
+
tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
|
| 357 |
tokens = torch.tensor(tokens)[None,].to(device)
|
| 358 |
else :
|
| 359 |
tokens = tokenizer.encode_plus(text)["input_ids"]
|
|
|
|
| 434 |
|
| 435 |
convert_tokens = convert_tokens[1:]
|
| 436 |
probs = probs[1:]
|
| 437 |
+
|
| 438 |
+
# Update history
|
| 439 |
+
history.insert(0, ('USER', text))
|
| 440 |
+
history.insert(0, ('ASSISTANT', out))
|
| 441 |
+
|
| 442 |
return out
|