
Commit f757a30

fix: rendering in streamlit
1 parent b3752eb

3 files changed: +62 −17 lines

pygwalker/__init__.py (+1 −1)

@@ -10,7 +10,7 @@
 from pygwalker.services.global_var import GlobalVarManager
 from pygwalker.services.kaggle import show_tips_user_kaggle as __show_tips_user_kaggle
 
-__version__ = "0.4.9.4"
+__version__ = "0.4.9.5"
 __hash__ = __rand_str()
 
 from pygwalker.api.jupyter import walk, render, table

pygwalker/api/streamlit.py (+10 −2)

@@ -21,6 +21,7 @@
 from pygwalker.utils.check_walker_params import check_expired_params
 from pygwalker.utils import fallback_value
 from pygwalker.services.streamlit_components import pygwalker_component
+from pygwalker.services.data_parsers import get_dataset_hash
 
 
 class PreFilter(BaseModel):

@@ -65,7 +66,14 @@ def __init__(
         default_tab: Literal["data", "vis"] = "vis",
         **kwargs
     ):
-        """Get pygwalker html render to streamlit
+        """Get pygwalker html render to streamlit.
+        In Streamlit, pygwalker calculates a somewhat inaccurate gid from the dataset to
+        distinguish between datasets, and uses it as the key of the Streamlit component to
+        avoid redundant rendering.
+
+        If a user frequently passes different dataframes to the same StreamlitRenderer, and
+        the differences between them are too small for pygwalker's gid calculation to tell
+        the datasets apart, the user should generate a custom gid to differentiate them.
 
         Args:
             - dataset (pl.DataFrame | pd.DataFrame | Connector, optional): dataframe.

@@ -87,7 +95,7 @@ def __init__(
         init_streamlit_comm()
 
         self.walker = PygWalker(
-            gid=gid,
+            gid=gid if gid is not None else get_dataset_hash(dataset),
             dataset=dataset,
             field_specs=field_specs if field_specs is not None else [],
             spec=spec,
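
In practice, a StreamlitRenderer created without an explicit gid now keys its component off the dataset hash. A usage sketch (the dataframe and the gid value "sales_2023" are hypothetical; it assumes the explorer() entry point of this version of the API):

import pandas as pd
from pygwalker.api.streamlit import StreamlitRenderer

df = pd.DataFrame({"year": [2023, 2024], "revenue": [10.0, 12.5]})

# gid defaults to None, so the component key now falls back to
# get_dataset_hash(df) instead of one shared key.
renderer = StreamlitRenderer(df)
renderer.explorer()

# If consecutive dataframes are too similar for the hash to tell apart,
# pass an explicit gid per dataset ("sales_2023" is a hypothetical value):
renderer_2023 = StreamlitRenderer(df, gid="sales_2023")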

pygwalker/services/data_parsers.py (+51 −14)

@@ -1,58 +1,61 @@
 import sys
-from typing import Dict, Optional, Union, Any, List
+import hashlib
+import pandas as pd
+from typing import Dict, Optional, Union, Any, List, Tuple
+from typing_extensions import Literal
 
 from pygwalker.data_parsers.base import BaseDataParser, FieldSpec
 from pygwalker.data_parsers.database_parser import Connector
 from pygwalker._typing import DataFrame
 
 __classname2method = {}
 
+DatasetType = Literal['pandas', 'polars', 'modin', 'pyspark', 'connector', 'cloud_dataset']
+
 
 # pylint: disable=import-outside-toplevel
-def _get_data_parser(dataset: Union[DataFrame, Connector, str]) -> BaseDataParser:
+def _get_data_parser(dataset: Union[DataFrame, Connector, str]) -> Tuple[BaseDataParser, DatasetType]:
     """
     Get DataFrameDataParser for dataset
     TODO: Maybe you can find a better way to handle the following code
     """
     if type(dataset) in __classname2method:
         return __classname2method[type(dataset)]
 
-    if 'pandas' in sys.modules:
-        import pandas as pd
-        if isinstance(dataset, pd.DataFrame):
-            from pygwalker.data_parsers.pandas_parser import PandasDataFrameDataParser
-            __classname2method[pd.DataFrame] = PandasDataFrameDataParser
-            return __classname2method[pd.DataFrame]
+    if isinstance(dataset, pd.DataFrame):
+        from pygwalker.data_parsers.pandas_parser import PandasDataFrameDataParser
+        __classname2method[pd.DataFrame] = (PandasDataFrameDataParser, "pandas")
+        return __classname2method[pd.DataFrame]
 
     if 'polars' in sys.modules:
         import polars as pl
         if isinstance(dataset, pl.DataFrame):
             from pygwalker.data_parsers.polars_parser import PolarsDataFrameDataParser
-            __classname2method[pl.DataFrame] = PolarsDataFrameDataParser
+            __classname2method[pl.DataFrame] = (PolarsDataFrameDataParser, "polars")
             return __classname2method[pl.DataFrame]
 
     if 'modin.pandas' in sys.modules:
         from modin import pandas as mpd
         if isinstance(dataset, mpd.DataFrame):
             from pygwalker.data_parsers.modin_parser import ModinPandasDataFrameDataParser
-            __classname2method[mpd.DataFrame] = ModinPandasDataFrameDataParser
+            __classname2method[mpd.DataFrame] = (ModinPandasDataFrameDataParser, "modin")
             return __classname2method[mpd.DataFrame]
 
     if 'pyspark' in sys.modules:
         from pyspark.sql import DataFrame as SparkDataFrame
         if isinstance(dataset, SparkDataFrame):
             from pygwalker.data_parsers.spark_parser import SparkDataFrameDataParser
-            __classname2method[SparkDataFrame] = SparkDataFrameDataParser
+            __classname2method[SparkDataFrame] = (SparkDataFrameDataParser, "pyspark")
             return __classname2method[SparkDataFrame]
 
     if isinstance(dataset, Connector):
         from pygwalker.data_parsers.database_parser import DatabaseDataParser
-        __classname2method[DatabaseDataParser] = DatabaseDataParser
+        __classname2method[DatabaseDataParser] = (DatabaseDataParser, "connector")
         return __classname2method[DatabaseDataParser]
 
     if isinstance(dataset, str):
         from pygwalker.data_parsers.cloud_dataset_parser import CloudDatasetParser
-        __classname2method[CloudDatasetParser] = CloudDatasetParser
+        __classname2method[CloudDatasetParser] = (CloudDatasetParser, "cloud_dataset")
         return __classname2method[CloudDatasetParser]
 
     raise TypeError(f"Unsupported data type: {type(dataset)}")
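
_get_data_parser now returns a (parser class, type label) pair, cached per concrete dataframe type; get_parser below consumes the class, while get_dataset_hash consumes the label. A minimal sketch of the new contract, assuming pandas is installed:

import pandas as pd
from pygwalker.services.data_parsers import _get_data_parser

parser_class, dataset_type = _get_data_parser(pd.DataFrame({"a": [1, 2]}))
print(dataset_type)           # "pandas"
print(parser_class.__name__)  # "PandasDataFrameDataParser"
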
@@ -70,11 +73,45 @@ def get_parser(
     if other_params is None:
         other_params = {}
 
-    parser = _get_data_parser(dataset)(
+    parser_func, _ = _get_data_parser(dataset)
+    parser = parser_func(
         dataset,
         field_specs,
         infer_string_to_date,
         infer_number_to_dimension,
         other_params
     )
     return parser
+
+
+def get_dataset_hash(dataset: Union[DataFrame, Connector, str]) -> str:
+    """Just a less accurate way to get different dataset hash values."""
+    _, dataset_type = _get_data_parser(dataset)
+    if dataset_type in ["pandas", "modin", "polars"]:
+        row_count = dataset.shape[0]
+        other_info = str(dataset.shape) + "_" + dataset_type
+        if row_count > 4000:
+            dataset = dataset[:2000] + dataset[-2000:]
+        if dataset_type == "modin":
+            dataset = dataset._to_pandas()
+        if dataset_type in ["pandas", "modin"]:
+            hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
+        else:
+            hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
+        return hashlib.md5(hash_bytes).hexdigest()
+
+    if dataset_type == "pyspark":
+        shape = (dataset.count(), len(dataset.columns))
+        row_count = shape[0]
+        other_info = str(shape) + "_" + dataset_type
+        if row_count > 4000:
+            dataset = dataset.limit(4000)
+        dataset_pd = dataset.toPandas()
+        hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
+        return hashlib.md5(hash_bytes).hexdigest()
+
+    if dataset_type == "connector":
+        return hashlib.md5("_".join([dataset.url, dataset.view_sql, dataset_type]).encode()).hexdigest()
+
+    if dataset_type == "cloud_dataset":
+        return hashlib.md5("_".join([dataset, dataset_type]).encode()).hexdigest()
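
The hash deliberately trades accuracy for speed: only the shape plus, for frames above 4,000 rows, the head and tail rows feed the digest, which is why the Streamlit docstring above recommends a custom gid when dataframes differ only slightly. A minimal sketch of the pandas branch for a small frame (below the sampling threshold), using only the calls shown above:

import hashlib
import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})

other_info = str(df.shape) + "_pandas"       # shape plus dataset-type salt
row_hashes = pd.util.hash_pandas_object(df)  # one uint64 per row
digest = hashlib.md5(row_hashes.values.tobytes() + other_info.encode()).hexdigest()
print(digest)  # stable for identical data, so Streamlit can reuse the component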
