"""This module contains njitted routines and data structures to:
- Find the best possible split of a node. For a given node, a split is
characterized by a feature and a bin.
- Apply a split to a node, i.e. split the indices of the samples at the node
into the newly created left and right children.
"""
import numpy as np
from numba import njit, jitclass, prange, float32, uint8, uint32
import numba
from .histogram import _build_histogram
from .histogram import _subtract_histograms
from .histogram import _build_histogram_no_hessian
from .histogram import _build_histogram_root
from .histogram import _build_histogram_root_no_hessian
from .histogram import HISTOGRAM_DTYPE
from .utils import get_threads_chunks
@jitclass([
    ('gain', float32),
    ('feature_idx', uint32),
    ('bin_idx', uint8),
    ('gradient_left', float32),
    ('hessian_left', float32),
    ('gradient_right', float32),
    ('hessian_right', float32),
    ('n_samples_left', uint32),
    ('n_samples_right', uint32),
])
class SplitInfo:
    """Plain record describing one candidate split of a node.

    A freshly built instance has a gain of -1, which the routines of this
    module use to mean "no valid split was found".

    Parameters
    ----------
    gain : float32
        Gain (loss reduction) obtained by taking the split.
    feature_idx : int
        Index of the feature the split is made on (stored as uint32).
    bin_idx : int
        Index of the bin the split is made on (stored as uint8). Samples
        whose bin is ``<= bin_idx`` belong to the left child.
    gradient_left : float32
        Sum of the gradients of the samples in the left child.
    hessian_left : float32
        Sum of the hessians of the samples in the left child.
    gradient_right : float32
        Sum of the gradients of the samples in the right child.
    hessian_right : float32
        Sum of the hessians of the samples in the right child.
    n_samples_left : int
        Number of samples in the left child.
    n_samples_right : int
        Number of samples in the right child.
    """
    def __init__(self, gain=-1., feature_idx=0, bin_idx=0,
                 gradient_left=0., hessian_left=0.,
                 gradient_right=0., hessian_right=0.,
                 n_samples_left=0, n_samples_right=0):
        # Where the split happens and how much it improves the loss.
        self.feature_idx = feature_idx
        self.bin_idx = bin_idx
        self.gain = gain
        # Aggregated statistics of the left child.
        self.n_samples_left = n_samples_left
        self.gradient_left = gradient_left
        self.hessian_left = hessian_left
        # Aggregated statistics of the right child.
        self.n_samples_right = n_samples_right
        self.gradient_right = gradient_right
        self.hessian_right = hessian_right
@jitclass([
    ('n_features', uint32),
    ('X_binned', uint8[::1, :]),
    ('max_bins', uint32),
    ('n_bins_per_feature', uint32[::1]),
    ('min_samples_leaf', uint32),
    ('min_gain_to_split', float32),
    ('gradients', float32[::1]),
    ('hessians', float32[::1]),
    ('ordered_gradients', float32[::1]),
    ('ordered_hessians', float32[::1]),
    ('sum_gradients', float32),
    ('sum_hessians', float32),
    ('constant_hessian', uint8),
    ('constant_hessian_value', float32),
    ('l2_regularization', float32),
    ('min_hessian_to_split', float32),
    ('partition', uint32[::1]),
    ('left_indices_buffer', uint32[::1]),
    ('right_indices_buffer', uint32[::1]),
])
class SplittingContext:
    """Pure data class defining a splitting context.

    Ideally it would also have methods but numba does not support annotating
    jitclasses (so we can't use parallel=True). This structure is
    instantiated in the grower and stores all the required information to
    compute the SplitInfo and histograms of each node.

    Parameters
    ----------
    X_binned : array of uint8, shape=(n_samples, n_features)
        The binned input samples. Must be Fortran-contiguous.
    max_bins : int
        The maximum number of bins. Used to define the shape of the
        histograms.
    n_bins_per_feature : array-like of uint32
        The actual number of bins needed for each feature, which is lower or
        equal to max_bins.
    gradients : array-like of float32, shape=(n_samples,)
        The gradients of each training sample. Those are the gradients of the
        loss w.r.t the predictions, evaluated at iteration i - 1.
    hessians : array-like of float32, shape=(n_samples,) or (1,)
        The hessians of each training sample. Those are the hessians of the
        loss w.r.t the predictions, evaluated at iteration i - 1. An array of
        shape (1,) means the hessian is the same for every sample.
    l2_regularization : float
        The L2 regularization parameter.
    min_hessian_to_split : float, optional(default=1e-3)
        The minimum sum of hessians needed in each node. Splits that result in
        at least one child having a sum of hessians less than
        min_hessian_to_split are discarded.
    min_samples_leaf : int, optional(default=20)
        The minimum number of samples per leaf.
    min_gain_to_split : float, optional(default=0.)
        The minimum gain needed to split a node. Splits with lower gain will
        be ignored.
    """
    def __init__(self, X_binned, max_bins, n_bins_per_feature,
                 gradients, hessians, l2_regularization,
                 min_hessian_to_split=1e-3, min_samples_leaf=20,
                 min_gain_to_split=0.):
        self.X_binned = X_binned
        self.n_features = X_binned.shape[1]
        # Note: all histograms will have <max_bins> bins, but some of the
        # last bins may be unused if n_bins_per_feature[f] < max_bins
        self.max_bins = max_bins
        self.n_bins_per_feature = n_bins_per_feature
        self.gradients = gradients
        self.hessians = hessians
        # for root node, gradients and hessians are already ordered
        self.ordered_gradients = gradients.copy()
        self.ordered_hessians = hessians.copy()
        self.sum_gradients = self.gradients.sum()
        # A hessians array with a single entry encodes a constant hessian.
        self.constant_hessian = hessians.shape[0] == 1
        self.l2_regularization = l2_regularization
        self.min_hessian_to_split = min_hessian_to_split
        self.min_samples_leaf = min_samples_leaf
        self.min_gain_to_split = min_gain_to_split
        if self.constant_hessian:
            self.constant_hessian_value = self.hessians[0]  # 1 scalar
            # hessians only stores one value shared by all samples: the total
            # over the root is that value times the number of samples (same
            # formula as used by find_node_split).
            self.sum_hessians = (self.constant_hessian_value *
                                 float32(gradients.shape[0]))
        else:
            self.constant_hessian_value = float32(1.)  # won't be used anyway
            self.sum_hessians = self.hessians.sum()

        # The partition array maps each sample index into the leaves of the
        # tree (a leaf in this context is a node that isn't split yet, not
        # necessarily a 'finalized' leaf). Initially, the root contains all
        # the indices, e.g.:
        # partition = [abcdefghijkl]
        # After a call to split_indices, it may look e.g. like this:
        # partition = [cef|abdghijkl]
        # we have 2 leaves, the left one is at position 0 and the second one at
        # position 3. The order of the samples is irrelevant.
        self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32)
        # buffers used in split_indices to support parallel splitting.
        self.left_indices_buffer = np.empty_like(self.partition)
        self.right_indices_buffer = np.empty_like(self.partition)
@njit(parallel=True,
      locals={'sample_idx': uint32,
              'left_count': uint32,
              'right_count': uint32})
def split_indices(context, split_info, sample_indices):
    """Split samples into left and right arrays.

    Parameters
    ----------
    context : SplittingContext
        The splitting context
    split_info : SplitInfo
        The SplitInfo of the node to split. Samples whose bin for feature
        ``split_info.feature_idx`` is ``<= split_info.bin_idx`` go to the
        left child, the others go to the right child.
    sample_indices : array of int
        The indices of the samples at the node to split. This is a view on
        context.partition, and it is modified inplace by placing the indices
        of the left child at the beginning, and the indices of the right child
        at the end.

    Returns
    -------
    left_indices : array of int
        The indices of the samples in the left child. This is a view on
        context.partition.
    right_indices : array of int
        The indices of the samples in the right child. This is a view on
        context.partition.
    """
    # This is a multi-threaded implementation inspired by lightgbm.
    # Here is a quick break down. Let's suppose we want to split a node with
    # 24 samples named from a to x. context.partition looks like this (the *
    # are indices in other leaves that we don't care about):
    # partition = [*************abcdefghijklmnopqrstuvwx****************]
    #                           ^                       ^
    #                  node_position     node_position + node.n_samples
    # Ultimately, we want to reorder the samples inside the boundaries of the
    # leaf (which becomes a node) to now represent the samples in its left and
    # right child. For example:
    # partition = [*************abefilmnopqrtuxcdghjksvw*****************]
    #                           ^              ^
    #                  left_child_pos     right_child_pos
    # Note that left_child_pos always takes the value of node_position, and
    # right_child_pos = left_child_pos + left_child.n_samples. The order of
    # the samples inside a leaf is irrelevant.

    # 1. sample_indices is a view on this region a..x. We conceptually
    #    divide it into n_threads regions. Each thread will be responsible for
    #    its own region. Here is an example with 4 threads:
    #    sample_indices = [abcdef|ghijkl|mnopqr|stuvwx]
    # 2. Each thread processes 6 = 24 // 4 entries and maps them into
    #    left_indices_buffer or right_indices_buffer. For example, we could
    #    have the following mapping ('.' denotes an undefined entry):
    #    - left_indices_buffer =  [abef..|il....|mnopqr|tux...]
    #    - right_indices_buffer = [cd....|ghjk..|......|svw...]
    # 3. We keep track of the start positions of the regions (the '|') in
    #    ``offset_in_buffers`` as well as the size of each region. We also
    #    keep track of the number of samples put into the left/right child by
    #    each thread. Concretely:
    #    - left_counts =  [4, 2, 6, 3]
    #    - right_counts = [2, 4, 0, 3]
    # 4. Finally, we put left/right_indices_buffer back into the
    #    sample_indices, without any undefined entries and the partition
    #    looks as expected
    #    partition = [*************abefilmnopqrtuxcdghjksvw*****************]

    # Note: We here show left/right_indices_buffer as being the same size as
    # sample_indices for simplicity, but in reality they are of the same size
    # as partition.

    X_binned = context.X_binned.T[split_info.feature_idx]

    n_threads = numba.config.NUMBA_DEFAULT_NUM_THREADS
    n_samples = sample_indices.shape[0]

    # Note: we could probably allocate all the arrays of size n_threads in the
    # splitting context as well, but gains are probably going to be minimal
    sizes = np.full(n_threads, n_samples // n_threads, dtype=np.int32)
    if n_samples % n_threads > 0:
        # array[:0] will cause a bug in numba 0.41 so we need the if. Remove
        # once issue numba 3554 is fixed.
        sizes[:n_samples % n_threads] += 1
    offset_in_buffers = np.zeros(n_threads, dtype=np.int32)
    offset_in_buffers[1:] = np.cumsum(sizes[:-1])

    left_counts = np.empty(n_threads, dtype=np.int32)
    right_counts = np.empty(n_threads, dtype=np.int32)

    # Need to declare local variables, else they're not updated :/
    # (see numba issue 3459)
    left_indices_buffer = context.left_indices_buffer
    right_indices_buffer = context.right_indices_buffer

    # map indices from sample_indices to left/right_indices_buffer
    for thread_idx in prange(n_threads):
        left_count = 0
        right_count = 0

        start = offset_in_buffers[thread_idx]
        stop = start + sizes[thread_idx]
        for i in range(start, stop):
            sample_idx = sample_indices[i]
            if X_binned[sample_idx] <= split_info.bin_idx:
                left_indices_buffer[start + left_count] = sample_idx
                left_count += 1
            else:
                right_indices_buffer[start + right_count] = sample_idx
                right_count += 1

        left_counts[thread_idx] = left_count
        right_counts[thread_idx] = right_count

    # position of right child = just after the left child
    right_child_position = left_counts.sum()

    # offset of each thread in sample_indices for left and right child, i.e.
    # where each thread will start to write.
    left_offset = np.zeros(n_threads, dtype=np.int32)
    left_offset[1:] = np.cumsum(left_counts[:-1])
    right_offset = np.full(n_threads, right_child_position, dtype=np.int32)
    right_offset[1:] += np.cumsum(right_counts[:-1])

    # map indices in left/right_indices_buffer back into sample_indices. This
    # also updates context.partition since sample_indices is a view.
    for thread_idx in prange(n_threads):
        for i in range(left_counts[thread_idx]):
            sample_indices[left_offset[thread_idx] + i] = \
                left_indices_buffer[offset_in_buffers[thread_idx] + i]
        for i in range(right_counts[thread_idx]):
            sample_indices[right_offset[thread_idx] + i] = \
                right_indices_buffer[offset_in_buffers[thread_idx] + i]

    return (sample_indices[:right_child_position],
            sample_indices[right_child_position:])
@njit(parallel=True)
def find_node_split(context, sample_indices):
    """For each feature, find the best bin to split on at a given node.

    Returns the best split info among all features, and the histograms of
    all the features. The histograms are computed by scanning the whole
    data.

    As side effects, the first ``len(sample_indices)`` entries of
    ``context.ordered_gradients`` (and of ``context.ordered_hessians`` when
    the hessian is not constant) are filled with the values of the node's
    samples, and ``context.sum_gradients`` / ``context.sum_hessians`` are
    set to the node's totals.

    Parameters
    ----------
    context : SplittingContext
        The splitting context
    sample_indices : array of int
        The indices of the samples at the node to split.

    Returns
    -------
    best_split_info : SplitInfo
        The info about the best possible split among all features.
    histograms : array of HISTOGRAM_DTYPE, shape=(n_features, max_bins)
        The histograms of each feature. A histogram is an array of
        HISTOGRAM_DTYPE of size ``max_bins`` (only
        ``n_bins_per_feature[feature]`` entries are relevant).
    """
    ctx = context  # shorter name to avoid various line breaks
    n_samples = sample_indices.shape[0]

    # Need to declare local variables, else they're not updated
    # (see numba issue 3459)
    ordered_gradients = ctx.ordered_gradients
    ordered_hessians = ctx.ordered_hessians

    # Populate ordered_gradients and ordered_hessians. (Already done for root)
    # Ordering the gradients and hessians helps to improve cache hit.
    # This is a parallelized version of the following vanilla code:
    # for i in range(n_samples):
    #     ctx.ordered_gradients[i] = ctx.gradients[sample_indices[i]]
    if sample_indices.shape[0] != ctx.gradients.shape[0]:
        starts, ends, n_threads = get_threads_chunks(n_samples)
        if ctx.constant_hessian:
            for thread_idx in prange(n_threads):
                for i in range(starts[thread_idx], ends[thread_idx]):
                    ordered_gradients[i] = ctx.gradients[sample_indices[i]]
        else:
            for thread_idx in prange(n_threads):
                for i in range(starts[thread_idx], ends[thread_idx]):
                    ordered_gradients[i] = ctx.gradients[sample_indices[i]]
                    ordered_hessians[i] = ctx.hessians[sample_indices[i]]

    ctx.sum_gradients = ctx.ordered_gradients[:n_samples].sum()
    if ctx.constant_hessian:
        ctx.sum_hessians = ctx.constant_hessian_value * float32(n_samples)
    else:
        ctx.sum_hessians = ctx.ordered_hessians[:n_samples].sum()

    # Pre-allocate the results datastructure to be able to use prange:
    # numba jitclass do not seem to properly support default values for kwargs.
    split_infos = [SplitInfo(-1., 0, 0, 0., 0., 0., 0., 0, 0)
                   for i in range(context.n_features)]
    histograms = np.empty(
        shape=(np.int64(context.n_features), np.int64(context.max_bins)),
        dtype=HISTOGRAM_DTYPE
    )
    for feature_idx in prange(context.n_features):
        split_info, histogram = _find_histogram_split(
            context, feature_idx, sample_indices)
        split_infos[feature_idx] = split_info
        histograms[feature_idx, :] = histogram

    split_info = _find_best_feature_to_split_helper(split_infos)
    return split_info, histograms
@njit(parallel=True)
def find_node_split_subtraction(context, sample_indices, parent_histograms,
                                sibling_histograms):
    """For each feature, find the best bin to split on at a given node.

    This does the same job as ``find_node_split()`` but uses the histograms
    of the parent and sibling of the node to split. This allows to use the
    identity: ``histogram(node) = histogram(parent) - histogram(sibling)``,
    which is significantly faster than computing the histograms from data.

    Returns the best SplitInfo among all features, along with all the feature
    histograms that can be later used to compute the sibling or children
    histograms by subtraction.

    As a side effect, ``context.sum_gradients`` and ``context.sum_hessians``
    are set to the node's totals.

    Parameters
    ----------
    context : SplittingContext
        The splitting context
    sample_indices : array of int
        The indices of the samples at the node to split.
    parent_histograms : array of HISTOGRAM_DTYPE, shape=(n_features, max_bins)
        The histograms of the parent
    sibling_histograms : array of HISTOGRAM_DTYPE, shape=(n_features, max_bins)
        The histograms of the sibling

    Returns
    -------
    best_split_info : SplitInfo
        The info about the best possible split among all features.
    histograms : array of HISTOGRAM_DTYPE, shape=(n_features, max_bins)
        The histograms of each feature. A histogram is an array of
        HISTOGRAM_DTYPE of size ``max_bins`` (only
        ``n_bins_per_feature[feature]`` entries are relevant).
    """
    # We can pick any feature (here the first) in the histograms to
    # compute the gradients: they must be the same across all features
    # anyway, we have tests ensuring this. Maybe a more robust way would
    # be to compute an average but it's probably not worth it.
    context.sum_gradients = (parent_histograms[0]['sum_gradients'].sum() -
                             sibling_histograms[0]['sum_gradients'].sum())

    n_samples = sample_indices.shape[0]
    if context.constant_hessian:
        context.sum_hessians = \
            context.constant_hessian_value * float32(n_samples)
    else:
        context.sum_hessians = (parent_histograms[0]['sum_hessians'].sum() -
                                sibling_histograms[0]['sum_hessians'].sum())

    # Pre-allocate the results datastructure to be able to use prange
    split_infos = [SplitInfo(-1., 0, 0, 0., 0., 0., 0., 0, 0)
                   for i in range(context.n_features)]
    histograms = np.empty(
        shape=(np.int64(context.n_features), np.int64(context.max_bins)),
        dtype=HISTOGRAM_DTYPE
    )
    for feature_idx in prange(context.n_features):
        split_info, histogram = _find_histogram_split_subtraction(
            context, feature_idx, parent_histograms,
            sibling_histograms, n_samples)
        split_infos[feature_idx] = split_info
        histograms[feature_idx, :] = histogram

    split_info = _find_best_feature_to_split_helper(split_infos)
    return split_info, histograms
@njit
def _find_best_feature_to_split_helper(split_infos):
    """Return the SplitInfo with the highest gain.

    Parameters
    ----------
    split_infos : list of SplitInfo
        One candidate split per feature. Must not be empty.

    Returns
    -------
    best_split_info : SplitInfo
        The candidate with the largest gain. Comparisons are strict, so on
        ties the candidate with the lowest index wins.
    """
    # Seed with the first candidate instead of an Optional (None) sentinel;
    # a later candidate only replaces it when its gain is strictly larger.
    best_split_info = split_infos[0]
    for i in range(1, len(split_infos)):
        candidate = split_infos[i]
        if candidate.gain > best_split_info.gain:
            best_split_info = candidate
    return best_split_info
@njit(fastmath=True)
def _find_histogram_split(context, feature_idx, sample_indices):
    """Build the histogram of one feature at a node and find its best split.

    Parameters
    ----------
    context : SplittingContext
        The splitting context. Its ordered gradients/hessians must already
        hold the values of the node's samples (see ``find_node_split``).
    feature_idx : int
        The feature whose histogram is built.
    sample_indices : array of int
        The indices of the samples at the node.

    Returns
    -------
    split_info : SplitInfo
        The best split among all the possible bins of the feature.
    histogram : array of HISTOGRAM_DTYPE
        The histogram of the feature at this node (``max_bins`` entries).
    """
    n_samples = sample_indices.shape[0]
    X_binned = context.X_binned.T[feature_idx]
    # The root holds every sample and its gradients/hessians are already in
    # sample order, so the histogram can be built without sample_indices.
    root_node = X_binned.shape[0] == n_samples

    ordered_gradients = context.ordered_gradients[:n_samples]
    ordered_hessians = context.ordered_hessians[:n_samples]

    if root_node:
        if context.constant_hessian:
            histogram = _build_histogram_root_no_hessian(
                context.max_bins, X_binned, ordered_gradients)
        else:
            # Use the sliced view like every other branch (at the root it
            # covers the whole array anyway).
            histogram = _build_histogram_root(
                context.max_bins, X_binned, ordered_gradients,
                ordered_hessians)
    else:
        if context.constant_hessian:
            histogram = _build_histogram_no_hessian(
                context.max_bins, sample_indices, X_binned,
                ordered_gradients)
        else:
            histogram = _build_histogram(
                context.max_bins, sample_indices, X_binned,
                ordered_gradients, ordered_hessians)

    return _find_best_bin_to_split_helper(context, feature_idx, histogram,
                                          n_samples)
@njit(fastmath=True)
def _find_histogram_split_subtraction(context, feature_idx,
                                      parent_histograms, sibling_histograms,
                                      n_samples):
    """Derive a feature's histogram by subtraction and find its best split.

    Relies on hist(parent) = hist(left) + hist(right): the histogram of the
    current node is the parent histogram minus the sibling histogram.

    Returns
    -------
    split_info : SplitInfo
        The best split among all the possible bins of the feature.
    histogram : array of HISTOGRAM_DTYPE
        The histogram of the feature at this node.
    """
    parent = parent_histograms[feature_idx]
    sibling = sibling_histograms[feature_idx]
    node_histogram = _subtract_histograms(context.max_bins, parent, sibling)
    return _find_best_bin_to_split_helper(
        context, feature_idx, node_histogram, n_samples)
@njit(locals={'gradient_left': float32, 'hessian_left': float32,
              'n_samples_left': uint32},
      fastmath=True)
def _find_best_bin_to_split_helper(context, feature_idx, histogram, n_samples):
    """Find best bin to split on, and return the corresponding SplitInfo.

    Splits that do not satisfy the splitting constraints (min_samples_leaf,
    min_hessian_to_split, min_gain_to_split) are discarded here. If no split
    can satisfy the constraints, a SplitInfo with a gain of -1 is returned.
    If for a given node the best SplitInfo has a gain of -1, it is finalized
    into a leaf.

    Parameters
    ----------
    context : SplittingContext
        The splitting context. Its sum_gradients / sum_hessians must hold the
        totals of the node being split.
    feature_idx : int
        The feature the histogram belongs to.
    histogram : array of HISTOGRAM_DTYPE
        Histogram of the feature at the node. Only the first
        ``n_bins_per_feature[feature_idx]`` bins are scanned.
    n_samples : int
        Number of samples at the node.

    Returns
    -------
    best_split : SplitInfo
        The best valid split for this feature, or a gain of -1 if none.
    histogram : array of HISTOGRAM_DTYPE
        The input histogram, returned unchanged.
    """
    # Allocate the structure for the best split information. It can be
    # returned as such (with a negative gain) if the min_hessian_to_split
    # condition is not satisfied. Such invalid splits are later discarded by
    # the TreeGrower.
    best_split = SplitInfo(-1., 0, 0, 0., 0., 0., 0., 0, 0)

    # Running (cumulative) statistics of the left child: choosing bin_idx
    # puts bins 0..bin_idx (inclusive) on the left.
    gradient_left, hessian_left = 0., 0.
    n_samples_left = 0

    for bin_idx in range(context.n_bins_per_feature[feature_idx]):
        n_samples_left += histogram[bin_idx]['count']
        n_samples_right = n_samples - n_samples_left

        if context.constant_hessian:
            hessian_left += (histogram[bin_idx]['count']
                             * context.constant_hessian_value)
        else:
            hessian_left += histogram[bin_idx]['sum_hessians']
        # Right-child statistics are the node totals minus the left ones.
        hessian_right = context.sum_hessians - hessian_left

        gradient_left += histogram[bin_idx]['sum_gradients']
        gradient_right = context.sum_gradients - gradient_left

        if n_samples_left < context.min_samples_leaf:
            continue
        if n_samples_right < context.min_samples_leaf:
            # won't get any better
            break

        if hessian_left < context.min_hessian_to_split:
            continue
        if hessian_right < context.min_hessian_to_split:
            # won't get any better (hessians are > 0 since loss is convex)
            break

        gain = _split_gain(gradient_left, hessian_left,
                           gradient_right, hessian_right,
                           context.sum_gradients, context.sum_hessians,
                           context.l2_regularization)

        if gain > best_split.gain and gain > context.min_gain_to_split:
            best_split.gain = gain
            best_split.feature_idx = feature_idx
            best_split.bin_idx = bin_idx
            best_split.gradient_left = gradient_left
            best_split.hessian_left = hessian_left
            best_split.n_samples_left = n_samples_left
            best_split.gradient_right = gradient_right
            best_split.hessian_right = hessian_right
            best_split.n_samples_right = n_samples_right

    return best_split, histogram
@njit(fastmath=False)
def _split_gain(gradient_left, hessian_left, gradient_right, hessian_right,
                sum_gradients, sum_hessians, l2_regularization):
    """Loss reduction obtained by splitting a node.

    Computes score(left) + score(right) - score(parent), where
    score(G, H) = G ** 2 / (H + l2_regularization). A positive value means
    the split lowers the loss compared to keeping the node a leaf.

    See Equation 7 of:
    XGBoost: A Scalable Tree Boosting System, T. Chen, C. Guestrin, 2016
    https://arxiv.org/abs/1603.02754
    """
    left_score = gradient_left ** 2 / (hessian_left + l2_regularization)
    right_score = gradient_right ** 2 / (hessian_right + l2_regularization)
    parent_score = sum_gradients ** 2 / (sum_hessians + l2_regularization)
    # Same evaluation order as summing the children, then removing the parent.
    return (left_score + right_score) - parent_score