Commit 160369c6 authored by zhengyaoqiu's avatar zhengyaoqiu

优化

parent 38adf698
......@@ -10,13 +10,13 @@ class FeatureExtractor:
__logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-L/14@336px"):
def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-B/32"):
device = "cpu"
self.model, self.preprocess = self.init_model(device, model_name)
self.device = device
@staticmethod
def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-L/14@336px"):
def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-B/32"):
print(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}")
model, preprocess = clip.load(model_name, device=device)
return model, preprocess
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment