How forwardRef() works internally in React?

forwardRef() in React allows us to create a component that can access the ref passed to it. We cannot achieve this ourselves, because ref is handled specially and is not part of props. Let’s dive into its internals today.

1. Why is forwardRef() needed?

The reason is simple: we cannot retrieve ref from props.

function Component(props) {
  return <div>
    Notice that for Component, <br/>we get props as <code>{JSON.stringify(props)}</code> <br/>which doesn't contain `ref` though it is set
  </div>
}
export default function App() {
  return <Component a={1} ref={() => {}}/>
}

This is because ref is a special prop that gets specially handled.

<Component a={1} ref={() => {}}/> is compiled to the React element below.

{
  $$typeof: Symbol("react.element"),
  key: null,
  props: {a: 1},
  // See that ref sits beside props, not inside it
  ref: () => {},
  type: function Component(props) {...},
  _owner: ...,
  _store: ...,
}

We can see that ref is not inside props, so the component is not able to retrieve it. This becomes an issue when we want to relay ref through all kinds of wrapper components.
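To make the separation concrete, here is a minimal sketch of element creation, assuming behavior like the React version examined in this post. `createElementSketch` is a hypothetical helper written for illustration, not React's actual `createElement()`. It only shows how `key` and `ref` could be pulled out of the incoming config before `props` is built.

```javascript
// Hypothetical, simplified stand-in for React's element creation.
// Assumption: `key` and `ref` are reserved names that never reach `props`.
function createElementSketch(type, config) {
  const props = {};
  let key = null;
  let ref = null;
  for (const name of Object.keys(config ?? {})) {
    if (name === 'key') {
      key = String(config[name]);
    } else if (name === 'ref') {
      ref = config[name];
    } else {
      props[name] = config[name];
    }
  }
  return { $$typeof: Symbol.for('react.element'), type, key, ref, props };
}

function Component(props) {
  return null;
}

const callback = () => {};
const element = createElementSketch(Component, { a: 1, ref: callback });

console.log(Object.keys(element.props)); // [ 'a' ]
console.log('ref' in element.props); // false
console.log(element.ref === callback); // true
```

Since the component function only ever receives `element.props`, nothing it can read contains the ref.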

2. How does forwardRef() work internally?

2.1 forwardRef() returns a special React element type tagged with REACT_FORWARD_REF_TYPE.

export function forwardRef<Props, ElementType: React$ElementType>(
  render: (props: Props, ref: React$Ref<ElementType>) => React$Node,
) {
  const elementType = {
    // forwardRef() does not return a function component,
    // but a special element type which later maps to a special Fiber node
    $$typeof: REACT_FORWARD_REF_TYPE,
    // This is our custom code, passed in as the parameter
    render,
  };
  return elementType;
}
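
A quick way to see the consequence is to rebuild the function in plain JavaScript. The sketch below is a hypothetical, stripped-down re-implementation (`forwardRefSketch` is not React's export, and the Flow types and DEV-only checks are left out). It shows that the result is a plain object wrapping our render function rather than something callable.

```javascript
// Assumption: the tag is the well-known symbol registered under 'react.forward_ref'.
const REACT_FORWARD_REF_TYPE = Symbol.for('react.forward_ref');

// Hypothetical, simplified re-implementation for illustration only.
function forwardRefSketch(render) {
  return { $$typeof: REACT_FORWARD_REF_TYPE, render };
}

// Our "component": it expects the ref as its second argument.
const render = (props, ref) =>
  `props=${JSON.stringify(props)}, hasRef=${ref != null}`;

const Forwarded = forwardRefSketch(render);

console.log(typeof Forwarded); // "object"
console.log(Forwarded.$$typeof === Symbol.for('react.forward_ref')); // true
console.log(Forwarded.render === render); // true
```

Because `Forwarded` is an object, the runtime has to recognize its `$$typeof` and call `render` on our behalf.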

Actually there are quite a lot of React Element Types.

// The Symbol used to tag the ReactElement-like types.
// REACT_ELEMENT_TYPE covers custom Function/Class components and intrinsic HTML tags.
export const REACT_ELEMENT_TYPE = Symbol.for('react.element');
// The other types are special ones that have special logic.
export const REACT_PORTAL_TYPE = Symbol.for('react.portal');
export const REACT_FRAGMENT_TYPE = Symbol.for('react.fragment');
export const REACT_STRICT_MODE_TYPE = Symbol.for('react.strict_mode');
export const REACT_PROFILER_TYPE = Symbol.for('react.profiler');
export const REACT_PROVIDER_TYPE = Symbol.for('react.provider');
export const REACT_CONTEXT_TYPE = Symbol.for('react.context');
export const REACT_SERVER_CONTEXT_TYPE = Symbol.for('react.server_context');
export const REACT_FORWARD_REF_TYPE = Symbol.for('react.forward_ref'); // <- used by forwardRef()
export const REACT_SUSPENSE_TYPE = Symbol.for('react.suspense');
export const REACT_SUSPENSE_LIST_TYPE = Symbol.for('react.suspense_list');
export const REACT_MEMO_TYPE = Symbol.for('react.memo');
export const REACT_LAZY_TYPE = Symbol.for('react.lazy');
export const REACT_SCOPE_TYPE = Symbol.for('react.scope');
export const REACT_DEBUG_TRACING_MODE_TYPE = Symbol.for(
  'react.debug_trace_mode',
);
export const REACT_OFFSCREEN_TYPE = Symbol.for('react.offscreen');
export const REACT_LEGACY_HIDDEN_TYPE = Symbol.for('react.legacy_hidden');
export const REACT_CACHE_TYPE = Symbol.for('react.cache');
export const REACT_TRACING_MARKER_TYPE = Symbol.for('react.tracing_marker');
export const REACT_SERVER_CONTEXT_DEFAULT_VALUE_NOT_LOADED = Symbol.for(
  'react.default_value',
);

Note that the types above are types of React elements, which basically means the types of JSX elements. React elements are blueprints that the React runtime uses to construct and update the Fiber tree.

In the realm of Fiber, there are also all kinds of Fiber types (tags). They do not map 1:1 to element types, though: some tags are Fiber-only.

export const FunctionComponent = 0;
export const ClassComponent = 1;
export const IndeterminateComponent = 2; // Before we know whether it is function or class
export const HostRoot = 3; // Root of a host tree. Could be nested inside another node.
export const HostPortal = 4; // A subtree. Could be an entry point to a different renderer.
// HostComponent means the HTML tags in Host DOM
export const HostComponent = 5;
export const HostText = 6;
export const Fragment = 7;
export const Mode = 8;
export const ContextConsumer = 9;
export const ContextProvider = 10;
export const ForwardRef = 11; // <- the tag for forwardRef()
export const Profiler = 12;
export const SuspenseComponent = 13;
export const MemoComponent = 14;
export const SimpleMemoComponent = 15;
export const LazyComponent = 16;
export const IncompleteClassComponent = 17;
export const DehydratedFragment = 18;
export const SuspenseListComponent = 19;
export const ScopeComponent = 21;
export const OffscreenComponent = 22;
export const LegacyHiddenComponent = 23;
export const CacheComponent = 24;
export const TracingMarkerComponent = 25;

As I said, React elements are blueprints for the Fiber tree, so we can find the logic that maps REACT_FORWARD_REF_TYPE to ForwardRef in the Fiber node creation.

export function createFiberFromElement(
  element: ReactElement,
  mode: TypeOfMode,
  lanes: Lanes,
): Fiber {
  let owner = null;
  const type = element.type;
  const key = element.key;
  const pendingProps = element.props;
  const fiber = createFiberFromTypeAndProps(
    type,
    key,
    pendingProps,
    owner,
    mode,
    lanes,
  );
  return fiber;
}

export function createFiberFromTypeAndProps(
  type: any, // React$ElementType
  key: null | string,
  pendingProps: any,
  owner: null | Fiber,
  mode: TypeOfMode,
  lanes: Lanes,
): Fiber {
  let fiberTag = IndeterminateComponent;
  // The resolved type is set if we know what the final type will be. I.e. it's not lazy.
  let resolvedType = type;
  if (typeof type === 'function') {
    // Custom components
    if (shouldConstruct(type)) {
      fiberTag = ClassComponent;
    }
  } else if (typeof type === 'string') {
    // HTML tags
    fiberTag = HostComponent;
  } else {
    getTag: switch (type) {
      case REACT_FRAGMENT_TYPE:
        return createFiberFromFragment(pendingProps.children, mode, lanes, key);
      case REACT_STRICT_MODE_TYPE:
        fiberTag = Mode;
        mode |= StrictLegacyMode;
        if (enableStrictEffects && (mode & ConcurrentMode) !== NoMode) {
          // Strict effects should never run on legacy roots
          mode |= StrictEffectsMode;
        }
        break;
      case REACT_PROFILER_TYPE:
        return createFiberFromProfiler(pendingProps, mode, lanes, key);
      case REACT_SUSPENSE_TYPE:
        return createFiberFromSuspense(pendingProps, mode, lanes, key);
      ....
      // eslint-disable-next-line no-fallthrough
      default: {
        if (typeof type === 'object' && type !== null) {
          switch (type.$$typeof) {
            case REACT_PROVIDER_TYPE:
              fiberTag = ContextProvider;
              break getTag;
            case REACT_CONTEXT_TYPE:
              // This is a consumer
              fiberTag = ContextConsumer;
              break getTag;
            case REACT_FORWARD_REF_TYPE:
              // Here it is
              fiberTag = ForwardRef;
              break getTag;
            case REACT_MEMO_TYPE:
              fiberTag = MemoComponent;
              break getTag;
            case REACT_LAZY_TYPE:
              fiberTag = LazyComponent;
              resolvedType = null;
              break getTag;
          }
        }
      }
    }
  }
  const fiber = createFiber(fiberTag, pendingProps, key, mode);
  fiber.elementType = type;
  fiber.type = resolvedType;
  fiber.lanes = lanes;
  return fiber;
}
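
The branching above can be condensed into a small standalone sketch. It assumes the tag numbers and `Symbol.for` keys quoted earlier in this post. `resolveFiberTagSketch` and `shouldConstructSketch` are hypothetical helpers, and the class check (a marker on the prototype) is an assumption about how class components are told apart from plain functions. Only the tag is returned here; no Fiber node is created.

```javascript
// Tag values and symbol keys as quoted earlier in the post.
const ClassComponent = 1;
const IndeterminateComponent = 2;
const HostComponent = 5;
const ForwardRef = 11;
const MemoComponent = 14;
const REACT_FORWARD_REF_TYPE = Symbol.for('react.forward_ref');
const REACT_MEMO_TYPE = Symbol.for('react.memo');

// Assumption: class components carry an `isReactComponent` marker on their prototype.
function shouldConstructSketch(type) {
  return Boolean(type.prototype && type.prototype.isReactComponent);
}

// Hypothetical helper: resolves only the tag for a given element type.
function resolveFiberTagSketch(type) {
  if (typeof type === 'function') {
    return shouldConstructSketch(type) ? ClassComponent : IndeterminateComponent;
  }
  if (typeof type === 'string') {
    return HostComponent;
  }
  if (typeof type === 'object' && type !== null) {
    switch (type.$$typeof) {
      case REACT_FORWARD_REF_TYPE:
        return ForwardRef;
      case REACT_MEMO_TYPE:
        return MemoComponent;
    }
  }
  throw new Error('Element type not covered by this sketch');
}

function Plain() {
  return null;
}
class Klass {
  render() {
    return null;
  }
}
Klass.prototype.isReactComponent = {};
const Forwarded = { $$typeof: REACT_FORWARD_REF_TYPE, render: (props, ref) => null };

console.log(resolveFiberTagSketch(Plain)); // 2
console.log(resolveFiberTagSketch(Klass)); // 1
console.log(resolveFiberTagSketch('div')); // 5
console.log(resolveFiberTagSketch(Forwarded)); // 11
```

The object returned by forwardRef() is the only one of these four that reaches the `$$typeof` switch, which is how it ends up with its own tag.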

By now we have seen how the special Fiber node is created for ForwardRef.

2.2 updateForwardRef() processes the special Fiber Node for forwardRef().

Having explained how React does the initial mount and re-render, we can easily find the code that processes the Fiber node for forwardRef() in beginWork().

function beginWork(
  current: Fiber | null,
  workInProgress: Fiber,
  renderLanes: Lanes,
): Fiber | null {
  ...
  switch (workInProgress.tag) {
    case FunctionComponent: {
      const Component = workInProgress.type;
      const unresolvedProps = workInProgress.pendingProps;
      const resolvedProps =
        workInProgress.elementType === Component
          ? unresolvedProps
          : resolveDefaultProps(Component, unresolvedProps);
      return updateFunctionComponent(
        current,
        workInProgress,
        Component,
        resolvedProps,
        renderLanes,
      );
    }
    case ClassComponent: {
      const Component = workInProgress.type;
      const unresolvedProps = workInProgress.pendingProps;
      const resolvedProps =
        workInProgress.elementType === Component
          ? unresolvedProps
          : resolveDefaultProps(Component, unresolvedProps);
      return updateClassComponent(
        current,
        workInProgress,
        Component,
        resolvedProps,
        renderLanes,
      );
    }
    case HostRoot:
      return updateHostRoot(current, workInProgress, renderLanes);
    case HostComponent:
      return updateHostComponent(current, workInProgress, renderLanes);
    case HostText:
      return updateHostText(current, workInProgress);
    ...
    case ForwardRef: {
      const type = workInProgress.type;
      const unresolvedProps = workInProgress.pendingProps;
      const resolvedProps =
        workInProgress.elementType === type
          ? unresolvedProps
          : resolveDefaultProps(type, unresolvedProps);
      return updateForwardRef(
        current,
        workInProgress,
        type,
        resolvedProps,
        renderLanes,
      );
    }
    ...
  }
}

function updateForwardRef(
  current: Fiber | null,
  workInProgress: Fiber,
  Component: any,
  nextProps: any,
  renderLanes: Lanes,
) {
  const render = Component.render;
  const ref = workInProgress.ref;
  // The rest is a fork of updateFunctionComponent
  // (as this comment says, it is treated the same as a Function Component)
  let nextChildren;
  let hasId;
  prepareToReadContext(workInProgress, renderLanes);
  // This is actually where our component gets run:
  // note that render and ref are both passed to renderWithHooks()
  nextChildren = renderWithHooks(
    current,
    workInProgress,
    render,
    nextProps,
    ref,
    renderLanes,
  );
  hasId = checkDidRenderIdHook();
  if (current !== null && !didReceiveUpdate) {
    bailoutHooks(current, workInProgress, renderLanes);
    return bailoutOnAlreadyFinishedWork(current, workInProgress, renderLanes);
  }
  if (getIsHydrating() && hasId) {
    pushMaterializedTreeId(workInProgress);
  }
  // React DevTools reads this flag.
  workInProgress.flags |= PerformedWork;
  reconcileChildren(current, workInProgress, nextChildren, renderLanes);
  return workInProgress.child;
}

So renderWithHooks() actually calls the component with two arguments: props and a second argument. For a normal Function Component, that second argument is not the ref; it is only used for legacy context.

function updateFunctionComponent(
  current,
  workInProgress,
  Component,
  nextProps: any,
  renderLanes,
) {
  let context;
  // This is the deprecated legacy context
  if (!disableLegacyContext) {
    const unmaskedContext = getUnmaskedContext(workInProgress, Component, true);
    context = getMaskedContext(workInProgress, unmaskedContext);
  }
  let nextChildren;
  let hasId;
  prepareToReadContext(workInProgress, renderLanes);
  if (enableSchedulingProfiler) {
    markComponentRenderStarted(workInProgress);
  }
  nextChildren = renderWithHooks(
    current,
    workInProgress,
    Component,
    nextProps,
    // For a normal Function Component, context is passed rather than ref.
    // But this context is the Legacy Context, which is deprecated,
    // so here we can just ignore it
    context,
    renderLanes,
  );
  hasId = checkDidRenderIdHook();
  if (enableSchedulingProfiler) {
    markComponentRenderStopped();
  }
  if (current !== null && !didReceiveUpdate) {
    bailoutHooks(current, workInProgress, renderLanes);
    return bailoutOnAlreadyFinishedWork(current, workInProgress, renderLanes);
  }
  if (getIsHydrating() && hasId) {
    pushMaterializedTreeId(workInProgress);
  }
  // React DevTools reads this flag.
  workInProgress.flags |= PerformedWork;
  reconcileChildren(current, workInProgress, nextChildren, renderLanes);
  return workInProgress.child;
}

Let’s take a look at renderWithHooks().

export function renderWithHooks<Props, SecondArg>(
  current: Fiber | null,
  workInProgress: Fiber,
  Component: (p: Props, arg: SecondArg) => any,
  props: Props,
  secondArg: SecondArg,
  nextRenderLanes: Lanes,
): any {
  renderLanes = nextRenderLanes;
  currentlyRenderingFiber = workInProgress;
  workInProgress.memoizedState = null;
  workInProgress.updateQueue = null;
  workInProgress.lanes = NoLanes;
  ...
  // props and ref are passed to our Component, which is run here
  let children = Component(props, secondArg);
  ...
  // We can assume the previous dispatcher is always this one, since we set it
  // at the beginning of the render phase and there's no re-entrance.
  ReactCurrentDispatcher.current = ContextOnlyDispatcher;
  // This check uses currentHook so that it works the same in DEV and prod bundles.
  // hookTypesDev could catch more cases (e.g. context) but only in DEV bundles.
  const didRenderTooFewHooks =
    currentHook !== null && currentHook.next !== null;
  renderLanes = NoLanes;
  currentlyRenderingFiber = (null: any);
  currentHook = null;
  workInProgressHook = null;
  ...
  return children;
}
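
Putting the two update paths side by side makes the difference easier to see. The following is a minimal sketch, not React's code: the `*Sketch` functions are hypothetical stand-ins that drop hooks setup, bailout and reconciliation, and the "fibers" are plain objects carrying only the fields read here.

```javascript
// Hypothetical stand-in: the real renderWithHooks() also sets up hooks state.
function renderWithHooksSketch(Component, props, secondArg) {
  return Component(props, secondArg);
}

// Mirrors the ForwardRef path: unwrap `render` and pass the fiber's ref along.
function updateForwardRefSketch(fiber) {
  const render = fiber.type.render;
  const ref = fiber.ref;
  return renderWithHooksSketch(render, fiber.pendingProps, ref);
}

// Mirrors the Function Component path: the second argument is legacy context, not ref.
function updateFunctionComponentSketch(fiber, legacyContext) {
  return renderWithHooksSketch(fiber.type, fiber.pendingProps, legacyContext);
}

// A "component" that simply reports what it was called with.
const describe = (props, second) => ({ props, second });
const ref = { current: null };

const forwardRefFiber = {
  type: { $$typeof: Symbol.for('react.forward_ref'), render: describe },
  pendingProps: { a: 1 },
  ref,
};
const functionFiber = { type: describe, pendingProps: { a: 1 }, ref };

console.log(updateForwardRefSketch(forwardRefFiber).second === ref); // true
console.log(updateFunctionComponentSketch(functionFiber, {}).second === ref); // false
```

Both fibers hold the same ref, but only the ForwardRef path hands it to the function being rendered.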

3. Summary

  1. ref is a property that sits beside props in a React element, not inside it.
  2. Internally, Function Components are run with 2 arguments, but by default the 2nd one is not the ref, so we cannot access anything other than the 1st one: props.
  3. With forwardRef(), React uses a new type of Fiber node that is processed much like a Function Component, except that React extracts the ref and passes it to our component as the 2nd argument.
