Complicated summary function -- is it possible to solve with R data.table package?
Great question!! The example data is especially well constructed and well explained.
First I'll show the answer, then I'll explain it step by step.
> ids = 1:3 # or from the data: unique(ds$ID)
> pos = 1:6 # or from the data: unique(ds$Pos)
> setkey(ds,ID,Pos)
> ds[CJ(ids,pos), roll=-Inf, nomatch=0][, .N, by=Pos]
Pos N
1: 1 3
2: 2 3
3: 3 3
4: 4 3
5: 5 2
6: 6 1
That should also be very efficient on your large data.
Step by step
First I tried a Cross Join (CJ); i.e., for each train for each position.
> ds[CJ(ids,pos)]
ID Pos Obs
1: 1 1 1.50
2: 1 2 NA
3: 1 3 2.50
4: 1 4 NA
5: 1 5 0.00
6: 1 6 1.25
7: 2 1 NA
8: 2 2 1.45
9: 2 3 1.50
10: 2 4 NA
11: 2 5 2.50
12: 2 6 NA
13: 3 1 NA
14: 3 2 0.00
15: 3 3 1.25
16: 3 4 1.45
17: 3 5 NA
18: 3 6 NA
I see 6 rows per train. I see 3 trains. I've got 18 rows as I expected. I see NA where that train wasn't observed. Good. Check. The cross join seems to be working. Let's now build the query up.
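If you want to follow along without the original data, here is a minimal sketch that rebuilds ds from the non-NA rows of the cross join output above. The exact construction in the question may differ, so treat the values and column types as an assumption.

```r
library(data.table)

# Reconstructed from the non-NA rows of the cross join output above;
# the question's own data may have been built differently.
ds <- data.table(
  ID  = c(1L, 1L, 1L, 1L, 2L, 2L, 2L, 3L, 3L, 3L),
  Pos = c(1L, 3L, 5L, 6L, 2L, 3L, 5L, 2L, 3L, 4L),
  Obs = c(1.50, 2.50, 0.00, 1.25, 1.45, 1.50, 2.50, 0.00, 1.25, 1.45)
)
setkey(ds, ID, Pos)

ids <- unique(ds$ID)  # the three trains
pos <- 1:6            # every position, whether or not a given train reported there

full <- ds[CJ(ids, pos)]     # one row per (train, position) pair
n_rows    <- nrow(full)      # 3 trains x 6 positions
n_missing <- sum(is.na(full$Obs))  # pairs with no observation
```

With this reconstruction, n_rows is 18 and n_missing is 8, matching the table above.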
You wrote that if a train is observed at position n, it must have passed the previous positions. Immediately I'm thinking roll. Let's try it.
> ds[CJ(ids,pos), roll=TRUE]
ID Pos Obs
1: 1 1 1.50
2: 1 2 1.50
3: 1 3 2.50
4: 1 4 2.50
5: 1 5 0.00
6: 1 6 1.25
7: 2 1 NA
8: 2 2 1.45
9: 2 3 1.50
10: 2 4 1.50
11: 2 5 2.50
12: 2 6 2.50
13: 3 1 NA
14: 3 2 0.00
15: 3 3 1.25
16: 3 4 1.45
17: 3 5 1.45
18: 3 6 1.45
Hm. That rolled the observations forwards for each train. It left some NA at position 1 for trains 2 and 3, but you said that if a train is observed at position 2 it must have passed position 1. It also rolled the last observation for trains 2 and 3 forward to position 6, but you said trains might explode. So, we want to roll backwards! That's roll=-Inf. It's spelled -Inf, rather than being a plain flag, because the same argument also lets you limit how far to roll backwards (a finite negative number). We don't need that for this question; we just want to roll backwards indefinitely. Let's try roll=-Inf and see what happens.
> ds[CJ(ids,pos), roll=-Inf]
ID Pos Obs
1: 1 1 1.50
2: 1 2 2.50
3: 1 3 2.50
4: 1 4 0.00
5: 1 5 0.00
6: 1 6 1.25
7: 2 1 1.45
8: 2 2 1.45
9: 2 3 1.50
10: 2 4 2.50
11: 2 5 2.50
12: 2 6 NA
13: 3 1 0.00
14: 3 2 0.00
15: 3 3 1.25
16: 3 4 1.45
17: 3 5 NA
18: 3 6 NA
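To see the direction of roll and its optional limit in isolation, here is a minimal sketch on a hypothetical two-row table (dt, x, v and gap are made-up names, not part of the question's data).

```r
library(data.table)

# Hypothetical toy table: observations exist only at x = 1 and x = 5.
dt <- data.table(x = c(1L, 5L), v = c("a", "b"))
setkey(dt, x)

gap <- c(2L, 4L)  # two lookups that fall between the observations

fwd     <- dt[J(gap), roll = TRUE]$v  # last observation carried forward
bwd     <- dt[J(gap), roll = -Inf]$v  # next observation carried backward
limited <- dt[J(gap), roll = -2]$v    # backward, but only across a distance of at most 2
```

Here fwd picks up "a" for both lookups and bwd picks up "b" for both. With the limit, the lookup at 2 is a distance of 3 from the next observation, so it stays NA, while the lookup at 4 still gets "b". In the train data the backward roll is unlimited, which is what produced the table just above.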
That's better. Almost there. All we need to do now is count. But those pesky NA are still there after trains 2 and 3 exploded. Let's remove them.
> ds[CJ(ids,pos), roll=-Inf, nomatch=0]
ID Pos Obs
1: 1 1 1.50
2: 1 2 2.50
3: 1 3 2.50
4: 1 4 0.00
5: 1 5 0.00
6: 1 6 1.25
7: 2 1 1.45
8: 2 2 1.45
9: 2 3 1.50
10: 2 4 2.50
11: 2 5 2.50
12: 3 1 0.00
13: 3 2 0.00
14: 3 3 1.25
15: 3 4 1.45
Btw, data.table likes as much as possible to be inside one single DT[...], as that's how it optimizes the query. Internally, it doesn't create the NA rows and then remove them; it never creates them in the first place. This concept is important for efficiency.
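As a rough illustration of that point, the sketch below compares the single query with a two-step version that builds the NA rows and then filters them out. The data are reconstructed from the outputs above (an assumption), and grid, in_query and two_steps are names I made up.

```r
library(data.table)

# Data reconstructed from the outputs shown above (an assumption about the original ds).
ds <- data.table(
  ID  = c(1L, 1L, 1L, 1L, 2L, 2L, 2L, 3L, 3L, 3L),
  Pos = c(1L, 3L, 5L, 6L, 2L, 3L, 5L, 2L, 3L, 4L),
  Obs = c(1.50, 2.50, 0.00, 1.25, 1.45, 1.50, 2.50, 0.00, 1.25, 1.45)
)
setkey(ds, ID, Pos)
grid <- CJ(1:3, 1:6)

# One query: unmatched rows are dropped as part of the join itself.
in_query <- ds[grid, roll = -Inf, nomatch = 0]

# Two steps: materialise the NA rows first, then filter them out.
# Same result here only because Obs has no NA in the data itself.
two_steps <- ds[grid, roll = -Inf][!is.na(Obs)]

same_rows <- nrow(in_query) == nrow(two_steps)
same_obs  <- identical(in_query$Obs, two_steps$Obs)
```

Both give the same 15 rows on this data; the difference is only in how much intermediate work is done. Newer data.table versions also accept nomatch = NULL for the same purpose.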
Finally, all we have to do is count. We can just tack this on the end as a compound query.
> ds[CJ(ids,pos), roll=-Inf, nomatch=0][, .N, by=Pos]
Pos N
1: 1 3
2: 2 3
3: 3 3
4: 4 3
5: 5 2
6: 6 1
data.table sounds like an excellent solution. From the way the data are ordered (positions increase within each train), one could find the maximum position of each train with
maxPos = ds$Pos[!duplicated(ds$ID, fromLast=TRUE)]
Then tabulate the trains that reach that position
nAtMax = tabulate(maxPos)
and calculate the cumulative sum of trains at each position, counting from the end
rev(cumsum(rev(nAtMax)))
## [1] 3 3 3 3 2 1
I think this will be quite fast for large data, though not entirely memory efficient.
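If the rows are not guaranteed to be ordered, a variant of the same idea can take the maximum per train explicitly. This is only a sketch: the ID and Pos vectors are reconstructed from the outputs in the first answer, and nbins = 6 assumes the positions run from 1 to 6.

```r
# Data reconstructed from the outputs in the first answer (an assumption).
ID  <- c(1L, 1L, 1L, 1L, 2L, 2L, 2L, 3L, 3L, 3L)
Pos <- c(1L, 3L, 5L, 6L, 2L, 3L, 5L, 2L, 3L, 4L)

# Last position per train, computed without relying on row order.
maxPos <- as.integer(tapply(Pos, ID, max))

# Number of trains whose journey ends at each position 1..6.
nAtMax <- tabulate(maxPos, nbins = 6)

# A train that ends at position p was present at every position <= p,
# so accumulate the counts starting from the far end.
counts <- rev(cumsum(rev(nAtMax)))
counts
```

On this data counts comes out as 3 3 3 3 2 1, the same as above. Passing nbins explicitly also keeps the result at full length when no train reaches the last position.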
You can try the approach below. I have purposely split it into several steps to make it easier to follow. You can probably combine them into one step by chaining [].
The logic here is that we first find the final position for each ID. Then we aggregate to count the IDs at each final position. Since every ID whose final position is 6 must also be counted at position 5, we use cumsum, taken from the highest final position downwards, to add the counts for all higher final positions to each lower one.
ds2 <- ds[, list(FinalPos=max(Pos)), by=ID]
ds2
## ID FinalPos
## 1: 1 6
## 2: 2 5
## 3: 3 4
ds3 <- ds2[ , list(Count = length(ID)), by = FinalPos][order(FinalPos, decreasing=TRUE), list(FinalPos, Count = cumsum(Count))]
ds3
## FinalPos Count
## 1: 6 1
## 2: 5 2
## 3: 4 3
setkey(ds3, FinalPos)
ds3[J(1:6), roll = -Inf]  # roll backwards: each position takes the count of the next final position at or after it
## FinalPos Count
## 1: 1 3
## 2: 2 3
## 3: 3 3
## 4: 4 3
## 5: 5 2
## 6: 6 1
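The last join fills in positions that have no row of their own, and the choice of roll matters once the final positions have gaps. Below is a sketch with the steps chained into one expression, run on hypothetical data (two made-up trains, one stopping at position 2 and one reaching 6), comparing roll = "nearest" with roll = -Inf.

```r
library(data.table)

# Hypothetical data with a gap: train 1 stops at position 2, train 2 reaches 6.
ds <- data.table(ID  = c(1L, 1L, 2L, 2L),
                 Pos = c(1L, 2L, 3L, 6L))

# The three steps chained into one expression.
ds3 <- ds[, .(FinalPos = max(Pos)), by = ID
        ][, .(Count = .N), by = FinalPos
        ][order(-FinalPos), .(FinalPos, Count = cumsum(Count))]
setkey(ds3, FinalPos)

query <- c(3L, 5L)  # positions with no row in ds3

nearest  <- ds3[J(query), roll = "nearest"]$Count  # closest final position in either direction
backward <- ds3[J(query), roll = -Inf]$Count       # next final position at or after the query
```

Only train 2 gets as far as position 3, so the count there should be 1. The backward roll gives 1 1 for the two queried positions, while "nearest" gives 2 1, because position 3 is closer to final position 2 than to 6. On the example data in the question the two happen to agree.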