|
|
import argparse |
|
|
import sys |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import time |
|
|
|
|
|
def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Bilinearly resize an image to the model input size, returning NCHW uint8.

    Args:
        im: Input image, shape (H, W, C) or (H, W) for grayscale. Values are
            interpolated in float32 and cast back to uint8 at the end.
        model_input_size: Target size as [height, width].

    Returns:
        np.ndarray of shape (1, C, H_target, W_target), dtype uint8.
    """
    # Promote grayscale (H, W) to (H, W, 1) so a channel axis always exists.
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    im_np = np.transpose(im, (2, 0, 1)).astype(np.float32)  # HWC -> CHW

    C, H_ori, W_ori = im_np.shape
    H_target, W_target = model_input_size

    # Sample grid with corner alignment: the endpoints map exactly onto the
    # first/last source pixels (align_corners=True convention).
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)

    # Integer neighbours, clamped at the right/bottom borders.
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)

    # Bilinear weights (fractional distance to each neighbour).
    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0

    # Vectorised over channels: fancy indexing gathers all channels at once
    # (shape (C, H_target, W_target)) instead of a per-channel Python loop.
    top = wx1 * im_np[:, y0, x0] + wx0 * im_np[:, y0, x1]
    bottom = wx1 * im_np[:, y1, x0] + wx0 * im_np[:, y1, x1]
    # Cast through float32 to keep the original loop's numerics, then add the
    # leading batch axis. The original's dead `/ 1.0` no-op is dropped.
    im_interp = (wy1 * top + wy0 * bottom).astype(np.float32)

    return im_interp[np.newaxis, ...].astype(np.uint8)
|
|
|
|
|
def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a model output back to the original size and min-max scale to uint8.

    Args:
        result: Model output of shape (1, C, H, W); the leading batch axis
            must be 1.
        im_size: Target size as [height, width] (the original image size).

    Returns:
        np.ndarray of dtype uint8 — shape (H_target, W_target) when C == 1,
        otherwise (H_target, W_target, C).
    """
    result_np = np.squeeze(result, axis=0)  # (1, C, H, W) -> (C, H, W)
    C, H_ori, W_ori = result_np.shape
    H_target, W_target = im_size

    # Corner-aligned sample grid (same convention as preprocess_image).
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)

    # Integer neighbours, clamped at the right/bottom borders.
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)

    # Bilinear weights.
    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0

    # Vectorised over channels via fancy indexing instead of a Python loop;
    # cast through float32 to match the original per-channel accumulator.
    top = wx1 * result_np[:, y0, x0] + wx0 * result_np[:, y0, x1]
    bottom = wx1 * result_np[:, y1, x0] + wx0 * result_np[:, y1, x1]
    result_interp = (wy1 * top + wy0 * bottom).astype(np.float32)

    # Min-max normalise to [0, 1] (epsilon guards a constant map), scale to
    # [0, 255] and truncate to uint8.
    ma = np.max(result_interp)
    mi = np.min(result_interp)
    result_norm = (result_interp - mi) / (ma - mi + 1e-8)
    result_scaled = result_norm * 255
    im_array = np.transpose(result_scaled, (1, 2, 0)).astype(np.uint8)

    # Drop only the trailing channel axis for single-channel output. The
    # original np.squeeze() with no axis would also collapse 1-pixel-high or
    # 1-pixel-wide images.
    if im_array.shape[-1] == 1:
        im_array = im_array[:, :, 0]
    return im_array
|
|
|
|
|
def inference(img_path,
              model_path,
              save_path):
    """Run background removal on one image and save a BGRA result.

    Args:
        img_path: Path to the input image file.
        model_path: Path to a compiled ``.axmodel`` file.
        save_path: Output path. ``.jpg``/``.jpeg`` drops the alpha channel;
            any other extension keeps it.

    Raises:
        ValueError: If ``model_path`` is not an ``.axmodel`` file, or the
            input image has an unsupported channel count.
        FileNotFoundError: If the input image cannot be read.
    """
    # Fail fast with a clear error for unsupported model formats. The
    # original imported axengine only for .axmodel paths but then used the
    # session unconditionally, raising a confusing NameError otherwise.
    if not model_path.endswith(".axmodel"):
        raise ValueError(f"Unsupported model format (expected .axmodel): {model_path}")
    import axengine as ort

    session = ort.InferenceSession(model_path)
    input_name = None
    for inp_meta in session.get_inputs():
        # Assumes NCHW layout, i.e. shape[2:] is (H, W) — TODO confirm for axengine.
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")

    # Fixed model input size (H, W); the reported input_shape above is
    # informational only.
    model_input_size = [1024, 1024]
    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)

    t1 = time.time()
    result = session.run(None, {input_name: image})
    t2 = time.time()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    # Resize the mask back to the original size and use it as alpha.
    result_image = postprocess_image(result[0], orig_im_size)
    orig_im_unchanged = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if orig_im_unchanged is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    mask = result_image
    # IMREAD_UNCHANGED returns a 2-D array for grayscale images; promote to
    # BGR so shape[-1] below is a channel count, not the image width.
    if orig_im_unchanged.ndim == 2:
        orig_im_unchanged = cv2.cvtColor(orig_im_unchanged, cv2.COLOR_GRAY2BGR)
    if orig_im_unchanged.shape[-1] == 3:
        b, g, r = cv2.split(orig_im_unchanged)
        no_bg_image = cv2.merge((b, g, r, mask))
    elif orig_im_unchanged.shape[-1] == 4:
        # Replace any existing alpha with the predicted mask.
        b, g, r, _ = cv2.split(orig_im_unchanged)
        no_bg_image = cv2.merge((b, g, r, mask))
    else:
        raise ValueError(f"不支持的图片通道数:{orig_im_unchanged.shape[-1]},仅支持3通道(BGR)或4通道(BGRA)")

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        # JPEG has no alpha channel; flatten to BGR before writing.
        cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        print(f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}")
    else:
        cv2.imwrite(save_path, no_bg_image)
        print(f"推理完成,结果已保存至:{save_path}")
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the RMBG example.

    Returns:
        argparse.Namespace with ``model``, ``img`` and ``save_path``
        attributes. (The original annotation said ``ArgumentParser``, but
        ``parse_args()`` returns a ``Namespace``.)
    """
    parser = argparse.ArgumentParser(description="ax rmbg exsample")
    # --model and --img are required: passing None through to inference()
    # previously produced an opaque crash deep inside cv2/axengine.
    parser.add_argument("--model", "-m", type=str, required=True, help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True, help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png", help="save result path (png)")

    args = parser.parse_args()
    return args
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()

    # Echo the exact invocation and the resolved parameters before running.
    print(f"Command: {' '.join(sys.argv)}")
    print("Parameters:")
    for label, value in (
        ("--model", args.model),
        ("--img_path", args.img),
        ("--save_path", args.save_path),
    ):
        print(f" {label}: {value}")

    inference(args.img, args.model, args.save_path)