Skip to content

Commit

Permalink
Functionality to prioritize user-specified fields
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya-balachander committed Nov 7, 2024
1 parent f86f94d commit 7fae06e
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 60 deletions.
9 changes: 5 additions & 4 deletions cumulusci/tasks/bulkdata/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def configure_step(self, mapping):
num_records_in_target = sobject_map.get(mapping.sf_object, None)

# Check for similarity selection strategy and modify fields accordingly
if mapping.selection_strategy == "similarity":
if mapping.select_options.strategy == "similarity":
# Describe the object to determine polymorphic lookups
describe_result = self.sf.restful(
f"sobjects/{mapping.sf_object}/describe"
Expand Down Expand Up @@ -469,8 +469,9 @@ def configure_step(self, mapping):
fields=fields,
api=mapping.api,
volume=volume,
selection_strategy=mapping.selection_strategy,
selection_filter=mapping.selection_filter,
selection_strategy=mapping.select_options.strategy,
selection_filter=mapping.select_options.filter,
selection_priority_fields=mapping.select_options.priority_fields,
content_type=content_type,
)
return step, query
Expand Down Expand Up @@ -577,7 +578,7 @@ def _query_db(self, mapping):
transformers = []
if (
mapping.action == DataOperationType.SELECT
and mapping.selection_strategy == "similarity"
and mapping.select_options.strategy == "similarity"
):
transformers.append(
DynamicLookupQueryExtender(
Expand Down
44 changes: 24 additions & 20 deletions cumulusci/tasks/bulkdata/mapping_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,21 @@
from typing import IO, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

from pydantic import Field, ValidationError, root_validator, validator
from requests.structures import CaseInsensitiveDict as RequestsCaseInsensitiveDict
from simple_salesforce import Salesforce
from typing_extensions import Literal

from cumulusci.core.enums import StrEnum
from cumulusci.core.exceptions import BulkDataException
from cumulusci.tasks.bulkdata.dates import iso_to_date
from cumulusci.tasks.bulkdata.select_utils import SelectStrategy
from cumulusci.tasks.bulkdata.select_utils import SelectOptions, SelectStrategy
from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType
from cumulusci.tasks.bulkdata.utils import CaseInsensitiveDict
from cumulusci.utils import convert_to_snake_case
from cumulusci.utils.yaml.model_parser import CCIDictModel

logger = getLogger(__name__)


class CaseInsensitiveDict(RequestsCaseInsensitiveDict):
    """Case-insensitive dict that also remembers each key's original spelling.

    Every assignment records the key exactly as it was written, indexed by
    its lowercased form, so ``canonical_key`` can map any casing of a name
    back to the spelling that was last used to set it.
    """

    def __init__(self, *args, **kwargs):
        # The lookup table must exist before the base initializer runs,
        # because ``__setitem__`` below writes to it.
        self._canonical_keys = {}
        super().__init__(*args, **kwargs)

    def canonical_key(self, name):
        """Return the originally-cased key matching ``name``.

        Raises:
            KeyError: if no key matching ``name`` (ignoring case) is stored.
        """
        return self._canonical_keys[name.lower()]

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._canonical_keys[key.lower()] = key

    def __delitem__(self, key):
        # Remove the item first so a missing key raises KeyError from the
        # base class before the canonical-key table is touched.
        super().__delitem__(key)
        # Drop the recorded spelling too; otherwise ``canonical_key`` would
        # keep answering for a key that is no longer in the mapping.
        self._canonical_keys.pop(key.lower(), None)


class MappingLookup(CCIDictModel):
"Lookup relationship between two tables."
table: Union[str, List[str]] # Support for polymorphic lookups
Expand Down Expand Up @@ -85,7 +72,7 @@ class BulkMode(StrEnum):

ENUM_VALUES = {
v.value.lower(): v.value
for enum in [BulkMode, DataApi, DataOperationType, SelectStrategy]
for enum in [BulkMode, DataApi, DataOperationType]
for v in enum.__members__.values()
}

Expand All @@ -108,13 +95,12 @@ class MappingStep(CCIDictModel):
)
anchor_date: Optional[Union[str, date]] = None
soql_filter: Optional[str] = None # soql_filter property
selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy
selection_filter: Optional[str] = (
None # filter to be added at the end of select query
select_options: Optional[SelectOptions] = Field(
default_factory=lambda: SelectOptions(strategy=SelectStrategy.STANDARD)
)
update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts

@validator("bulk_mode", "api", "action", "selection_strategy", pre=True)
@validator("bulk_mode", "api", "action", pre=True)
def case_normalize(cls, val):
if isinstance(val, Enum):
return val
Expand All @@ -134,6 +120,24 @@ def split_update_key(cls, val):
), "`update_key` should be a field name or list of field names."
assert False, "Should be unreachable" # pragma: no cover

@root_validator
def validate_priority_fields(cls, values):
    """Ensure every priority field is also listed in the step's fields.

    Reads ``select_options`` and ``fields_`` from ``values``. When select
    options are present and carry a non-empty ``priority_fields`` mapping,
    each of its keys must also be a key of ``fields_``.

    Returns:
        The ``values`` mapping, unchanged.

    Raises:
        ValueError: if any priority field name is absent from ``fields_``.
    """
    options = values.get("select_options")
    declared_fields = values.get("fields_", {})

    # Nothing to validate without select options or priority fields.
    if not (options and options.priority_fields):
        return values

    wanted = set(options.priority_fields.keys())
    available = set(declared_fields.keys())

    # Any priority field that is not a declared field is an error.
    missing_fields = wanted - available
    if missing_fields:
        raise ValueError(
            f"Priority fields {missing_fields} are not present in 'fields'"
        )

    return values

def get_oid_as_pk(self):
"""Returns True if using Salesforce Ids as primary keys."""
return "Id" in self.fields
Expand Down
6 changes: 5 additions & 1 deletion cumulusci/tasks/bulkdata/query_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ def columns_to_add(self):
load_fields = lookup_mapping_step.get_load_field_list()
for field in load_fields:
matching_column = next(
(col for col in aliased_table.columns if col.name == field)
(
col
for col in aliased_table.columns
if col.name == lookup_mapping_step.fields[field]
)
)
columns.append(
matching_column.label(f"{aliased_table.name}_{field}")
Expand Down
Loading

0 comments on commit 7fae06e

Please sign in to comment.