"""
Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM.

Converts to LLaMA-compatible format for standard inference engines.

Usage:
    python extract_llm.py --input ./HyperCLOVAX-SEED-Think-32B --output ./HyperCLOVAX-SEED-Text-Think-32B

Requirements:
    pip install safetensors torch tqdm
"""
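# After extraction, the output directory loads like any LLaMA checkpoint. A minimal
# usage sketch (an assumption, not exercised by this script; requires `transformers`
# and reuses the output path from the usage example above):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("./HyperCLOVAX-SEED-Text-Think-32B")
#   model = AutoModelForCausalLM.from_pretrained(
#       "./HyperCLOVAX-SEED-Text-Think-32B", torch_dtype="bfloat16"
#   )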
import argparse
import json
import shutil
from collections import defaultdict
from pathlib import Path

import torch  # kept explicit: safetensors is used with the PyTorch backend (framework="pt")
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm


def load_weight_index(model_path: Path) -> dict:
    """Load the safetensors weight index file."""
    index_path = model_path / "model.safetensors.index.json"
    with open(index_path, "r") as f:
        return json.load(f)


def extract_llm_weights(model_path: Path, output_path: Path):
    """
    Extract LLM weights from VLM.

    Key mapping:
        - model.language_model.model.*   → model.*
        - model.language_model.lm_head.* → lm_head.*

    All vision encoder and MM projector weights are excluded.
    """
    output_path.mkdir(parents=True, exist_ok=True)

    weight_index = load_weight_index(model_path)
    weight_map = weight_index["weight_map"]
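
    # Illustrative example of the remapping (layer naming follows the standard
    # LLaMA layout; the exact tensor names in the checkpoint are an assumption):
    #   model.language_model.model.layers.0.self_attn.q_proj.weight -> model.layers.0.self_attn.q_proj.weight
    #   model.language_model.lm_head.weight                         -> lm_head.weight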

    # Select only language-model tensors and compute their LLaMA-style keys.
    llm_weights = {}
    for key, shard_file in weight_map.items():
        if key.startswith("model.language_model."):
            if key.startswith("model.language_model.model."):
                new_key = key.replace("model.language_model.model.", "model.")
            else:
                # Covers lm_head.* and any other tensors directly under the prefix.
                new_key = key.replace("model.language_model.", "")
            llm_weights[new_key] = (key, shard_file)

    print(f"Found {len(llm_weights)} LLM weight tensors")
    print(f"Excluded {len(weight_map) - len(llm_weights)} vision/projector tensors")

    # Group the selected tensors by the shard file that currently stores them,
    # so each shard is opened only once.
    shard_to_weights = defaultdict(list)
    for new_key, (old_key, shard_file) in llm_weights.items():
        shard_to_weights[shard_file].append((old_key, new_key))

    all_tensors = {}
    shard_files = sorted(set(shard_to_weights.keys()))

    print(f"\nLoading weights from {len(shard_files)} shards...")
    for shard_file in tqdm(shard_files, desc="Loading shards"):
        shard_path = model_path / shard_file
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for old_key, new_key in shard_to_weights[shard_file]:
                tensor = f.get_tensor(old_key)
                all_tensors[new_key] = tensor

    print(f"\nTotal tensors extracted: {len(all_tensors)}")

    total_size = sum(t.numel() * t.element_size() for t in all_tensors.values())
    print(f"Total size: {total_size / 1e9:.2f} GB")
    max_shard_size = 5 * 1024 * 1024 * 1024

    print("\nSaving extracted weights...")
    save_sharded_safetensors(all_tensors, output_path, max_shard_size)

    return list(all_tensors.keys())


def save_sharded_safetensors(tensors: dict, output_path: Path, max_shard_size: int):
    """Save tensors as sharded safetensors files with index."""
    sorted_keys = sorted(tensors.keys())

    # Greedily pack tensors into shards until the size cap would be exceeded.
    shards = []
    current_shard = {}
    current_size = 0
    shard_idx = 1
    weight_map = {}

    for key in sorted_keys:
        tensor = tensors[key]
        tensor_size = tensor.numel() * tensor.element_size()

        if current_size + tensor_size > max_shard_size and current_shard:
            shards.append((shard_idx, current_shard))
            shard_idx += 1
            current_shard = {}
            current_size = 0

        current_shard[key] = tensor
        current_size += tensor_size

    if current_shard:
        shards.append((shard_idx, current_shard))

    total_shards = len(shards)
    total_size = sum(t.numel() * t.element_size() for t in tensors.values())

    for shard_idx, shard_tensors in tqdm(shards, desc="Saving shards"):
        shard_name = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
        shard_path = output_path / shard_name
        save_file(shard_tensors, shard_path)

        for key in shard_tensors.keys():
            weight_map[key] = shard_name

    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
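
    # Illustrative shape of the index file (keys and values are made-up examples):
    #   {"metadata": {"total_size": 65000000000},
    #    "weight_map": {"model.embed_tokens.weight": "model-00001-of-00013.safetensors", ...}}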
    index_path = output_path / "model.safetensors.index.json"
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total_shards} shards to {output_path}")


def create_llama_config(original_config_path: Path, output_path: Path):
    """
    Create LLaMA-compatible config from VLM config.

    Note: HyperCLOVAX uses attention_multiplier ≈ 1/sqrt(head_dim),
    which matches standard LLaMA scaled dot-product attention.
    """
    with open(original_config_path, "r") as f:
        vlm_config = json.load(f)

    text_config = vlm_config["text_config"]

    llama_config = {
        "architectures": ["LlamaForCausalLM"],
        "attention_bias": text_config.get("attention_bias", False),
        "attention_dropout": text_config.get("attention_dropout", 0.0),
        "bos_token_id": text_config.get("bos_token_id", 128000),
        "eos_token_id": text_config.get("eos_token_id", 128001),
        "head_dim": text_config.get("head_dim", 128),
        "hidden_act": text_config.get("hidden_act", "silu"),
        "hidden_size": text_config.get("hidden_size", 5120),
        "initializer_range": text_config.get("initializer_range", 0.006),
        "intermediate_size": text_config.get("intermediate_size", 24192),
        "max_position_embeddings": text_config.get("max_position_embeddings", 131072),
        "mlp_bias": text_config.get("mlp_bias", False),
        "model_type": "llama",
        "num_attention_heads": text_config.get("num_attention_heads", 40),
        "num_hidden_layers": text_config.get("num_hidden_layers", 72),
        "num_key_value_heads": text_config.get("num_key_value_heads", 8),
        "pad_token_id": text_config.get("pad_token_id", 0),
        "pretraining_tp": 1,
        "rms_norm_eps": text_config.get("rms_norm_eps", 1e-05),
        "rope_scaling": text_config.get("rope_scaling", None),
        "rope_theta": text_config.get("rope_theta", 50000000),
        "tie_word_embeddings": text_config.get("tie_word_embeddings", False),
        "torch_dtype": "bfloat16",
        "transformers_version": "4.52.4",
        "use_cache": True,
        "vocab_size": text_config.get("vocab_size", 128256),
    }
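
    # With the fallback values above (an assumption if the checkpoint's own config
    # differs): 40 query heads over 8 KV heads is grouped-query attention with
    # 40 / 8 = 5 query heads sharing each KV head.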

    config_path = output_path / "config.json"
    with open(config_path, "w") as f:
        json.dump(llama_config, f, indent=2)

    print(f"Saved LLaMA config to {config_path}")

    gen_config = {
        "bos_token_id": llama_config["bos_token_id"],
        "eos_token_id": llama_config["eos_token_id"],
        "pad_token_id": llama_config["pad_token_id"],
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "max_length": 4096,
    }
    gen_config_path = output_path / "generation_config.json"
    with open(gen_config_path, "w") as f:
        json.dump(gen_config, f, indent=2)

    return llama_config


def copy_tokenizer_files(original_path: Path, output_path: Path):
    """Copy tokenizer files from original model."""
    tokenizer_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
        "chat_template.jinja",
    ]

    copied = []
    for fname in tokenizer_files:
        src = original_path / fname
        if src.exists():
            dst = output_path / fname
            shutil.copy2(src, dst)
            copied.append(fname)

    print(f"Copied tokenizer files: {copied}")


def main():
    parser = argparse.ArgumentParser(
        description="Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
    # Download the original VLM
    huggingface-cli download naver-hyperclovax/HyperCLOVAX-SEED-Think-32B \\
        --local-dir ./HyperCLOVAX-SEED-Think-32B

    # Extract the text-only LLM
    python extract_llm.py \\
        --input ./HyperCLOVAX-SEED-Think-32B \\
        --output ./HyperCLOVAX-SEED-Text-Think-32B
""",
    )
    parser.add_argument(
        "--input", "-i",
        type=Path,
        required=True,
        help="Path to original HyperCLOVAX-SEED-Think-32B VLM",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        required=True,
        help="Output path for extracted text-only LLM",
    )

    args = parser.parse_args()

    if not args.input.exists():
        print(f"Error: Input path does not exist: {args.input}")
        return 1

    if not (args.input / "model.safetensors.index.json").exists():
        print(f"Error: model.safetensors.index.json not found in {args.input}")
        return 1

    print("=" * 60)
    print("HyperCLOVAX VLM → Text-only LLM Extraction")
    print("=" * 60)
    print(f"Input: {args.input}")
    print(f"Output: {args.output}")

    print("\n[Step 1] Extracting LLM weights...")
    extract_llm_weights(args.input, args.output)

    print("\n[Step 2] Creating LLaMA-compatible config...")
    config = create_llama_config(args.input / "config.json", args.output)

    print("\n[Step 3] Copying tokenizer files...")
    copy_tokenizer_files(args.input, args.output)

    print("\n" + "=" * 60)
    print("Extraction complete!")
    print(f"Output: {args.output}")
    print("=" * 60)

    print("\nModel summary:")
    print("  - Architecture: LlamaForCausalLM")
    print(f"  - Hidden size: {config['hidden_size']}")
    print(f"  - Layers: {config['num_hidden_layers']}")
    print(f"  - Attention heads: {config['num_attention_heads']}")
    print(f"  - KV heads: {config['num_key_value_heads']}")
    print(f"  - Vocab size: {config['vocab_size']}")
    print(f"  - Max context: {config['max_position_embeddings']}")

    print("\nYou can now use the model with vLLM, transformers, or other LLaMA-compatible frameworks.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())