Commit be610aa6 authored by zhengyaoqiu's avatar zhengyaoqiu

优化

parent 324f97a4
...@@ -10,19 +10,19 @@ class FeatureExtractor: ...@@ -10,19 +10,19 @@ class FeatureExtractor:
__logger = logging.getLogger(__name__) __logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-B/32"): def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-L/14@336px"):
device = "cpu" device = "cpu"
self.model, self.preprocess = self.init_model(device, model_name) self.model, self.preprocess = self.init_model(device, model_name)
self.device = device self.device = device
@staticmethod @staticmethod
def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-B/32"): def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-L/14@336px"):
print(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}") print(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}")
model, preprocess = clip.load(model_name, device=device) model, preprocess = clip.load(model_name, device=device)
return model, preprocess return model, preprocess
@staticmethod @staticmethod
def resize_with_padding(img, target_size = (224, 224)): def resize_with_padding(img, target_size = (336, 336)):
""" """
调整图像大小,保持纵横比并添加填充 调整图像大小,保持纵横比并添加填充
...@@ -97,7 +97,7 @@ class FeatureExtractor: ...@@ -97,7 +97,7 @@ class FeatureExtractor:
# device = "xpu" if torch.xpu.is_available() else "cpu" # device = "xpu" if torch.xpu.is_available() else "cpu"
# device = "cpu" # device = "cpu"
# model_name = "ViT-B/32" # model_name = "ViT-L/14@336px"
# model, preprocess = self.init_model(device, model_name) # model, preprocess = self.init_model(device, model_name)
try: try:
...@@ -113,7 +113,9 @@ class FeatureExtractor: ...@@ -113,7 +113,9 @@ class FeatureExtractor:
# 归一化特征向量 # 归一化特征向量
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().astype(np.float32).flatten() result = image_features.cpu().numpy().astype(np.float32).flatten()
# ViT-L/14@336px 需要转换列表结构
return result.tolist()
except Exception as e: except Exception as e:
self.__logger.error(f"Error extracting features from image: {e}") self.__logger.error(f"Error extracting features from image: {e}")
......
...@@ -13,12 +13,12 @@ class ImageUpload: ...@@ -13,12 +13,12 @@ class ImageUpload:
def upload_many(self, bucket, image2keys): def upload_many(self, bucket, image2keys):
keys = [] keys = []
vectors = [] vectors = []
images = [image2key["image"] for image2key in image2keys]
new_images, exist_images = self.milvus.filter_new_urls(bucket, images)
images = [] images = []
print(f"总图片数: {len(images)}") all_images = [image2key["image"] for image2key in image2keys]
new_images, exist_images = self.milvus.filter_new_urls(bucket, all_images)
print(f"总图片数: {len(all_images)}")
print(f"新图片数: {len(new_images)}") print(f"新图片数: {len(new_images)}")
print(f"已存在图片数: {len(exist_images)}") print(f"已存在图片数: {len(exist_images)}")
print(f"新图片: {new_images}") print(f"新图片: {new_images}")
......
...@@ -11,7 +11,8 @@ class TestCreateCollectionFunction(unittest.TestCase): ...@@ -11,7 +11,8 @@ class TestCreateCollectionFunction(unittest.TestCase):
fields: List[FieldSchema] = [ fields: List[FieldSchema] = [
FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False), FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
FieldSchema(name="key", dtype=DataType.VARCHAR, max_length=256), FieldSchema(name="key", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512 # FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=768) # CLIP ViT-L/14@336px的特征维度为512
] ]
MilvusClient().connect().create_collection("pc3", fields, "PC3 图片向量存储") MilvusClient().connect().create_collection("pc3", fields, "PC3 图片向量存储")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment