Add new method sqlmodel_update() to update models in place, including an update parameter for extra data (#804)

This commit is contained in:
Sebastián Ramírez
2024-02-17 14:49:39 +01:00
committed by GitHub
parent 7fec884864
commit fa12c5d87b
15 changed files with 1871 additions and 26 deletions

View File

@@ -6,6 +6,7 @@ from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
ForwardRef,
Generator,
@@ -18,6 +19,7 @@ from typing import (
)
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin
@@ -46,9 +48,11 @@ class ObjectWithUpdateWrapper:
update: Dict[str, Any]
def __getattribute__(self, __name: str) -> Any:
if __name in self.update:
return self.update[__name]
return getattr(self.obj, __name)
update = super().__getattribute__("update")
obj = super().__getattribute__("obj")
if __name in update:
return update[__name]
return getattr(obj, __name)
def _is_union_type(t: Any) -> bool:
@@ -94,9 +98,14 @@ if IS_PYDANTIC_V2:
) -> None:
model.model_config[parameter] = value # type: ignore[literal-required]
def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]:
return model.model_fields
def get_fields_set(
object: InstanceOrType["SQLModel"],
) -> Union[Set[str], Callable[[BaseModel], Set[str]]]:
return object.model_fields_set
def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None:
object.__setattr__(new_object, "__pydantic_fields_set__", set())
object.__setattr__(new_object, "__pydantic_extra__", None)
@@ -384,9 +393,14 @@ else:
) -> None:
setattr(model.__config__, parameter, value) # type: ignore
def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]:
return model.__fields__ # type: ignore
def get_fields_set(
object: InstanceOrType["SQLModel"],
) -> Union[Set[str], Callable[[BaseModel], Set[str]]]:
return object.__fields_set__
def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None:
object.__setattr__(new_object, "__fields_set__", set())

View File

@@ -758,7 +758,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
update=update,
)
# TODO: remove when deprecating Pydantic v1, only for compatibility
def model_dump(
self,
*,
@@ -869,3 +868,32 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
exclude_unset=exclude_unset,
update=update,
)
def sqlmodel_update(
self: _TSQLModel,
obj: Union[Dict[str, Any], BaseModel],
*,
update: Union[Dict[str, Any], None] = None,
) -> _TSQLModel:
use_update = (update or {}).copy()
if isinstance(obj, dict):
for key, value in {**obj, **use_update}.items():
if key in get_model_fields(self):
setattr(self, key, value)
elif isinstance(obj, BaseModel):
for key in get_model_fields(obj):
if key in use_update:
value = use_update.pop(key)
else:
value = getattr(obj, key)
setattr(self, key, value)
for remaining_key in use_update:
if remaining_key in get_model_fields(self):
value = use_update.pop(remaining_key)
setattr(self, remaining_key, value)
else:
raise ValueError(
"Can't use sqlmodel_update() with something that "
f"is not a dict or SQLModel or Pydantic model: {obj}"
)
return self