import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
import torch

# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained GIT-base checkpoint. Its processor handles image preprocessing
# and token decoding for both captioning models in this app.
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
# Place the model on `device`: captioning inputs are moved to this device at
# inference time, and a CPU-resident model fed CUDA tensors raises (and would
# never use the GPU anyway).
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Fine-tuned GIT-base checkpoint, shown side by side with model1 in the UI.
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)

# English -> Chinese translation pipeline applied to generated captions.
en_zh_translator = pipeline("translation", model="liam168/trans-opus-mt-en-zh")
def format_bilingual_caption(english, translation_output):
    """Join an English caption and its translation into one display string.

    Args:
        english: The English caption text.
        translation_output: Result of calling a transformers ``translation``
            pipeline on a single string, i.e. a list holding one dict with a
            ``"translation_text"`` key. An empty list means "no translation".

    Returns:
        ``english`` followed by a newline and the translated text, or just
        ``english`` when ``translation_output`` is empty.
    """
    if not translation_output:
        return english
    return f"{english}\n{translation_output[0]['translation_text']}"


def _bilingual_caption(model, pixel_values):
    """Caption preprocessed pixels with *model* and append a Chinese translation.

    Args:
        model: Captioning model loaded via ``AutoModelForCausalLM``.
        pixel_values: Image tensor produced by ``processor``.

    Returns:
        Display text as built by ``format_bilingual_caption``. An empty
        caption is returned as-is without calling the translator.
    """
    # Follow the model's own device instead of assuming where it lives:
    # mixing CUDA inputs with a CPU model (or the reverse) raises in generate.
    generated_ids = model.generate(
        pixel_values=pixel_values.to(model.device), max_length=50
    )
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    if not caption:
        return caption
    return format_bilingual_caption(caption, en_zh_translator(caption))


def img2cap_com(image):
    """Caption *image* with both the pre-trained and the fine-tuned model.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input; may be
            ``None`` if the user submits without uploading an image.

    Returns:
        A pair ``(pretrained_text, finetuned_text)`` of strings, each holding
        an English caption followed by its Chinese translation on the next
        line. Both are empty strings when no image was provided.
    """
    if image is None:
        return "", ""
    # Both models share the same processor, so one preprocessing pass
    # serves both of them.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    return (
        _bilingual_caption(model1, pixel_values),
        _bilingual_caption(model2, pixel_values),
    )
# Gradio UI: one uploaded image in, two captions out
# (pre-trained model vs. fine-tuned model).
inputs = [gr.Image(type="pil", label="Original Image")]

outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]

title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."

# Clickable sample inputs. These are relative file paths; presumably the
# images ship next to the app — confirm they exist where it is launched.
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"],
]

demo = gr.Interface(
    fn=img2cap_com,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    # NOTE(review): accepted string theme names differ across Gradio
    # releases; confirm "huggingface" is recognized by the pinned version.
    theme="huggingface",
)
demo.launch()