x10z committed on
Commit
378bb55
·
verified ·
1 Parent(s): 067ce82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -108
app.py CHANGED
@@ -1,144 +1,142 @@
1
- ###########################################################################################
2
- # Code based on the Hugging Face Space of Depth Anything v2
3
- # https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/app.py
4
- ###########################################################################################
5
-
6
  import gradio as gr
7
  import cv2
8
- import matplotlib
9
  import numpy as np
10
- import os
11
- from PIL import Image
12
- import spaces
13
  import torch
14
  import tempfile
15
- from gradio_imageslider import ImageSlider
16
- from huggingface_hub import hf_hub_download
 
17
 
18
- from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
19
- from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
20
  from diffusers import DDIMScheduler, AutoencoderKL
21
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 
22
 
23
-
24
- css = """
25
- #img-display-container {
26
- max-height: 100vh;
27
- }
28
- #img-display-input {
29
- max-height: 80vh;
30
- }
31
- #img-display-output {
32
- max-height: 80vh;
33
- }
34
- #download {
35
- height: 62px;
36
- }
37
- """
38
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
39
  checkpoint_path = "GonzaloMG/geowizard-e2e-ft"
 
 
40
  vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder='vae')
41
  scheduler = DDIMScheduler.from_pretrained(checkpoint_path, timestep_spacing="trailing", subfolder='scheduler')
42
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(checkpoint_path, subfolder="image_encoder")
43
  feature_extractor = CLIPImageProcessor.from_pretrained(checkpoint_path, subfolder="feature_extractor")
44
  unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
45
- pipe = DepthNormalEstimationPipeline(vae=vae,
46
- image_encoder=image_encoder,
47
- feature_extractor=feature_extractor,
48
- unet=unet,
49
- scheduler=scheduler)
50
- pipe = pipe.to(DEVICE)
 
 
 
51
  pipe.unet.eval()
52
 
53
- title = "# End-to-End Fine-Tuned GeoWizard"
 
54
  description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
55
-
56
  @spaces.GPU
57
- def predict(image, processing_res_choice):
 
 
 
58
  with torch.no_grad():
59
- pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", processing_res=processing_res_choice, match_input_res=True)
60
- # depth
61
- depth_pred = pipe_out.depth_np
62
- depth_colored = pipe_out.depth_colored
63
- # normals
64
- normal_pred = pipe_out.normal_np
65
- normal_colored = pipe_out.normal_colored
66
- return depth_pred, depth_colored, normal_pred, normal_colored
67
-
68
- with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  gr.Markdown(title)
70
  gr.Markdown(description)
71
- gr.Markdown("### Depth and Normals Prediction demo")
72
 
73
  with gr.Row():
74
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
75
- normal_image_slider = ImageSlider(label="Normal Map with Slider View", elem_id='normal-display-output', position=0.5)
76
-
77
- with gr.Row():
78
- input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
79
  with gr.Column():
80
  processing_res_choice = gr.Radio(
81
  [
82
  ("Recommended (768)", 768),
83
- ("Native", 0),
84
  ],
85
  label="Processing resolution",
86
  value=768,
87
  )
88
  submit = gr.Button(value="Compute Depth and Normals")
89
-
90
- colored_depth_file = gr.File(label="Colored Depth Image", elem_id="download")
91
- gray_depth_file = gr.File(label="Grayscale Depth Map", elem_id="download")
92
- raw_depth_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")
93
- colored_normal_file = gr.File(label="Colored Normal Image", elem_id="download")
94
- raw_normal_file = gr.File(label="Raw Normal Data (.npy)", elem_id="download")
95
-
96
- cmap = matplotlib.colormaps.get_cmap('Spectral_r')
97
-
98
- def on_submit(image, processing_res_choice):
99
-
100
- if image is None:
101
- print("No image uploaded.")
102
- return None
103
-
104
- pil_image = Image.fromarray(image.astype('uint8'))
105
- depth_pred, depth_colored, normal_pred, normal_colored = predict(pil_image, processing_res_choice)
106
-
107
- # Save depth and normals npy data
108
- tmp_npy_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
109
- np.save(tmp_npy_depth.name, depth_pred)
110
- tmp_npy_normal = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
111
- np.save(tmp_npy_normal.name, normal_pred)
112
-
113
- # Save the grayscale depth map
114
- depth_gray = (depth_pred * 65535.0).astype(np.uint16)
115
- tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
116
- Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
117
-
118
- # Save the colored depth and normals maps
119
- tmp_colored_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
120
- depth_colored.save(tmp_colored_depth.name)
121
- tmp_colored_normal = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
122
- normal_colored.save(tmp_colored_normal.name)
123
-
124
- return (
125
- (pil_image, depth_colored), # For ImageSlider: (base image, overlay image)
126
- (pil_image, normal_colored), # For gr.Image
127
- tmp_colored_depth.name, # File outputs
128
- tmp_gray_depth.name,
129
- tmp_npy_depth.name,
130
- tmp_colored_normal.name,
131
- tmp_npy_normal.name
132
- )
133
 
134
- submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider,normal_image_slider,colored_depth_file,gray_depth_file,raw_depth_file,colored_normal_file,raw_normal_file])
135
-
136
- example_files = os.listdir('assets/examples')
137
- example_files.sort()
138
- example_files = [os.path.join('assets/examples', filename) for filename in example_files]
139
- example_files = [[image, 768] for image in example_files]
140
- examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider,normal_image_slider,colored_depth_file,gray_depth_file,raw_depth_file,colored_normal_file,raw_normal_file], fn=on_submit)
141
 
 
 
 
 
 
142
 
143
- if __name__ == '__main__':
144
- demo.queue().launch(share=True)
 
 
 
 
 
 
import gradio as gr
import cv2
import numpy as np
import torch
import tempfile
from PIL import Image
import spaces
from tqdm.auto import tqdm

from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline

# Device setup: prefer GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face Hub repo holding the fine-tuned GeoWizard checkpoint.
checkpoint_path = "GonzaloMG/geowizard-e2e-ft"

# Load pretrained components (each lives in its own subfolder of the repo).
vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder='vae')
scheduler = DDIMScheduler.from_pretrained(checkpoint_path, timestep_spacing="trailing", subfolder='scheduler')
image_encoder = CLIPVisionModelWithProjection.from_pretrained(checkpoint_path, subfolder="image_encoder")
feature_extractor = CLIPImageProcessor.from_pretrained(checkpoint_path, subfolder="feature_extractor")
unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")

# Instantiate pipeline and move all modules to the selected device.
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler
).to(device)
# Inference only: disable dropout/batch-norm training behavior on the UNet.
pipe.unet.eval()

# UI texts
title = "# End-to-End Fine-Tuned GeoWizard Video"
description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
39
@spaces.GPU
def predict(image: Image.Image, processing_res_choice: int):
    """Run single-frame depth/normal prediction on the GPU.

    Args:
        image: Input RGB frame as a PIL image.
        processing_res_choice: Processing resolution (0 selects native size).

    Returns:
        Tuple ``(depth_np, depth_colored, normal_np, normal_colored)`` —
        the raw and colorized depth and normal predictions.
    """
    with torch.no_grad():
        pipe_out = pipe(
            image,
            denoising_steps=1,        # single-step E2E fine-tuned model
            ensemble_size=1,
            noise="zeros",
            processing_res=processing_res_choice,
            match_input_res=True,     # resize predictions back to input size
        )
    # BUG FIX: callers unpack four values, but the raw pipeline output is a
    # result object, not a 4-tuple. Return its fields explicitly, as the
    # previous single-image version of this app did.
    return (
        pipe_out.depth_np,
        pipe_out.depth_colored,
        pipe_out.normal_np,
        pipe_out.normal_colored,
    )
55
def on_submit_video(video_path: str, processing_res_choice: int):
    """Process each frame of the input video into depth and normal videos.

    Args:
        video_path: Filesystem path of the uploaded video, or None when the
            user submitted without uploading anything.
        processing_res_choice: Processing resolution forwarded to ``predict``.

    Returns:
        Tuple ``(depth_video_path, normal_video_path)`` of temporary ``.mp4``
        files, or ``(None, None)`` when no video was provided.
    """
    if video_path is None:
        print("No video uploaded.")
        # One value per Gradio output component (two video outputs).
        return None, None

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back when FPS metadata is missing/0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Temporary output files for the rendered depth and normal videos.
    tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_depth = cv2.VideoWriter(tmp_depth.name, fourcc, fps, (width, height))
    out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))

    try:
        for _ in tqdm(range(frame_count), desc="Processing frames"):
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the pipeline expects an RGB PIL image.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb)

            # BUG FIX: removed a stray bare `time_error` name that raised
            # NameError on the first processed frame.
            depth_np, depth_colored, normal_np, normal_colored = predict(pil_image, processing_res_choice)

            # Write colorized depth frame (convert back to BGR for OpenCV).
            depth_bgr = cv2.cvtColor(np.array(depth_colored), cv2.COLOR_RGB2BGR)
            out_depth.write(depth_bgr)

            # Write colorized normal frame.
            normal_bgr = cv2.cvtColor(np.array(normal_colored), cv2.COLOR_RGB2BGR)
            out_normal.write(normal_bgr)
    finally:
        # Always release the capture and writers, even if a frame fails.
        cap.release()
        out_depth.release()
        out_normal.release()

    return tmp_depth.name, tmp_normal.name
108
# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth and Normals Prediction on Video")

    with gr.Row():
        # BUG FIX: gr.Video has no `type` kwarg (that belongs to gr.Image);
        # on Gradio 4 an unknown kwarg raises TypeError. gr.Video already
        # passes the uploaded file's path to the callback.
        input_video = gr.Video(
            label="Input Video",
            elem_id='video-display-input'
        )
        with gr.Column():
            processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native (original)", 0),
                ],
                label="Processing resolution",
                value=768,
            )
            submit = gr.Button(value="Compute Depth and Normals")

    with gr.Row():
        output_depth_video = gr.Video(label="Depth Video", elem_id='download')
        output_normal_video = gr.Video(label="Normal Video", elem_id='download')

    # Wire the button to the per-frame video processing callback.
    submit.click(
        fn=on_submit_video,
        inputs=[input_video, processing_res_choice],
        outputs=[output_depth_video, output_normal_video]
    )
140
 
141
if __name__ == "__main__":
    # Enable request queueing (needed for long-running video jobs) and
    # launch with a public share link.
    demo.queue().launch(share=True)