Skip to content
Snippets Groups Projects
Commit d6a3fea6 authored by Alexander Heidelbach's avatar Alexander Heidelbach
Browse files

Update zfit pdf getter and adjust for new extended features in ModelParser

parent 7c1822ee
Branches
No related tags found
No related merge requests found
......@@ -14,12 +14,6 @@ __all__ = ["ModelsReadType", "LimitType", "Model"]
ModelsReadType = Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
LimitType = Union[List[float], List[List[float]], list]
ModelsOutputType = Union[ModelsReadType, str, LimitType, CovarianceType]
ModelsDict = {
"DoubleCB": zfit.pdf.DoubleCB,
"Cauchy": zfit.pdf.Cauchy,
"Gauss": zfit.pdf.Gauss,
"CrystalBall": zfit.pdf.CrystalBall,
}
class Model(BaseModel):
......@@ -55,10 +49,11 @@ class Model(BaseModel):
), f"[Models]: Please provide modelparameters for model '{modelname}'"
for parametername, values in modelparameters["parameter"].items():
assert (
"name" in values.keys()
), f"[Models]: Please provide an unique pararametername for '{parametername}' in '{modelname}'"
try:
parameter = {values["name"]: values}
except TypeError:
parameter = {parametername: values}
self._parameters.update(dictToFitparameter(**parameter))
tmpmodels.update({modelname: modelparameters})
......@@ -66,14 +61,29 @@ class Model(BaseModel):
return tmpmodels
def _compute_model(self, node: Node) -> zfit.core.basepdf.BasePDF:
if node.token_type == TokenType.T_MODEL:
if node.token_type == TokenType.T_MODEL or node.token_type == TokenType.T_EXT:
assert node.value is not None
model = self._models[node.value]
pdf = model["pdf"]
parameter = {} # type: Dict[str, zfit.Parameter]
for parametername, values in model["parameter"].items():
parameter.update({parametername: self.modelparameters[values["name"]]})
if not isinstance(values, list):
parameter.update(
{parametername: self.modelparameters[values["name"]]}
)
else:
if len(values) > 0:
parameter.update(
{
parametername: [
self.modelparameters[value["name"]]
for value in values
]
}
)
else:
parameter.update({parametername: []})
obs = self.obs
if "limits" in model.keys():
......@@ -85,7 +95,14 @@ class Model(BaseModel):
(tuple(limits) if isinstance(limits[0], float) else limits),
)
ext_param: Optional[zfit.Parameter] = None
if node.token_type == TokenType.T_EXT:
assert node.ext_param is not None
ext_param = parameter.pop(node.ext_param)
zfit_model = getattr(zfit.pdf, pdf)(obs=obs, **parameter)
if node.token_type == TokenType.T_EXT:
zfit_model = zfit_model.create_extended(ext_param)
return zfit_model
......@@ -102,10 +119,16 @@ class Model(BaseModel):
kwargs = {} # type: Dict[str, Any]
if node.token_type == TokenType.T_PLUS:
if left_result.is_extended != right_result.is_extended:
raise Exception(
"[Models]: SumPDF does not work with one pdf extended and not the other."
)
if not left_result.is_extended:
fraction_infodict = {
f"fraction{self._fraction}": {
"name": f"fraction{self._fraction}",
"value": 0.5,
"value": 0.1,
"lower": 0.0,
"upper": 1.0,
}
......@@ -126,12 +149,14 @@ class Model(BaseModel):
}
)
kwargs.update({"pdfs": [left_result, right_result]})
kwargs.update(
{"fracs": self._modelparameters[f"fraction{self._fraction}"]}
)
self._fraction += 1
kwargs.update({"pdfs": [left_result, right_result]})
elif node.token_type == TokenType.T_CONV:
kwargs.update({"func": left_result})
kwargs.update({"kernel": right_result})
......@@ -150,6 +175,28 @@ class Model(BaseModel):
return self._model
@property
def models(self) -> List["Model"]:
models: List["Model"] = []
for modelname, model in self._models.items():
models.append(
Model(
obs=self.obs,
modelstring=modelname,
models={modelname: model},
modelname=modelname,
)
)
for model in models:
for modelparametername, modelparameter in self.modelparameters.items():
if modelparametername in model.modelparameters.keys():
model.update_modelparameter(
name=modelparametername, value=modelparameter.value().numpy()
)
return models
@property
def dict(self) -> Dict[str, ModelsOutputType]:
output = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment