Skip to content

Commit e61bcbe

Browse files
authored
[python-package] Infer feature names from pyarrow.Table (#6781)
1 parent e0c34e7 commit e61bcbe

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

python-package/lightgbm/basic.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -2126,6 +2126,8 @@ def _lazy_init(
21262126
categorical_feature=categorical_feature,
21272127
pandas_categorical=self.pandas_categorical,
21282128
)
2129+
elif _is_pyarrow_table(data) and feature_name == "auto":
2130+
feature_name = data.column_names
21292131

21302132
# process for args
21312133
params = {} if params is None else params
@@ -2185,7 +2187,6 @@ def _lazy_init(
21852187
self.__init_from_np2d(data, params_str, ref_dataset)
21862188
elif _is_pyarrow_table(data):
21872189
self.__init_from_pyarrow_table(data, params_str, ref_dataset)
2188-
feature_name = data.column_names
21892190
elif isinstance(data, list) and len(data) > 0:
21902191
if _is_list_of_numpy_arrays(data):
21912192
self.__init_from_list_np2d(data, params_str, ref_dataset)

tests/python_package_test/test_arrow.py

+22
Original file line number | Diff line number | Diff line change
@@ -432,3 +432,25 @@ def test_predict_ranking():
432432
num_boost_round=5,
433433
)
434434
assert_equal_predict_arrow_pandas(booster, data)
435+
436+
437+
def test_arrow_feature_name_auto():
438+
data = generate_dummy_arrow_table()
439+
dataset = lgb.Dataset(
440+
data, label=pa.array([0, 1, 0, 0, 1]), params=dummy_dataset_params(), categorical_feature=["a"]
441+
)
442+
booster = lgb.train({"num_leaves": 7}, dataset, num_boost_round=5)
443+
assert booster.feature_name() == ["a", "b"]
444+
445+
446+
def test_arrow_feature_name_manual():
447+
data = generate_dummy_arrow_table()
448+
dataset = lgb.Dataset(
449+
data,
450+
label=pa.array([0, 1, 0, 0, 1]),
451+
params=dummy_dataset_params(),
452+
feature_name=["c", "d"],
453+
categorical_feature=["c"],
454+
)
455+
booster = lgb.train({"num_leaves": 7}, dataset, num_boost_round=5)
456+
assert booster.feature_name() == ["c", "d"]

0 commit comments

Comments (0)