Skip to content

Commit 1090a93

Browse files
authored
[python-package] do not copy column-major numpy arrays when predicting (#6751)
1 parent b33a12e commit 1090a93

File tree

2 files changed

+17
-5
lines changed

python-package/lightgbm/basic.py

+2 -5
Original file line number | Diff line number | Diff line change
@@ -1291,10 +1291,7 @@ def __inner_predict_np2d(
1291 1291          predict_type: int,
1292 1292          preds: Optional[np.ndarray],
1293 1293      ) -> Tuple[np.ndarray, int]:
1294      -        if mat.dtype == np.float32 or mat.dtype == np.float64:
1295      -            data = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
1296      -        else:  # change non-float data to float data, need to copy
1297      -            data = np.array(mat.reshape(mat.size), dtype=np.float32)
     1294 +        data, layout = _np2d_to_np1d(mat)
1298 1295          ptr_data, type_ptr_data, _ = _c_float_array(data)
1299 1296          n_preds = self.__get_num_preds(
1300 1297              start_iteration=start_iteration,
@@ -1314,7 +1311,7 @@ def __inner_predict_np2d(
1314 1311                  ctypes.c_int(type_ptr_data),
1315 1312                  ctypes.c_int32(mat.shape[0]),
1316 1313                  ctypes.c_int32(mat.shape[1]),
1317      -                ctypes.c_int(_C_API_IS_ROW_MAJOR),
     1314 +                ctypes.c_int(layout),
1318 1315                  ctypes.c_int(predict_type),
1319 1316                  ctypes.c_int(start_iteration),
1320 1317                  ctypes.c_int(num_iteration),

tests/python_package_test/test_engine.py

+15
Original file line number | Diff line number | Diff line change
@@ -4611,3 +4611,18 @@ def test_bagging_by_query_in_lambdarank():
4611 4611      ndcg_score_no_bagging_by_query = gbm_no_bagging_by_query.best_score["valid_0"]["ndcg@5"]
4612 4612      assert ndcg_score_bagging_by_query >= ndcg_score - 0.1
4613 4613      assert ndcg_score_no_bagging_by_query >= ndcg_score - 0.1
     4614 +
     4615 +
     4616 +def test_equal_predict_from_row_major_and_col_major_data():
     4617 +    X_row, y = make_synthetic_regression()
     4618 +    assert X_row.flags["C_CONTIGUOUS"] and not X_row.flags["F_CONTIGUOUS"]
     4619 +    ds = lgb.Dataset(X_row, y)
     4620 +    params = {"num_leaves": 8, "verbose": -1}
     4621 +    bst = lgb.train(params, ds, num_boost_round=5)
     4622 +    preds_row = bst.predict(X_row)
     4623 +
     4624 +    X_col = np.asfortranarray(X_row)
     4625 +    assert X_col.flags["F_CONTIGUOUS"] and not X_col.flags["C_CONTIGUOUS"]
     4626 +    preds_col = bst.predict(X_col)
     4627 +
     4628 +    np.testing.assert_allclose(preds_row, preds_col)

0 commit comments

Comments
 (0)