# app.py — MiDaS depth-map 3D viewer (Hugging Face Space by johnrobinsn, rev ba9e177)
# Monkey-patch gradio_client's broken schema parser: the JSON Schema spec
# allows a schema to be the bare boolean `true`/`false` (e.g. for
# additionalProperties), which crashes _json_schema_to_python_type.
import gradio_client.utils as client_utils

_original_json_schema_to_python_type = client_utils._json_schema_to_python_type

def _patched_json_schema_to_python_type(schema, defs=None, *args, **kwargs):
    """Map boolean schemas to "Any"; delegate everything else to the original.

    Extra positional/keyword arguments are forwarded unchanged so the patch
    keeps working if a gradio_client release extends the private function's
    signature.
    """
    if isinstance(schema, bool):
        return "Any"
    return _original_json_schema_to_python_type(schema, defs, *args, **kwargs)

client_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
# Inline CSS overrides passed to gr.Blocks(css=...): constrain the page width,
# cap preview-image height, and give the Three.js pane (matched by the
# #depth-viewer elem_id set on the HTML component below) a fixed height.
custom_css = """
.gradio-container {
max-width: 900px !important;
margin: auto !important;
}
.gr-image {
max-height: 300px !important;
}
.gr-button {
padding: 8px 16px !important;
}
.gr-padded {
padding: 10px !important;
}
h1 {
font-size: 1.5rem !important;
}
#depth-viewer { height: 600px; }
"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from depth_viewer import depthviewer2html
# Lazily-initialized singletons (populated by get_model on first use).
# Loading the DPT model at import time would slow app startup, so defer
# until the first request needs it.
_model = None
_feature_extractor = None
def get_model():
    """Return the (processor, model) pair, loading Intel/dpt-large on first call.

    The pair is cached in module-level singletons so the weights are only
    downloaded/loaded once per process.
    """
    global _model, _feature_extractor
    if _model is not None:
        return _feature_extractor, _model
    # Import here so the heavy transformers dependency is paid on first use.
    from transformers import DPTImageProcessor, DPTForDepthEstimation
    _feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    _model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    return _feature_extractor, _model
def process_image(image_path):
    """Estimate a depth map for the image at *image_path* and return an HTML
    3D viewer built by depthviewer2html.

    Returns "" when no image is provided (e.g. the input is cleared).
    """
    if image_path is None:
        return ""
    feature_extractor, model = get_model()
    image_path = Path(image_path)
    # Close the file handle promptly; convert to RGB so RGBA/grayscale/palette
    # uploads don't break the 3-channel DPT processor.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Cap width at 512px (preserving aspect ratio) to bound inference cost.
    if image.size[0] > 512:
        image = image.resize((512, int(512 * image.size[1] / image.size[0])), Image.Resampling.LANCZOS)
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth
    # Upsample the low-resolution depth prediction back to image resolution.
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],  # PIL size is (w, h); interpolate wants (h, w)
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    output = prediction.cpu().numpy()
    # Scale to 8-bit, guarding against an all-zero map (division by zero → NaN).
    peak = float(np.max(output))
    if peak > 0:
        depth = (output * 255 / peak).astype('uint8')
    else:
        depth = np.zeros_like(output, dtype='uint8')
    return depthviewer2html(image, depth)
# Page header text.
title = "3D Visualization of Depth Maps Generated using MiDaS"
description = "Improved 3D interactive depth viewer using Three.js"
# Build the UI; the module-level `demo` name and the unconditional launch()
# are the entry point Hugging Face Spaces expects when it runs app.py.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    # type="filepath" so process_image receives a path it can open with PIL.
    input_image = gr.Image(type="filepath", label="Input Image")
    # elem_id matches the #depth-viewer rule in custom_css (fixed 600px pane).
    output_html = gr.HTML(label="Depth Viewer", elem_id="depth-viewer")
    # Re-run depth estimation whenever the uploaded image changes.
    input_image.change(fn=process_image, inputs=input_image, outputs=output_html)
    gr.Examples(
        examples=[["examples/owl1.jpg"], ["examples/marsattacks.jpg"], ["examples/kitten.jpg"]],
        inputs=input_image,
        cache_examples=False  # don't run the model at build time
    )
demo.launch()