Skip to content
Snippets Groups Projects
Commit 7c1822ee authored by Alexander Heidelbach's avatar Alexander Heidelbach
Browse files

Add method to update fitparameters and some linting

parent b5f9d77f
Branches
No related tags found
No related merge requests found
......@@ -3,7 +3,12 @@ from typing import Dict, Optional, Union
import zfit
import zfit.core
from .FitParameter import Fitparameter, FitCovarianceMatrix, CovarianceInputType, CovarianceType
from .FitParameter import (
Fitparameter,
FitCovarianceMatrix,
CovarianceInputType,
CovarianceType,
)
class BaseModel(ABC):
......@@ -18,9 +23,7 @@ class BaseModel(ABC):
self._modelparameters = {} # type: Dict[str, zfit.Parameter]
self._covariance = {} # type: CovarianceType
self._model = None # type: Optional[zfit.core.basepdf.BasePDF]
self._ext_model = (
None
) # type: Optional[Union[Union[zfit.core.basepdf.BasePDF,zfit.core.interfaces.ZfitPDF] , Fitparameter]]
self._ext_model = None # type: Optional[Union[Union[zfit.core.basepdf.BasePDF,zfit.core.interfaces.ZfitPDF] , Fitparameter]]
def __add__(self, model: "BaseModel") -> "BaseModel":
assert self.obs == model.obs, Exception(
......@@ -55,7 +58,9 @@ class BaseModel(ABC):
def modelparameters(self) -> Dict[str, zfit.Parameter]:
if self.parameters and not self._modelparameters:
for key, parameter in self.parameters.items():
self._modelparameters[key] = zfit.Parameter(**parameter.parameter_info_dict)
self._modelparameters[key] = zfit.Parameter(
**parameter.parameter_info_dict
)
return self._modelparameters
......@@ -76,7 +81,11 @@ class BaseModel(ABC):
@property
def ext_model(
self,
) -> Optional[Union[Union[zfit.core.basepdf.BasePDF, zfit.core.interfaces.ZfitPDF], Fitparameter]]:
) -> Optional[
Union[
Union[zfit.core.basepdf.BasePDF, zfit.core.interfaces.ZfitPDF], Fitparameter
]
]:
return self._ext_model
@ext_model.setter
......@@ -98,7 +107,9 @@ class BaseModel(ABC):
return OutputParameter
def update_fitparameter(self, name: str, value: float, lower: float, upper: float) -> None:
def update_fitparameter(
self, name: str, value: float, lower: float, upper: float
) -> None:
if self.parameters:
for parameter in self.parameters.values():
if name == parameter.name:
......@@ -106,6 +117,12 @@ class BaseModel(ABC):
parameter.lower = lower
parameter.upper = upper
def update_modelparameter(self, name: str, value: float) -> None:
if self.modelparameters:
for parameter in self.modelparameters.values():
if name == str(parameter.name):
parameter.set_value(value)
def get_modelparameter(self, name: str) -> zfit.Parameter:
OutputParameter = None
for parameter in self.modelparameters.values():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment