# Training Run: marin-8b-instruct-orpo
This document outlines the configuration and parameters used for training the model marin-8b-instruct-orpo using the EasyDeL library.
EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on JAX/Flax for TPU/GPU environments.
## How to Load This Checkpoint
You can load the checkpoint generated from this training run using EasyDeL as follows:
```python
import easydel as ed
from jax import numpy as jnp, lax

# Path to the directory where this README.md is located
repo_id = "user/model-id"  # <-- TODO: Update this with the actual save directory or model repo

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        # use_scan_mlp=False,  # Set to True to potentially reduce memory usage
        attn_dtype=jnp.float16,  # Or jnp.bfloat16
        # freq_max_position_embeddings=max_length,  # Set if using RoPE and truncation is needed
        # mask_max_position_embeddings=max_length,  # Set if a maximum length is defined
        attn_mechanism=ed.AttentionMechanisms.SPLASH,  # Matches the mechanism used by this model
    ),
    dtype=jnp.float16,  # Or jnp.bfloat16 - computation data type
    param_dtype=jnp.float16,  # Or jnp.bfloat16 - parameter data type
    precision=lax.Precision("fastest"),  # One of "default", "fastest", "high", "highest"
    auto_shard_model=True,  # Auto-shard across available devices
)
```
**Note:** Replace `repo_id` with the actual Hugging Face repo ID or the path to the saved checkpoint directory. The loaded model and its parameters are then ready to use.
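Once the model is loaded, prompts still need to be tokenized with the matching tokenizer. A minimal sketch, assuming the tokenizer files and a chat template are stored in the same repo (typical for instruct checkpoints):

```python
from transformers import AutoTokenizer

# Assumes the tokenizer is hosted alongside the model weights in the same repo.
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Build a chat-formatted prompt; instruct models are usually trained
# against a chat template rather than raw text.
messages = [{"role": "user", "content": "Summarize what ORPO training does."}]
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)

# NumPy tensors interoperate directly with JAX arrays.
input_ids = tokenizer(prompt, return_tensors="np").input_ids
```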
## Training Configuration Summary

### Model & Hardware
- **Model Name (Run Name):** `marin-8b-instruct-orpo`
- **Base Model Architecture:** llama
- **Platform:** TPU
- **Number of Devices Used:** 4 (total), 4 (local)
- **EasyDeL Version:** v0.1.5
### Key Training Parameters

- **Learning Rate (Start → End):** 8e-07
- **Optimizer:** `EasyDeLOptimizers.ADAMW`
- **Scheduler:** `EasyDeLSchedulers.COSINE` (see the optax sketch after this list)
- **Warmup Steps:** 0
- **Weight Decay:** 0.01
- **Loss Configuration:** `LossConfig` with:
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: `SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS`
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None
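For orientation, the optimizer and scheduler above correspond roughly to the following optax construction. This is a minimal sketch, not EasyDeL's internal code; `total_steps` is a hypothetical placeholder, since the run did not fix a maximum step count.

```python
import optax

total_steps = 10_000  # hypothetical placeholder; the true step count depends on the dataset

# Cosine decay starting at the configured learning rate, with zero warmup steps.
schedule = optax.cosine_decay_schedule(init_value=8e-7, decay_steps=total_steps)

# AdamW with the configured decoupled weight decay.
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)
```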
### Data & Batching

- **Number of Training Epochs:** 8
- **Total Batch Size (per step):** 4
- **Maximum Sequence Length:** 4096
- **Gradient Accumulation Steps:** 1
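The per-step token budget follows directly from these numbers: 4 sequences × 1 accumulation step × 4,096 tokens gives at most 16,384 tokens per optimizer step.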
### Datatypes & Precision

- **Computation dtype:** `jnp.bfloat16`
- **Parameter dtype (`param_dtype`):** `jnp.bfloat16`
- **Gradient Checkpointing Method:** `EasyDeLGradientCheckPointers.NOTHING_SAVEABLE`
- **Attention Mechanism Used in Training:** `splash` (can be loaded as `AttentionMechanisms.SPLASH` if using `EasyDeLConfig`)
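To reproduce the training-time numerics at load time, the bfloat16 dtypes and splash attention listed above can be passed to the same `from_pretrained` call shown earlier. A minimal sketch (`repo_id` is the placeholder from the loading example):

```python
import easydel as ed
from jax import numpy as jnp

repo_id = "user/model-id"  # placeholder, as in the loading example above

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_dtype=jnp.bfloat16,                       # matches training
        attn_mechanism=ed.AttentionMechanisms.SPLASH,  # matches training
    ),
    dtype=jnp.bfloat16,        # training computation dtype
    param_dtype=jnp.bfloat16,  # training parameter dtype
    auto_shard_model=True,
)
```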
### Run Control

- **Max Training Steps:** Not set
- **Max Evaluation Steps:** Not set
- **Training Time Limit:** Not set
## Citation
If you use EasyDeL in your research or work, please cite it:
```bibtex
@misc{Zare_Chavoshi_2023,
  title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
  url={https://github.com/erfanzar/EasyDeL},
  author={Zare Chavoshi, Erfan},
  year={2023}
}
```
This document was automatically generated by EasyDeL v0.1.5 during the training run.