Commit 9858b763 authored by zhengyaoqiu's avatar zhengyaoqiu

特征提取模块

parent 8a4d6982
import torch
import clip
import requests
from PIL import Image
from io import BytesIO
class FeatureExtractor:
"""
使用CLIP模型提取图像特征向量的工具类
"""
# 类变量,用于存储模型实例
_model = None
_preprocess = None
_device = None
@classmethod
def initialize(cls, model_name="ViT-B/32"):
"""
初始化CLIP模型
Args:
model_name (str): CLIP模型名称
"""
cls._device = "cuda" if torch.cuda.is_available() else "cpu"
cls._model, cls._preprocess = clip.load(model_name, device=cls._device)
print(f"CLIP model {model_name} loaded on {cls._device}")
@staticmethod
def resize_with_padding(img):
"""
调整图像大小,保持纵横比并添加填充
Args:
img (Image.Image): 输入图像
Returns:
PIL.Image: 调整大小后的图像
"""
target_size = (224, 224)
# 计算调整大小的比例
ratio = min(target_size[0] / img.width, target_size[1] / img.height)
new_size = (int(img.width * ratio), int(img.height * ratio))
# 调整图像大小,保持长宽比
resized_img = img.resize(new_size, Image.LANCZOS)
# 创建新的填充图像
new_img = Image.new("RGB", target_size, (255, 255, 255))
# 计算粘贴位置(居中)
paste_position = ((target_size[0] - new_size[0]) // 2,
(target_size[1] - new_size[1]) // 2)
# 粘贴调整后的图像
new_img.paste(resized_img, paste_position)
return new_img
@classmethod
def extract_from_url(cls, image_url):
"""
从URL加载图像并提取特征向量
Args:
image_url (str): 图像URL
Returns:
numpy.ndarray: 特征向量
"""
if cls._model is None:
cls.initialize()
try:
# 下载图片
response = requests.get(image_url, stream=True)
response.raise_for_status() # 确保请求成功
# 将图片数据转换为 PIL Image 对象
image = Image.open(BytesIO(response.content)).convert("RGB")
return cls.extract_from_image(image)
except Exception as e:
print(f"Error extracting features from URL: {e}")
return None
@classmethod
def extract_from_image(cls, img):
"""
从PIL图像对象提取特征向量
Args:
img (Image.Image): 输入图像
Returns:
numpy.ndarray: 特征向量
"""
if cls._model is None:
cls.initialize()
try:
# 调整图像大小并添加填充
image = cls.resize_with_padding(img)
# 预处理并提取特征
image_tensor = cls._preprocess(image).unsqueeze(0).to(cls._device)
with torch.no_grad():
image_features = cls._model.encode_image(image_tensor)
# 归一化特征向量
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().flatten()
except Exception as e:
print(f"Error extracting features from image: {e}")
return None
\ No newline at end of file
......@@ -2,4 +2,34 @@ Flask~=3.1.1
python-dotenv~=1.1.0
pip~=25.1.1
protobuf~=6.31.0
filelock~=3.18.0
\ No newline at end of file
filelock~=3.18.0
torch~=2.7.0
requests~=2.32.3
numpy~=2.2.6
pillow~=11.2.1
typing_extensions~=4.13.2
MarkupSafe~=3.0.2
Werkzeug~=3.1.3
click~=8.2.1
Jinja2~=3.1.6
blinker~=1.9.0
itsdangerous~=2.2.0
clip~=1.0
packaging~=25.0
torchvision~=0.22.0
tqdm~=4.67.1
ftfy~=6.3.1
regex~=2024.11.6
wcwidth~=0.2.13
colorama~=0.4.6
typing_extensions~=4.13.2
charset-normalizer~=3.4.2
setuptools~=80.7.1
pytz~=2025.2
sympy~=1.14.0
mpmath~=1.3.0
networkx~=3.4.2
urllib3~=2.4.0
fsspec~=2025.3.2
ujson~=5.10.0
pandas~=2.2.3
\ No newline at end of file
import unittest
from app.models.feature_extractor import FeatureExtractor
class TestFeatureExtractorFunction(unittest.TestCase):
def test_feature_extractor(self):
url = "https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg"
feature = FeatureExtractor.extract_from_url(url)
print(feature)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment