# Source code for optuna.integration.allennlp

import json
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import warnings

import optuna
from optuna._experimental import experimental

# AllenNLP and jsonnet are optional dependencies. Record whether they could be
# imported; the failure is reported lazily by _check_allennlp_availability().
try:
    import _jsonnet
    import allennlp.commands
    import allennlp.common.util
except ImportError as err:
    # Keep the original error so it can be surfaced to the user later.
    _import_error = err
    _available = False
    # NOTE(review): TrackerCallback is not referenced anywhere in this module's
    # visible code; it looks like a placeholder fallback -- confirm whether it
    # is still needed.
    TrackerCallback = object
else:
    _available = True


def dump_best_config(input_config_file: str, output_config_file: str, study: optuna.Study) -> None:
    """Save jsonnet after updating with parameters from the best trial in the study.

    Args:
        input_config_file:
            Input configuration file used with
            :class:`~optuna.integration.AllenNLPExecutor`.
        output_config_file:
            Output configuration file.
        study:
            An optimized study (``study.best_trial`` does not raise an error).

    Raises:
        ImportError:
            If AllenNLP or jsonnet could not be imported.
    """
    # Fail with the descriptive ImportError instead of a NameError on _jsonnet.
    _check_allennlp_availability()

    # The parameters are passed to jsonnet as external variables in string form.
    # Build a fresh mapping instead of rewriting the dict returned by
    # ``study.best_params`` in place.
    best_params = {key: str(value) for key, value in study.best_params.items()}
    best_config = json.loads(_jsonnet.evaluate_file(input_config_file, ext_vars=best_params))

    # Values substituted from string ext_vars may remain strings; infer_and_cast
    # converts them back into their desired types.
    best_config = allennlp.common.params.infer_and_cast(best_config)

    with open(output_config_file, "w") as f:
        json.dump(best_config, f, indent=4)
@experimental("1.4.0")
class AllenNLPExecutor(object):
    """AllenNLP extension to use optuna with a jsonnet config file.

    This feature is experimental because a major AllenNLP release is expected soon.
    The interface may change without prior notice to follow that update.

    See the examples of `objective function <https://github.com/optuna/optuna/blob/
    master/examples/allennlp/allennlp_jsonnet.py>`_ and
    `config file <https://github.com/optuna/optuna/blob/master/
    examples/allennlp/classifier.jsonnet>`_.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of
            the objective function.
        config_file:
            Config file for AllenNLP.
            Hyperparameters should be masked with ``std.extVar``.
            Please refer to `the config example <https://github.com/allenai/allentune/blob/
            master/examples/classifier.jsonnet>`_.
        serialization_dir:
            A path to which model weights and logs are saved.
        metrics:
            An evaluation metric for the result of ``objective``.
        include_package:
            Additional packages to include. Either a single package name or a list of
            package names.
            For more information, please see
            `AllenNLP documentation <https://docs.allennlp.org/master/api/commands/train/>`_.

    Raises:
        ImportError:
            If AllenNLP or jsonnet could not be imported.
    """

    def __init__(
        self,
        trial: optuna.Trial,
        config_file: str,
        serialization_dir: str,
        metrics: str = "best_validation_accuracy",
        *,
        include_package: Union[str, List[str], None] = None
    ):
        _check_allennlp_availability()

        # Copy so that this executor owns its parameter mapping.
        self._params = dict(trial.params)
        self._config_file = config_file
        self._serialization_dir = serialization_dir
        self._metrics = metrics

        # ``None`` replaces the former shared mutable ``[]`` default.
        if include_package is None:
            include_package = []
        if isinstance(include_package, str):
            include_package = [include_package]
        # Copy so that later changes to the caller's list do not leak in.
        self._include_package = list(include_package)

    def _build_params(self) -> Dict[str, Any]:
        """Create a dict of params for AllenNLP."""
        # _build_params is based on allentune's train_func.
        # https://github.com/allenai/allentune/blob/master/allentune/modules/allennlp_runner.py#L34-L65

        # Parameters are passed to jsonnet as external variables in string form.
        # Build a new mapping rather than overwriting self._params in place.
        ext_vars = {key: str(value) for key, value in self._params.items()}
        params = json.loads(_jsonnet.evaluate_file(self._config_file, ext_vars=ext_vars))

        # params contains a list of string or string as value values.
        # Some params couldn't be casted correctly and
        # infer_and_cast converts them into desired values.
        return allennlp.common.params.infer_and_cast(params)

    def run(self) -> float:
        """Train a model using AllenNLP.

        Returns:
            The value stored under ``metrics`` in the ``metrics.json`` file that
            training writes to ``serialization_dir``.
        """
        # Older AllenNLP exposes ``import_submodules``; when it is missing, fall back
        # to ``import_module_and_submodules``. Per the warning text, that name
        # belongs to releases newer than 0.9.
        try:
            import_func = allennlp.common.util.import_submodules
        except AttributeError:
            import_func = allennlp.common.util.import_module_and_submodules
            warnings.warn("AllenNLP>0.9 has not been supported officially yet.")

        for package_name in self._include_package:
            import_func(package_name)

        params = allennlp.common.params.Params(self._build_params())
        allennlp.commands.train.train_model(params, self._serialization_dir)

        # Use a context manager so the metrics file handle is closed promptly.
        with open(os.path.join(self._serialization_dir, "metrics.json")) as f:
            metrics = json.load(f)
        return metrics[self._metrics]
def _check_allennlp_availability() -> None:
    """Raise an informative ``ImportError`` if the optional AllenNLP imports failed.

    Raises:
        ImportError:
            If importing ``_jsonnet`` or ``allennlp`` failed when this module was
            loaded. The original error is included in the message and chained as
            the exception's ``__cause__``.
    """
    if _available:
        return
    # Chain the original exception so its traceback is not lost.
    raise ImportError(
        "AllenNLP is not available. Please install AllenNLP to use this feature. "
        "AllenNLP can be installed by executing `$ pip install allennlp`. "
        "For further information, please refer to the installation guide of AllenNLP. "
        "(The actual import error is as follows: " + str(_import_error) + ")"
    ) from _import_error