"""AX (axengine) background-removal (RMBG) example.

Runs a compiled ``.axmodel`` on one image and writes the image back out with
the predicted mask as its alpha channel.
"""

import argparse
import sys
import time

import cv2
import numpy as np


def _bilinear_resize_chw(chw: np.ndarray, out_hw) -> np.ndarray:
    """Bilinearly resample a (C, H, W) array to (C, out_h, out_w) float32.

    Sample positions come from ``np.linspace(0, size - 1, out_size)``, so the
    first/last output rows and columns coincide with the first/last input
    rows and columns ("align corners" style sampling).
    """
    _, in_h, in_w = chw.shape
    out_h, out_w = out_hw
    xs, ys = np.meshgrid(
        np.linspace(0, in_w - 1, out_w),
        np.linspace(0, in_h - 1, out_h),
    )
    x0 = np.floor(xs).astype(np.int32)
    y0 = np.floor(ys).astype(np.int32)
    # Clamp the right/bottom neighbours so samples on the last row/column
    # stay in bounds.
    x1 = np.minimum(x0 + 1, in_w - 1)
    y1 = np.minimum(y0 + 1, in_h - 1)
    wx0 = xs - x0  # weight of the right neighbour
    wx1 = 1 - wx0  # weight of the left neighbour
    wy0 = ys - y0  # weight of the bottom neighbour
    wy1 = 1 - wy0  # weight of the top neighbour
    # Fancy indexing with (out_h, out_w) index grids gathers every channel at
    # once, giving (C, out_h, out_w) arrays.
    top = wx1 * chw[:, y0, x0] + wx0 * chw[:, y0, x1]
    bottom = wx1 * chw[:, y1, x0] + wx0 * chw[:, y1, x1]
    return (wy1 * top + wy0 * bottom).astype(np.float32)


def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Resize an image into a (1, C, H, W) uint8 model input.

    Args:
        im: Image laid out as (H, W, C), or (H, W) for a single channel.
        model_input_size: Target size as (height, width).

    Returns:
        uint8 array of shape (1, C, height, width). Interpolated values are
        truncated (not rounded) to uint8 and no normalisation is applied.
        NOTE(review): presumably the compiled model expects raw 0-255 input;
        confirm against the model's build configuration.
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    chw = np.transpose(im, (2, 0, 1)).astype(np.float32)
    resized = _bilinear_resize_chw(chw, model_input_size)
    return resized[np.newaxis].astype(np.uint8)


def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Turn a (1, C, h, w) model output into a 0-255 uint8 mask.

    The output is bilinearly resized to ``im_size`` (height, width) and
    min-max normalised to [0, 255]. Singleton dimensions are squeezed, so a
    one-channel output becomes an (H, W) array.
    """
    chw = np.squeeze(result, axis=0)
    resized = _bilinear_resize_chw(chw, im_size)
    lo = np.min(resized)
    hi = np.max(resized)
    # The tiny epsilon avoids division by zero for a constant output.
    norm = (resized - lo) / (hi - lo + 1e-8)
    im_array = np.transpose(norm * 255, (1, 2, 0)).astype(np.uint8)
    return np.squeeze(im_array)


def inference(img_path, model_path, save_path, *, model_input_size=(1024, 1024)):
    """Remove the background of ``img_path`` and save the result to ``save_path``.

    Args:
        img_path: Path of the input image.
        model_path: Path of a compiled ``.axmodel``.
        save_path: Output path. JPEG outputs drop the alpha channel; other
            formats get the mask as their alpha channel.
        model_input_size: (height, width) the image is resized to before
            inference. The default matches the previously hard-coded size.

    Raises:
        ValueError: ``model_path`` is not an ``.axmodel`` file.
        RuntimeError: The model reports no inputs.
        FileNotFoundError: The image cannot be read.
        OSError: OpenCV fails to write ``save_path``.
    """
    if not model_path.endswith(".axmodel"):
        raise ValueError(f"不支持的模型格式:{model_path},仅支持 .axmodel 模型")
    import axengine as ort  # only needed (and only importable) on AX targets

    session = ort.InferenceSession(model_path)
    input_name = None
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")
    if input_name is None:
        raise RuntimeError(f"模型没有输入:{model_path}")

    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)  # (H, W, 3) RGB
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)

    t1 = time.perf_counter()
    result = session.run(None, {input_name: image})
    t2 = time.perf_counter()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    mask = postprocess_image(result[0], orig_im_size)  # (H, W), values 0-255

    # Build the output from the same decoded image the mask was computed on.
    # Re-reading with IMREAD_UNCHANGED would ignore EXIF orientation (size or
    # orientation mismatch with the mask). It would also return 2-D arrays
    # for grayscale files and uint16 data for 16-bit files, neither of which
    # cv2.merge can combine with the uint8 mask.
    b, g, r = cv2.split(orig_im_bgr)
    no_bg_image = cv2.merge((b, g, r, mask))  # BGRA

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        ok = cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        message = f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}"
    else:
        ok = cv2.imwrite(save_path, no_bg_image)
        message = f"推理完成,结果已保存至:{save_path}"
    if not ok:  # imwrite reports many failures via its return value
        raise OSError(f"结果保存失败:{save_path}")
    print(message)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this example."""
    parser = argparse.ArgumentParser(description="ax rmbg example")
    parser.add_argument("--model", "-m", type=str, required=True,
                        help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True,
                        help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png",
                        help="save result path (png)")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print(f"Command: {' '.join(sys.argv)}")
    print("Parameters:")
    print(f" --model: {args.model}")
    print(f" --img: {args.img}")
    print(f" --save_path: {args.save_path}")
    inference(args.img, args.model, args.save_path)