Checkpoint modification
This document describes the checkpoint modification procedure required when adding your own checkpoints to the SambaStudio Model Hub. Follow this procedure whenever you add your own checkpoint to the Model Hub through either the "Import a checkpoint using the CLI" workflow or the "Add a checkpoint from storage using the GUI" workflow.
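Currently, checkpoint modification is only required for the models whose updated configurations are listed in the Updated configurations section below: Qwen2.5 0.5B, Llama 3.2 3B, QwQ 32B, Qwen2.5-Coder 32B, and Qwen2.5 72B.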
Overview
The SambaNova software stack includes compiler optimizations that require weights to be split (sharded) symmetrically so that models run fast on the SambaStudio platform. For some models, this means weight tensors in the checkpoint must be zero-padded or duplicated.
This is a straightforward, one-time step for each checkpoint, performed before the checkpoint is added to the Model Hub. Given a checkpoint that matches the original config.json, modify it once with the script provided below, and then use the modified checkpoint on the RDU.
Steps to do checkpoint modification
-
Confirm that the config.json file of your current checkpoint is identical to the original one published for the model. For example, the original config.json file for Qwen2.5 0.5B can be found on Hugging Face.
-
Copy the updated config.json for your model from the Updated configurations section and save it as a .json file. The updated config describes the changes that must be made to the checkpoint, and the provided script reads it to modify the checkpoint.
-
Copy the Python script from the Modification script section and save it as a .py file.
-
In the INPUT SECTION of the Python script, set the following four locations:
- the location of your checkpoint;
- the location where the modified checkpoint should be written;
- the original config.json file, i.e., the Hugging Face config for the model named in the Updated configurations section;
- the updated config.json file saved in Step 2.
-
Run the Python script. The modified checkpoint, together with a copy of the tokenizer, is saved to the updated checkpoint location you specified.
Modification script
See the python script below.
import transformers
import torch
import tqdm
import json
####### INPUT SECTION ##################################################################
original_ckpt_location = "Qwen2.5-Coder-0.5B-Instruct/"  # Directory holding the unmodified checkpoint
updated_ckpt_location = "Qwen2.5-Coder-0.5B-Instruct_surgery/"  # Directory the modified checkpoint is written to
original_config_location = "0.5B_config.json"  # The model's original config.json
updated_config_location = "0.5b_config_surgery.json"  # The updated config.json copied from this page
####### END OF INPUT SECTION ###########################################################


def _read_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path, "r") as handle:
        return json.load(handle)


# Load the source model and its tokenizer, then the two config files that drive the
# modification (original first, updated second).
model = transformers.AutoModelForCausalLM.from_pretrained(original_ckpt_location)
tokenizer = transformers.AutoTokenizer.from_pretrained(original_ckpt_location)
original_config, updated_config = (
    _read_json(original_config_location),
    _read_json(updated_config_location),
)
# Make sure the updated config differs from the original only in fields this script
# knows how to apply (head counts, intermediate size, embedding tying) or in fields
# the script never uses to reshape weights (transformers_version, eos_token_id).
# Every other original field must be present in the updated config with the same value.
# Explicit exceptions are raised instead of `assert`, which is stripped under `python -O`.
ALLOWED_CONFIG_CHANGES = {
    "num_attention_heads",
    "intermediate_size",
    "tie_word_embeddings",
    "num_key_value_heads",
    "transformers_version",
    "eos_token_id",
}
for param, original_value in original_config.items():
    if param in ALLOWED_CONFIG_CHANGES:
        continue
    if param not in updated_config:
        raise ValueError(
            f"Updated config is missing {param!r}, which is present in the original config"
        )
    if updated_config[param] != original_value:
        raise ValueError(f"Updated config must have the same {param} as original config")
# Derive the attention/FFN shape parameters from the two configs.
# The per-head width is taken from the original model and stays fixed: padding adds heads,
# it never changes their width. Use the explicit head_dim when the original config has one,
# otherwise fall back to hidden_size // num_attention_heads.
if "head_dim" in original_config:
    HEAD_SIZE = int(original_config["head_dim"])
else:
    HEAD_SIZE = int(original_config["hidden_size"]) // int(original_config["num_attention_heads"])
OLD_NUM_KV_HEADS = int(original_config["num_key_value_heads"])
NEW_NUM_KV_HEADS = int(updated_config["num_key_value_heads"])
OLD_NUM_HEADS = int(original_config["num_attention_heads"])
NEW_NUM_HEADS = int(updated_config["num_attention_heads"])
# Both intermediate sizes are always read. To keep the original FFN size, set the updated
# config's intermediate_size equal to the original (do not comment these lines out; the
# per-layer loop reads both names).
OLD_INTERMEDIATE_SIZE = int(original_config["intermediate_size"])
NEW_INTERMEDIATE_SIZE = int(updated_config["intermediate_size"])

# Reject shape changes that zero-padding / duplication cannot represent exactly. Without
# these checks the integer truncation of group sizes would silently produce mis-shaped weights.
if NEW_NUM_HEADS < OLD_NUM_HEADS or NEW_NUM_HEADS % OLD_NUM_KV_HEADS:
    raise ValueError(
        f"num_attention_heads must grow from {OLD_NUM_HEADS} to a multiple of the original "
        f"num_key_value_heads ({OLD_NUM_KV_HEADS}); got {NEW_NUM_HEADS}"
    )
if NEW_NUM_KV_HEADS < OLD_NUM_KV_HEADS or NEW_NUM_KV_HEADS % OLD_NUM_KV_HEADS:
    raise ValueError(
        f"num_key_value_heads can only grow by an integer factor from {OLD_NUM_KV_HEADS}; "
        f"got {NEW_NUM_KV_HEADS}"
    )
if NEW_NUM_HEADS % NEW_NUM_KV_HEADS:
    raise ValueError(
        f"num_attention_heads ({NEW_NUM_HEADS}) must be a multiple of "
        f"num_key_value_heads ({NEW_NUM_KV_HEADS})"
    )
if NEW_INTERMEDIATE_SIZE < OLD_INTERMEDIATE_SIZE:
    raise ValueError(
        f"intermediate_size can only grow (zero padding); got {OLD_INTERMEDIATE_SIZE} -> "
        f"{NEW_INTERMEDIATE_SIZE}"
    )
def _pad_head_groups(tensor, dim, num_groups, old_group_size, new_group_size, head_size):
    """Return a copy of ``tensor`` with zero-filled heads appended to every KV group.

    Along ``dim`` the input holds ``num_groups`` consecutive groups of ``old_group_size``
    heads, each ``head_size`` wide. In the output, each group's original heads sit at the
    start of a group of ``new_group_size`` heads and the remaining heads are zero. This
    keeps every original query head paired with the same KV head. The result has the
    input's dtype and device.
    """
    old_span = old_group_size * head_size
    new_span = new_group_size * head_size
    if tensor.shape[dim] != num_groups * old_span:
        raise ValueError(
            f"Expected size {num_groups * old_span} along dim {dim}, got {tensor.shape[dim]}"
        )
    new_shape = list(tensor.shape)
    new_shape[dim] = num_groups * new_span
    padded = tensor.new_zeros(new_shape)
    for group in range(num_groups):
        padded.narrow(dim, group * new_span, old_span).copy_(
            tensor.narrow(dim, group * old_span, old_span)
        )
    return padded


def _repeat_kv_heads(tensor, num_heads, head_size, factor):
    """Repeat each of the ``num_heads`` heads stacked along dim 0 ``factor`` times in a row.

    Head ``i`` becomes heads ``i * factor`` ... ``i * factor + factor - 1``. Each query
    head therefore still attends to a copy of its original KV head after
    num_key_value_heads is multiplied by ``factor``.
    """
    trailing = tuple(tensor.shape[1:])
    per_head = tensor.reshape(num_heads, head_size, *trailing)
    return per_head.repeat_interleave(factor, dim=0).reshape(
        num_heads * factor * head_size, *trailing
    )


def _append_zeros(tensor, dim, extra):
    """Return ``tensor`` with ``extra`` zero slices appended along ``dim`` (same dtype/device)."""
    pad_shape = list(tensor.shape)
    pad_shape[dim] = extra
    return torch.cat((tensor, tensor.new_zeros(pad_shape)), dim=dim)


# Loop-invariant sizes, validated once so integer truncation cannot corrupt shapes.
if OLD_NUM_HEADS != NEW_NUM_HEADS:
    if NEW_NUM_HEADS < OLD_NUM_HEADS or NEW_NUM_HEADS % OLD_NUM_KV_HEADS:
        raise ValueError(
            "num_attention_heads can only grow to a multiple of the original num_key_value_heads"
        )
    old_group_size = OLD_NUM_HEADS // OLD_NUM_KV_HEADS
    new_group_size = NEW_NUM_HEADS // OLD_NUM_KV_HEADS
if OLD_NUM_KV_HEADS != NEW_NUM_KV_HEADS:
    if NEW_NUM_KV_HEADS % OLD_NUM_KV_HEADS:
        raise ValueError("num_key_value_heads can only grow by an integer factor")
    duplication_factor = NEW_NUM_KV_HEADS // OLD_NUM_KV_HEADS
if NEW_INTERMEDIATE_SIZE < OLD_INTERMEDIATE_SIZE:
    raise ValueError("intermediate_size can only grow (zero padding)")

# Weights are edited in place; no_grad keeps autograd from tracking the copies.
with torch.no_grad():
    for layer in tqdm.tqdm(model.model.layers):
        attn = layer.self_attn

        # Pad Q rows / O columns with zero heads when num_attention_heads grows.
        # Sizes of the non-head axis (hidden size) are taken from the tensors themselves.
        if OLD_NUM_HEADS != NEW_NUM_HEADS:
            regroup = {
                "num_groups": OLD_NUM_KV_HEADS,
                "old_group_size": old_group_size,
                "new_group_size": new_group_size,
                "head_size": HEAD_SIZE,
            }
            attn.q_proj.weight.data = _pad_head_groups(attn.q_proj.weight.data, 0, **regroup)
            if attn.q_proj.bias is not None:
                attn.q_proj.bias.data = _pad_head_groups(attn.q_proj.bias.data, 0, **regroup)
            attn.o_proj.weight.data = _pad_head_groups(attn.o_proj.weight.data, 1, **regroup)

        # Duplicate each K/V head consecutively when num_key_value_heads grows.
        if OLD_NUM_KV_HEADS != NEW_NUM_KV_HEADS:
            for proj in (attn.k_proj, attn.v_proj):
                proj.weight.data = _repeat_kv_heads(
                    proj.weight.data, OLD_NUM_KV_HEADS, HEAD_SIZE, duplication_factor
                )
                if proj.bias is not None:
                    proj.bias.data = _repeat_kv_heads(
                        proj.bias.data, OLD_NUM_KV_HEADS, HEAD_SIZE, duplication_factor
                    )

        # Zero-pad the FFN when intermediate_size grows: gate/up gain output rows and
        # down_proj gains matching input columns, so the extra units contribute nothing.
        if OLD_INTERMEDIATE_SIZE != NEW_INTERMEDIATE_SIZE:
            padding = NEW_INTERMEDIATE_SIZE - OLD_INTERMEDIATE_SIZE
            mlp = layer.mlp
            mlp.down_proj.weight.data = _append_zeros(mlp.down_proj.weight.data, 1, padding)
            for proj in (mlp.gate_proj, mlp.up_proj):
                proj.weight.data = _append_zeros(proj.weight.data, 0, padding)
                if proj.bias is not None:
                    proj.bias.data = _append_zeros(proj.bias.data, 0, padding)
# Record the new geometry on the in-memory config so it is saved with the weights.
# head_dim is written explicitly because, once heads are padded, hidden_size divided by
# num_attention_heads may no longer equal the per-head width.
for config_key, config_value in (
    ("head_dim", HEAD_SIZE),
    ("num_attention_heads", NEW_NUM_HEADS),
    ("intermediate_size", NEW_INTERMEDIATE_SIZE),
    ("num_key_value_heads", NEW_NUM_KV_HEADS),
):
    setattr(model.config, config_key, config_value)
# Untie the output embeddings (LM head) from the input embeddings when the original
# checkpoint ties them but the updated config asks for separate weights. The LM head
# receives its own copy of the embedding matrix so the two are stored independently.
originally_tied = original_config["tie_word_embeddings"]
tied_after_update = updated_config["tie_word_embeddings"]
if originally_tied and not tied_after_update:
    new_weight = torch.nn.Parameter(model.get_input_embeddings().weight.data.clone())
    assert model.get_input_embeddings().weight.data.data_ptr() != new_weight.data.data_ptr()
    model.get_output_embeddings().weight = new_weight
    model.config.tie_word_embeddings = False
elif tied_after_update and not originally_tied:
    # Previously this branch "untied" and forced tie_word_embeddings=False, silently
    # contradicting the updated config. Tying separate weights is not supported here.
    raise ValueError(
        "Updated config sets tie_word_embeddings=true but the original checkpoint has "
        "untied embeddings; tying them is not supported by this script"
    )
# Write the modified model (weights + config) and an unmodified copy of the tokenizer
# into the output directory, model first.
for artifact in (model, tokenizer):
    artifact.save_pretrained(updated_ckpt_location)
Updated configurations
These configs are copied from Hugging Face, with the required changes applied.
Qwen2.5 0.5B
{
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"head_dim": 64,
"hidden_size": 896,
"initializer_range": 0.02,
"intermediate_size": 5120,
"max_position_embeddings": 32768,
"max_window_layers": 21,
"model_type": "qwen2",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.43.1",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}
Llama3.2 3B
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": [
128001,
128008,
128009
],
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.45.0.dev0",
"use_cache": true,
"vocab_size": 128256
}
QwQ 32B
{
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"head_dim": 128,
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 28672,
"max_position_embeddings": 32768,
"max_window_layers": 64,
"model_type": "qwen2",
"num_attention_heads": 48,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.43.1",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064
}
Qwen2.5-Coder 32B
{
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"head_dim": 128,
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 28672,
"max_position_embeddings": 32768,
"max_window_layers": 64,
"model_type": "qwen2",
"num_attention_heads": 48,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 131072,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.43.1",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064
}
Qwen2.5 72B
{
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 8192,
"initializer_range": 0.02,
"intermediate_size": 30720,
"max_position_embeddings": 32768,
"max_window_layers": 70,
"model_type": "qwen2",
"num_attention_heads": 64,
"num_hidden_layers": 80,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 131072,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.43.1",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064
}