| | """ |
| | Client test. |
| | |
| | Run server: |
| | |
| | python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b |
| | |
| | NOTE: For private models, add --use_auth_token=True |
| | |
| | NOTE: On multi-GPU systems, keep --use_gpu_id=True (the default) if you see failures caused by cuda:x vs. cuda:y device mismatches. |
| | Currently, this forces the model onto a single GPU. |
| | |
| | Then run this client as: |
| | |
| | python src/client_test.py |
| | |
| | |
| | |
| | For HF spaces: |
| | |
| | HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py |
| | |
| | Result: |
| | |
| | Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔ |
| | {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''} |
| | |
| | |
| | For demo: |
| | |
| | HOST="https://gpt.h2o.ai" python src/client_test.py |
| | |
| | Result: |
| | |
| | Loaded as API: https://gpt.h2o.ai ✔ |
| | {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''} |
| | |
| | NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict: |
| | |
| | {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''} |
| | |
| | |
| | """ |
| | import ast |
| | import time |
| | import os |
| | import markdown |
| | import pytest |
| | from bs4 import BeautifulSoup |
| |
|
| | from enums import DocumentSubset, LangChainAction |
| |
|
# Set to True to print the server's API description (all endpoints) whenever get_client() builds a client.
debug = False

# Opt out of Hugging Face Hub telemetry. This runs at import time, i.e. before
# get_client() lazily imports gradio_client.
os.environ.update(HF_HUB_DISABLE_TELEMETRY='1')
| |
|
| |
|
def get_client(serialize=True):
    """Create a gradio_client ``Client`` for the server under test.

    The server URL comes from the ``HOST`` environment variable and falls back to
    ``http://localhost:7860``. When the module-level ``debug`` flag is True, the
    result of ``view_api(all_endpoints=True)`` is printed.

    :param serialize: forwarded unchanged to ``gradio_client.Client``.
    :return: the newly created client.
    """
    # Local import: gradio_client is only needed once a client is actually created.
    from gradio_client import Client

    server_url = os.getenv('HOST', "http://localhost:7860")
    api_client = Client(server_url, serialize=serialize)
    if debug:
        api_description = api_client.view_api(all_endpoints=True)
        print(api_description)
    return api_client
| |
|
| |
|
def get_args(prompt, prompt_type, chat=False, stream_output=False,
             max_new_tokens=50,
             top_k_docs=3,
             langchain_mode='Disabled',
             add_chat_history_to_context=True,
             langchain_action=LangChainAction.QUERY.value,
             langchain_agents=None,
             prompt_dict=None):
    """Build the full, ordered parameter set for the h2oGPT gradio endpoints.

    :param prompt: user prompt. It goes into ``instruction`` when ``chat`` is True,
        otherwise into ``instruction_nochat``; the other field is set to ''.
    :param prompt_type: prompt template name sent to the server.
    :param chat: select chat vs. non-chat parameters. In chat mode an empty
        ``chatbot`` history list is appended as the last entry.
    :param stream_output: request streamed output.
    :param max_new_tokens: generation length limit sent to the server.
    :param top_k_docs: number of documents requested for LangChain.
    :param langchain_mode: LangChain mode name ('Disabled' by default).
    :param add_chat_history_to_context: forwarded as-is.
    :param langchain_action: LangChain action value.
    :param langchain_agents: list of agents; None (default) means a fresh empty list.
    :param prompt_dict: optional custom prompt definition, forwarded as-is.
    :return: ``(kwargs, args)``. ``kwargs`` is an ``OrderedDict`` of all parameters
        and ``args`` is ``list(kwargs.values())`` in the same order, for positional
        ``client.predict(*args, ...)`` calls.
    :raises AssertionError: if a name in ``evaluate_params.eval_func_param_names``
        is missing from the built parameters.
    """
    from collections import OrderedDict
    from evaluate_params import eval_func_param_names

    if langchain_agents is None:
        # A fresh list per call instead of one mutable default shared by every call.
        langchain_agents = []

    kwargs = OrderedDict(instruction=prompt if chat else '',
                         iinput='',
                         context='',
                         stream_output=stream_output,
                         prompt_type=prompt_type,
                         prompt_dict=prompt_dict,
                         temperature=0.1,
                         top_p=0.75,
                         top_k=40,
                         num_beams=1,
                         max_new_tokens=max_new_tokens,
                         min_new_tokens=0,
                         early_stopping=False,
                         max_time=20,
                         repetition_penalty=1.0,
                         num_return_sequences=1,
                         do_sample=True,
                         chat=chat,
                         instruction_nochat=prompt if not chat else '',
                         iinput_nochat='',
                         langchain_mode=langchain_mode,
                         add_chat_history_to_context=add_chat_history_to_context,
                         langchain_action=langchain_action,
                         langchain_agents=langchain_agents,
                         top_k_docs=top_k_docs,
                         chunk=True,
                         chunk_size=512,
                         document_subset=DocumentSubset.Relevant.name,
                         document_choice=[],
                         )
    # Every parameter the server-side evaluate function expects must be present.
    # NOTE(review): only membership is checked, not order, although `args` is used positionally.
    missing = set(eval_func_param_names) - set(kwargs)
    assert not missing, "get_args is missing evaluate parameters: %s" % sorted(missing)
    if chat:
        # Chat callers (see run_client) treat the last positional argument as the chat history.
        kwargs['chatbot'] = []

    return kwargs, list(kwargs.values())
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic(prompt_type='human_bot'):
    """Manual check: ask "Who are you?" through the positional ``/submit_nochat`` endpoint.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_nochat``.
    """
    question = 'Who are you?'
    return run_client_nochat(prompt=question, prompt_type=prompt_type, max_new_tokens=50)
