Arcus Documentation

Model Training

With Arcus Model Enrichment, you connected your model and first-party data to external signals that give the model additional context and help it make better predictions. To do so, you first configured the model and ran a trial to see how it performs against different external data candidates. Based on the trial results, you then selected the external data candidate that best suits your needs.

Once this selection is made, you can use the enriched model in your training workflow. Using the Arcus Model SDK, you wrap the model and the training loop. This keeps the original model and workflow intact while letting them consume external data from the platform.
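To make the wrapping idea concrete, here is a minimal pure-Python sketch of the general wrapper pattern. It is a conceptual illustration only, not the Arcus SDK's implementation: `FirstPartyModel`, `EnrichedModel`, the `lookup` callable, and the feature names are all hypothetical and exist only for this example.

```python
from typing import Callable, Dict, Optional

Row = Dict[str, float]


class FirstPartyModel:
    """Hypothetical original model: scores a row from its own features."""

    def predict(self, row: Row) -> float:
        return sum(row.values())


class EnrichedModel:
    """Hypothetical wrapper: adds external features, then delegates."""

    def __init__(self, model: FirstPartyModel,
                 lookup: Optional[Callable[[Row], Row]] = None):
        self.model = model    # the original model is kept unchanged
        self.lookup = lookup  # None stands for "no data selection made"

    def predict(self, row: Row) -> float:
        if self.lookup is None:
            # No external data available: use first-party features only.
            return self.model.predict(row)
        # Merge external features into a copy; the input row is not mutated.
        enriched = {**row, **self.lookup(row)}
        return self.model.predict(enriched)


base = FirstPartyModel()
row = {"age": 30.0, "income": 5.0}

plain = EnrichedModel(base)
enriched = EnrichedModel(base, lookup=lambda r: {"region_score": 2.0})

print(plain.predict(row))     # 35.0
print(enriched.predict(row))  # 37.0
```

The wrapper exposes the same `predict` interface as the model it wraps, so the calling code does not change. This is the property that lets a wrapped model drop into an existing workflow.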

Using the PyTorch Lightning setup from the trial and selection sections, you can now train the model with your data selection.

Setup

Similar to the previous sections, you define the Arcus Config and wrap your original model with the Arcus Model wrapper.

import arcus

# Initialize the original PyTorch model
my_model = MyModel()

# Provide the configuration
arcus_config = arcus.model.shared.Config(
  api_key='MY_API_KEY',
  project_id='MY_PROJECT_ID',
)

# Wrap the model with Arcus
arcus_model = arcus.model.torch.Model(my_model, arcus_config)

Training

With PyTorch Lightning, the model was originally trained using the Trainer class and a MyLightningModule class, which contains the model’s training and validation loops. The original model training code might look like the following:

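import pytorch_lightning as pl

# Wrap the original model in the Lightning module
my_lightning_module = MyLightningModule(my_model)

# Standard PyTorch Lightning trainer
my_trainer = pl.Trainer()

# Train and validate with the original workflow
my_trainer.fit(
  my_lightning_module,
  train_dataloader,
  val_dataloader
)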

To train the model with your data selection, change the three statements of the original training code to use the Arcus Trainer class. You use the same MyLightningModule class, but pass it arcus_model, which communicates with the Arcus Data Platform. The arcus.model.torch.Trainer class keeps all of the existing functionality of the PyTorch Lightning Trainer class and integrates with the Arcus Data Platform to enrich the model during training.

arcus_module = MyLightningModule(arcus_model)

arcus_trainer = arcus.model.torch.Trainer()

arcus_trainer.fit(
  arcus_module,
  train_dataloader,
  val_dataloader
)

Note: you can run this code only after you’ve made a data selection on the Arcus platform. Otherwise, your model is not enriched with external data and uses only your first-party data.

Serving a model uses an identical workflow, which integrates the model with the platform to enrich the model’s predictions with external data.