[Scheduler] Move predict epsilon to init #1155
)
parser.add_argument(
    "--predict_mode",
Let's try to align naming all over the codebase
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
predict_epsilon=True,
predict_epsilon is an inherent config parameter just like beta_start that should not change during inference
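To illustrate the point of this comment, here is a minimal sketch of a scheduler that takes the flag at construction time instead of as a per-call argument. ToySchedulerConfig, ToyScheduler and predict_original_sample are hypothetical names invented for this example, not the classes touched by the PR, and the sketch uses plain floats rather than torch tensors so it stays self-contained.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySchedulerConfig:
    # Hypothetical config container; frozen so values cannot drift during inference.
    beta_start: float = 0.0001
    beta_end: float = 0.02
    num_train_timesteps: int = 1000
    predict_epsilon: bool = True


class ToyScheduler:
    """Illustrative scheduler: predict_epsilon lives in the config, not in the step call."""

    def __init__(self, **config_kwargs):
        self.config = ToySchedulerConfig(**config_kwargs)
        cfg = self.config
        n = cfg.num_train_timesteps
        # Linear beta schedule and the cumulative product of (1 - beta).
        self.alphas_cumprod = []
        running = 1.0
        for i in range(n):
            beta = cfg.beta_start + (cfg.beta_end - cfg.beta_start) * i / max(n - 1, 1)
            running *= 1.0 - beta
            self.alphas_cumprod.append(running)

    def predict_original_sample(self, model_output, timestep, sample):
        # No predict_epsilon argument here: the behaviour is fixed by the config.
        alpha_prod = self.alphas_cumprod[timestep]
        if self.config.predict_epsilon:
            # Model output is the noise; invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps.
            return (sample - math.sqrt(1.0 - alpha_prod) * model_output) / math.sqrt(alpha_prod)
        # Model output is already the x0 estimate.
        return model_output
```

With this layout, a caller chooses the behaviour once, for example ToyScheduler(predict_epsilon=False), and every later call interprets the model output the same way.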
I was wondering whether we should store this flag on the model, because it's a characteristic of the model.
pcuenca left a comment
Very nice! I left a very minor docstring suggestion and have two general comments:
- How are we going to deal with additional objectives like v-prediction?
- For the property name, I wonder if we should call it model_predicts_epsilon instead. It's probably not worthwhile changing, but it might be a clearer indication that this is something that has to be configured depending on the model being used, and does not really affect the scheduling per se.
The model actually doesn't know what it's predicting. It depends on what we are computing the loss against (noise or x0), and most of the changes needed to handle x0 prediction are in the scheduler, so I think this should go in the scheduler config, as in this PR.
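A small sketch of that argument, assuming a scalar toy setup: the same network output can be scored against either target, and only the choice of target decides what the network ends up predicting. The function names add_noise, training_target and squared_error are made up for this illustration and are not part of the library.

```python
import math


def add_noise(x0, noise, alpha_prod):
    # Forward diffusion for one sample: x_t = sqrt(a) * x0 + sqrt(1 - a) * noise.
    return math.sqrt(alpha_prod) * x0 + math.sqrt(1.0 - alpha_prod) * noise


def training_target(x0, noise, predict_epsilon):
    # The network itself is unchanged; only the regression target differs.
    return noise if predict_epsilon else x0


def squared_error(model_output, x0, noise, predict_epsilon):
    # Loss for a single scalar prediction against the chosen target.
    target = training_target(x0, noise, predict_epsilon)
    return (model_output - target) ** 2
```

A network trained with predict_epsilon=True is pushed toward the noise, and one trained with predict_epsilon=False toward x0, which is why whatever consumes the output at inference time has to be told which convention was used.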
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Regarding naming, I think Maybe Regarding general
if predict_epsilon is not None:
    new_config = dict(self.scheduler.config)
    new_config["predict_epsilon"] = predict_epsilon
@pcuenca note that now you should change this into:
new_config["prediction_type"] = "predict_epsilon"
Here and everywhere else
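The override pattern in the diff excerpt above copies the scheduler config before changing a key. A minimal sketch of that pattern with plain dictionaries follows; override_config is a hypothetical helper, and the keys and values are placeholders for illustration only, not a statement of what the library accepts.

```python
def override_config(config, **overrides):
    # Copy first so the original config object is never mutated.
    new_config = dict(config)
    new_config.update(overrides)
    return new_config


# Placeholder values for illustration only.
original = {"beta_start": 0.0001, "beta_end": 0.02, "predict_epsilon": True}
updated = override_config(original, predict_epsilon=False)
```

Copying keeps the original config intact, so a scheduler rebuilt from the updated dictionary does not change the settings of the one it was derived from.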
* [Scheduler] Move predict epsilon to init
* up
* uP
* uP
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This PR cleans up the different implementations of predict_epsilon that we currently have. It uses predict_epsilon=True/False instead of a mix of predict_mode and predict_epsilon.