# feature_extractor.py 3.88 KB
import numpy as np
import torch
import clip
import requests
from PIL import Image
from io import BytesIO
import logging

class FeatureExtractor:
    """Compute L2-normalised CLIP image embeddings from PIL images or URLs.

    The CLIP model and its preprocessing transform are loaded once per
    instance (in ``__init__``) and reused by every extraction call.
    """

    __logger = logging.getLogger(__name__)

    # Class-level value kept for backward compatibility only; every instance
    # replaces it with the device chosen in ``__init__``.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def __init__(self, device=None, model_name="ViT-B/32"):
        """Load the CLIP model onto ``device``.

        Args:
            device: Torch device string (e.g. ``"cpu"``, ``"cuda"``,
                ``"xpu"``). When omitted the model is loaded on ``"cpu"``;
                pass an accelerator name explicitly to use one.
            model_name: CLIP model name handed to ``clip.load``.
        """
        if device is None:
            device = "cpu"
        self.model, self.preprocess = self.init_model(device, model_name)
        self.device = device

    @staticmethod
    def init_model(device=None, model_name="ViT-B/32"):
        """Create a CLIP model and its preprocessing transform.

        Args:
            device: Torch device string. When omitted, ``"xpu"`` is used if
                this PyTorch build exposes an available XPU, else ``"cpu"``.
            model_name: CLIP model name handed to ``clip.load``.

        Returns:
            The ``(model, preprocess)`` pair returned by ``clip.load``.
        """
        # ``torch.xpu`` does not exist on every PyTorch build, so look it up
        # defensively at call time instead of in the default arguments.
        xpu = getattr(torch, "xpu", None)
        xpu_available = xpu is not None and xpu.is_available()
        if device is None:
            device = "xpu" if xpu_available else "cpu"
        if xpu_available:
            xpu.empty_cache()

        FeatureExtractor.__logger.info(
            "创建并初始化 CLIP 模型: %s 在设备: %s", model_name, device
        )
        model, preprocess = clip.load(model_name, device=device)
        return model, preprocess

    @staticmethod
    def _scaled_size(width, height, target_size):
        """Return the largest (w, h) that fits ``target_size`` at the same aspect ratio."""
        ratio = min(target_size[0] / width, target_size[1] / height)
        # Clamp to one pixel: extreme aspect ratios would otherwise truncate
        # a side to 0, which ``Image.resize`` rejects.
        return (max(1, int(width * ratio)), max(1, int(height * ratio)))

    @staticmethod
    def resize_with_padding(img, target_size=(224, 224)):
        """Resize an image to fit ``target_size``, keeping its aspect ratio.

        The resized image is centred on a white RGB canvas of exactly
        ``target_size``.

        Args:
            img: Input PIL image.
            target_size: ``(width, height)`` of the output, default (224, 224).

        Returns:
            A new RGB PIL image of size ``target_size``.
        """
        new_size = FeatureExtractor._scaled_size(img.width, img.height, target_size)

        # Resize while preserving the aspect ratio.
        resized_img = img.resize(new_size, Image.LANCZOS)

        # White canvas that provides the padding.
        new_img = Image.new("RGB", target_size, (255, 255, 255))

        # Centre the resized image on the canvas.
        paste_position = (
            (target_size[0] - new_size[0]) // 2,
            (target_size[1] - new_size[1]) // 2,
        )
        new_img.paste(resized_img, paste_position)

        return new_img

    def extract_from_url(self, image_url):
        """Download an image and extract its feature vector.

        Args:
            image_url: URL of the image to download (10 second timeout).

        Returns:
            The feature vector, or ``None`` if the download or the
            extraction fails (the failure is logged).
        """
        try:
            # The ``with`` block releases the connection even when
            # ``raise_for_status`` raises on a streamed response.
            with requests.get(image_url, stream=True, timeout=10) as response:
                response.raise_for_status()
                payload = response.content

            # Decode the downloaded bytes into an RGB PIL image.
            image = Image.open(BytesIO(payload)).convert("RGB")

            return self.extract_from_image(image)

        except requests.RequestException as e:
            self.__logger.error(
                "Network error when downloading image from %s: %s", image_url, e
            )
            return None
        except Exception as e:
            self.__logger.error(
                "Error extracting features from URL %s: %s", image_url, e
            )
            return None

    def extract_from_image(self, img):
        """Extract the feature vector of a PIL image.

        Args:
            img: Input PIL image.

        Returns:
            A flattened ``float32`` NumPy array holding the L2-normalised
            image embedding, or ``None`` if extraction fails (the failure
            is logged).
        """
        try:
            # Letterbox the image before running the model's own preprocessing.
            image = self.resize_with_padding(img)

            # Preprocess and add a batch dimension of one.
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)

                # Normalise to unit length along the feature dimension.
                image_features /= image_features.norm(dim=-1, keepdim=True)

            return image_features.cpu().numpy().astype(np.float32).flatten()

        except Exception as e:
            self.__logger.error("Error extracting features from image: %s", e)
            return None

def get_feature_extractor():
    """Return the shared module-level ``FeatureExtractor`` instance."""
    return feature_extractor

# Shared module-level instance handed out by get_feature_extractor().
# NOTE: constructing it here loads the CLIP model as an import-time side effect.
feature_extractor = FeatureExtractor()