Move loss to task heads (#2825)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
@ -48,25 +48,22 @@ trainer.train()
|
||||
|
||||
You now realize that you need to customize the trainer further to:
|
||||
|
||||
* Customize the `loss function`.
|
||||
* * Customize the `loss function`.
|
||||
* Add `callback` that uploads model to your Google Drive after every 10 `epochs`
|
||||
Here's how you can do it:
|
||||
|
||||
```python
|
||||
from ultralytics.yolo.v8.detect import DetectionTrainer
|
||||
from ultralytcs.nn.tasks import DetectionModel
|
||||
|
||||
class MyCustomModel(DetectionModel):
|
||||
def init_criterion():
|
||||
...
|
||||
|
||||
|
||||
class CustomTrainer(DetectionTrainer):
|
||||
def get_model(self, cfg, weights):
|
||||
...
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
# get ground truth
|
||||
imgs = batch["imgs"]
|
||||
bboxes = batch["bboxes"]
|
||||
...
|
||||
return loss, loss_items # see Reference-> Trainer for details on the expected format
|
||||
|
||||
return MyCustomModel(...)
|
||||
|
||||
# callback to upload model weights
|
||||
def log_model(trainer):
|
||||
@ -84,4 +81,4 @@ To know more about Callback triggering events and entry point, checkout our [Cal
|
||||
## Other engine components
|
||||
|
||||
There are other components that can be customized similarly like `Validators` and `Predictors`
|
||||
See Reference section for more information on these.
|
||||
See Reference section for more information on these.
|
||||
|
Reference in New Issue
Block a user