integrations.base
Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
Classes
Name | Description |
---|---|
BaseOptimizerFactory | Base class for factories to create custom optimizers |
BasePlugin | Base class for all plugins. Defines the interface for plugin methods. |
PluginManager | The PluginManager class is responsible for loading and managing plugins. |
BaseOptimizerFactory
integrations.base.BaseOptimizerFactory()
Base class for factories to create custom optimizers
BasePlugin
integrations.base.BasePlugin()
Base class for all plugins. Defines the interface for plugin methods.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
Note
Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg, preprocess=False): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
- post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
- add_callbacks_pre_trainer(cfg, model): Adds callbacks before the trainer is created.
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after it is created.
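As a rough illustration of how the hooks above are overridden, here is a minimal, self-contained sketch. `BasePlugin` in the sketch is a hypothetical stand-in with no-op hooks so that the code runs without Axolotl installed, `cfg` is a plain `dict` rather than `DictDefault`, and `LoggingPlugin` is an invented name. A real plugin would subclass `integrations.base.BasePlugin` and override only the hooks it needs.

```python
from typing import Any, Optional


class BasePlugin:
    """Hypothetical stand-in for Axolotl's BasePlugin: every hook is a no-op."""

    def register(self, cfg: dict) -> None:
        pass

    def pre_model_load(self, cfg: dict) -> None:
        pass

    def post_model_load(self, cfg: dict, model: Any) -> None:
        pass

    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        # Returning None means "this plugin provides no custom optimizer".
        return None


class LoggingPlugin(BasePlugin):
    """Hypothetical plugin that records which hooks it received."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def register(self, cfg: dict) -> None:
        self.calls.append("register")

    def pre_model_load(self, cfg: dict) -> None:
        self.calls.append("pre_model_load")

    def post_model_load(self, cfg: dict, model: Any) -> None:
        self.calls.append(f"post_model_load:{type(model).__name__}")


plugin = LoggingPlugin()
plugin.register({})
plugin.pre_model_load({})
plugin.post_model_load({}, model=object())
print(plugin.calls)  # ['register', 'pre_model_load', 'post_model_load:object']
print(plugin.create_optimizer({}, trainer=None))  # None (inherited no-op)
```

Only the overridden hooks do anything; every other hook falls back to the base class, which is why `create_optimizer` returns `None` here.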
Methods
Name | Description |
---|---|
add_callbacks_post_trainer | Adds callbacks to the trainer after creating the trainer. This is useful for callbacks that require access to the model or trainer. |
add_callbacks_pre_trainer | Set up callbacks before creating the trainer. |
create_lr_scheduler | Creates and returns a learning rate scheduler. |
create_optimizer | Creates and returns an optimizer for training. |
get_collator_cls_and_kwargs | Returns a custom class for the collator. |
get_input_args | Returns a pydantic model for the plugin’s input arguments. |
get_trainer_cls | Returns a custom class for the trainer. |
get_training_args | Returns custom training arguments to set on TrainingArgs. |
get_training_args_mixin | Returns a dataclass model for the plugin’s training arguments. |
load_datasets | Loads and preprocesses the dataset for training. |
post_lora_load | Performs actions after LoRA weights are loaded. |
post_model_build | Performs actions after the model is built/loaded, but before any adapters are applied. |
post_model_load | Performs actions after the model is loaded. |
post_train | Performs actions after training is complete. |
post_train_unload | Performs actions after training is complete and the model is unloaded. |
post_trainer_create | Performs actions after the trainer is created. |
pre_lora_load | Performs actions before LoRA weights are loaded. |
pre_model_load | Performs actions before the model is loaded. |
register | Registers the plugin with the given configuration. |
add_callbacks_post_trainer
integrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)
Adds callbacks to the trainer after creating the trainer. This is useful for callbacks that require access to the model or trainer.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
trainer | Trainer | The trainer object for training. | required |
Returns
Type | Description |
---|---|
list[Callable] | A list of callback functions to be added. |
add_callbacks_pre_trainer
integrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)
Set up callbacks before creating the trainer.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
model | PreTrainedModel | The loaded model. | required |
Returns
Type | Description |
---|---|
list[Callable] | A list of callback functions to be added to the TrainingArgs. |
create_lr_scheduler
integrations.base.BasePlugin.create_lr_scheduler(cfg, trainer, optimizer, num_training_steps)
Creates and returns a learning rate scheduler.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
trainer | Trainer | The trainer object for training. | required |
optimizer | Optimizer | The optimizer for training. | required |
num_training_steps | int | Total number of training steps. | required |
Returns
Type | Description |
---|---|
Union[LRScheduler, None] | The created learning rate scheduler, or None. |
create_optimizer
integrations.base.BasePlugin.create_optimizer(cfg, trainer)
Creates and returns an optimizer for training.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
trainer | Trainer | The trainer object for training. | required |
Returns
Type | Description |
---|---|
Union[Optimizer, None] | The created optimizer, or None. |
get_collator_cls_and_kwargs
integrations.base.BasePlugin.get_collator_cls_and_kwargs(cfg, is_eval=False)
Returns a custom class for the collator.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The global axolotl configuration. | required |
is_eval | bool | Whether this is an eval split. | False |
Returns
Type | Description |
---|---|
class | The class for the collator. |
get_input_args
integrations.base.BasePlugin.get_input_args()
Returns a pydantic model for the plugin’s input arguments.
get_trainer_cls
integrations.base.BasePlugin.get_trainer_cls(cfg)
Returns a custom class for the trainer.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The global axolotl configuration. | required |
Returns
Type | Description |
---|---|
Union[Trainer, None] | The custom trainer class, or None. |
get_training_args
integrations.base.BasePlugin.get_training_args(cfg)
Returns custom training arguments to set on TrainingArgs.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The global axolotl configuration. | required |
Returns
Type | Description |
---|---|
object | A dict containing the training arguments. |
get_training_args_mixin
integrations.base.BasePlugin.get_training_args_mixin()
Returns a dataclass model for the plugin’s training arguments.
load_datasets
integrations.base.BasePlugin.load_datasets(cfg, preprocess=False)
Loads and preprocesses the dataset for training.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
preprocess | bool | Whether this is the preprocess step of the datasets. | False |
Returns
Name | Type | Description |
---|---|---|
dataset_meta | Union['TrainDatasetMeta', None] | The metadata for the training dataset. |
post_lora_load
integrations.base.BasePlugin.post_lora_load(cfg, model)
Performs actions after LoRA weights are loaded.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
model | Union[PreTrainedModel, PeftModel] | The loaded model. | required |
post_model_build
integrations.base.BasePlugin.post_model_build(cfg, model)
Performs actions after the model is built/loaded, but before any adapters are applied.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
model | PreTrainedModel | The loaded model. | required |
post_model_load
integrations.base.BasePlugin.post_model_load(cfg, model)
Performs actions after the model is loaded, inclusive of any adapters.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
model | Union[PreTrainedModel, PeftModel] | The loaded model. | required |
post_train
integrations.base.BasePlugin.post_train(cfg, model)
Performs actions after training is complete.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The axolotl configuration. | required |
model | Union[PreTrainedModel, PeftModel] | The loaded model. | required |
post_train_unload
integrations.base.BasePlugin.post_train_unload(cfg)
Performs actions after training is complete and the model is unloaded.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
post_trainer_create
integrations.base.BasePlugin.post_trainer_create(cfg, trainer)
Performs actions after the trainer is created.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
trainer | Trainer | The trainer object for training. | required |
pre_lora_load
integrations.base.BasePlugin.pre_lora_load(cfg, model)
Performs actions before LoRA weights are loaded.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
model | PreTrainedModel | The loaded model. | required |
pre_model_load
integrations.base.BasePlugin.pre_model_load(cfg)
Performs actions before the model is loaded.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
register
integrations.base.BasePlugin.register(cfg)
Registers the plugin with the given configuration.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugin. | required |
PluginManager
integrations.base.PluginManager()
The PluginManager class is responsible for loading and managing plugins. It is a singleton so that it can be accessed from anywhere in the codebase.
Attributes
Name | Type | Description |
---|---|---|
plugins | OrderedDict[str, BasePlugin] | An ordered mapping of plugin names to loaded plugin instances. |
Note
Key methods include:
- get_instance(): Static method to get the singleton instance of PluginManager.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
Methods
Name | Description |
---|---|
add_callbacks_post_trainer | Calls the add_callbacks_post_trainer method of all registered plugins. |
add_callbacks_pre_trainer | Calls the add_callbacks_pre_trainer method of all registered plugins. |
create_lr_scheduler | Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. |
create_optimizer | Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. |
get_collator_cls_and_kwargs | Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class. |
get_input_args | Returns a list of Pydantic classes for all registered plugins’ input arguments. |
get_instance | Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one. |
get_trainer_cls | Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class. |
get_training_args | Calls the get_training_args method of all registered plugins and returns the combined training arguments. |
get_training_args_mixin | Returns a list of dataclasses for all registered plugins’ training args mixins. |
load_datasets | Calls the load_datasets method of each registered plugin. |
post_lora_load | Calls the post_lora_load method of all registered plugins. |
post_model_build | Calls the post_model_build method of all registered plugins after the model has been built/loaded, but before any adapters have been applied. |
post_model_load | Calls the post_model_load method of all registered plugins after the model has been loaded, inclusive of any adapters. |
post_train | Calls the post_train method of all registered plugins. |
post_train_unload | Calls the post_train_unload method of all registered plugins. |
post_trainer_create | Calls the post_trainer_create method of all registered plugins. |
pre_lora_load | Calls the pre_lora_load method of all registered plugins. |
pre_model_load | Calls the pre_model_load method of all registered plugins. |
register | Registers a new plugin by its name. |
add_callbacks_post_trainer
integrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)
Calls the add_callbacks_post_trainer method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
trainer | Trainer | The trainer object for training. | required |
Returns
Type | Description |
---|---|
list[Callable] | A list of callback functions to be added to the TrainingArgs. |
add_callbacks_pre_trainer
integrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)
Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
model | PreTrainedModel | The loaded model. | required |
Returns
Type | Description |
---|---|
list[Callable] | A list of callback functions to be added to the TrainingArgs. |
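Because every registered plugin may contribute callbacks, the manager has to gather them into a single list. The sketch below assumes each plugin returns a list or None and that the lists are concatenated in registration order; `collect_callbacks`, `EarlyStopPlugin`, and `NoCallbackPlugin` are hypothetical names, and the lambda stands in for a real trainer callback object.

```python
from collections import OrderedDict
from typing import Any, Callable, Optional


class EarlyStopPlugin:
    """Hypothetical plugin contributing one (placeholder) callback."""

    def add_callbacks_pre_trainer(self, cfg: dict, model: Any) -> list[Callable]:
        return [lambda: "early-stop"]


class NoCallbackPlugin:
    """Hypothetical plugin that contributes nothing."""

    def add_callbacks_pre_trainer(self, cfg: dict, model: Any) -> Optional[list[Callable]]:
        return None


def collect_callbacks(plugins: "OrderedDict[str, Any]", cfg: dict, model: Any) -> list[Callable]:
    """Concatenate the callbacks returned by every plugin, in registration order."""
    callbacks: list[Callable] = []
    for plugin in plugins.values():
        plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
        if plugin_callbacks:  # skip None and empty lists
            callbacks.extend(plugin_callbacks)
    return callbacks


plugins = OrderedDict(early_stop=EarlyStopPlugin(), noop=NoCallbackPlugin())
callbacks = collect_callbacks(plugins, cfg={}, model=None)
print(len(callbacks), callbacks[0]())  # 1 early-stop
```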
create_lr_scheduler
integrations.base.PluginManager.create_lr_scheduler(trainer, optimizer, num_training_steps)
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters
Name | Type | Description | Default |
---|---|---|---|
trainer | Trainer | The trainer object for training. | required |
optimizer | Optimizer | The optimizer for training. | required |
num_training_steps | int | Total number of training steps. | required |
Returns
Type | Description |
---|---|
Union[LRScheduler, None] | The created learning rate scheduler, or None if not found. |
create_optimizer
integrations.base.PluginManager.create_optimizer(trainer)
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters
Name | Type | Description | Default |
---|---|---|---|
trainer | Trainer | The trainer object for training. | required |
Returns
Type | Description |
---|---|
Union[Optimizer, None] | The created optimizer, or None if none was found. |
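The "first non-None" rule can be sketched as a small dispatch helper. This is an illustration under stated assumptions, not Axolotl's code: `first_non_none` and both plugin classes are hypothetical, the dict stands in for a real optimizer, and `cfg` is passed explicitly because the manager's signature above does not show where it obtains the configuration.

```python
from collections import OrderedDict
from typing import Any, Optional


class PassThroughPlugin:
    """Hypothetical plugin that does not provide an optimizer."""

    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return None


class CustomOptimizerPlugin:
    """Hypothetical plugin that provides a placeholder optimizer."""

    def create_optimizer(self, cfg: dict, trainer: Any) -> Optional[Any]:
        return {"name": "my-optimizer"}  # stand-in for a torch Optimizer


def first_non_none(plugins: "OrderedDict[str, Any]", method: str, *args: Any) -> Optional[Any]:
    """Call `method` on each plugin in order and return the first non-None result."""
    for plugin in plugins.values():
        result = getattr(plugin, method)(*args)
        if result is not None:
            return result  # later plugins are never consulted
    return None


plugins = OrderedDict(a=PassThroughPlugin(), b=CustomOptimizerPlugin())
print(first_non_none(plugins, "create_optimizer", {}, None))  # {'name': 'my-optimizer'}
print(first_non_none(OrderedDict(a=PassThroughPlugin()), "create_optimizer", {}, None))  # None
```

The same helper shape fits the other methods documented as returning the first non-None result (create_lr_scheduler, get_trainer_cls, get_collator_cls_and_kwargs). One consequence of the rule is that registration order decides which plugin wins when several of them provide a value.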
get_collator_cls_and_kwargs
integrations.base.PluginManager.get_collator_cls_and_kwargs(cfg, is_eval=False)
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
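Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | The configuration for the plugins. | required |
is_eval | bool | Whether this is an eval split. | False |
Returns
Type | Description |
---|---|
object | The collator class, or None if none was found. |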
get_input_args
integrations.base.PluginManager.get_input_args()
Returns a list of Pydantic classes for all registered plugins’ input arguments.
Returns
Type | Description |
---|---|
list[str] | A list of Pydantic classes for all registered plugins’ input arguments. |
get_instance
integrations.base.PluginManager.get_instance()
Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one.
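A lazily created singleton of the kind described here can be sketched as follows. `PluginManagerSketch` is a hypothetical name; the sketch only mirrors the documented behavior (create the instance on the first call, reuse it afterwards) and the ordered `plugins` mapping.

```python
from collections import OrderedDict
from typing import Any, Optional


class PluginManagerSketch:
    """Hypothetical, minimal singleton modeled on the description above."""

    _instance: Optional["PluginManagerSketch"] = None

    def __init__(self) -> None:
        self.plugins: "OrderedDict[str, Any]" = OrderedDict()

    @staticmethod
    def get_instance() -> "PluginManagerSketch":
        # Create the shared instance on first access, then always return it.
        if PluginManagerSketch._instance is None:
            PluginManagerSketch._instance = PluginManagerSketch()
        return PluginManagerSketch._instance


first = PluginManagerSketch.get_instance()
first.plugins["demo"] = object()
second = PluginManagerSketch.get_instance()
print(first is second, list(second.plugins))  # True ['demo']
```

Nothing in this sketch stops a caller from invoking the constructor directly, which would create a second, unshared instance; code that relies on shared state should go through get_instance().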
get_trainer_cls
integrations.base.PluginManager.get_trainer_cls(cfg)
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
Returns
Type | Description |
---|---|
Union[Trainer, None] | The first non-None trainer class returned by a plugin. |
get_training_args
integrations.base.PluginManager.get_training_args(cfg)
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
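Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | The configuration for the plugins. | required |
Returns
Type | Description |
---|---|
object | The training arguments. |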
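The reference does not say how the per-plugin results are combined. One plausible sketch, assuming each plugin returns a dict or None and that later plugins override earlier ones when keys collide, is shown below; `combine_training_args`, `WarmupPlugin`, and `SilentPlugin` are hypothetical names.

```python
from collections import OrderedDict
from typing import Any, Optional


class WarmupPlugin:
    """Hypothetical plugin contributing one training argument."""

    def get_training_args(self, cfg: dict) -> Optional[dict[str, Any]]:
        return {"warmup_steps": cfg.get("warmup_steps", 10)}


class SilentPlugin:
    """Hypothetical plugin with nothing to add."""

    def get_training_args(self, cfg: dict) -> Optional[dict[str, Any]]:
        return None


def combine_training_args(plugins: "OrderedDict[str, Any]", cfg: dict) -> dict[str, Any]:
    """Merge every plugin's dict into one; None results are skipped."""
    combined: dict[str, Any] = {}
    for plugin in plugins.values():
        plugin_args = plugin.get_training_args(cfg)
        if plugin_args:
            combined.update(plugin_args)  # assumption: later plugins win on key conflicts
    return combined


plugins = OrderedDict(warmup=WarmupPlugin(), silent=SilentPlugin())
print(combine_training_args(plugins, {"warmup_steps": 50}))  # {'warmup_steps': 50}
```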
get_training_args_mixin
integrations.base.PluginManager.get_training_args_mixin()
Returns a list of dataclasses for all registered plugins’ training args mixins.
Returns
Type | Description |
---|---|
list[str] | A list of dataclasses. |
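One way to merge training-argument mixins with a base arguments dataclass is multiple inheritance, sketched here with `dataclasses.make_dataclass`. `BaseTrainingArgs` and both mixins are hypothetical stand-ins. The reference gives the return type as list[str], which may mean dotted import paths rather than class objects; the sketch uses class objects directly for brevity, so resolving strings would need an extra import step.

```python
from dataclasses import dataclass, fields, make_dataclass


@dataclass
class BaseTrainingArgs:
    """Hypothetical stand-in for the trainer's base arguments dataclass."""

    learning_rate: float = 5e-5


@dataclass
class MyPluginArgsMixin:
    """Hypothetical mixin a plugin might return from get_training_args_mixin()."""

    my_plugin_enabled: bool = False


@dataclass
class OtherPluginArgsMixin:
    """A second hypothetical mixin from another plugin."""

    other_plugin_alpha: float = 0.5


mixins = [MyPluginArgsMixin, OtherPluginArgsMixin]  # e.g. one per registered plugin
# Build one dataclass whose fields are the union of the base class and all mixins.
ComposedArgs = make_dataclass("ComposedArgs", [], bases=(BaseTrainingArgs, *mixins))

args = ComposedArgs(learning_rate=1e-4, my_plugin_enabled=True)
print(sorted(f.name for f in fields(args)))
# ['learning_rate', 'my_plugin_enabled', 'other_plugin_alpha']
```

Giving every mixin field a default keeps the composition valid regardless of base order, because dataclasses reject a non-default field that follows a defaulted one.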
load_datasets
integrations.base.PluginManager.load_datasets(cfg, preprocess=False)
Calls the load_datasets method of each registered plugin.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
preprocess | bool | Whether this is the preprocess step of the datasets. | False |
Returns
Type | Description |
---|---|
Union['TrainDatasetMeta', None] | The dataset metadata loaded from all registered plugins. |
post_lora_load
integrations.base.PluginManager.post_lora_load(cfg, model)
Calls the post_lora_load method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
model | Union[PreTrainedModel, PeftModel] | The loaded model. | required |
post_model_build
integrations.base.PluginManager.post_model_build(cfg, model)
Calls the post_model_build method of all registered plugins after the model has been built/loaded, but before any adapters have been applied.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
model | PreTrainedModel | The loaded model. | required |
post_model_load
integrations.base.PluginManager.post_model_load(cfg, model)
Calls the post_model_load method of all registered plugins after the model has been loaded, inclusive of any adapters.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
model | Union[PreTrainedModel, PeftModel] | The loaded model. | required |
post_train
integrations.base.PluginManager.post_train(cfg, model)
Calls the post_train method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
model | Union[PreTrainedModel, PeftModel] | The loaded model. | required |
post_train_unload
integrations.base.PluginManager.post_train_unload(cfg)
Calls the post_train_unload method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
post_trainer_create
integrations.base.PluginManager.post_trainer_create(cfg, trainer)
Calls the post_trainer_create method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
trainer | Trainer | The trainer object for training. | required |
pre_lora_load
integrations.base.PluginManager.pre_lora_load(cfg, model)
Calls the pre_lora_load method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
model | PreTrainedModel | The loaded model. | required |
pre_model_load
integrations.base.PluginManager.pre_model_load(cfg)
Calls the pre_model_load method of all registered plugins.
Parameters
Name | Type | Description | Default |
---|---|---|---|
cfg | DictDefault | The configuration for the plugins. | required |
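Lifecycle methods such as pre_model_load fan out to every registered plugin instead of stopping at the first result. A minimal sketch, assuming plugins are invoked in registration order (the order of the `plugins` OrderedDict) and that return values are ignored; `RecordingPlugin` and `broadcast_pre_model_load` are hypothetical names.

```python
from collections import OrderedDict


class RecordingPlugin:
    """Hypothetical plugin that records when its hook runs."""

    def __init__(self, name: str, log: list[str]) -> None:
        self.name = name
        self.log = log

    def pre_model_load(self, cfg: dict) -> None:
        self.log.append(self.name)


def broadcast_pre_model_load(plugins: "OrderedDict[str, RecordingPlugin]", cfg: dict) -> None:
    # Every plugin is called, in insertion order; return values are ignored.
    for plugin in plugins.values():
        plugin.pre_model_load(cfg)


log: list[str] = []
plugins = OrderedDict()
plugins["first"] = RecordingPlugin("first", log)
plugins["second"] = RecordingPlugin("second", log)
broadcast_pre_model_load(plugins, cfg={})
print(log)  # ['first', 'second']
```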
register
integrations.base.PluginManager.register(plugin_name)
Registers a new plugin by its name.
Parameters
Name | Type | Description | Default |
---|---|---|---|
plugin_name | str | The name of the plugin to be registered. | required |
Raises
Type | Description |
---|---|
ImportError | If the plugin module cannot be imported. |
Functions
Name | Description |
---|---|
load_plugin | Loads a plugin based on the given plugin name. |
load_plugin
integrations.base.load_plugin(plugin_name)
Loads a plugin based on the given plugin name.
The plugin name should be in the format “module_name.class_name”. This function splits the plugin name into module and class, imports the module, retrieves the class from the module, and creates an instance of the class.
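The steps above map closely onto the standard library's `importlib`. The sketch below is an illustration, not Axolotl's implementation: `load_plugin_sketch` is a hypothetical name, and splitting on the last dot is an assumption that keeps dotted module paths (such as `pkg.sub.module.ClassName`) working.

```python
import importlib
from typing import Any


def load_plugin_sketch(plugin_name: str) -> Any:
    """Import "module_name.class_name" and return an instance of the class."""
    # Split on the LAST dot so that dotted module paths still resolve.
    module_name, class_name = plugin_name.rsplit(".", 1)
    module = importlib.import_module(module_name)  # ImportError if the module is missing
    plugin_class = getattr(module, class_name)  # AttributeError if the class is missing
    return plugin_class()


counter = load_plugin_sketch("collections.Counter")
print(type(counter).__name__)  # Counter

try:
    load_plugin_sketch("no_such_module_xyz.SomePlugin")
except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
    print("import failed:", type(err).__name__)  # import failed: ModuleNotFoundError
```

A name without any dot would fail at the unpacking step with a ValueError, so a production loader would likely validate the format first.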
Parameters
Name | Type | Description | Default |
---|---|---|---|
plugin_name | str | The name of the plugin to be loaded. The name should be in the format “module_name.class_name”. | required |
Returns
Type | Description |
---|---|
BasePlugin | An instance of the loaded plugin. |
Raises
Type | Description |
---|---|
ImportError | If the plugin module cannot be imported. |