| |
|
| |
|
def run_client_nochat(prompt, prompt_type, max_new_tokens):
    """Send ``prompt`` to the positional ``/submit_nochat`` endpoint and print the reply.

    :param prompt: non-chat instruction.
    :param prompt_type: prompt template name.
    :param max_new_tokens: generation length limit.
    :return: ``(res_dict, client)``. ``res_dict`` holds the sent ``prompt`` and
        ``iinput`` plus the reply converted to plain text under ``response``.
    """
    kwargs, positional = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)

    client = get_client(serialize=True)
    raw_reply = client.predict(*positional, api_name='/submit_nochat')
    print("Raw client result: %s" % raw_reply, flush=True)

    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': md_to_text(raw_reply),
    }
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api(prompt_type='human_bot'):
    """Manual check: ask "Who are you?" through ``/submit_nochat_api`` with the full parameter dict.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_nochat_api``.
    """
    question = 'Who are you?'
    return run_client_nochat_api(prompt=question, prompt_type=prompt_type, max_new_tokens=50)
| |
|
| |
|
def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
    """Send the full ``get_args`` parameter dict to ``/submit_nochat_api`` and print the reply.

    The request is a single string: the repr of a plain ``dict``. The reply is
    likewise the string form of a Python dict with ``response`` and ``sources`` keys.

    :param prompt: non-chat instruction.
    :param prompt_type: prompt template name.
    :param max_new_tokens: generation length limit.
    :return: ``(res_dict, client)`` with the sent ``prompt``/``iinput``, the
        plain-text ``response`` and the raw ``sources``.
    """
    kwargs, _ = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)

    api_name = '/submit_nochat_api'
    client = get_client(serialize=True)
    # dict(...) so the repr is a plain `{...}` literal rather than `OrderedDict([...])`.
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # Parse the stringified reply dict once (literal_eval only accepts Python literals).
    reply = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
                    response=md_to_text(reply['response']),
                    sources=reply['sources'])
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean(prompt_type='human_bot'):
    """Manual check: ask "Who are you?" through ``/submit_nochat_api`` sending only the instruction.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_nochat_api_lean``.
    """
    question = 'Who are you?'
    return run_client_nochat_api_lean(prompt=question, prompt_type=prompt_type, max_new_tokens=50)
| |
|
| |
|
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
    """Send only ``instruction_nochat`` to ``/submit_nochat_api`` and print the reply.

    No other parameters are sent. ``prompt_type`` and ``max_new_tokens`` are
    accepted but NOT used (presumably to keep the same signature as the other
    ``run_client_nochat*`` helpers).

    :param prompt: non-chat instruction.
    :return: ``(res_dict, client)`` with ``prompt``, the plain-text ``response``
        and the raw ``sources`` from the stringified reply dict.
    """
    kwargs = dict(instruction_nochat=prompt)

    api_name = '/submit_nochat_api'
    client = get_client(serialize=True)
    res = client.predict(
        str(kwargs),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # The reply is the string form of a Python dict; parse it once.
    reply = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(reply['response']),
                    sources=reply['sources'])
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
    """Manual check: ask "Who are you?" through ``/submit_nochat_api`` with a hand-built parameter subset.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_nochat_api_lean_morestuff``.
    """
    question = 'Who are you?'
    return run_client_nochat_api_lean_morestuff(prompt=question, prompt_type=prompt_type,
                                                max_new_tokens=50)
| |
|
| |
|
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
    """Send a hand-built subset of generation parameters to ``/submit_nochat_api``.

    Unlike ``get_args``, the dict is written inline and omits some keys
    (e.g. ``prompt_dict``, ``chunk``, ``chunk_size``).

    :param prompt: non-chat instruction.
    :param prompt_type: prompt template name.
    :param max_new_tokens: generation length limit forwarded to the server.
    :return: ``(res_dict, client)`` with ``prompt``, the plain-text ``response``
        and the raw ``sources`` from the stringified reply dict.
    """
    kwargs = dict(
        instruction='',
        iinput='',
        context='',
        stream_output=False,
        prompt_type=prompt_type,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        # Honour the caller's limit instead of a hard-coded value.
        max_new_tokens=max_new_tokens,
        min_new_tokens=0,
        early_stopping=False,
        max_time=20,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        chat=False,
        instruction_nochat=prompt,
        iinput_nochat='',
        langchain_mode='Disabled',
        add_chat_history_to_context=True,
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        top_k_docs=4,
        document_subset=DocumentSubset.Relevant.name,
        document_choice=[],
    )

    api_name = '/submit_nochat_api'
    client = get_client(serialize=True)
    res = client.predict(
        str(kwargs),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # The reply is the string form of a Python dict; parse it once.
    reply = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(reply['response']),
                    sources=reply['sources'])
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat(prompt_type='human_bot'):
    """Manual check: one non-streaming chat turn with LangChain disabled.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_chat``.
    """
    langchain_off = dict(langchain_mode='Disabled',
                         langchain_action=LangChainAction.QUERY.value,
                         langchain_agents=[])
    return run_client_chat(prompt='Who are you?', prompt_type=prompt_type,
                           stream_output=False, max_new_tokens=50, **langchain_off)
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat_stream(prompt_type='human_bot'):
    """Manual check: one long, streamed chat turn with LangChain disabled.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_chat``.
    """
    story_request = "Tell a very long kid's story about birds."
    langchain_off = dict(langchain_mode='Disabled',
                         langchain_action=LangChainAction.QUERY.value,
                         langchain_agents=[])
    return run_client_chat(prompt=story_request, prompt_type=prompt_type,
                           stream_output=True, max_new_tokens=512, **langchain_off)
| |
|
| |
|
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
                    langchain_mode, langchain_action, langchain_agents,
                    prompt_dict=None):
    """Build chat-mode arguments with ``get_args`` and run a single turn via ``run_client``.

    An unserialized client (``serialize=False``) is created before the arguments are built.

    :return: the ``(res_dict, client)`` pair produced by ``run_client``.
    """
    client = get_client(serialize=False)

    request_kwargs, request_args = get_args(
        prompt, prompt_type,
        chat=True,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
        prompt_dict=prompt_dict,
    )
    return run_client(client, prompt, request_args, request_kwargs)
| |
|
| |
|
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """Run one chat turn.

    The turn has two steps: ``/instruction`` registers the prompt, then
    ``/instruction_bot`` produces the bot answer, streamed or not.

    :param client: gradio client, e.g. from ``get_client`` (``run_client_chat`` uses ``serialize=False``).
    :param prompt: user prompt, stored in the result under ``'prompt'``.
    :param args: positional arguments from ``get_args(..., chat=True)``. The last element
        is the chat history list and is extended in place. It is the same list object
        as ``kwargs['chatbot']`` when built by ``get_args``.
    :param kwargs: keyword view of the same parameters; ``kwargs['chat']`` must be True.
        NOTE: this dict is also used and mutated as the returned result.
    :param do_md_to_text: convert markdown to plain text when printing/storing.
    :param verbose: print the full job outputs after streaming finishes.
    :return: ``(res_dict, client)``. ``res_dict['response']`` is the latest bot answer.
        NOTE(review): without streaming it is stored raw (only the print is converted);
        with streaming it is passed through ``md_to_text``.
    :raises AssertionError: if not in chat mode, or if a streamed job produced no output.
    """
    assert kwargs['chat'], "Chat mode only"
    res = client.predict(*tuple(args), api_name='/instruction')
    # Append what /instruction returned to the history (last positional arg) before asking the bot.
    args[-1] += [res[-1]]

    res_dict = kwargs
    res_dict['prompt'] = prompt
    if not kwargs['stream_output']:
        res = client.predict(*tuple(args), api_name='/instruction_bot')
        # res[0] appears to be the chat history: [-1] is the latest turn, [1] its bot answer.
        res_dict['response'] = res[0][-1][1]
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client

    job = client.submit(*tuple(args), api_name='/instruction_bot')
    while not job.done():
        outputs_list = job.communicator.job.outputs
        if outputs_list:
            partial = outputs_list[-1]
            print(md_to_text(partial[0][-1][-1], do_md_to_text=do_md_to_text))
        time.sleep(0.1)
    # Read the outputs only after the job is done, to avoid racing the stream.
    full_outputs = job.outputs()
    if verbose:
        print('job.outputs: %s' % str(full_outputs))
    assert len(full_outputs) > 0, "No response, check server"
    # [-1]: last streamed output; [0]: chat history; [-1]: latest turn (matching the
    # loop above and the non-stream path); [1]: bot answer.
    res_dict['response'] = md_to_text(full_outputs[-1][0][-1][1], do_md_to_text=do_md_to_text)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_nochat_stream(prompt_type='human_bot'):
    """Manual check: a long, streamed non-chat generation with LangChain disabled.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_nochat_gen``.
    """
    story_request = "Tell a very long kid's story about birds."
    langchain_off = dict(langchain_mode='Disabled',
                         langchain_action=LangChainAction.QUERY.value,
                         langchain_agents=[])
    return run_client_nochat_gen(prompt=story_request, prompt_type=prompt_type,
                                 stream_output=True, max_new_tokens=512, **langchain_off)
| |
|
| |
|
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
                          langchain_mode, langchain_action, langchain_agents):
    """Build non-chat arguments with ``get_args`` and hand them to ``run_client_gen``.

    An unserialized client (``serialize=False``) is created before the arguments are built.

    :return: the ``(res_dict, client)`` pair produced by ``run_client_gen``.
    """
    client = get_client(serialize=False)

    request_kwargs, request_args = get_args(
        prompt, prompt_type,
        chat=False,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
    )
    return run_client_gen(client, prompt, request_args, request_kwargs)
| |
|
| |
|
def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """Query ``/submit_nochat_api`` with ``str(dict(kwargs))``, optionally streaming.

    The endpoint replies with the string form of a Python dict (e.g. with
    ``response`` and ``sources`` keys), which is parsed with ``ast.literal_eval``.

    :param client: gradio client, e.g. from ``get_client``.
    :param prompt: prompt text, stored under ``'prompt'``.
    :param args: unused here; accepted for signature parity with ``run_client``.
    :param kwargs: request parameters. NOTE(review): ``'prompt'`` is added to this dict
        before it is serialized, so it is sent to the server as well.
    :param do_md_to_text: convert the markdown response to plain text when printing
        (non-stream path only).
    :param verbose: print the full job outputs after streaming finishes.
    :return: ``(res_dict, client)``. Without streaming, ``res_dict`` is ``kwargs``
        updated with the parsed reply. With streaming, it is the parsed final reply.
    :raises AssertionError: if a streamed job produced no output.
    """
    res_dict = kwargs
    res_dict['prompt'] = prompt
    if not kwargs['stream_output']:
        res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
        # The reply is a stringified dict: parse it rather than indexing the raw
        # string, whose first character is just '{'.
        res_dict.update(ast.literal_eval(res))
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client

    job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
    while not job.done():
        outputs_list = job.communicator.job.outputs
        if outputs_list:
            partial = ast.literal_eval(outputs_list[-1])
            print('Stream: %s' % partial['response'])
        time.sleep(0.1)
    res_list = job.outputs()
    if verbose:
        print('job.outputs: %s' % str(res_list))
    assert len(res_list) > 0, "No response, check server"
    res_dict = ast.literal_eval(res_list[-1])
    print('Final: %s' % res_dict['response'])
    return res_dict, client
| |
|
| |
|
def md_to_text(md, do_md_to_text=True):
    """Render markdown to HTML and return its visible plain text.

    :param md: markdown string; must not be None when conversion is enabled.
    :param do_md_to_text: when False, ``md`` is returned untouched (even if None).
    :return: text extracted from the rendered HTML by BeautifulSoup's
        ``html.parser``, or ``md`` itself when conversion is disabled.
    :raises AssertionError: if conversion is enabled and ``md`` is None.
    """
    if do_md_to_text:
        assert md is not None, "Markdown is None"
        rendered_html = markdown.markdown(md)
        return BeautifulSoup(rendered_html, features='html.parser').get_text()
    return md
| |
|
| |
|
def run_client_many(prompt_type='human_bot'):
    """Run every manual client check in sequence against the configured server.

    :param prompt_type: prompt template name passed to each check.
    :return: a 7-tuple with each check's result dict, in the order the checks ran.
    """
    checks = (
        test_client_chat,
        test_client_chat_stream,
        test_client_nochat_stream,
        test_client_basic,
        test_client_basic_api,
        test_client_basic_api_lean,
        test_client_basic_api_lean_morestuff,
    )
    results = []
    for check in checks:
        result_dict, _unused_client = check(prompt_type=prompt_type)
        results.append(result_dict)
    return tuple(results)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point. Under pytest these checks are skip-marked and do not run.
    run_client_many()
| |
|