Commit be610aa6 authored by zhengyaoqiu

优化

parent 324f97a4
......@@ -10,19 +10,19 @@ class FeatureExtractor:
__logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-B/32"):
def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-L/14@336px"):
device = "cpu"
self.model, self.preprocess = self.init_model(device, model_name)
self.device = device
@staticmethod
def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-B/32"):
def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-L/14@336px"):
print(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}")
model, preprocess = clip.load(model_name, device=device)
return model, preprocess
@staticmethod
def resize_with_padding(img, target_size = (224, 224)):
def resize_with_padding(img, target_size = (336, 336)):
"""
调整图像大小,保持纵横比并添加填充
......@@ -97,7 +97,7 @@ class FeatureExtractor:
# device = "xpu" if torch.xpu.is_available() else "cpu"
# device = "cpu"
# model_name = "ViT-B/32"
# model_name = "ViT-L/14@336px"
# model, preprocess = self.init_model(device, model_name)
try:
......@@ -113,7 +113,9 @@ class FeatureExtractor:
# 归一化特征向量
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().astype(np.float32).flatten()
result = image_features.cpu().numpy().astype(np.float32).flatten()
# ViT-L/14@336px 需要转换列表结构
return result.tolist()
except Exception as e:
self.__logger.error(f"Error extracting features from image: {e}")
......
......@@ -13,12 +13,12 @@ class ImageUpload:
def upload_many(self, bucket, image2keys):
keys = []
vectors = []
images = [image2key["image"] for image2key in image2keys]
new_images, exist_images = self.milvus.filter_new_urls(bucket, images)
images = []
print(f"总图片数: {len(images)}")
all_images = [image2key["image"] for image2key in image2keys]
new_images, exist_images = self.milvus.filter_new_urls(bucket, all_images)
print(f"总图片数: {len(all_images)}")
print(f"新图片数: {len(new_images)}")
print(f"已存在图片数: {len(exist_images)}")
print(f"新图片: {new_images}")
......
......@@ -11,7 +11,8 @@ class TestCreateCollectionFunction(unittest.TestCase):
fields: List[FieldSchema] = [
FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
FieldSchema(name="key", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512
# FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=768) # CLIP ViT-L/14@336px的特征维度为768
]
MilvusClient().connect().create_collection("pc3", fields, "PC3 图片向量存储")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment