How do Relational Networks work?

I have just re-read Google's paper on Relational Networks. This is a neural architecture designed to learn relationships between objects in a scene. It evaluates a function, g, on each pair of 'objects' in a scene, sums the results, and then applies another function, f, to the sum. The form is shown below. f and g are implemented as trainable multi-layer perceptrons. O_i and O_j are objects from the scene (each including its location), and q is the query being asked about the scene: an embedding vector created by passing the textual query through a recurrent network.

Relational Network:

$$\mathrm{RN}(O) = f\Big(\sum_{i,j} g(O_i, O_j, q)\Big)$$
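For concreteness, here is a minimal NumPy sketch of the forward pass as I understand it. It is not the paper's code: the layer sizes, the random untrained weights, and the helper names `mlp` and `relation_network` are my own assumptions, and each object is assumed to be a feature vector that already includes its location. It sums g over all ordered pairs (i, j), including i = j, as the double sum suggests.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(sizes, rng):
    """Build a small ReLU MLP with random (untrained) weights; returns a callable."""
    params = [(rng.normal(0, 0.1, (m, n)), np.zeros(n))
              for m, n in zip(sizes[:-1], sizes[1:])]
    def forward(x):
        for k, (W, b) in enumerate(params):
            x = x @ W + b
            if k < len(params) - 1:  # ReLU on hidden layers only
                x = np.maximum(x, 0.0)
        return x
    return forward

def relation_network(objects, q, g, f):
    """RN(O) = f(sum_{i,j} g(o_i, o_j, q)) over all ordered pairs (i, j)."""
    n = objects.shape[0]
    left = np.repeat(objects, n, axis=0)   # o_i, each row repeated n times
    right = np.tile(objects, (n, 1))       # o_j, cycling through all objects
    qs = np.tile(q, (n * n, 1))            # the same query appended to every pair
    pairs = np.concatenate([left, right, qs], axis=1)
    relations = g(pairs)                   # one g output per ordered pair
    return f(relations.sum(axis=0))        # sum over pairs, then apply f

# Hypothetical sizes: 6 objects, 8-dim objects, 16-dim query, 10 possible answers.
n_obj, d_obj, d_q, d_g, n_answers = 6, 8, 16, 32, 10
g = mlp([2 * d_obj + d_q, 64, d_g], rng)
f = mlp([d_g, 64, n_answers], rng)
objects = rng.normal(size=(n_obj, d_obj))
q = rng.normal(size=d_q)

logits = relation_network(objects, q, g, f)
print(logits.shape)  # (10,)
```

Because the pairwise outputs are summed, the result does not depend on the order of the objects, and f always sees a vector of the same size (d_g) however many objects are in the scene.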

I used to think I understood this paper, but now I have doubts. Consider the question 'What colour is the object farthest from the blue sphere?'. Function g is applied to each pair of objects. It is conditioned on the query q, and can compute the distance between the two objects when, say, the first object is the blue sphere. g can then output a vector containing that distance and the colour of the second object. These vectors are summed, and f computes the final answer from the sum. Unfortunately, I can't see a general way for f to identify the farthest object from that sum. One possibility is that each object uses a different segment of the vector returned by g to encode its colour and its distance from the blue sphere. An example layout:

obj1 obj2 obj3 obj4 obj5 obj6

[red, 15, blue, 23, blue, 10, blue, 0, green, 18, red, 5]

f could then learn to output the colour immediately to the left of the largest distance (in this case 'blue', at distance 23), but this is a hard function for an MLP to learn: it essentially needs to understand relations between elements of the vector. Another problem is that this solution is limited by the length of the vector, so it can only handle a fixed number of objects.
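To make the segment idea concrete, here is a hand-written (not learned) sketch of the g and f that the layout above would require. Everything here is hypothetical: `g_slot`, `f_slot` and `run_scene` are illustrative names, colours are integer codes, and I assume each object carries an integer index so that g knows which segment to write into (the layout seems to need something like this). Positions are chosen so the distances match the example.

```python
import numpy as np

COLOURS = ["red", "green", "blue"]  # hypothetical colour vocabulary (codes 0, 1, 2)

def g_slot(obj_i, obj_j, max_objects):
    """Stand-in for g: if obj_i is the blue sphere (the anchor), write
    (colour code of obj_j, distance from obj_i to obj_j) into obj_j's segment.
    Assumes each object carries an integer 'index' (hypothetical)."""
    out = np.zeros(2 * max_objects)
    if obj_i["is_anchor"]:
        k = obj_j["index"]
        out[2 * k] = COLOURS.index(obj_j["colour"])
        out[2 * k + 1] = np.linalg.norm(obj_i["pos"] - obj_j["pos"])
    return out

def f_slot(summed):
    """Stand-in for f: find the largest distance and return the colour
    stored immediately to its left."""
    colours, distances = summed[0::2], summed[1::2]
    return COLOURS[int(colours[np.argmax(distances)])]

def run_scene(scene, max_objects=6):
    # Sum g over every ordered pair, then apply f, as in the RN formula.
    summed = sum(g_slot(a, b, max_objects) for a in scene for b in scene)
    return summed, f_slot(summed)

# The example scene: obj4 (index 3) is the blue sphere at the origin.
spec = [("red", 15.0), ("blue", 23.0), ("blue", 10.0),
        ("blue", 0.0), ("green", 18.0), ("red", 5.0)]
scene = [{"index": k, "colour": c, "pos": np.array([x, 0.0]), "is_anchor": k == 3}
         for k, (c, x) in enumerate(spec)]

summed, answer = run_scene(scene)
print(answer)  # blue
```

The `np.argmax` inside `f_slot` is exactly the comparison between vector elements that an MLP would have to approximate, and with `max_objects=6` a seventh object has no segment to write into (indexing past the end raises an `IndexError`).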

Are there any other sensible ways for g to pass information to f that would allow the farthest object to be identified? I feel I must be missing something here.

submitted by /u/TurnipYadaYada6941