| | import streamlit as st |
| | import pandas as pd |
| | import numpy as np |
| | from streamlit_echarts import st_echarts |
| | from app.show_examples import * |
| | from app.content import * |
| |
|
| | import pandas as pd |
| |
|
| | from app.content import wer_displayname2datasetname |
| | from model_information import get_dataframe |
# Model metadata table, loaded once at import time. Its 'Original Name',
# 'Proper Display Name' and 'Link' columns are used to map the model names found
# in the results CSVs to display names and links when the leaderboard is drawn.
info_df = get_dataframe()
| |
|
| |
|
# CSS injected by draw(): restyles the selected-item chips of st.multiselect widgets.
_MULTISELECT_CSS = """
    <style>
    .stMultiSelect [data-baseweb=select] span {
        max-width: 800px;
        font-size: 0.9rem;
        background-color: #3C6478 !important; /* Background color for selected items */
        color: white; /* Change text color */
    }
    </style>
    """


def _highlight_best_score(table, column):
    """Return a Styler CSS frame that shades ``column`` in the first row of ``table``.

    Intended for ``Styler.apply(..., axis=None)`` on a table that is already
    sorted best-first; ``table`` must be non-empty.
    """
    styles = pd.DataFrame('', index=table.index, columns=table.columns)
    styles.at[table.index[0], column] = 'background-color: #b0c1d7'
    return styles


def _yaxis_bounds(values):
    """Return ``(min, max)`` for the chart's y-axis: inlier range padded by 10%, rounded to 3 dp."""
    q1 = values.quantile(0.25)
    q3 = values.quantile(0.75)
    iqr = q3 - q1
    # Points outside the 1.5 * IQR fences are ignored when sizing the axis, so
    # the axis is scaled to the bulk of the scores (outlier bars may be clipped).
    inliers = values[(values >= q1 - 1.5 * iqr) & (values <= q3 + 1.5 * iqr)]
    low, high = inliers.min(), inliers.max()
    # Pad by 10% of the magnitude; abs() keeps the padding pointing outward
    # even when the bounds are negative.
    return round(low - 0.1 * abs(low), 3), round(high + 0.1 * abs(high), 3)


def _render_table(chart_data, displayname):
    """Render the score table (best score first and highlighted) with model links."""
    with st.container():
        st.markdown('##### TABLE')

        # Display name -> model URL, taken from the model metadata table.
        model_link = {key.strip(): val for key, val in zip(info_df['Proper Display Name'], info_df['Link'])}
        table = chart_data.assign(model_link=chart_data['model_show'].map(model_link))
        table = table[['model_show', displayname, 'model_link']]

        # Datasets listed in wer_displayname2datasetname are sorted ascending
        # (lower error first); every other dataset is sorted descending.
        lower_is_better = displayname in wer_displayname2datasetname
        table = table.sort_values(by=displayname, ascending=lower_is_better).reset_index(drop=True)

        styled_df = (
            table.style
            .format({displayname: "{:.3f}"})
            .apply(_highlight_best_score, axis=None, column=displayname)
        )

        st.dataframe(
            styled_df,
            column_config={
                'model_show': 'Model',
                displayname: {'alignment': 'left'},
                "model_link": st.column_config.LinkColumn(
                    "Model Link",
                ),
            },
            hide_index=True,
            use_container_width=True,
        )


def _render_chart(chart_data, dataset_name, displayname):
    """Render the "Show Chart" toggle button and, while toggled on, the ECharts bar chart."""
    # The toggle state lives in st.session_state so it survives Streamlit reruns.
    if "show_chart" not in st.session_state:
        st.session_state.show_chart = False

    if st.button("Show Chart"):
        st.session_state.show_chart = not st.session_state.show_chart

    if not st.session_state.show_chart:
        return

    with st.container():
        st.markdown('##### CHART')

        min_value, max_value = _yaxis_bounds(chart_data[displayname])
        series_name = f"{dataset_name}"

        options = {
            "tooltip": {
                "trigger": "axis",
                "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
                "triggerOn": 'mousemove',
            },
            # legend.data entries are matched to series by name; the previous
            # hard-coded 'Overall Accuracy' label matched no series.
            "legend": {"data": [series_name]},
            "toolbox": {"feature": {"saveAsImage": {}}},
            "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
            "xAxis": [
                {
                    "type": "category",
                    "boundaryGap": True,
                    "triggerEvent": True,
                    "data": chart_data['model_show'].tolist(),
                }
            ],
            "yAxis": [
                {
                    "type": "value",
                    "min": min_value,
                    "max": max_value,
                    "boundaryGap": True,
                }
            ],
            "series": [
                {
                    "name": series_name,
                    "type": "bar",
                    "data": chart_data[displayname].tolist(),
                }
            ],
        }

        events = {
            "click": "function(params) { return params.value }"
        }

        st_echarts(options=options, events=events, height="500px")


def draw(folder_name, category_name, displayname, metrics, cus_sort=True):
    """Render the leaderboard for one dataset: model picker, score table, optional bar chart.

    Args:
        folder_name: Not used by this function; kept so existing callers keep working.
        category_name: Category whose results file
            ``./results_organized/<metrics>/<category_name lowercased>.csv`` is loaded.
        displayname: Display name of the dataset. It must be a key of
            ``displayname2datasetname``, and after the column rename the score
            column is expected to carry this name (sorting and charting rely on it).
        metrics: Sub-directory of ``./results_organized`` that holds the CSV files.
        cus_sort: Order of the chart's bars; ``True`` sorts scores ascending.

    Returns:
        None. Output is emitted as Streamlit widgets. Nothing past the model
        picker is drawn when no selected model has a score. Toggling the chart
        writes ``st.session_state.show_chart``.
    """
    data_path = f"./results_organized/{metrics}/{category_name.lower()}.csv"
    chart_data = pd.read_csv(data_path).round(3)

    # Keep only the model column and this dataset's scores, then relabel the
    # score column with its display name.
    dataset_name = displayname2datasetname[displayname]
    chart_data = chart_data[['Model', dataset_name]].rename(columns=datasetname2diaplayname)

    st.markdown(_MULTISELECT_CSS, unsafe_allow_html=True)

    # Map raw model names from the CSV to their display names (unknown names pass through).
    display_model_names = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])
    }
    chart_data['model_show'] = chart_data['Model'].map(lambda name: display_model_names.get(name, name))

    all_models = sorted(chart_data['model_show'].tolist())
    models = st.multiselect(
        "Please choose the model",
        all_models,
        default=all_models,
    )

    chart_data = chart_data[chart_data['model_show'].isin(models)]
    chart_data = chart_data.sort_values(by=[displayname], ascending=cus_sort).dropna(axis=0)

    if chart_data.empty:
        return

    _render_table(chart_data, displayname)
    _render_chart(chart_data, dataset_name, displayname)
| | |