-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathgradio_runner.py
7084 lines (6542 loc) · 433 KB
/
gradio_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import ast
import base64
import copy
import functools
import inspect
import itertools
import json
import os
import platform
import pprint
import random
import shutil
import sys
import time
import traceback
import uuid
import filelock
import numpy as np
import pandas as pd
import requests
import ujson
from iterators import TimeoutIterator
from gradio_utils.css import get_css
from gradio_utils.prompt_form import make_chatbots, get_chatbot_name
from gradio_funcs import visible_models_to_model_choice, clear_embeddings, fix_text_for_gradio, get_response, \
my_db_state_done, update_langchain_mode_paths, process_audio, is_valid_key, is_from_ui, get_llm_history, prep_bot, \
allow_empty_instruction, update_prompt, gen1_fake, get_one_key, get_fun_with_dict_str_plain, bot, choose_exc
from db_utils import set_userid, get_username_direct, get_userid_direct, fetch_user, upsert_user, get_all_usernames, \
append_to_user_data, append_to_users_data
from model_utils import switch_a_roo_llama, get_on_disk_models, get_inf_models, model_lock_to_state
from src.prompter_utils import get_chat_template, base64_decode_jinja_template
from tts_utils import combine_audios
from src.enums import IMAGE_EXTENSIONS
# This is a hack to prevent Gradio from phoning home when it gets imported
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
def my_get(url, **kwargs):
    """Replacement for ``requests.get`` used while importing gradio.

    Ignores the requested *url* and issues the GET against localhost instead,
    so gradio's import-time telemetry/version check never leaves the machine.
    """
    print('Gradio HTTP request redirected to localhost :)', flush=True)
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
# Temporarily monkey-patch requests.get so that any HTTP GET gradio issues at
# import time (analytics/version checks) is redirected to localhost by my_get();
# the real requests.get is restored immediately after the import completes.
original_get = requests.get
requests.get = my_get
import gradio as gr
requests.get = original_get
def fix_pydantic_duplicate_validators_error():
    """Work around pydantic's duplicate-validator error on module re-import.

    Pydantic skips its duplicate-validator check when it believes it is
    running inside IPython; pretend that is always the case.  A missing
    pydantic installation is silently tolerated.
    """
    try:
        from pydantic import class_validators
    except ImportError:
        return
    class_validators.in_ipython = lambda: True  # type: ignore[attr-defined]


fix_pydantic_duplicate_validators_error()
from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
DocumentChoice, langchain_modes_intrinsic, LangChainTypes, langchain_modes_non_db, gr_to_lg, invalid_key_msg, \
LangChainAgent, docs_ordering_types, docs_token_handlings, docs_joiner_default, split_google, response_formats, \
summary_prefix, extract_prefix, unknown_prompt_type, my_db_state0, requests_state0, noneset, \
is_vision_model, is_video_model
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, \
get_dark_js, get_heap_js, wrap_js_to_lambda, \
spacing_xsm, radius_xsm, text_xsm
from prompter import prompt_type_to_model_name, prompt_types_strings, non_hf_types, \
get_prompt, model_names_curated, get_system_prompts, get_llava_prompts, get_llm_history
from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
ping, makedirs, get_kwargs, system_info, ping_gpu, get_url, \
save_generate_output, url_alive, remove, dict_to_html, text_to_html, lg_to_gr, str_to_dict, have_serpapi, \
have_librosa, have_gradio_pdf, have_pyrubberband, is_gradio_version4, have_fiftyone, n_gpus_global, \
get_accordion_named, get_is_gradio_h2oai, is_uuid4, get_show_username, deepcopy_by_pickle_object, get_gradio_depth, \
get_supports_schema
from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, \
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
get_model_max_length_from_tokenizer, \
get_model_retry, remove_refs, model_name_to_prompt_type
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
input_args_list, image_quality_choices, image_size_default
from apscheduler.schedulers.background import BackgroundScheduler
def _prompt_type_dropdown(is_public, kwargs, label, visible):
    """Shared builder for the prompt-type Dropdowns (was duplicated verbatim
    in get_prompt_type1 and get_prompt_type2).

    :param is_public: public deployments get a non-interactive dropdown
    :param kwargs: app kwargs dict; reads 'model_lock' and 'prompt_type'
    :param label: UI label for the dropdown
    :param visible: whether the dropdown is shown
    :return: configured gr.Dropdown
    """
    prompt_types_strings_used = prompt_types_strings.copy()
    if kwargs['model_lock']:
        # locked models cannot change prompt type; offer the no-model placeholder
        prompt_types_strings_used += [no_model_str]
        default_prompt_type = kwargs['prompt_type'] or no_model_str
    else:
        default_prompt_type = kwargs['prompt_type'] or unknown_prompt_type
    return gr.Dropdown(prompt_types_strings_used,
                       value=default_prompt_type,
                       label=label,
                       info="Auto-Detected if known (template or unknown means will try to use chat template).",
                       visible=visible,
                       interactive=not is_public,
                       )


def get_prompt_type1(is_public, **kwargs):
    """Prompt-type Dropdown for the primary model (hidden under model_lock)."""
    return _prompt_type_dropdown(is_public, kwargs,
                                 label="Choose/Select Prompt Type",
                                 visible=not kwargs['model_lock'])


def get_prompt_type2(is_public, **kwargs):
    """Prompt-type Dropdown for model 2.

    The original visibility expression `False and not kwargs['model_lock']`
    always evaluates to False, so this dropdown is permanently hidden;
    that behavior is preserved here.
    """
    return _prompt_type_dropdown(is_public, kwargs,
                                 label="Choose/Select Prompt Type Model 2",
                                 visible=False)
def ask_block(kwargs, instruction_label, visible_upload, file_types, mic_sources_kwargs, mic_kwargs, noqueue_kwargs2,
              submit_kwargs, stop_kwargs):
    """Build the 'Ask or Ingest' prompt area: instruction textbox, mic/attach/
    ingest buttons, optional speech-to-text wiring, and the submit-button row.

    :param kwargs: app-level settings dict (reads enable_stt, input_lines,
        actions_in_sidebar, transcriber_func, stt_continue_mode,
        tts_action_phrases, tts_stop_phrases, visible_submit_buttons)
    :param instruction_label: info text shown with the instruction textbox
    :param visible_upload: whether upload/ingest controls are shown
    :param file_types: allowed upload extensions (without leading dot)
    :param mic_sources_kwargs: gradio-version-dependent Audio source kwargs
    :param mic_kwargs: JS kwargs that click the hidden record control
    :param noqueue_kwargs2: gradio-version-dependent no-queue kwargs
    :param submit_kwargs: JS kwargs that click the Submit button
    :param stop_kwargs: JS kwargs that click the Stop button
    :return: tuple of the created components/buttons for wiring by the caller
    """
    with gr.Row():
        with gr.Column(scale=50):
            with gr.Row(elem_id="prompt-form-row"):
                label_instruction = 'Ask or Ingest'
                instruction = gr.Textbox(
                    lines=kwargs['input_lines'],
                    label=label_instruction,
                    info=instruction_label,
                    # info=None,
                    elem_id='prompt-form',
                    container=True,
                )
                # shared minimum width for the small inline buttons
                mw0 = 20
                mic_button = gr.Button(
                    elem_id="microphone-button" if kwargs['enable_stt'] else None,
                    value="🔴",
                    size="sm",
                    min_width=mw0,
                    visible=kwargs['enable_stt'])
                attach_button = gr.UploadButton(
                    elem_id="attach-button" if visible_upload else None,
                    value=None,
                    label="Upload",
                    size="sm",
                    min_width=mw0,
                    file_types=['.' + x for x in file_types],
                    file_count="multiple",
                    visible=visible_upload)
                add_button = gr.Button(
                    elem_id="add-button" if visible_upload and not kwargs[
                        'actions_in_sidebar'] else None,
                    value="Ingest",
                    size="sm",
                    min_width=mw0,
                    visible=visible_upload and not kwargs['actions_in_sidebar'])

            # AUDIO
            if kwargs['enable_stt']:
                def action(btn, instruction1, audio_state1, stt_continue_mode=1):
                    # print("B0: %s %s" % (audio_state1[0], instruction1), flush=True)
                    """Changes button text on click"""
                    # toggles recording on/off; audio_state1 is mutated in place
                    if btn == '🔴':
                        audio_state1[3] = 'on'
                        # print("A: %s %s" % (audio_state1[0], instruction1), flush=True)
                        if stt_continue_mode == 1:
                            # snapshot current instruction so transcription appends to it
                            audio_state1[0] = instruction1
                            audio_state1[1] = instruction1
                            audio_state1[2] = None
                        return '⭕', instruction1, audio_state1
                    else:
                        audio_state1[3] = 'off'
                        if stt_continue_mode == 1:
                            audio_state1[0] = None  # indicates done for race case
                            instruction1 = audio_state1[1]
                            audio_state1[2] = []
                        # print("B1: %s %s" % (audio_state1[0], instruction1), flush=True)
                        return '🔴', instruction1, audio_state1

                # while audio state used, entries are pre_text, instruction source, and audio chunks, condition
                audio_state0 = [None, None, None, 'off']
                audio_state = gr.State(value=audio_state0)
                audio_output = gr.HTML(visible=False)
                audio = gr.Audio(**mic_sources_kwargs, streaming=True, visible=False,
                                 # max_length=30 if is_public else None,
                                 elem_id='audio',
                                 # waveform_options=dict(show_controls=True),
                                 )
                mic_button_kwargs = dict(fn=functools.partial(action,
                                                              stt_continue_mode=kwargs[
                                                                  'stt_continue_mode']),
                                         inputs=[mic_button, instruction,
                                                 audio_state],
                                         outputs=[mic_button, instruction,
                                                  audio_state],
                                         api_name=False,
                                         show_progress='hidden')
                # JS first, then python, but all in one click instead of using .then() that will delay
                mic_button.click(fn=lambda: None, **mic_kwargs, **noqueue_kwargs2) \
                    .then(**mic_button_kwargs)
                # stream mic chunks through the transcriber into the instruction box
                audio.stream(fn=kwargs['transcriber_func'],
                             inputs=[audio_state, audio],
                             outputs=[audio_state, instruction],
                             show_progress='hidden')

            submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
            with submit_buttons:
                mw1 = 50
                mw2 = 50
                with gr.Column(min_width=mw1):
                    submit = gr.Button(value='Submit', variant='primary', size='sm',
                                       min_width=mw1, elem_id="submit")
                    stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
                                         min_width=mw1, elem_id='stop')
                    save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
                with gr.Column(min_width=mw2):
                    retry_btn = gr.Button("Redo", size='sm', min_width=mw2)
                    undo = gr.Button("Undo", size='sm', min_width=mw2)
                    clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2)

            # voice-command wiring: spoken action/stop phrases trigger submit/stop
            if kwargs['enable_stt'] and (
                    kwargs['tts_action_phrases'] or kwargs['tts_stop_phrases']):
                def detect_words(action_text1, stop_text1, text):
                    # scan transcribed text for configured action/stop phrases;
                    # strip the phrase from the text and bump the corresponding
                    # hidden textbox (appending '.') so its .change event fires
                    got_action_word = False
                    action_words = kwargs['tts_action_phrases']
                    if action_words:
                        for action_word in action_words:
                            if action_word.lower() in text.lower():
                                text = text[:text.lower().index(action_word.lower())]
                                print("Got action: %s %s" % (action_text1, text), flush=True)
                                got_action_word = True
                    if got_action_word:
                        action_text1 = action_text1 + '.'
                    got_stop_word = False
                    stop_words = kwargs['tts_stop_phrases']
                    if stop_words:
                        for stop_word in stop_words:
                            if stop_word.lower() in text.lower():
                                text = text[:text.lower().index(stop_word.lower())]
                                print("Got stop: %s %s" % (stop_text1, text), flush=True)
                                got_stop_word = True
                    if got_stop_word:
                        stop_text1 = stop_text1 + '.'
                    return action_text1, stop_text1, text

                # hidden textboxes whose .change events fan out to submit/stop
                action_text = gr.Textbox(value='', visible=False)
                stop_text = gr.Textbox(value='', visible=False)
                # avoid if no action word, may take extra time
                instruction.change(fn=detect_words,
                                   inputs=[action_text, stop_text, instruction],
                                   outputs=[action_text, stop_text, instruction])

                def clear_audio_state():
                    # reset recording state before triggering submit/stop
                    return audio_state0

                action_text.change(fn=clear_audio_state, outputs=audio_state) \
                    .then(fn=lambda: None, **submit_kwargs)
                stop_text.change(fn=clear_audio_state, outputs=audio_state) \
                    .then(fn=lambda: None, **stop_kwargs)
    return attach_button, add_button, submit_buttons, instruction, submit, retry_btn, undo, clear_chat_btn, save_chat_btn, stop_btn
def go_gradio(**kwargs):
page_title = kwargs['page_title']
model_label_prefix = kwargs['model_label_prefix']
allow_api = kwargs['allow_api']
is_public = kwargs['is_public']
is_hf = kwargs['is_hf']
memory_restriction_level = kwargs['memory_restriction_level']
n_gpus = kwargs['n_gpus']
admin_pass = kwargs['admin_pass']
model_states = kwargs['model_states']
dbs = kwargs['dbs']
db_type = kwargs['db_type']
visible_langchain_actions = kwargs['visible_langchain_actions']
visible_langchain_agents = kwargs['visible_langchain_agents']
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
enable_sources_list = kwargs['enable_sources_list']
enable_url_upload = kwargs['enable_url_upload']
enable_text_upload = kwargs['enable_text_upload']
use_openai_embedding = kwargs['use_openai_embedding']
hf_embedding_model = kwargs['hf_embedding_model']
load_db_if_exists = kwargs['load_db_if_exists']
migrate_embedding_model = kwargs['migrate_embedding_model']
captions_model = kwargs['captions_model']
caption_loader = kwargs['caption_loader']
doctr_loader = kwargs['doctr_loader']
llava_model = kwargs['llava_model']
asr_model = kwargs['asr_model']
asr_loader = kwargs['asr_loader']
n_jobs = kwargs['n_jobs']
verbose = kwargs['verbose']
# for dynamic state per user session in gradio
model_state0 = kwargs['model_state0']
score_model_state0 = kwargs['score_model_state0']
selection_docs_state0 = kwargs['selection_docs_state0']
visible_models_state0 = kwargs['visible_models_state0']
visible_vision_models_state0 = kwargs['visible_vision_models_state0']
visible_image_models_state0 = kwargs['visible_image_models_state0']
roles_state0 = kwargs['roles_state0']
# For Heap analytics
is_heap_analytics_enabled = kwargs['enable_heap_analytics']
heap_app_id = kwargs['heap_app_id']
# easy update of kwargs needed for evaluate() etc.
queue = True
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
allow_upload_api = allow_api and allow_upload
h2ogpt_key1 = get_one_key(kwargs['h2ogpt_api_keys'], kwargs['enforce_h2ogpt_api_key'])
kwargs.update(locals().copy())
# import control
if kwargs['langchain_mode'] != 'Disabled':
from gpt_langchain import file_types, have_arxiv
else:
have_arxiv = False
file_types = []
if 'mbart-' in kwargs['model_lower']:
instruction_label_nochat = "Text to translate"
else:
instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
" use Enter for multiple input lines)"
if kwargs['visible_h2ogpt_links']:
description = """<a href="https://github.com/pseudotensor/open-strawberry">🍓strawberry🍓 project: </a> <br /><a href="https://gpt-docs.h2o.ai">🎉✨ GO: OpenWebUI ✨🎉</a> <br /> <br /><a href="https://github.com/h2oai/h2ogpt">h2oGPT Code</a> <br /><a href="https://huggingface.co/h2oai">🤗 Models</a> <br /><a href="https://h2o.ai/platform/enterprise-h2ogpte/">h2oGPTe</a>"""
else:
description = None
description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[CodeLlama](https://codellama.h2o.ai)<br>[Llama2 70B](https://llama.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
if is_hf:
description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
task_info_md = ''
css_code = get_css(kwargs, select_string='\"Select_%s\"' % kwargs['max_visible_models'] if kwargs[
'max_visible_models'] else '\"Select_Any\"')
if kwargs['gradio_offline_level'] >= 0:
# avoid GoogleFont that pulls from internet
if kwargs['gradio_offline_level'] == 1:
# front end would still have to download fonts or have cached it at some point
base_font = 'Source Sans Pro'
else:
base_font = 'Helvetica'
theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
else:
theme_kwargs = dict()
if kwargs['gradio_size'] == 'xsmall':
theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
elif kwargs['gradio_size'] in [None, 'small']:
theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
radius_size=gr.themes.sizes.spacing_sm))
elif kwargs['gradio_size'] == 'large':
theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg),
radius_size=gr.themes.sizes.spacing_lg)
elif kwargs['gradio_size'] == 'medium':
theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md,
radius_size=gr.themes.sizes.spacing_md))
theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
demo = gr.Blocks(theme=theme, css=css_code, title=page_title, analytics_enabled=False)
callback = gr.CSVLogger()
# modify, if model lock then don't show models, then need prompts in expert
kwargs['visible_models_tab'] = kwargs['visible_models_tab'] and not bool(kwargs['model_lock'])
# Initial model options
if kwargs['visible_all_prompter_models']:
model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
else:
model_options0 = []
if kwargs['visible_curated_models']:
model_options0.extend(model_names_curated)
model_options0.extend(kwargs['extra_model_options'])
if kwargs['base_model'].strip() and kwargs['base_model'].strip() not in model_options0:
model_options0 = [kwargs['base_model'].strip()] + model_options0
if kwargs['add_disk_models_to_ui'] and kwargs['visible_models_tab'] and not kwargs['model_lock']:
model_options0.extend(get_on_disk_models(llamacpp_path=kwargs['llamacpp_path'],
use_auth_token=kwargs['use_auth_token'],
trust_remote_code=kwargs['trust_remote_code']))
model_options0 = sorted(set(model_options0))
# Initial LORA options
lora_options = kwargs['extra_lora_options']
if kwargs['lora_weights'].strip() and kwargs['lora_weights'].strip() not in lora_options:
lora_options = [kwargs['lora_weights'].strip()] + lora_options
# Initial server options
server_options = kwargs['extra_server_options']
if kwargs['inference_server'].strip() and kwargs['inference_server'].strip() not in server_options:
server_options = [kwargs['inference_server'].strip()] + server_options
if os.getenv('OPENAI_API_KEY'):
if 'openai_chat' not in server_options:
server_options += ['openai_chat']
if 'openai' not in server_options:
server_options += ['openai']
# always add in no lora case
# add fake space so doesn't go away in gradio dropdown
model_options0 = [no_model_str] + sorted(model_options0)
lora_options = [no_lora_str] + sorted(lora_options)
server_options = [no_server_str] + sorted(server_options)
# always add in no model case so can free memory
# add fake space so doesn't go away in gradio dropdown
# transcribe, will be detranscribed before use by evaluate()
if not kwargs['base_model'].strip():
kwargs['base_model'] = no_model_str
if not kwargs['lora_weights'].strip():
kwargs['lora_weights'] = no_lora_str
if not kwargs['inference_server'].strip():
kwargs['inference_server'] = no_server_str
# transcribe for gradio
kwargs['gpu_id'] = str(kwargs['gpu_id'])
no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
chat_name0 = get_chatbot_name(kwargs.get("base_model"),
kwargs.get("display_name"),
kwargs.get("llamacpp_dict", {}).get("model_path_llama"),
kwargs.get("inference_server"),
kwargs.get("prompt_type"),
kwargs.get("model_label_prefix"),
)
output_label0 = chat_name0 if kwargs.get('base_model') else no_model_msg
output_label0_model2 = no_model_msg
default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
# ensure prompt_type consistent with prep_bot(), so nochat API works same way
default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
model_state1=model_state0,
which_model=visible_models_to_model_choice(kwargs['visible_models'], model_states),
global_scope=True, # don't assume state0 is the prompt for all models
**kwargs,
)
for k in no_default_param_names:
default_kwargs[k] = ''
def dummy_fun(x):
    """Identity pass-through used as a gradio event no-op.

    Wiring an event through this blocks new input from being sent until the
    output is done; otherwise the input_list captured at submit time is stale
    and shows up truncated in the chatbot.
    """
    return x
def update_auth_selection(auth_user, selection_docs_state1, save=False):
    """Synchronize document-selection state between an auth record and session state.

    Mutates both arguments in place.  With ``save=True`` the session state
    ``selection_docs_state1`` is copied into ``auth_user``; otherwise the
    stored auth state is merged back into the session, unless the CLI state
    is configured to win (``kwargs['update_selection_state_from_cli']``).

    :param auth_user: per-user auth dict; gains a 'selection_docs_state' key if missing
    :param selection_docs_state1: current session selection state (dict of dicts/lists)
    :param save: sync direction (True: session -> auth, False: auth -> session)
    :raises RuntimeError: if a selection entry is neither dict nor list
    """
    # in-place update of both
    if 'selection_docs_state' not in auth_user:
        auth_user['selection_docs_state'] = selection_docs_state0
    for k, v in auth_user['selection_docs_state'].items():
        if isinstance(selection_docs_state1[k], dict):
            if save:
                # session wins: replace stored dict contents wholesale
                auth_user['selection_docs_state'][k].clear()
                auth_user['selection_docs_state'][k].update(selection_docs_state1[k])
            else:
                if not kwargs['update_selection_state_from_cli']:
                    # stored auth state wins over session/CLI state
                    selection_docs_state1[k].clear()
                    selection_docs_state1[k].update(auth_user['selection_docs_state'][k])
        elif isinstance(selection_docs_state1[k], list):
            if save:
                auth_user['selection_docs_state'][k].clear()
                auth_user['selection_docs_state'][k].extend(selection_docs_state1[k])
            else:
                if not kwargs['update_selection_state_from_cli']:
                    selection_docs_state1[k].clear()
                    selection_docs_state1[k].extend(auth_user['selection_docs_state'][k])
                    # de-duplicate and sort the merged list in place
                    newlist = sorted(set(selection_docs_state1[k]))
                    selection_docs_state1[k].clear()
                    selection_docs_state1[k].extend(newlist)
        else:
            raise RuntimeError("Bad type: %s" % selection_docs_state1[k])
# BEGIN AUTH THINGS
def get_auth_password(username1, auth_filename):
    """Look up the stored password for *username1* in the auth store.

    Supports both the sqlite ('.db') backend and the JSON-file backend.
    A corrupt JSON file is moved aside to a unique '.bak' name and treated
    as empty.  Returns None when the user or the store file is absent.
    """
    auth_dict = {}
    with filelock.FileLock(auth_filename + '.lock'):
        if os.path.isfile(auth_filename):
            if auth_filename.endswith('.db'):
                auth_dict = fetch_user(auth_filename, username1, verbose=verbose)
            else:
                try:
                    with open(auth_filename, 'rt') as f:
                        auth_dict = json.load(f)
                except json.decoder.JSONDecodeError as e:
                    # corrupt auth file: preserve it for inspection, start fresh
                    print("Auth exception: %s" % str(e), flush=True)
                    shutil.move(auth_filename, auth_filename + '.bak' + str(uuid.uuid4()))
                    auth_dict = {}
    user_record = auth_dict.get(username1, {})
    return user_record.get('password')
def auth_func(username1, password1, auth_pairs=None, auth_filename=None,
              auth_access=None,
              auth_freeze=None,
              guest_name=None,
              selection_docs_state1=None,
              selection_docs_state00=None,
              id0=None,
              **kwargs):
    """Gradio login callback deciding whether (username1, password1) may log in.

    Resolution order: guest short-circuit (open mode), CLI+file agreement,
    file-only record, CLI-only record (copied into the file), then
    open-access signup.  On success the user's document-selection state is
    synced via update_auth_selection() and the auth store is saved.

    :param username1: login name; guest names may carry a random suffix
    :param password1: supplied password
    :param auth_pairs: user->password dict taken from the CLI
    :param auth_filename: path to the auth store ('.db' sqlite or JSON file)
    :param auth_access: 'open' or 'closed'
    :param auth_freeze: must be provided (asserted); not otherwise used here
    :param guest_name: prefix identifying guest accounts
    :param selection_docs_state1: session doc-selection state; falls back to selection_docs_state00
    :param selection_docs_state00: pristine copy of the selection state
    :param id0: optional pre-chosen userid for newly created users
    :return: True when login is accepted, False otherwise
    :raises RuntimeError: if auth_access is neither 'open' nor 'closed' at signup
    """
    assert auth_freeze is not None
    if selection_docs_state1 is None:
        selection_docs_state1 = selection_docs_state00
    assert selection_docs_state1 is not None
    assert auth_filename and isinstance(auth_filename, str), "Auth file must be a non-empty string, got: %s" % str(
        auth_filename)
    if auth_access == 'open' and guest_name and username1.startswith(guest_name):
        # guests bypass password checks entirely in open mode
        return True
    if username1 == '':
        # some issue with login
        return False
    if guest_name and username1.startswith(guest_name):
        # for random access with persistent password in auth case
        # username1 here only for auth check, rest of time full guest name used
        username1 = guest_name
    with filelock.FileLock(auth_filename + '.lock'):
        auth_dict = {}
        if os.path.isfile(auth_filename):
            print("Auth access: %s" % username1)
            if auth_filename.endswith('.db'):
                auth_dict = fetch_user(auth_filename, username1, verbose=verbose)
            else:
                try:
                    with open(auth_filename, 'rt') as f:
                        auth_dict = json.load(f)
                except json.decoder.JSONDecodeError as e:
                    # corrupt JSON store: move aside and start from empty
                    print("Auth exception: %s" % str(e), flush=True)
                    shutil.move(auth_filename, auth_filename + '.bak' + str(uuid.uuid4()))
                    auth_dict = {}
        if username1 in auth_dict and username1 in auth_pairs:
            # user known both on disk and via CLI: both passwords must match
            if password1 == auth_dict[username1]['password'] and password1 == auth_pairs[username1]:
                auth_user = auth_dict[username1]
                update_auth_selection(auth_user, selection_docs_state1)
                save_auth_dict(auth_dict, auth_filename, username1)
                return True
            else:
                return False
        elif username1 in auth_dict and 'password' in auth_dict[username1]:
            # user known only on disk
            if password1 == auth_dict[username1]['password']:
                auth_user = auth_dict[username1]
                update_auth_selection(auth_user, selection_docs_state1)
                save_auth_dict(auth_dict, auth_filename, username1)
                return True
            else:
                return False
        elif username1 in auth_pairs:
            # copy over CLI auth to file so only one state to manage
            auth_dict[username1] = dict(password=auth_pairs[username1], userid=id0 or str(uuid.uuid4()))
            auth_user = auth_dict[username1]
            update_auth_selection(auth_user, selection_docs_state1)
            save_auth_dict(auth_dict, auth_filename, username1)
            return True
        else:
            if auth_access == 'closed':
                return False
            # open access
            auth_dict[username1] = dict(password=password1, userid=id0 or str(uuid.uuid4()))
            auth_user = auth_dict[username1]
            update_auth_selection(auth_user, selection_docs_state1)
            save_auth_dict(auth_dict, auth_filename, username1)
            if auth_access == 'open':
                return True
            else:
                raise RuntimeError("Invalid auth_access: %s" % auth_access)
def auth_func_open(*args, **kwargs):
    """Permissive auth callback: accept every login attempt unconditionally."""
    return True
def get_username(requests_state1):
    """Return the username stored in the request state, or None if absent.

    :param requests_state1: per-session request-state dict
    :return: the 'username' value, or None when not set
    """
    # dict.get is the idiomatic single-lookup form of the original
    # membership test + index; behavior is identical.
    return requests_state1.get('username')
def get_userid_auth_func(requests_state1, auth_filename=None, auth_access=None, guest_name=None, id0=None,
                         **kwargs):
    """Resolve a stable userid for the current session.

    Guests always get a fresh uuid4.  Known users get the userid persisted in
    the auth store ('.db' sqlite or JSON file).  Otherwise falls back to
    ``id0``, then the username, then a new uuid4.

    NOTE(review): auth_access is accepted but unused here — presumably kept
    for signature compatibility with the functools.partial wiring; confirm.

    :param requests_state1: per-session request-state dict (holds 'username')
    :param auth_filename: path to the auth store; skipped when falsy
    :param guest_name: prefix identifying guest accounts
    :param id0: preferred fallback id when no persisted userid exists
    :return: userid string
    """
    username1 = get_username(requests_state1)
    if auth_filename and isinstance(auth_filename, str):
        if username1:
            if username1.startswith(guest_name):
                # guests are anonymous: never persist, always a new id
                return str(uuid.uuid4())
            with filelock.FileLock(auth_filename + '.lock'):
                if os.path.isfile(auth_filename):
                    if auth_filename.endswith('.db'):
                        auth_dict = fetch_user(auth_filename, username1, verbose=verbose)
                    else:
                        with open(auth_filename, 'rt') as f:
                            auth_dict = json.load(f)
                    if username1 in auth_dict:
                        return auth_dict[username1]['userid']
    # if here, then not persistently associated with username1,
    # but should only be one-time asked if going to persist within a single session!
    return id0 or username1 or str(uuid.uuid4())
get_userid_auth = functools.partial(get_userid_auth_func,
auth_filename=kwargs['auth_filename'],
auth_access=kwargs['auth_access'],
guest_name=kwargs['guest_name'],
)
if kwargs['auth_access'] == 'closed':
auth_message1 = "Closed access"
else:
auth_message1 = "WELCOME to %s! Open access" \
" (%s/%s or any unique user/pass)" % (page_title, kwargs['guest_name'], kwargs['guest_name'])
if kwargs['auth_message'] is not None:
auth_message = kwargs['auth_message']
else:
auth_message = auth_message1
# always use same callable
auth_pairs0 = {}
if isinstance(kwargs['auth'], list):
for k, v in kwargs['auth']:
auth_pairs0[k] = v
authf = functools.partial(auth_func,
auth_pairs=auth_pairs0,
auth_filename=kwargs['auth_filename'],
auth_access=kwargs['auth_access'],
auth_freeze=kwargs['auth_freeze'],
guest_name=kwargs['guest_name'],
selection_docs_state00=copy.deepcopy(selection_docs_state0))
def get_request_state(requests_state1, request, db1s):
    """Populate per-session request state (headers, host, username, picture).

    Fields already present in requests_state1 are kept; missing ones are
    filled from the gradio Request.  Google-OAuth usernames are packed with
    the ``split_google`` separator (email<sep>name[<sep>picture]) and are
    unpacked here.  All keys and values are stringified before returning.

    NOTE(review): the guard ``if requests:`` tests the imported requests
    module, which is always truthy; it looks like ``if request:`` was
    intended — confirm before changing.

    :param requests_state1: existing per-session state dict (may be empty/falsy)
    :param request: gradio Request object for the current call
    :param db1s: per-session db state, used to recover a prior username
    :return: new dict with all keys/values as strings
    """
    # if need to get state, do it now
    if not requests_state1:
        requests_state1 = requests_state0.copy()
    if requests:
        if not requests_state1.get('headers', '') and hasattr(request, 'headers'):
            requests_state1.update(request.headers)
        if not requests_state1.get('host', '') and hasattr(request, 'host'):
            requests_state1.update(dict(host=request.host))
        if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
            requests_state1.update(dict(host2=request.client.host))
        if not requests_state1.get('username', '') and hasattr(request, 'username'):
            # use already-defined username instead of keep changing to new uuid
            # should be same as in requests_state1
            db_username = get_username_direct(db1s)
            if request.username and split_google in request.username:
                assert len(request.username.split(split_google)) >= 2  # 3 if already got pic out
                username = split_google.join(request.username.split(split_google)[0:2])  # no picture
            else:
                username = request.username
            requests_state1.update(dict(username=username or db_username or str(uuid.uuid4())))
            if not requests_state1.get('picture', ''):
                if request.username and split_google in request.username and len(
                        request.username.split(split_google)) == 3:
                    picture = split_google.join(request.username.split(split_google)[2:3])  # picture
                else:
                    picture = None
                requests_state1.update(dict(picture=picture))
    requests_state1 = {str(k): str(v) for k, v in requests_state1.items()}
    return requests_state1
def user_state_setup(db1s, requests_state1, guest_name1, request: gr.Request, *args):
    """Refresh per-request state and ensure the session has a user id.

    Returns (db1s, refreshed requests_state1, *args) so it can be chained in
    front of any gradio event handler.
    """
    requests_state1 = get_request_state(requests_state1, request, db1s)
    set_userid(db1s, requests_state1, get_userid_auth, guest_name=guest_name1)
    return (db1s, requests_state1) + tuple(args)
# END AUTH THINGS
image_audio_loaders_options0, image_audio_loaders_options, \
pdf_loaders_options0, pdf_loaders_options, \
url_loaders_options0, url_loaders_options = lg_to_gr(**kwargs)
jq_schema0 = '.[]'
def click_js():
    """JS snippet that finds the gradio record control by class and clicks it."""
    snippet = """function audioRecord() {
var xPathRes = document.evaluate ('//*[contains(@class, "record")]', document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null);
xPathRes.singleNodeValue.click();}"""
    return snippet
def click_submit():
    """JS snippet that programmatically clicks the Submit button."""
    snippet = """function check() {
document.getElementById("submit").click();
}"""
    return snippet
def click_stop():
    """JS snippet that programmatically clicks the Stop button."""
    snippet = """function check() {
document.getElementById("stop").click();
}"""
    return snippet
if is_gradio_version4:
noqueue_kwargs = dict(concurrency_limit=None)
noqueue_kwargs2 = dict(concurrency_limit=None)
noqueue_kwargs_curl = dict(queue=False)
mic_kwargs = dict(js=click_js())
submit_kwargs = dict(js=click_submit())
stop_kwargs = dict(js=click_stop())
dark_kwargs = dict(js=wrap_js_to_lambda(0, get_dark_js()))
queue_kwargs = dict(default_concurrency_limit=kwargs['concurrency_count'])
mic_sources_kwargs = dict(sources=['microphone'],
waveform_options=dict(show_controls=False, show_recording_waveform=False))
else:
noqueue_kwargs = dict(queue=False)
noqueue_kwargs2 = dict()
noqueue_kwargs_curl = dict(queue=False)
mic_kwargs = dict(_js=click_js())
submit_kwargs = dict(_js=click_submit())
stop_kwargs = dict(_js=click_stop())
dark_kwargs = dict(_js=wrap_js_to_lambda(0, get_dark_js()))
queue_kwargs = dict(concurrency_count=kwargs['concurrency_count'])
mic_sources_kwargs = dict(source='microphone')
if kwargs['model_lock']:
have_vision_models = any(
[is_vision_model(x.get('base_model', '')) or
x.get('display_name', x.get('base_model')) in kwargs['is_vision_models'] for x in kwargs['model_lock']])
else:
have_vision_models = is_vision_model(kwargs['base_model']) or kwargs.get('display_name',
kwargs['base_model']) in kwargs[
'is_vision_models']
is_gradio_h2oai = get_is_gradio_h2oai()
# image control prep
image_gen_visible = kwargs['enable_imagegen']
image_change_visible = kwargs['enable_imagechange']
image_control_panels_visible = False # WIP
image_tab_visible = image_control_panels_visible and (image_gen_visible or image_change_visible)
visible_image_models_visible = len(visible_image_models_state0) > 1
visible_image_models_kwargs = dict(choices=visible_image_models_state0,
label="Visible ImageGen Models",
value=visible_image_models_state0[
0] if visible_image_models_state0 else None,
interactive=True,
multiselect=False,
visible=visible_image_models_visible,
filterable=False,
max_choices=None,
)
image_quality_kwargs = dict(choices=image_quality_choices, label="Image Quality", value=image_quality_choices[0],
visible=not is_public)
image_size_kwargs = dict(value=image_size_default, label="Image Size", visible=not is_public)
image_guidance_kwargs = dict(label="Image generation guidance", value=3.0, visible=not is_public)
image_num_inference_steps_kwargs = dict(label="Image generation inference steps", value=50, visible=not is_public)
with demo:
support_state_callbacks = hasattr(gr.State(), 'callback')
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
# https://github.com/gradio-app/gradio/issues/3558
def model_state_done(state):
    """Cleanup callback for a model gr.State.

    If the state dict holds a live model object (anything with a ``.cpu``
    method), move it off the GPU, drop the reference, and free torch cache.
    Non-dict or model-less states are left untouched.
    """
    if not isinstance(state, dict):
        return
    model = state.get('model')
    if hasattr(model, 'cpu'):
        model.cpu()
        state['model'] = None
        clear_torch_cache()
model_state_cb = dict(callback=model_state_done) if support_state_callbacks else {}
# Default dict seeding the primary model gr.State; 'model'/'tokenizer'/'device'
# hold placeholder strings, never live objects (see note above about deepcopy).
model_state_default = dict(model='model', tokenizer='tokenizer', device='device',
                           base_model=kwargs['base_model'],
                           display_name=kwargs['base_model'],
                           tokenizer_base_model=kwargs['tokenizer_base_model'],
                           lora_weights=kwargs['lora_weights'],
                           inference_server=kwargs['inference_server'],
                           prompt_type=kwargs['prompt_type'],
                           prompt_dict=kwargs['prompt_dict'],
                           chat_template=kwargs['chat_template'],
                           visible_models=visible_models_to_model_choice(kwargs['visible_models'],
                                                                         model_states),
                           h2ogpt_key=None,
                           # only apply at runtime when doing API call with gradio inference server
                           )
# Merge in any remaining keys from model_state_none without clobbering explicit
# values (was a list comprehension used only for its .update side effect).
model_state_default.update({k: v for k, v in kwargs['model_state_none'].items()
                            if k not in model_state_default})
model_state = gr.State(value=model_state_default, **model_state_cb)
my_db_state_cb = dict(callback=my_db_state_done) if support_state_callbacks else {}
model_state2 = gr.State(kwargs['model_state_none'].copy())
# NOTE(review): reuses model_state_cb; model_state_done is a no-op for non-dict state
model_options_state = gr.State([model_options0], **model_state_cb)
lora_options_state = gr.State([lora_options])
server_options_state = gr.State([server_options])
my_db_state = gr.State(my_db_state0, **my_db_state_cb)
chat_state = gr.State({})
if kwargs['enable_tts'] and kwargs['tts_model'].startswith('tts_models/'):
    from tts_coqui import get_role_to_wave_map
    roles_state0 = roles_state0 if roles_state0 else get_role_to_wave_map()
else:
    roles_state0 = {}
roles_state = gr.State(roles_state0)
docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
# De-duplicate document choices while preserving order (was a side-effect
# list comprehension around .append).
docs_state0 = []
for x in docs_state00:
    if x not in docs_state0:
        docs_state0.append(x)
docs_state = gr.State(docs_state0)
viewable_docs_state0 = ['None']
viewable_docs_state = gr.State(viewable_docs_state0)
selection_docs_state0 = update_langchain_mode_paths(selection_docs_state0)
selection_docs_state = gr.State(selection_docs_state0)
requests_state = gr.State(requests_state0)
if description is None:
description = ''
# Page banner: fancy H2O-branded title when h2ocolors, otherwise a simple title.
markdown_logo = f"""
{get_h2o_title(page_title, description, visible_h2ogpt_qrcode=kwargs['visible_h2ogpt_qrcode'])
if kwargs['h2ocolors'] else get_simple_title(page_title, description)}
"""
if kwargs['visible_h2ogpt_logo']:
gr.Markdown(markdown_logo)
# go button visible if a real base model is set and login-gating mode is on
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
# one "NA" score slot per locked model for the response-score label
nas = ' '.join(['NA'] * len(kwargs['model_states']))
res_value = "Response Score: NA" if not kwargs[
    'model_lock'] else "Response Scores: %s" % nas
user_can_do_sum = kwargs['langchain_mode'] != LangChainMode.DISABLED.value and \
                  (kwargs['visible_side_bar'] or kwargs['visible_system_tab'])
# Build the instruction textbox label from capability flags.
if user_can_do_sum:
extra_prompt_form = ". Just Click Submit for simple Summarize/Extract"
else:
extra_prompt_form = ""
if allow_upload:
extra_prompt_form += ". Clicking Ingest adds text as URL/ArXiv/YouTube/Text."
if kwargs['input_lines'] > 1:
instruction_label = "Shift-Enter to Submit, Enter adds lines%s" % extra_prompt_form
else:
instruction_label = "Enter to Submit, Shift-Enter adds lines%s" % extra_prompt_form
def get_langchain_choices(selection_docs_state1):
    """Return the collection (langchain mode) names to offer in the UI dropdowns."""
    modes = selection_docs_state1['langchain_modes']
    if is_hf:
        # don't show 'wiki' since only usually useful for internal testing at moment
        hidden = ['Disabled', 'wiki']
    else:
        hidden = ['Disabled']
    permitted = list(modes)
    # permitted = [x for x in permitted if x in dbs]
    permitted.append('LLM')
    if allow_upload_to_my_data and 'MyData' not in permitted:
        permitted.append('MyData')
    if allow_upload_to_user_data and 'UserData' not in permitted:
        permitted.append('UserData')
    return [mode for mode in modes if mode in permitted and mode not in hidden]
