
bug in backward() or help needed #2077

Open
Apogeum12 opened this issue Jul 29, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@Apogeum12

I have a strange issue with backward(). I have two generators, gen1 and gen2, and I calculate the loss in three ways: loss_1, loss_2, and loss_3.

All computations for gen1 are OK.
Part 1.
let out = gen1.forward(input);
let out2 = gen2.forward(out.clone());
... calculate loss
then total_loss = loss_1 + loss_2 + loss_3
and:
let grads1 = total_loss.backward();
let grads_gen1 = GradientsParams::from_grads(grads1, &gen1);
This works fine.

And then:
Part 2.
let out3 = gen2.forward(out);
let out = gen1.forward(out3);
... calculate loss
then, in a similar way but with the other arguments, total_loss = loss_1 + loss_2 + loss_3
and:
let grads2 = total_loss.backward();
let grads_gen2 = GradientsParams::from_grads(grads2, &gen2);

and update the models
.. .step()...

All computations are done in one loop. And here, for grads2, I have an issue: when I remove loss_3 from total_loss everything works fine, but with loss_3 I get this error message:

thread 'main' panicked at /home/euuki/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.13.2/src/tensor/api/numeric.rs:22:9:
=== Tensor Operation Error ===
  Operation: 'Add'
  Reason:
    1. The provided tensors have incompatible shapes. Incompatible size at dimension '2' => '1278 != 1280', which can't be broadcasted. Lhs tensor shape [2, 1, 1278, 1278], Rhs tensor shape [2, 1, 1280, 1280]. 
    2. The provided tensors have incompatible shapes. Incompatible size at dimension '3' => '1278 != 1280', which can't be broadcasted. Lhs tensor shape [2, 1, 1278, 1278], Rhs tensor shape [2, 1, 1280, 1280]. 

stack backtrace:
   0:     0x577d25479995 - <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt::h1e1a1972118942ad
   1:     0x577d2549e41b - core::fmt::write::hc090a2ffd6b28c4a
   2:     0x577d2547773f - std::io::Write::write_fmt::h8898bac6ff039a23
   3:     0x577d2547976e - std::sys_common::backtrace::print::ha96650907276675e
   4:     0x577d2547aa29 - std::panicking::default_hook::{{closure}}::h215c2a0a8346e0e0
   5:     0x577d2547a76d - std::panicking::default_hook::h207342be97478370
   6:     0x577d2547aec3 - std::panicking::rust_panic_with_hook::hac8bdceee1e4fe2c
   7:     0x577d2547ada4 - std::panicking::begin_panic_handler::{{closure}}::h00d785e82757ce3c
   8:     0x577d25479e59 - std::sys_common::backtrace::__rust_end_short_backtrace::h1628d957bcd06996
   9:     0x577d2547aad7 - rust_begin_unwind
  10:     0x577d24f95ff3 - core::panicking::panic_fmt::hdc63834ffaaefae5
  11:     0x577d25085a41 - core::panicking::panic_display::hd504bfa7a23e079b
  12:     0x577d24f6910d - burn_tensor::tensor::api::numeric::<impl burn_tensor::tensor::api::base::Tensor<B,_,K>>::add::panic_cold_display::haa334297998f63f1
  13:     0x577d2507fd6f - burn_tensor::tensor::api::numeric::<impl burn_tensor::tensor::api::base::Tensor<B,_,K>>::add::h197481e80e899342
  14:     0x577d2503c0b7 - burn_autodiff::grads::Gradients::register::h7f50eb7d3a39e84f
  15:     0x577d25018acf - <burn_autodiff::ops::module::<impl burn_tensor::tensor::ops::modules::base::ModuleOps<burn_autodiff::backend::Autodiff<B,C>> for burn_autodiff::backend::Autodiff<B,C>>::conv2d::Conv2DWithBias as burn_autodiff::ops::backward::Backward<B,4_usize,3_usize>>::backward::h0eb6b130b28fdf5d
  16:     0x577d25060b8f - <burn_autodiff::ops::base::OpsStep<B,T,SB,_,_> as burn_autodiff::graph::base::Step>::step::hd9a3762bb2722aab
  17:     0x577d25446e8d - burn_autodiff::runtime::server::AutodiffServer::backward::h034fbb21f4457df1
  18:     0x577d24fcf75a - <burn_autodiff::runtime::mutex::MutexClient as burn_autodiff::runtime::client::AutodiffClient>::backward::h235c73a9c48299ab
  19:     0x577d2504a5c2 - burn_test::nn::flowscaller::deg_training::flow_net::h0a08effeaee47cf1
  20:     0x577d250a7f28 - burn_test::main::h3db9588c4d2b6cd5
  21:     0x577d24fc2a33 - std::sys_common::backtrace::__rust_begin_short_backtrace::h1ab658f13ba31837
  22:     0x577d2503da39 - std::rt::lang_start::{{closure}}::ha5d1a6d22c45cdd7
  23:     0x577d25472ff0 - std::rt::lang_start_internal::h3ed4fe7b2f419135
  24:     0x577d250a8035 - main
  25:     0x731bf462a1ca - __libc_start_call_main
                               at ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
  26:     0x731bf462a28b - __libc_start_main_impl
                               at ./csu/../csu/libc-start.c:360:3
  27:     0x577d24f96745 - _start
  28:                0x0 - <unknown>

where 1280 is the output tensor size for width and height. All padding etc. should be fine for gen1 and gen2, because otherwise Part 1 wouldn't work. Do you have any idea or suggestion what could be the cause? I wanted to write on Discord in the #help channel, but with the logs the message exceeded 2000 characters.

@antimora antimora added the bug Something isn't working label Aug 5, 2024
@laggui
Member

laggui commented Aug 8, 2024

Btw you can still write on discord, just embed the logs as a file.

Would love to help, but this will be difficult to reproduce without a minimal example.

@Apogeum12
Author

> Btw you can still write on discord, just embed the logs as a file.
>
> Would love to help, but this will be difficult to reproduce without a minimal example.

This week I'll prepare an example and upload the code with more details, but based on my experience with PyTorch I think it is not a bug; rather, Burn doesn't support this feature yet. I think backward() is missing support for retain_graph=True, and this is the cause.

@laggui
Member

laggui commented Aug 9, 2024

Ahhh ok that might be the case. We have an open issue for that (#1802).

@Apogeum12
Author

Apogeum12 commented Aug 30, 2024

@laggui
Sorry for such a late response.
This is part of the code I used. Currently I'm not able to print the current logs because I switched to PyTorch to test this approach for the model. In general this approach works in PyTorch, so I think this is not a bug but rather a feature (retain_graph) that isn't supported yet. If you can confirm it's a missing feature and not a bug, I'll close the issue, or you can close it ;)

    for epoch in 1..config_generator.num_epochs + 1 {
        let mut total_training_luma_ssim = 0.0;
        let mut iter_training = 0;
        for (iteration, batch) in dataloader_train.iter().enumerate() {
            let image_lr = batch.clone().input_luma;
            let image_hr = batch.clone().output_luma;

            // ------ First Step -----
            let degenerate_lr = model_degenerator.forward(image_hr.clone());
            let generate_hr = model_generator.forward(degenerate_lr.clone());

            let loss_charbonier_degenerate =
                charbonier_loss(degenerate_lr.clone(), image_lr.clone(), 1e-3);
            let loss_ssim_degenerate =
                ssim_simple(degenerate_lr.clone(), image_lr.clone(), true, 1.0);
            let loss_degenerate_recon = loss_charbonier_degenerate.mul_scalar(0.1)
                + loss_ssim_degenerate.clone().mul_scalar(0.9);

            let cycle_hr = cycle_loss(image_hr.clone(), generate_hr, lambda);
            let identify_lr = identify_loss(image_lr.clone(), degenerate_lr.clone(), lambda);

            let total_loss_deg = loss_degenerate_recon.clone() + cycle_hr + identify_lr;
            let grads = total_loss_deg.backward();
            let grads_degenerator = GradientsParams::from_grads(grads, &model_degenerator);

            // ------ Second Step -----
            let generate_hr = model_generator.forward(degenerate_lr);
            let degenerate_lr = model_degenerator.forward(generate_hr.clone());

            let cycle_lr: Tensor<B, 1> =
                cycle_loss(image_lr.clone(), degenerate_lr.clone(), lambda);
            let identify_hr = identify_loss(image_hr.clone(), generate_hr.clone(), lambda);

            let loss_charbonier_generate =
                charbonier_loss(generate_hr.clone(), image_hr.clone(), 1e-3);
            let loss_ssim_generate = ssim_simple(generate_hr, image_hr, true, 1.0);
            let loss_generator = loss_charbonier_generate.mul_scalar(0.1)
                + loss_ssim_generate.clone().mul_scalar(0.9);

            let total_loss_gen = loss_generator.clone() + cycle_lr + identify_hr;
            let grads = total_loss_gen.backward();
            let grads_generator = GradientsParams::from_grads(grads, &model_generator);

            model_degenerator = optim_degenerator.step(
                config_degenerator.l_rate,
                model_degenerator,
                grads_degenerator,
            );
            model_generator =
                optim_generator.step(config_generator.l_rate, model_generator, grads_generator);

            // Stats
            total_training_luma_ssim += loss_ssim_generate.clone().into_scalar().elem::<f32>();
            iter_training += 1;
            println!(
                    "[Train - Epoch {} Iteration {}] Loss D_Gen {:.4} Loss Gen: {:.4} SSIM D_Gen: {:.4} SSIM Gen: {:.4}",
                    epoch,
                    iteration,
                    total_loss_deg.into_scalar(),
                    total_loss_gen.into_scalar(),
                    loss_ssim_degenerate.into_scalar(),
                    loss_ssim_generate.into_scalar(),
                );
        }
        // ... (per-epoch handling elided in the original snippet)
    }
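
If the failure really comes from the second backward() re-entering the graph that the first backward() already traversed, one workaround worth sketching (an assumption for illustration, not something confirmed in this thread) is to detach the tensor that bridges the two steps, so the second forward pass builds an independent autodiff graph. A possible variant of the Second Step above, keeping the rest of the loop as posted:

    // ------ Second Step (sketch with a detached bridge tensor) -----
    // detach() cuts the autodiff link back to the First Step's graph, so the
    // backward() below only walks nodes created in this step. Since this step
    // only collects gradients for model_generator, cutting that path should
    // not remove any gradient it needs.
    let degenerate_lr = degenerate_lr.detach();
    let generate_hr = model_generator.forward(degenerate_lr.clone());
    let degenerate_lr = model_degenerator.forward(generate_hr.clone());
    // ... same losses as in the original Second Step, then:
    // let grads = total_loss_gen.backward();
    // let grads_generator = GradientsParams::from_grads(grads, &model_generator);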
        

@laggui
Member

laggui commented Aug 30, 2024

While the lack of an equivalent to retain_graph is indeed a missing feature, the issue you're encountering could be a bug.

Without retain_graph you can't calculate the gradients multiple times on the same graph, and if this is something that is required for your use case it might simply not give the expected results. But it looks like there might be a bug in one of the backward implementations causing a shape mismatch (though it's hard to confirm with just this info).

I'll keep this open. Hopefully we can reproduce this somehow with a small example.
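
For whoever tries to reproduce this, here is a minimal sketch of the double-backward structure described in the original post, using plain tensors on the NdArray backend instead of the real generator modules (an assumption made for brevity, so it is not guaranteed to hit the exact conv2d shape mismatch above). It only shows the shape of the problem: both losses hang off the same node, and backward() is called once per loss with nothing equivalent to retain_graph.

    // Requires burn with the "autodiff" and "ndarray" features enabled.
    use burn::backend::{Autodiff, NdArray};
    use burn::tensor::Tensor;

    type B = Autodiff<NdArray>;

    fn main() {
        let device = Default::default();
        let x: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0], &device).require_grad();

        // First step: build a graph and run backward on it.
        let shared = x.clone().mul_scalar(2.0); // node reused by the second step
        let loss_1 = shared.clone().sum();
        let _grads_1 = loss_1.backward();

        // Second step: the new loss still hangs off `shared`, so this backward
        // walks back into the graph already traversed above.
        let loss_2 = shared.mul_scalar(3.0).sum();
        let _grads_2 = loss_2.backward();
    }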
