|
3 | 3 | from uuid import UUID
|
4 | 4 |
|
5 | 5 | from soundevent import data
|
6 |
| -from sqlalchemy import Select |
| 6 | +from sqlalchemy import Select, or_ |
7 | 7 |
|
8 | 8 | from whombat import models
|
9 | 9 | from whombat.filters import base
|
|
12 | 12 | "AnnotationProjectFilter",
|
13 | 13 | "DatasetFilter",
|
14 | 14 | "AnnotationTaskFilter",
|
| 15 | + "SearchRecordingsFilter", |
15 | 16 | ]
|
16 | 17 |
|
17 | 18 |
|
@@ -118,6 +119,25 @@ def filter(self, query: Select) -> Select:
|
118 | 119 | )
|
119 | 120 |
|
120 | 121 |
|
| 122 | +class IsCompletedFilter(base.Filter): |
| 123 | + """Filter for tasks if rejected.""" |
| 124 | + |
| 125 | + eq: bool | None = None |
| 126 | + |
| 127 | + def filter(self, query: Select) -> Select: |
| 128 | + """Filter the query.""" |
| 129 | + if self.eq is None: |
| 130 | + return query |
| 131 | + |
| 132 | + return query.where( |
| 133 | + models.AnnotationTask.status_badges.any( |
| 134 | + models.AnnotationStatusBadge.state |
| 135 | + == data.AnnotationState.completed, |
| 136 | + ) |
| 137 | + == self.eq, |
| 138 | + ) |
| 139 | + |
| 140 | + |
121 | 141 | class IsAssignedFilter(base.Filter):
|
122 | 142 | """Filter for tasks if assigned."""
|
123 | 143 |
|
@@ -206,6 +226,34 @@ def filter(self, query: Select) -> Select:
|
206 | 226 | )
|
207 | 227 |
|
208 | 228 |
|
| 229 | +class SearchRecordingsFilter(base.Filter): |
| 230 | + """Filter recordings by the dataset they are in.""" |
| 231 | + |
| 232 | + search_recordings: str | None = None |
| 233 | + |
| 234 | + def filter(self, query: Select) -> Select: |
| 235 | + """Filter the query.""" |
| 236 | + query = ( |
| 237 | + query.join( |
| 238 | + models.ClipAnnotation, |
| 239 | + models.AnnotationTask.clip_annotation_id |
| 240 | + == models.ClipAnnotation.id, |
| 241 | + ) |
| 242 | + .join( |
| 243 | + models.Clip, |
| 244 | + models.ClipAnnotation.clip_id == models.Clip.id, |
| 245 | + ) |
| 246 | + .join( |
| 247 | + models.Recording, |
| 248 | + models.Recording.id == models.Clip.recording_id, |
| 249 | + ) |
| 250 | + ) |
| 251 | + fields = [models.Recording.path] |
| 252 | + |
| 253 | + term = f"%{self.search_recordings}%" |
| 254 | + return query.where(or_(*[field.ilike(term) for field in fields])) |
| 255 | + |
| 256 | + |
209 | 257 | class SoundEventAnnotationTagFilter(base.Filter):
|
210 | 258 | """Filter for tasks by sound event annotation tag."""
|
211 | 259 |
|
@@ -258,10 +306,12 @@ def filter(self, query: Select) -> Select:
|
258 | 306 |
|
259 | 307 |
|
260 | 308 | AnnotationTaskFilter = base.combine(
|
| 309 | + SearchRecordingsFilter, |
261 | 310 | assigned_to=AssignedToFilter,
|
262 | 311 | pending=PendingFilter,
|
263 | 312 | verified=IsVerifiedFilter,
|
264 | 313 | rejected=IsRejectedFilter,
|
| 314 | + completed=IsCompletedFilter, |
265 | 315 | assigned=IsAssignedFilter,
|
266 | 316 | annotation_project=AnnotationProjectFilter,
|
267 | 317 | dataset=DatasetFilter,
|
|
0 commit comments