RMBG-1.4 / ax_inference.py
import argparse
import sys
import cv2
import numpy as np
import time

def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    # Promote grayscale images to a single-channel HWC layout.
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    # HWC -> NCHW, float32.
    im_np = np.transpose(im, (2, 0, 1)).astype(np.float32)
    im_np = np.expand_dims(im_np, axis=0)
    _, C, H_ori, W_ori = im_np.shape
    H_target, W_target = model_input_size
    # Build the sampling grid for a bilinear resize to the model input size.
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)
    # Interpolation weights from the fractional parts of the grid coordinates.
    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0
    im_interp = np.zeros((1, C, H_target, W_target), dtype=np.float32)
    for c in range(C):
        channel_data = im_np[0, c, :, :]
        top = wx1 * channel_data[y0, x0] + wx0 * channel_data[y0, x1]
        bottom = wx1 * channel_data[y1, x0] + wx0 * channel_data[y1, x1]
        im_interp[0, c, :, :] = wy1 * top + wy0 * bottom
    # Keep the raw 0-255 range and cast to uint8 before feeding the model.
    image = im_interp.astype(np.uint8)
    return image

def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    # Drop the batch dimension: (1, C, H, W) -> (C, H, W).
    result_np = np.squeeze(result, axis=0)
    C, H_ori, W_ori = result_np.shape
    H_target, W_target = im_size  # target size (H, W)
    # Bilinear resize of the mask back to the original image size.
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)
    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0
    result_interp = np.zeros((C, H_target, W_target), dtype=np.float32)
    for c in range(C):
        channel_data = result_np[c, :, :]
        top = wx1 * channel_data[y0, x0] + wx0 * channel_data[y0, x1]
        bottom = wx1 * channel_data[y1, x0] + wx0 * channel_data[y1, x1]
        result_interp[c, :, :] = wy1 * top + wy0 * bottom
    # Min-max normalize to [0, 255].
    ma = np.max(result_interp)
    mi = np.min(result_interp)
    result_norm = (result_interp - mi) / (ma - mi + 1e-8)  # epsilon avoids division by zero
    result_scaled = result_norm * 255
    # CHW -> HWC, then squeeze to a single-channel (H, W) mask.
    im_array = np.transpose(result_scaled, (1, 2, 0)).astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array

def inference(img_path, model_path, save_path):
    if model_path.endswith(".axmodel"):
        import axengine as ort
        session = ort.InferenceSession(model_path)
    else:
        # Without this branch, `session` would be undefined for other model formats.
        raise ValueError(f"Unsupported model format: {model_path}, only .axmodel is supported")
    input_name = None
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"Input name: {input_name}, input size: {input_shape}")
    model_input_size = [1024, 1024]
    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"Failed to read image file: {img_path}; check that the path is correct and the image is not corrupted")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)  # convert to RGB (H, W, 3)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)
    t1 = time.time()
    result = session.run(None, {input_name: image})
    t2 = time.time()
    print(f"Inference time: {(t2 - t1) * 1000:.2f} ms")
    result_image = postprocess_image(result[0], orig_im_size)  # single-channel mask (H, W)
    orig_im_unchanged = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # read all channels (BGR/BGRA)
    mask = result_image  # single-channel mask (H, W), values in 0-255
    if orig_im_unchanged.shape[-1] == 3:  # original image is BGR (no alpha channel)
        b, g, r = cv2.split(orig_im_unchanged)
        a = mask
        no_bg_image = cv2.merge((b, g, r, a))  # merge into BGRA
    elif orig_im_unchanged.shape[-1] == 4:  # original image is BGRA (alpha channel already present)
        b, g, r, _ = cv2.split(orig_im_unchanged)
        a = mask
        no_bg_image = cv2.merge((b, g, r, a))
    else:
        raise ValueError(f"Unsupported number of channels: {orig_im_unchanged.shape[-1]}; only 3-channel (BGR) and 4-channel (BGRA) images are supported")
    if save_path.lower().endswith(('.jpg', '.jpeg')):
        # JPG cannot store an alpha channel, so drop it before saving.
        cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        print(f"JPG does not support an alpha channel; the alpha channel was dropped. Result saved to: {save_path}")
    else:
        cv2.imwrite(save_path, no_bg_image)
    print(f"Inference finished, result saved to: {save_path}")

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="ax rmbg example")
    parser.add_argument("--model", "-m", type=str, required=True, help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True, help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png", help="save result path (png)")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    print(f"Command: {' '.join(sys.argv)}")
    print("Parameters:")
    print(f"  --model: {args.model}")
    print(f"  --img: {args.img}")
    print(f"  --save_path: {args.save_path}")
    inference(args.img, args.model, args.save_path)
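
# Example invocation (file names below are hypothetical; substitute your own
# compiled .axmodel and input image):
#   python ax_inference.py --model ./RMBG-1.4.axmodel --img ./input.jpg --save_path ./result.png
# The result is written as a PNG whose alpha channel holds the predicted foreground mask.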