What type is a sklearn model?
I think the most generic class that all models inherit from would be sklearn.base.BaseEstimator.
If you want to be more specific, maybe use sklearn.base.ClassifierMixin or sklearn.base.RegressorMixin.
So I would do:
import numpy as np
from sklearn.base import RegressorMixin

def model_tester(model: RegressorMixin, parameter: int) -> np.ndarray:
    """An example function with type hints."""
    # do stuff to model
    # `values` is a placeholder for whatever array the function computes
    return values
I am no expert in type checking, so correct me if this is not right.
From Python 3.8 on (or earlier using typing-extensions), you can use typing.Protocol. Using protocols, you can use a concept called structural subtyping to define exactly the type's expected structure:
from typing import Protocol
# from typing_extensions import Protocol  # for Python <3.8

class ScikitModel(Protocol):
    def fit(self, X, y, sample_weight=None): ...
    def predict(self, X): ...
    def score(self, X, y, sample_weight=None): ...
    def set_params(self, **params): ...
which you can then use as a type hint:
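from typing import Any

# train_data, train_labels, test_data and test_labels are assumed to be defined elsewhere
def do_stuff(model: ScikitModel) -> Any:
    model.fit(train_data, train_labels)  # this type checks
    score = model.score(test_data, test_labels)  # this type checks
    ...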
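To make the structural subtyping point concrete, here is a minimal, self-contained sketch that runs without scikit-learn. MeanRegressor and fit_and_score are hypothetical names made up for illustration, and the @runtime_checkable decorator is an addition that the snippets above do not use. The toy class never inherits from the protocol; it is accepted only because it has the four methods.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable  # addition for this sketch: allows isinstance() checks
class ScikitModel(Protocol):
    """Structural type: any object with these four methods qualifies."""

    def fit(self, X, y, sample_weight=None): ...
    def predict(self, X): ...
    def score(self, X, y, sample_weight=None): ...
    def set_params(self, **params): ...


class MeanRegressor:
    """Hypothetical toy model; it does not inherit from ScikitModel."""

    def __init__(self) -> None:
        self.mean_ = 0.0

    def fit(self, X, y, sample_weight=None):
        # Remember the mean of the training targets.
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        # Predict the stored mean for every sample.
        return [self.mean_ for _ in X]

    def score(self, X, y, sample_weight=None):
        # Negative mean absolute error, so that higher is better.
        predictions = self.predict(X)
        return -sum(abs(p - t) for p, t in zip(predictions, y)) / len(y)

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


def fit_and_score(model: ScikitModel, X, y) -> Any:
    """Accepts any object whose methods match the protocol."""
    model.fit(X, y)
    return model.score(X, y)


result = fit_and_score(MeanRegressor(), [[0], [1], [2]], [1.0, 2.0, 3.0])
is_model = isinstance(MeanRegressor(), ScikitModel)
is_not_model = isinstance(object(), ScikitModel)
```

One caveat: the isinstance() check enabled by @runtime_checkable only verifies that the methods exist, not that their signatures match. Signature compatibility is left to a static type checker.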