# Source code for optuna.integration.lightgbm

import sys

import optuna

# Import lightgbm lazily and record availability so that the rest of this
# module (and LightGBMPruningCallback) can degrade gracefully when the
# optional dependency is not installed.
try:
    import lightgbm as lgb  # NOQA
    _available = True
except ImportError as e:
    _import_error = e
    # LightGBMPruningCallback is disabled because LightGBM is not available.
    _available = False

# Attach lightgbm API.
if _available:
    # API from optuna integration.
    from optuna.integration import lightgbm_tuner as tuner

    # Workaround for mypy.
    from lightgbm import Dataset  # NOQA
    from optuna.integration.lightgbm_tuner import LightGBMTuner  # NOQA

    # These names are provided by the tuner module instead of lightgbm itself,
    # so they must be skipped when copying lightgbm's public API below.
    _names_from_tuners = ['train', 'LGBMModel', 'LGBMClassifier', 'LGBMRegressor']

    # API from lightgbm: re-export everything in lightgbm's __all__ except the
    # names that the tuner overrides.
    for api_name in lgb.__dict__['__all__']:
        if api_name in _names_from_tuners:
            continue
        setattr(sys.modules[__name__], api_name, lgb.__dict__[api_name])

    # API from the tuner: these shadow the corresponding lightgbm names.
    for api_name in _names_from_tuners:
        setattr(sys.modules[__name__], api_name, tuner.__dict__[api_name])
else:
    # Placeholder so that `LightGBMTuner` is importable even without lightgbm.
    LightGBMTuner = object  # type: ignore

class LightGBMPruningCallback(object):
    """Callback for LightGBM to prune unpromising trials.

    Example:

        Add a pruning callback which observes validation scores to training of a LightGBM model.

        .. code::

                param = {'objective': 'binary', 'metric': 'binary_error'}
                pruning_callback = LightGBMPruningCallback(trial, 'binary_error')
                gbm = lgb.train(param, dtrain, valid_sets=[dtest], callbacks=[pruning_callback])

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        metric:
            An evaluation metric for pruning, e.g., ``binary_error`` and ``multi_error``.
            Please refer to
            `LightGBM reference
            <https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric>`_
            for further details.
        valid_name:
            The name of the target validation.
            Validation names are specified by ``valid_names`` option of
            `train method
            <https://lightgbm.readthedocs.io/en/latest/Python-API.html#lightgbm.train>`_.
            If omitted, ``valid_0`` is used which is the default name of the first validation.
            Note that this argument will be ignored if you are calling
            `cv method <https://lightgbm.readthedocs.io/en/latest/Python-API.html#lightgbm.cv>`_
            instead of train method.
    """

    def __init__(self, trial, metric, valid_name='valid_0'):
        # type: (optuna.trial.Trial, str, str) -> None

        _check_lightgbm_availability()

        self._trial = trial
        self._valid_name = valid_name
        self._metric = metric

    def __call__(self, env):
        # type: (lgb.callback.CallbackEnv) -> None

        # If this callback has been passed to `lightgbm.cv` function,
        # the value of `is_cv` becomes `True`: `lightgbm.cv` aggregates results
        # under the fixed validation name 'cv_agg' and each entry of
        # `evaluation_result_list` carries an extra stdv field.
        # Note that `5` is not the number of folds but the length of sequence.
        is_cv = len(env.evaluation_result_list) > 0 and len(env.evaluation_result_list[0]) == 5
        if is_cv:
            target_valid_name = 'cv_agg'
        else:
            target_valid_name = self._valid_name

        for evaluation_result in env.evaluation_result_list:
            valid_name, metric, current_score, is_higher_better = evaluation_result[:4]
            if valid_name != target_valid_name or metric != self._metric:
                continue

            # The direction of the metric must agree with the study direction;
            # otherwise the reported intermediate values would be meaningless
            # for the pruner.
            if is_higher_better:
                if self._trial.study.direction != \
                        optuna.structs.StudyDirection.MAXIMIZE:
                    raise ValueError(
                        "The intermediate values are inconsistent with the objective values in "
                        "terms of study directions. Please specify a metric to be minimized for "
                        "LightGBMPruningCallback.")
            else:
                if self._trial.study.direction != \
                        optuna.structs.StudyDirection.MINIMIZE:
                    raise ValueError(
                        "The intermediate values are inconsistent with the objective values in "
                        "terms of study directions. Please specify a metric to be maximized for "
                        "LightGBMPruningCallback.")

            # Report the score so the pruner can decide whether to stop early.
            self._trial.report(current_score, step=env.iteration)
            if self._trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(env.iteration)
                raise optuna.exceptions.TrialPruned(message)

            return None

        raise ValueError(
            'The entry associated with the validation name "{}" and the metric name "{}" '
            'is not found in the evaluation result list {}.'.format(
                target_valid_name, self._metric, str(env.evaluation_result_list)))
def _check_lightgbm_availability():
    # type: () -> None
    # Raise a helpful ImportError (including the original import failure) when
    # lightgbm could not be imported at module load time.
    if not _available:
        raise ImportError(
            'LightGBM is not available. Please install LightGBM to use this feature. '
            'LightGBM can be installed by executing `$ pip install lightgbm`. '
            'For further information, please refer to the installation guide of LightGBM. '
            '(The actual import error is as follows: ' + str(_import_error) + ')')