#!/usr/bin/env python3
"""
Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM.

Converts to LLaMA-compatible format for standard inference engines.

Usage:
    python extract_llm.py --input ./HyperCLOVAX-SEED-Think-32B --output ./HyperCLOVAX-SEED-Text-Think-32B

Requirements:
    pip install safetensors torch tqdm
"""
import argparse
import json
import os
import shutil
from pathlib import Path
from collections import defaultdict
from safetensors import safe_open
from safetensors.torch import save_file
import torch
from tqdm import tqdm


def load_weight_index(model_path: Path) -> dict:
    """Load the safetensors weight index file."""
    index_path = model_path / "model.safetensors.index.json"
    with open(index_path, "r") as f:
        return json.load(f)


def extract_llm_weights(model_path: Path, output_path: Path):
    """
    Extract LLM weights from VLM.

    Key mapping:
      - model.language_model.model.*   → model.*
      - model.language_model.lm_head.* → lm_head.*

    All vision encoder and MM projector weights are excluded.
    """
    output_path.mkdir(parents=True, exist_ok=True)

    weight_index = load_weight_index(model_path)
    weight_map = weight_index["weight_map"]

    # Filter and remap LLM weights
    llm_weights = {}
    for key, shard_file in weight_map.items():
        if key.startswith("model.language_model."):
            if key.startswith("model.language_model.model."):
                new_key = key.replace("model.language_model.model.", "model.")
            else:
                # Covers lm_head.* and any other direct children of language_model.
                new_key = key.replace("model.language_model.", "")
            llm_weights[new_key] = (key, shard_file)

    print(f"Found {len(llm_weights)} LLM weight tensors")
    print(f"Excluded {len(weight_map) - len(llm_weights)} vision/projector tensors")

    # Group by source shard for efficient loading
    shard_to_weights = defaultdict(list)
    for new_key, (old_key, shard_file) in llm_weights.items():
        shard_to_weights[shard_file].append((old_key, new_key))

    # Load all LLM tensors
    all_tensors = {}
    shard_files = sorted(set(shard_to_weights.keys()))
    print(f"\nLoading weights from {len(shard_files)} shards...")
    for shard_file in tqdm(shard_files, desc="Loading shards"):
        shard_path = model_path / shard_file
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for old_key, new_key in shard_to_weights[shard_file]:
                tensor = f.get_tensor(old_key)
                all_tensors[new_key] = tensor

    print(f"\nTotal tensors extracted: {len(all_tensors)}")
    total_size = sum(t.numel() * t.element_size() for t in all_tensors.values())
    print(f"Total size: {total_size / 1e9:.2f} GB")

    # Save as sharded safetensors (~5 GB per shard)
    max_shard_size = 5 * 1024 * 1024 * 1024
    print("\nSaving extracted weights...")
    save_sharded_safetensors(all_tensors, output_path, max_shard_size)

    return list(all_tensors.keys())


def save_sharded_safetensors(tensors: dict, output_path: Path, max_shard_size: int):
    """Save tensors as sharded safetensors files with index."""
    sorted_keys = sorted(tensors.keys())

    # Greedily pack tensors into shards of at most max_shard_size bytes.
    shards = []
    current_shard = {}
    current_size = 0
    shard_idx = 1
    weight_map = {}

    for key in sorted_keys:
        tensor = tensors[key]
        tensor_size = tensor.numel() * tensor.element_size()
        if current_size + tensor_size > max_shard_size and current_shard:
            shards.append((shard_idx, current_shard))
            shard_idx += 1
            current_shard = {}
            current_size = 0
        current_shard[key] = tensor
        current_size += tensor_size
    if current_shard:
        shards.append((shard_idx, current_shard))

    total_shards = len(shards)
    total_size = sum(t.numel() * t.element_size() for t in tensors.values())

    for shard_idx, shard_tensors in tqdm(shards, desc="Saving shards"):
        shard_name = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
        shard_path = output_path / shard_name
        save_file(shard_tensors, shard_path)
        for key in shard_tensors.keys():
            weight_map[key] = shard_name

    # Create index file
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    index_path = output_path / "model.safetensors.index.json"
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total_shards} shards to {output_path}")


def create_llama_config(original_config_path: Path, output_path: Path):
    """
    Create LLaMA-compatible config from VLM config.

    Note: HyperCLOVAX uses attention_multiplier ≈ 1/sqrt(head_dim),
    which matches standard LLaMA scaled dot-product attention.
    """
    with open(original_config_path, "r") as f:
        vlm_config = json.load(f)
    text_config = vlm_config["text_config"]

    llama_config = {
        "architectures": ["LlamaForCausalLM"],
        "attention_bias": text_config.get("attention_bias", False),
        "attention_dropout": text_config.get("attention_dropout", 0.0),
        "bos_token_id": text_config.get("bos_token_id", 128000),
        "eos_token_id": text_config.get("eos_token_id", 128001),
        "head_dim": text_config.get("head_dim", 128),
        "hidden_act": text_config.get("hidden_act", "silu"),
        "hidden_size": text_config.get("hidden_size", 5120),
        "initializer_range": text_config.get("initializer_range", 0.006),
        "intermediate_size": text_config.get("intermediate_size", 24192),
        "max_position_embeddings": text_config.get("max_position_embeddings", 131072),
        "mlp_bias": text_config.get("mlp_bias", False),
        "model_type": "llama",
        "num_attention_heads": text_config.get("num_attention_heads", 40),
        "num_hidden_layers": text_config.get("num_hidden_layers", 72),
        "num_key_value_heads": text_config.get("num_key_value_heads", 8),
        "pad_token_id": text_config.get("pad_token_id", 0),
        "pretraining_tp": 1,
        "rms_norm_eps": text_config.get("rms_norm_eps", 1e-05),
        "rope_scaling": text_config.get("rope_scaling", None),
        "rope_theta": text_config.get("rope_theta", 50000000),
        "tie_word_embeddings": text_config.get("tie_word_embeddings", False),
        "torch_dtype": "bfloat16",
        "transformers_version": "4.52.4",
        "use_cache": True,
        "vocab_size": text_config.get("vocab_size", 128256),
    }

    config_path = output_path / "config.json"
    with open(config_path, "w") as f:
        json.dump(llama_config, f, indent=2)
    print(f"Saved LLaMA config to {config_path}")

    # Generation config
    gen_config = {
        "bos_token_id": llama_config["bos_token_id"],
        "eos_token_id": llama_config["eos_token_id"],
        "pad_token_id": llama_config["pad_token_id"],
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "max_length": 4096,
    }
    gen_config_path = output_path / "generation_config.json"
    with open(gen_config_path, "w") as f:
        json.dump(gen_config, f, indent=2)

    return llama_config


def copy_tokenizer_files(original_path: Path, output_path: Path):
    """Copy tokenizer files from original model."""
    tokenizer_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
        "chat_template.jinja",
    ]
    copied = []
    for fname in tokenizer_files:
        src = original_path / fname
        if src.exists():
            dst = output_path / fname
            shutil.copy2(src, dst)
            copied.append(fname)
    print(f"Copied tokenizer files: {copied}")


def main():
    parser = argparse.ArgumentParser(
        description="Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
  # Download original VLM
  huggingface-cli download naver-hyperclovax/HyperCLOVAX-SEED-Think-32B \\
      --local-dir ./HyperCLOVAX-SEED-Think-32B

  # Extract text-only LLM
  python extract_llm.py \\
      --input ./HyperCLOVAX-SEED-Think-32B \\
      --output ./HyperCLOVAX-SEED-Text-Think-32B
""",
    )
    parser.add_argument(
        "--input", "-i",
        type=Path,
        required=True,
        help="Path to original HyperCLOVAX-SEED-Think-32B VLM",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        required=True,
        help="Output path for extracted text-only LLM",
    )
    args = parser.parse_args()

    if not args.input.exists():
        print(f"Error: Input path does not exist: {args.input}")
        return 1
    if not (args.input / "model.safetensors.index.json").exists():
        print(f"Error: model.safetensors.index.json not found in {args.input}")
        return 1
print("=" * 60)
print("HyperCLOVAX VLM → Text-only LLM Extraction")
print("=" * 60)
print(f"Input: {args.input}")
print(f"Output: {args.output}")
print("\n[Step 1] Extracting LLM weights...")
extracted_keys = extract_llm_weights(args.input, args.output)
print("\n[Step 2] Creating LLaMA-compatible config...")
config = create_llama_config(args.input / "config.json", args.output)
print("\n[Step 3] Copying tokenizer files...")
copy_tokenizer_files(args.input, args.output)
print("\n" + "=" * 60)
print("Extraction complete!")
print(f"Output: {args.output}")
print("=" * 60)
print(f"\nModel summary:")
print(f" - Architecture: LlamaForCausalLM")
print(f" - Hidden size: {config['hidden_size']}")
print(f" - Layers: {config['num_hidden_layers']}")
print(f" - Attention heads: {config['num_attention_heads']}")
print(f" - KV heads: {config['num_key_value_heads']}")
print(f" - Vocab size: {config['vocab_size']}")
print(f" - Max context: {config['max_position_embeddings']}")
print(f"\nYou can now use the model with vLLM, transformers, or other LLaMA-compatible frameworks.")
return 0


if __name__ == "__main__":
    exit(main())
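
# ---------------------------------------------------------------------------
# Example usage of the extracted model (illustrative sketch only, not executed
# by this script). It assumes the output directory from the usage example
# above and a recent transformers release; adjust the path, dtype, and device
# placement to your setup.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model_dir = "./HyperCLOVAX-SEED-Text-Think-32B"
#   tokenizer = AutoTokenizer.from_pretrained(model_dir)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_dir, torch_dtype="bfloat16", device_map="auto"
#   )
#   prompt = tokenizer.apply_chat_template(
#       [{"role": "user", "content": "Hello!"}],
#       tokenize=False, add_generation_prompt=True,
#   )
#   inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))
# ---------------------------------------------------------------------------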