Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
image_search
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
service
image_search
Commits
285b001c
Commit
285b001c
authored
May 24, 2025
by
zhengyaoqiu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
模块重构
parent
ef587af1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
307 deletions
+72
-307
feature_extractor.py
app/models/feature_extractor.py
+4
-4
image_search.py
app/models/image_search.py
+12
-206
milvus.py
app/models/milvus.py
+56
-97
No files found.
app/models/feature_extractor.py
View file @
285b001c
...
@@ -11,12 +11,12 @@ import logging
...
@@ -11,12 +11,12 @@ import logging
class
FeatureExtractor
:
class
FeatureExtractor
:
__model
=
None
#
__model = None
__preprocess
=
None
#
__preprocess = None
__logger
=
logging
.
getLogger
(
__name__
)
# __device = "ViT-B/32"
__instance
=
None
__instance
=
None
__lock
=
threading
.
Lock
()
__lock
=
threading
.
Lock
()
__
device
=
"ViT-B/32"
__
logger
=
logging
.
getLogger
(
__name__
)
def
__new__
(
cls
,
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
,
model_name
=
"ViT-B/32"
):
def
__new__
(
cls
,
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
,
model_name
=
"ViT-B/32"
):
# 第一次检查 - 不带锁
# 第一次检查 - 不带锁
...
...
app/models/image_search.py
View file @
285b001c
from
typing
import
List
,
Tuple
,
Optional
,
Union
import
logging
import
logging
import
numpy
as
np
from
PIL
import
Image
from
pymilvus
import
Collection
# 引入特征提取器和MilvusClient
from
feature_extractor
import
FeatureExtractor
from
milvus
import
MilvusClient
# 假设MilvusClient在milvus_client.py文件中
class
ImageSearch
:
class
ImageSearch
:
"""
__logger
=
logging
.
getLogger
(
__name__
)
图像搜索类,提供基于图像相似度的搜索功能
"""
_logger
=
logging
.
getLogger
(
__name__
)
_feature_extractor_initialized
=
False
_milvus_client
=
None
@
classmethod
def
_initialize_feature_extractor
(
cls
)
->
None
:
"""初始化特征提取器"""
if
not
cls
.
_feature_extractor_initialized
:
try
:
FeatureExtractor
.
initialize
()
cls
.
_feature_extractor_initialized
=
True
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Failed to initialize feature extractor: {e}"
)
raise
RuntimeError
(
f
"Failed to initialize feature extractor: {e}"
)
@
classmethod
def
get_milvus_client
(
cls
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
MilvusClient
:
"""
获取或创建Milvus客户端
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
MilvusClient: Milvus客户端实例
"""
if
cls
.
_milvus_client
is
None
:
cls
.
_milvus_client
=
MilvusClient
(
host
,
port
,
collection_name
)
.
connect
()
return
cls
.
_milvus_client
@
classmethod
def
get_collection
(
cls
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
Collection
:
"""
获取Milvus集合
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
Collection: Milvus集合对象
Raises:
def
__init__
(
self
,
feature_extractor
,
milvus
):
RuntimeError: 如果无法获取集合
self
.
feature_extractor
=
feature_extractor
"""
self
.
milvus
=
milvus
try
:
client
=
cls
.
get_milvus_client
(
host
,
port
,
collection_name
)
return
client
.
get_collection
()
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Failed to get collection: {e}"
)
raise
RuntimeError
(
f
"Failed to get collection: {e}"
)
@
classmethod
def
extract_features
(
cls
,
image
:
Union
[
str
,
Image
.
Image
])
->
np
.
ndarray
:
"""
从图像提取特征向量
Args:
image: 图像URL或PIL图像对象
Returns:
np.ndarray: 特征向量
Raises:
ValueError: 如果特征提取失败
"""
# 确保特征提取器已初始化
cls
.
_initialize_feature_extractor
()
try
:
# id = product_id
# 根据图像类型调用相应的提取方法
def
image_to_image_search
(
self
,
image
,
key_name
,
top_k
=
100
):
if
isinstance
(
image
,
str
):
features
=
FeatureExtractor
.
extract_from_url
(
image
)
elif
isinstance
(
image
,
Image
.
Image
):
features
=
FeatureExtractor
.
extract_from_image
(
image
)
else
:
raise
ValueError
(
f
"Unsupported image type: {type(image)}"
)
if
features
is
None
:
raise
ValueError
(
"Feature extraction returned None"
)
return
features
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Feature extraction failed: {e}"
)
raise
ValueError
(
f
"Failed to extract features: {e}"
)
@
classmethod
def
image_to_image_search
(
cls
,
image
:
Union
[
str
,
Image
.
Image
],
top_k
:
int
=
100
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
List
[
Tuple
[
str
,
float
]]:
"""
使用图像查询相似图像
Args:
image: 查询图像的URL或PIL图像对象
top_k: 返回的最相似结果数量
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
List[Tuple[str, float]]: 产品ID和相似度分数的列表,按相似度降序排列
Raises:
ValueError: 如果图像处理或搜索过程中出错
"""
try
:
try
:
# 提取查询图像的特征
# 提取查询图像的特征
query_embedding
=
cls
.
extract_features
(
image
)
query_embedding
=
self
.
feature_extractor
.
extract_features
(
image
)
# 获取Milvus客户端并搜索
results
=
self
.
milvus
.
search
(
query_embedding
,
limit
=
top_k
)
client
=
cls
.
get_milvus_client
(
host
,
port
,
collection_name
)
results
=
client
.
search
(
query_embedding
,
limit
=
top_k
)
# 处理结果
# 处理结果
if
not
results
or
len
(
results
)
==
0
:
if
not
results
or
len
(
results
)
==
0
:
return
[]
return
[]
# 返回结果
# 返回结果
product_ids
=
[
hit
.
entity
.
get
(
'product_id'
)
for
hit
in
results
[
0
]]
keys
=
[
hit
.
entity
.
get
(
key_name
)
for
hit
in
results
[
0
]]
scores
=
[
hit
.
score
for
hit
in
results
[
0
]]
scores
=
[
hit
.
score
for
hit
in
results
[
0
]]
return
list
(
zip
(
product_id
s
,
scores
))
return
list
(
zip
(
key
s
,
scores
))
except
Exception
as
e
:
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Image search failed: {e}"
)
self
.
_
_logger
.
error
(
f
"Image search failed: {e}"
)
raise
ValueError
(
f
"Image search failed: {e}"
)
raise
ValueError
(
f
"Image search failed: {e}"
)
\ No newline at end of file
@
classmethod
def
batch_image_search
(
cls
,
images
:
List
[
Union
[
str
,
Image
.
Image
]],
top_k
:
int
=
100
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
List
[
List
[
Tuple
[
str
,
float
]]]:
"""
批量图像搜索
Args:
images: 查询图像URL或PIL图像对象的列表
top_k: 每个查询返回的最相似结果数量
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
List[List[Tuple[str, float]]]: 每个查询图像对应的结果列表
"""
# 确保特征提取器已初始化
cls
.
_initialize_feature_extractor
()
# 批量提取特征
batch_embeddings
=
[]
valid_indices
=
[]
for
i
,
image
in
enumerate
(
images
):
try
:
features
=
cls
.
extract_features
(
image
)
batch_embeddings
.
append
(
features
)
valid_indices
.
append
(
i
)
except
Exception
as
e
:
cls
.
_logger
.
warning
(
f
"Feature extraction failed for image at index {i}: {e}"
)
# 如果没有有效的特征,返回空列表
if
not
batch_embeddings
:
return
[[]
for
_
in
range
(
len
(
images
))]
try
:
# 获取集合
collection
=
cls
.
get_collection
(
host
,
port
,
collection_name
)
collection
.
load
()
# 批量搜索
search_params
=
{
"metric_type"
:
"IP"
,
"params"
:
{
"ef"
:
100
}}
batch_results
=
collection
.
search
(
data
=
batch_embeddings
,
anns_field
=
"embedding"
,
param
=
search_params
,
limit
=
top_k
,
output_fields
=
[
"image"
,
"product_id"
]
)
# 处理结果
all_results
=
[[]
for
_
in
range
(
len
(
images
))]
for
i
,
results
in
enumerate
(
batch_results
):
original_idx
=
valid_indices
[
i
]
product_ids
=
[
hit
.
entity
.
get
(
'product_id'
)
for
hit
in
results
]
scores
=
[
hit
.
score
for
hit
in
results
]
all_results
[
original_idx
]
=
list
(
zip
(
product_ids
,
scores
))
return
all_results
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Batch image search failed: {e}"
)
raise
ValueError
(
f
"Batch image search failed: {e}"
)
app/models/milvus.py
View file @
285b001c
...
@@ -2,111 +2,70 @@ from typing import Dict, List, Any, Optional, Union
...
@@ -2,111 +2,70 @@ from typing import Dict, List, Any, Optional, Union
import
numpy
as
np
import
numpy
as
np
from
pymilvus
import
connections
,
Collection
,
FieldSchema
,
CollectionSchema
,
DataType
,
utility
from
pymilvus
import
connections
,
Collection
,
FieldSchema
,
CollectionSchema
,
DataType
,
utility
# , collection_name: str = "image_collection"
class
MilvusClient
:
class
MilvusClient
:
def
__init__
(
self
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
None
:
def
__init__
(
self
,
host
=
"localhost"
,
port
=
"19530"
):
"""初始化Milvus客户端
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
"""
self
.
host
:
str
=
host
self
.
host
:
str
=
host
self
.
port
:
str
=
port
self
.
port
:
str
=
port
self
.
collection_name
:
str
=
collection_name
self
.
collection
:
Optional
[
Collection
]
=
None
def
connect
(
self
)
->
'MilvusClient'
:
def
connect
(
self
,
alias
=
"default"
)
->
'MilvusClient'
:
"""连接到Milvus服务器
connections
.
connect
(
alias
,
host
=
self
.
host
,
port
=
self
.
port
)
Returns:
MilvusClient: 当前客户端实例,支持链式调用
"""
connections
.
connect
(
"default"
,
host
=
self
.
host
,
port
=
self
.
port
)
return
self
return
self
def
get_collection
(
self
)
->
Collection
:
# # 定义集合结构
"""获取或创建集合
# fields: List[FieldSchema] = [
# FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
Returns:
# FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=256),
Collection: Milvus集合对象
# FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512
"""
# ]
# 定义集合结构
@
staticmethod
fields
:
List
[
FieldSchema
]
=
[
def
create_new_collection
(
collection_name
,
fields
,
description
)
->
Collection
:
FieldSchema
(
name
=
"image"
,
dtype
=
DataType
.
VARCHAR
,
max_length
=
256
,
is_primary
=
True
,
auto_id
=
False
),
schema
:
CollectionSchema
=
CollectionSchema
(
fields
,
description
)
FieldSchema
(
name
=
"product_id"
,
dtype
=
DataType
.
VARCHAR
,
max_length
=
256
),
FieldSchema
(
name
=
"embedding"
,
dtype
=
DataType
.
FLOAT_VECTOR
,
dim
=
512
)
# CLIP ViT-B/32的特征维度为512
# 创建集合
]
collection
=
Collection
(
name
=
collection_name
,
schema
=
schema
)
schema
:
CollectionSchema
=
CollectionSchema
(
fields
,
"图像特征集合"
)
return
collection
# 创建或获取集合
if
utility
.
has_collection
(
self
.
collection_name
):
@
staticmethod
self
.
collection
=
Collection
(
name
=
self
.
collection_name
)
def
get_collection
(
collection_name
):
else
:
if
not
utility
.
has_collection
(
collection_name
):
self
.
collection
=
Collection
(
name
=
self
.
collection_name
,
schema
=
schema
)
raise
RuntimeError
(
f
"集合 '{collection_name}' 不存在"
)
self
.
_create_index
()
return
Collection
(
name
=
collection_name
)
return
self
.
collection
# index_params: Dict[str, Any] = {
# "index_type": "HNSW",
def
_create_index
(
self
)
->
None
:
# "metric_type": "IP", # 内积相似度
"""创建索引"""
# "params": {"M": 16, "efConstruction": 200}
index_params
:
Dict
[
str
,
Any
]
=
{
# }
"index_type"
:
"HNSW"
,
# def create_index(self, index_params) -> None:
"metric_type"
:
"IP"
,
# 内积相似度
# if self.collection is not None:
"params"
:
{
"M"
:
16
,
"efConstruction"
:
200
}
# self.collection.create_index(field_name="embedding", index_params=index_params)
}
if
self
.
collection
is
not
None
:
# anns_field = embedding
self
.
collection
.
create_index
(
field_name
=
"embedding"
,
index_params
=
index_params
)
# output_fields=["image", "product_id"]
# search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
def
search
(
self
,
vector
:
Union
[
List
[
float
],
np
.
ndarray
],
limit
:
int
=
10
)
->
Any
:
"""搜索相似向量
def
search
(
self
,
collection_name
,
vector
,
anns_field
,
search_params
,
output_fields
,
top_k
=
10
)
->
Any
:
collection
=
self
.
get_collection
(
collection_name
)
Args:
vector: 查询向量,可以是列表或numpy数组
results
=
collection
.
search
(
limit: 返回结果数量
Returns:
查询结果
"""
if
self
.
collection
is
None
:
self
.
get_collection
()
self
.
collection
.
load
()
search_params
:
Dict
[
str
,
Any
]
=
{
"metric_type"
:
"IP"
,
"params"
:
{
"ef"
:
100
}}
results
=
self
.
collection
.
search
(
data
=
[
vector
],
data
=
[
vector
],
anns_field
=
"embedding"
,
anns_field
=
anns_field
,
param
=
search_params
,
param
=
search_params
,
limit
=
limit
,
limit
=
top_k
,
output_fields
=
[
"image"
,
"product_id"
]
output_fields
=
output_fields
)
)
return
results
return
results
def
insert
(
self
,
data
:
List
[
Dict
[
str
,
Any
]])
->
None
:
# entities: List[List[Any]] = [
"""插入数据
# [item["image"] for item in data],
# [item["product_id"] for item in data],
Args:
# [item["embedding"] for item in data]
data: 包含image、product_id和embedding的字典列表
# ]
每个字典应包含键: "image", "product_id", "embedding"
def
insert
(
self
,
collection_name
,
entities
):
"""
self
.
get_collection
(
collection_name
)
.
insert
(
entities
)
if
self
.
collection
is
None
:
self
.
get_collection
()
@
staticmethod
def
close
(
alias
=
"default"
)
->
None
:
entities
:
List
[
List
[
Any
]]
=
[
connections
.
disconnect
(
alias
)
[
item
[
"image"
]
for
item
in
data
],
[
item
[
"product_id"
]
for
item
in
data
],
[
item
[
"embedding"
]
for
item
in
data
]
]
self
.
collection
.
insert
(
entities
)
def
drop_collection
(
self
)
->
None
:
"""删除集合"""
if
utility
.
has_collection
(
self
.
collection_name
):
utility
.
drop_collection
(
self
.
collection_name
)
self
.
collection
=
None
def
close
(
self
)
->
None
:
"""关闭连接"""
connections
.
disconnect
(
"default"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment