Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import io | |
| import os | |
| import yaml | |
| import pyarrow | |
| import tokenizers | |
# Explicitly set TOKENIZERS_PARALLELISM for the Hugging Face ``tokenizers``
# package imported above.
os.environ.update(TOKENIZERS_PARALLELISM="true")

# Configure the Streamlit page to use the full browser width ("wide" layout).
st.set_page_config(layout="wide")
def from_library():
    """Import and return the parts of the ``retro_reader`` package this app uses.

    Returns:
        A ``(constants_module, RetroReader)`` tuple: the
        ``retro_reader.constants`` module and the ``RetroReader`` class.
    """
    from retro_reader import RetroReader
    from retro_reader import constants as constants_module

    return constants_module, RetroReader


# ``C`` supplies the example and help-text constants read by ``main`` via getattr.
C, RetroReader = from_library()
# ``hash_funcs`` overrides for Streamlit's legacy ``st.cache`` decorator.
# Objects of these types raise "unhashable" errors when Streamlit builds cache
# keys (see link). Mapping each type to a function that returns a constant
# makes every instance hash the same, so those objects are effectively ignored.
# https://stackoverflow.com/questions/70274841/streamlit-unhashable-typeerror-when-i-use-st-cache
my_hash_func = dict(
    [
        (io.TextIOWrapper, lambda _: None),
        (pyarrow.lib.Buffer, lambda _: 0),
        (tokenizers.Tokenizer, lambda _: None),
        (tokenizers.AddedToken, lambda _: None),
    ]
)
| # @st.cache(hash_funcs=my_hash_func, allow_output_mutation=True) | |
| # def load_ko_roberta_large_model(): | |
| # config_file = "configs/inference_ko_roberta_large.yaml" | |
| # return RetroReader.load(config_file=config_file) | |
| # @st.cache(hash_funcs=my_hash_func, allow_output_mutation=True) | |
| # def load_ko_electra_small_model(): | |
| # config_file = "configs/inference_ko_electra_small.yaml" | |
| # return RetroReader.load(config_file=config_file) | |
| # @st.cache(hash_funcs=my_hash_func, allow_output_mutation=True) | |
| # def load_en_electra_large_model(): | |
| # config_file = "configs/inference_en_electra_large.yaml" | |
| # return RetroReader.load(config_file=config_file) | |
def _cache_resource(func):
    """Cache *func*'s return value across Streamlit script reruns.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached model loader would rebuild the model each time. Use
    ``st.cache_resource`` when the installed Streamlit provides it. Otherwise
    fall back to the legacy ``st.cache`` decorator, configured with the
    ``my_hash_func`` overrides defined above, as the older loaders in this
    file were.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(hash_funcs=my_hash_func, allow_output_mutation=True)(func)


@_cache_resource
def load_vi_electra_base_model():
    """Load the RetroReader described by ``configs/inference_vi_electra_base.yaml``.

    Returns:
        The object produced by ``RetroReader.load`` for that config. It is
        cached, so repeated calls (including on every Streamlit rerun) return
        the same instance instead of loading the model again.
    """
    config_file = "configs/inference_vi_electra_base.yaml"
    return RetroReader.load(config_file=config_file)
# Readers served by the demo, keyed by model name. Each loader runs once here,
# while the module is executed. The Korean entries are disabled because their
# loaders are commented out.
# NOTE(review): the key says google/electra-large-discriminator, but the loader
# reads configs/inference_vi_electra_base.yaml -- confirm the label is intended.
RETRO_READER_HOST = {
    model_name: load()
    for model_name, load in (
        # ("klue/roberta-large", load_ko_roberta_large_model),
        # ("monologg/koelectra-small-v3-discriminator", load_ko_electra_small_model),
        ("google/electra-large-discriminator", load_vi_electra_base_model),
    )
}
def main():
    """Render the Retrospective Reader demo page.

    Shows a form with a query and a context. Defaults and help texts come from
    the ``{lang_prefix}_*`` attributes of ``C``. After the user submits, the
    form runs the reader and shows its answer and rear verification score.
    """
    st.title("Retrospective Reader Demo")

    # Reuse the reader already built for RETRO_READER_HOST. Calling the loader
    # again here would construct the model a second time on every script run.
    model_name = "google/electra-large-discriminator"
    retro_reader = RETRO_READER_HOST[model_name]

    # NOTE(review): the registered model uses the vi ELECTRA-base config, but the
    # example/help texts come from the EN_* constants -- confirm this is intended.
    lang_prefix = "EN"
    height = 300
    # Reader attributes such as null_score_diff_threshold, rear_threshold,
    # n_best_size, beta1, beta2 and best_cof can be exposed as st.sidebar
    # sliders if interactive tuning is wanted.
    # Set to True to also display the sketch/intensive reader outputs below.
    return_submodule_outputs = False

    st.markdown("## Demonstration")
    with st.form(key="my_form"):
        query = st.text_input(
            label="Type your query",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        context = st.text_area(
            label="Type your context",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")

    if not submit_button:
        return

    with st.spinner("Please wait.."):
        outputs = retro_reader(
            query=query,
            context=context,
            return_submodule_outputs=return_submodule_outputs,
        )

    # outputs[0] is indexed by "id-01", presumably the id the reader assigns to
    # the single submitted example -- confirm against RetroReader.__call__.
    answer, score = outputs[0]["id-01"], outputs[1]
    if not answer:
        answer = "No answer"

    st.markdown("## Results")
    st.write(answer)
    st.markdown("### Rear Verification Score")
    st.json(score)

    if return_submodule_outputs:
        score_ext, nbest_preds, score_diff = outputs[2:]
        st.markdown("### Sketch Reader Score (score_ext)")
        st.json(score_ext)
        st.markdown("### Intensive Reader Score (score_diff)")
        st.json(score_diff)
        st.markdown("### N Best Predictions (from intensive reader)")
        st.json(nbest_preds)


if __name__ == "__main__":
    main()