Skip to content
Snippets Groups Projects
Commit 5a18044a authored by Alexander Heidelbach's avatar Alexander Heidelbach
Browse files

Rework extended parameter treatment

parent ff66c962
Branches
No related tags found
No related merge requests found
......@@ -13,7 +13,7 @@ __all__ = ["ModelsReadType", "LimitType", "Model"]
ModelsReadType = Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
LimitType = Union[List[float], List[List[float]], list]
ModelsOutputType = Union[ModelsReadType, str, LimitType, CovarianceType]
ModelsOutputType = Union[ModelsReadType, str, LimitType, CovarianceType, float]
class Model(BaseModel):
......@@ -23,6 +23,7 @@ class Model(BaseModel):
modelstring: str,
models: ModelsReadType,
modelname: Optional[str] = None,
gof: float = 0.0,
) -> None:
super().__init__(obs)
......@@ -33,6 +34,8 @@ class Model(BaseModel):
self._node = parser.parse()
self._fraction = 0
self.gof = gof
def _prepare_models(self, models: ModelsReadType) -> Dict[str, Dict[str, Any]]:
tmpmodels = {} # type: Dict[str, Dict[str, Dict[str, Any]]]
......@@ -67,11 +70,16 @@ class Model(BaseModel):
pdf = model["pdf"]
parameter = {} # type: Dict[str, zfit.Parameter]
ext_param: Optional[zfit.Parameter] = None
for parametername, values in model["parameter"].items():
if not isinstance(values, list):
parameter.update(
{parametername: self.modelparameters[values["name"]]}
)
if not self.parameters[values["name"]].extended:
parameter.update(
{parametername: self.modelparameters[values["name"]]}
)
else:
ext_param = self.modelparameters[values["name"]]
else:
if len(values) > 0:
parameter.update(
......@@ -95,13 +103,8 @@ class Model(BaseModel):
(tuple(limits) if isinstance(limits[0], float) else limits),
)
ext_param: Optional[zfit.Parameter] = None
if node.token_type == TokenType.T_EXT:
assert node.ext_param is not None
ext_param = parameter.pop(node.ext_param)
zfit_model = getattr(zfit.pdf, pdf)(obs=obs, **parameter)
if node.token_type == TokenType.T_EXT:
if ext_param is not None:
zfit_model = zfit_model.create_extended(ext_param)
return zfit_model
......@@ -236,5 +239,6 @@ class Model(BaseModel):
limits.append(tmplimits) # type: ignore
output.update({"limits": limits})
output.update({"gof": self.gof})
return output
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment