-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauto_labeler.py
51 lines (39 loc) · 2 KB
/
auto_labeler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import sys
from PySide6.QtCore import QThread, Signal, Slot
from PIL import Image
import torch
class Auto_Labbeler(QThread):
    """Background worker that classifies one image with a ViT model.

    Typical use: call ``set_image_path()``, then ``start()``. The model and
    its image processor are loaded on the first run, inside the worker
    thread. The predicted label is sent through ``result_signal``, a failure
    message through ``error_signal``, and ``finished_signal`` is emitted at
    the end of every run, whether it succeeded or not.
    """

    DEFAULT_MODEL_NAME = 'google/vit-base-patch16-224'

    result_signal = Signal(str)  # Predicted class label, sent back to the main thread
    error_signal = Signal(str)  # "<ExceptionType>: <message>" when a run fails
    finished_signal = Signal()  # Emitted once per run, after success or failure

    def __init__(self, parent=None, model_name=DEFAULT_MODEL_NAME):
        """Create the worker.

        Args:
            parent: Optional Qt parent object.
            model_name: Hugging Face model id (or local path) of a ViT
                image-classification checkpoint; its image processor is
                loaded from the same location.
        """
        super().__init__(parent)
        self.model_name = model_name
        self.model = None
        self.feature_extractor = None
        self.device = None
        self.image_path = ""

    def load_model(self):
        """Load the classifier and its image processor, then move the model to a device.

        Uses CUDA when ``torch.cuda.is_available()``, otherwise the CPU, and
        puts the model in eval mode. The attributes are assigned only after
        both loads succeed, so a failed load can be retried on the next run.
        """
        # transformers is imported on demand rather than at module level.
        from transformers import ViTForImageClassification
        try:
            # ViTImageProcessor supersedes the deprecated ViTFeatureExtractor.
            from transformers import ViTImageProcessor as _ImageProcessor
        except ImportError:  # older transformers releases
            from transformers import ViTFeatureExtractor as _ImageProcessor

        model = ViTForImageClassification.from_pretrained(self.model_name)
        processor = _ImageProcessor.from_pretrained(self.model_name)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        model.eval()

        self.model = model
        self.feature_extractor = processor
        self.device = device

    def set_image_path(self, image_path):
        """Set the path of the image to classify on the next run."""
        self.image_path = image_path

    def run(self):
        """Thread entry point: classify ``self.image_path`` and emit the label.

        If no image path is set, only ``finished_signal`` is emitted. Any
        exception is reported through ``error_signal`` (this is the thread's
        top-level boundary, where an escaping exception would otherwise end
        the thread without notifying the GUI).
        """
        try:
            if self.model is None:
                self.load_model()
            if self.image_path:
                self.result_signal.emit(self._classify(self.image_path))
        except Exception as exc:
            import traceback
            traceback.print_exc()  # keep the traceback visible for debugging
            self.error_signal.emit(f"{type(exc).__name__}: {exc}")
        finally:
            self.finished_signal.emit()

    def _classify(self, image_path):
        """Return the model's predicted label for the image at *image_path*."""
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # No manual resize here: the image processor performs its own
        # resizing/normalization to the size it was configured with.
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        inputs = {key: tensor.to(self.device) for key, tensor in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
        predicted_class_idx = torch.argmax(logits, dim=-1).item()
        return self.model.config.id2label[predicted_class_idx]