How to enable_shared_from_this of both parent and derived

The OP's solution can be made more convenient by defining the following member template in the base class:

protected:
    template <typename Derived>
    std::shared_ptr<Derived> shared_from_base()
    {
        return std::static_pointer_cast<Derived>(shared_from_this());
    }

This can be made more convenient still by placing it in a base class template of its own, so it can be reused across hierarchies:

#include <memory>

template <class Base>
class enable_shared_from_base
  : public std::enable_shared_from_this<Base>
{
protected:
    template <class Derived>
    std::shared_ptr<Derived> shared_from_base()
    {
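        // this-> is required because std::enable_shared_from_this<Base>
        // is a dependent base class.
        return std::static_pointer_cast<Derived>(this->shared_from_this());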
    }
};

and then deriving from it as follows.

#include <functional>
#include <iostream>

class foo : public enable_shared_from_base<foo> {
    void foo_do_it()
    {
        std::cout << "foo::do_it\n";
    }
public:
    virtual std::function<void()> get_callback()
    {
        return std::bind(&foo::foo_do_it, shared_from_base<foo>());
    }
};

class bar1 : public foo {
    void bar1_do_it()
    {
        std::cout << "bar1::do_it\n";
    }
public:
    virtual std::function<void()> get_callback() override
    {
        return std::bind(&bar1::bar1_do_it, shared_from_base<bar1>());
    }
};

Sorry, but there isn't.

The problem is that std::enable_shared_from_this<foo> and std::enable_shared_from_this<bar1> are two different base classes, each holding its own internal weak_ptr. When a shared_ptr takes ownership of the object, its constructor only initializes that weak_ptr if the class has exactly one unambiguous, accessible enable_shared_from_this base. bar1 has two, so the lookup is ambiguous and neither weak_ptr gets set. When you then call bar1's shared_from_this in get_callback, you get the exception (std::bad_weak_ptr) because the internal weak_ptr isn't pointing to anything.

Essentially, enable_shared_from_this only works from a single class in a hierarchy: the object's type must have exactly one enable_shared_from_this base. If you try implementing it manually, the problem should become obvious.


Here is a solution similar to @evoskuil's that reduces boilerplate in derived classes that want to provide their own shared_from_this() function. At the point of use in the class, the code is just:

auto shared_from_this() {
    return shared_from(this);
}  

This uses 'shim' functions outside of the class. Doing it that way also provides a clean way to do the same for classes whose interface can't be modified but which derive from enable_shared_from_this, for example:

auto shared_that = shared_from(that);

Note: auto return-type deduction for ordinary functions requires C++14; with an older (C++11) compiler, write the return type out explicitly instead.

Shim functions that could be placed in a library header:

#include <memory>

template <typename Base>
inline std::shared_ptr<Base>
shared_from_base(std::enable_shared_from_this<Base>* base) 
{
    return base->shared_from_this();
}
template <typename Base>
inline std::shared_ptr<const Base>
shared_from_base(std::enable_shared_from_this<Base> const* base) 
{
    return base->shared_from_this();
}
template <typename That>
inline std::shared_ptr<That>
shared_from(That* that) 
{
    return std::static_pointer_cast<That>(shared_from_base(that));
}

The above code relies on the type passed to shared_from(...) inheriting publicly and unambiguously from std::enable_shared_from_this<Base> somewhere in its ancestry.

Calling shared_from_base lets template argument deduction work out which Base that ultimately derives from. Since we know that That derives from Base, a static downcast from std::shared_ptr<Base> to std::shared_ptr<That> is safe.

There are probably some pathological corner cases involving classes with type conversion operators, but they are unlikely to occur in code that isn't designed to break this.

Example:

struct base : public std::enable_shared_from_this<base> {};
struct derived : public base
{
    auto shared_from_this() {
        return shared_from(this);
    }
    // Can also provide a version for const:
    auto shared_from_this() const {
        return shared_from(this);
    }
    // Note that it is also possible to use shared_from(...) from
    // outside the class, e.g. 
    // auto sp = shared_from(that);
};
template <typename X>
struct derived_x : public derived
{
    auto shared_from_this() {
        return shared_from(this);
    }
};

Compilation test:

int main()
{
    auto pbase = std::make_shared<base>();
    auto pderived = std::make_shared<derived>();
    auto pderived_x = std::make_shared<derived_x<int> >();

    auto const& const_pderived = *pderived;
    const_pderived.shared_from_this();

    std::shared_ptr<base> test1 = pbase->shared_from_this();
    std::shared_ptr<derived> test2 = pderived->shared_from_this();
    std::shared_ptr<derived_x<int> > test3 = pderived_x->shared_from_this();

    return 0;
}

https://onlinegdb.com/SJWM5CYIG

Below is the prior solution I posted, kept so that the comments still make sense. It placed the functions in the base class, which had some problems, particularly a non-uniformity between the implementation required for 'normal' classes and for template classes.
Additionally, the implementation in the base class would need to be repeated for every new class hierarchy, which is not very DRY. Furthermore, the base class function could be misused by passing it a pointer to a different object. The newer scheme above avoids this entirely, and the runtime assert(...) check is no longer needed.

Old implementation:

#include <cassert>
#include <memory>

class base : public std::enable_shared_from_this<base>
{
protected:   
    template <typename T>
    std::shared_ptr<T> shared_from(T* derived) {
        assert(this == derived);
        return std::static_pointer_cast<T>(shared_from_this());
    }
};

class derived : public base
{
public:
    auto shared_from_this() {
        return shared_from(this);
    }
};

template <typename X>
class derived_x : public derived
{
public:
    auto shared_from_this() {
        return this->shared_from(this);
    }
};

int main()
{
    auto pbase = std::make_shared<base>();
    auto pderived = std::make_shared<derived>();
    auto pderived_x = std::make_shared<derived_x<int> >();

    std::shared_ptr<base> test1 = pbase->shared_from_this();
    std::shared_ptr<derived> test2 = pderived->shared_from_this();
    std::shared_ptr<derived_x<int> > test3 = pderived_x->shared_from_this();

    return 0;
}