def get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=None):
    """Build a DataFrame summarizing each visible collection for UI display.

    Columns (when available): Type, Path, persist Directory, Embedding model,
    and vector DB flavor.  Only collections currently offered as choices
    (per get_langchain_choices) are included.

    :param selection_docs_state1: dict with 'langchain_modes',
        'langchain_mode_paths', 'langchain_mode_types'
    :param db1s: per-user db state passed through to get_persist_directory
    :param dbs1: shared dbs mapping passed through to get_persist_directory
    :return: pandas DataFrame (possibly empty)
    """
    langchain_choices1 = get_langchain_choices(selection_docs_state1)
    langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
    langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if k in langchain_choices1}
    if langchain_mode_paths:
        langchain_mode_paths = langchain_mode_paths.copy()
        # non-db modes (e.g. pure LLM) have no source path to show
        for langchain_mode1 in langchain_modes_non_db:
            langchain_mode_paths.pop(langchain_mode1, None)
        df1 = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
        df1.columns = ['Collection', 'Path']
        df1 = df1.set_index('Collection')
    else:
        df1 = pd.DataFrame(None)
    langchain_mode_types = selection_docs_state1['langchain_mode_types']
    langchain_mode_types = {k: v for k, v in langchain_mode_types.items() if k in langchain_choices1}
    if langchain_mode_types:
        langchain_mode_types = langchain_mode_types.copy()
        for langchain_mode1 in langchain_modes_non_db:
            langchain_mode_types.pop(langchain_mode1, None)
        df2 = pd.DataFrame.from_dict(langchain_mode_types.items(), orient='columns')
        df2.columns = ['Collection', 'Type']
        df2 = df2.set_index('Collection')
        from gpt_langchain import get_persist_directory, load_embed
        persist_directory_dict = {}
        embed_dict = {}
        chroma_version_dict = {}
        for langchain_mode3 in langchain_mode_types:
            langchain_type3 = langchain_mode_types.get(langchain_mode3, LangChainTypes.EITHER.value)
            # this also makes a directory, but may not use it later
            persist_directory3, langchain_type3 = get_persist_directory(langchain_mode3,
                                                                       langchain_type=langchain_type3,
                                                                       db1s=db1s, dbs=dbs1)
            got_embedding3, use_openai_embedding3, hf_embedding_model3 = load_embed(
                persist_directory=persist_directory3, use_openai_embedding=use_openai_embedding)
            persist_directory_dict[langchain_mode3] = persist_directory3
            embed_dict[langchain_mode3] = 'OpenAI' if not hf_embedding_model3 else hf_embedding_model3
            # detect which vector DB backs this collection from its on-disk layout
            if os.path.isfile(os.path.join(persist_directory3, 'chroma.sqlite3')):
                chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4'
            elif os.path.isdir(os.path.join(persist_directory3, 'index')):
                chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4'
            elif not os.listdir(persist_directory3):
                # empty directory: report what the DB/embedding *will be* once created
                if db_type == 'chroma':
                    chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4'  # will be
                elif db_type == 'chroma_old':
                    chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4'  # will be
                else:
                    chroma_version_dict[langchain_mode3] = 'Weaviate'  # will be
                if isinstance(hf_embedding_model, dict):
                    hf_embedding_model3 = hf_embedding_model['name']
                else:
                    hf_embedding_model3 = 'OpenAI' if not hf_embedding_model else hf_embedding_model
                assert isinstance(hf_embedding_model3, str)
                embed_dict[langchain_mode3] = hf_embedding_model3  # will be
            else:
                chroma_version_dict[langchain_mode3] = 'Weaviate'
        df3 = pd.DataFrame.from_dict(persist_directory_dict.items(), orient='columns')
        df3.columns = ['Collection', 'Directory']
        df3 = df3.set_index('Collection')
        df4 = pd.DataFrame.from_dict(embed_dict.items(), orient='columns')
        df4.columns = ['Collection', 'Embedding']
        df4 = df4.set_index('Collection')
        df5 = pd.DataFrame.from_dict(chroma_version_dict.items(), orient='columns')
        df5.columns = ['Collection', 'DB']
        df5 = df5.set_index('Collection')
    else:
        df2 = pd.DataFrame(None)
        df3 = pd.DataFrame(None)
        df4 = pd.DataFrame(None)
        df5 = pd.DataFrame(None)
    df_list = [df2, df1, df3, df4, df5]
    df_list = [x for x in df_list if x.shape[1] > 0]
    if len(df_list) > 1:
        df = df_list[0].join(df_list[1:]).replace(np.nan, '').reset_index()
    elif len(df_list) == 1:
        # BUG FIX: condition was `len(df_list) == 0`, which indexed an empty
        # list (IndexError) and wrongly sent the single-frame case to the
        # empty-DataFrame branch below.
        df = df_list[0].replace(np.nan, '').reset_index()
    else:
        df = pd.DataFrame(None)
    return df
# Main UI layout: everything except the "ENTER" login gate lives in this row.
normal_block = gr.Row(visible=not base_wanted, equal_height=False, elem_id="col_container")
with normal_block:
side_bar = gr.Column(elem_id="sidebar", scale=1, min_width=100, visible=kwargs['visible_side_bar'])
with side_bar:
with gr.Accordion("Chats", open=False, visible=True):
radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False,
                       visible=True, interactive=True,
                       type='value')
# TTS buttons: visible only when TTS is enabled and a predict function exists.
visible_speak_me = kwargs['enable_tts'] and kwargs['predict_from_text_func'] is not None
speak_human_button = gr.Button("Speak Instruction", visible=visible_speak_me, size='sm')
speak_bot_button = gr.Button("Speak Response", visible=visible_speak_me, size='sm')
speak_text_api_button = gr.Button("Speak Text API", visible=False)
speak_text_plain_api_button = gr.Button("Speak Text Plain API", visible=False)
stop_speak_button = gr.Button("Stop/Clear Speak", visible=visible_speak_me, size='sm')
# Role/speaker selectors depend on which TTS backend is configured.
if kwargs['enable_tts'] and kwargs['tts_model'].startswith('tts_models/'):
from tts_coqui import get_roles
chatbot_role = get_roles(choices=list(roles_state.value.keys()), value=kwargs['chatbot_role'])
else:
chatbot_role = gr.Dropdown(choices=['None'], visible=False, value='None')
if kwargs['enable_tts'] and kwargs['tts_model'].startswith('microsoft'):
from tts import get_speakers_gr
speaker = get_speakers_gr(value=kwargs['speaker'])
else:
speaker = gr.Radio(visible=False)
# speed below 1.0 requires pyrubberband for pitch-preserving time stretch
min_tts_speed = 1.0 if not have_pyrubberband else 0.1
tts_speed = gr.Number(minimum=min_tts_speed, maximum=10.0, step=0.1,
                      value=kwargs['tts_speed'],
                      label='Speech Speed',
                      visible=kwargs['enable_tts'] and not is_public,
                      interactive=not is_public)
# Upload controls: visibility gated on langchain being enabled and upload allowed.
upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload
url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
if have_arxiv and have_librosa:
url_label = 'URLs/ArXiv/Youtube'
elif have_arxiv:
url_label = 'URLs/ArXiv'
elif have_librosa:
url_label = 'URLs/Youtube'
else:
url_label = 'URLs'
text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload
fileup_output_text = gr.Textbox(visible=False)
with gr.Accordion("Upload", open=False, visible=upload_visible and kwargs['actions_in_sidebar']):
fileup_output = gr.File(show_label=False,
                        file_types=['.' + x for x in file_types],
                        # file_types=['*', '*.*'], # for iPhone etc. needs to be unconstrained else doesn't work with extension-based restrictions
                        file_count="multiple",
                        scale=1,
                        min_width=0,
                        elem_id="warning", elem_classes="feedback",
                        )
if kwargs['actions_in_sidebar']:
max_quality = gr.Checkbox(label="Max Ingest Quality", value=kwargs['max_quality'],
                          visible=kwargs['visible_max_quality'] and not is_public)
gradio_upload_to_chatbot = gr.Checkbox(label="Add Doc to Chat",
                                       value=kwargs['gradio_upload_to_chatbot'],
                                       visible=kwargs[
                                           'visible_add_doc_to_chat'] and not is_public)
url_text = gr.Textbox(label=url_label,
# placeholder="Enter Submits",