Skip to content

Commit 41c56d4

Browse files
committed
Merge pull request #10996 from cweld510:cweld/optionally-close-unix-sockets-on-save
PiperOrigin-RevId: 684217787
2 parents 62eaadc + befd16e commit 41c56d4

File tree

14 files changed

+115
-37
lines changed

14 files changed

+115
-37
lines changed

pkg/sentry/fsimpl/gofer/socket.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ type endpoint struct {
5454
}
5555

5656
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
57-
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
57+
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint), opts transport.UnixSocketOpts) *syserr.Error {
5858
// No lock ordering required as only the ConnectingEndpoint has a mutex.
5959
ce.Lock()
6060

@@ -68,7 +68,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
6868
return syserr.ErrInvalidEndpointState
6969
}
7070

71-
c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue())
71+
c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue(), opts)
7272
if err != nil {
7373
ce.Unlock()
7474
return err
@@ -85,8 +85,8 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
8585

8686
// UnidirectionalConnect implements
8787
// transport.BoundEndpoint.UnidirectionalConnect.
88-
func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
89-
c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{})
88+
func (e *endpoint) UnidirectionalConnect(ctx context.Context, opts transport.UnixSocketOpts) (transport.ConnectedEndpoint, *syserr.Error) {
89+
c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{}, opts)
9090
if err != nil {
9191
return nil, err
9292
}
@@ -102,15 +102,15 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
102102
return c, nil
103103
}
104104

105-
func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*transport.SCMConnectedEndpoint, *syserr.Error) {
105+
func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue, opts transport.UnixSocketOpts) (*transport.SCMConnectedEndpoint, *syserr.Error) {
106106
e.dentry.fs.renameMu.RLock()
107107
hostSockFD, err := e.dentry.connect(ctx, sockType)
108108
e.dentry.fs.renameMu.RUnlock()
109109
if err != nil {
110110
return nil, syserr.ErrConnectionRefused
111111
}
112112

113-
c, serr := transport.NewSCMEndpoint(hostSockFD, queue, e.path)
113+
c, serr := transport.NewSCMEndpoint(hostSockFD, queue, e.path, opts)
114114
if serr != nil {
115115
unix.Close(hostSockFD)
116116
log.Warningf("NewSCMEndpoint failed: path=%q, err=%v", e.path, serr)

pkg/sentry/fsimpl/testutil/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ go_library(
3131
"//pkg/sentry/platform/kvm",
3232
"//pkg/sentry/platform/ptrace",
3333
"//pkg/sentry/seccheck",
34+
"//pkg/sentry/socket/unix/transport",
3435
"//pkg/sentry/time",
3536
"//pkg/sentry/usage",
3637
"//pkg/sentry/vfs",

pkg/sentry/fsimpl/testutil/kernel.go

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
3636
"gvisor.dev/gvisor/pkg/sentry/platform"
3737
"gvisor.dev/gvisor/pkg/sentry/seccheck"
38+
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
3839
"gvisor.dev/gvisor/pkg/sentry/time"
3940
"gvisor.dev/gvisor/pkg/sentry/usage"
4041
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -106,6 +107,7 @@ func Boot() (*kernel.Kernel, error) {
106107
RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
107108
RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
108109
PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
110+
UnixSocketOpts: transport.UnixSocketOpts{},
109111
}); err != nil {
110112
return nil, fmt.Errorf("initializing kernel: %v", err)
111113
}

pkg/sentry/kernel/kernel.go

+8
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ import (
7070
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
7171
"gvisor.dev/gvisor/pkg/sentry/platform"
7272
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
73+
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
7374
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
7475
"gvisor.dev/gvisor/pkg/sentry/unimpl"
7576
uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
@@ -387,6 +388,9 @@ type Kernel struct {
387388
// attempt succeeded, after which at least one more checkpoint attempt was
388389
// made and failed with this error. It's protected by checkpointMu.
389390
lastCheckpointStatus error `state:"nosave"`
391+
392+
// UnixSocketOpts stores configuration options for management of unix sockets.
393+
UnixSocketOpts transport.UnixSocketOpts
390394
}
391395

392396
// Saver is an interface for saving the kernel.
@@ -445,6 +449,9 @@ type InitKernelArgs struct {
445449
// used by processes. If it is zero, the limit will be set to
446450
// unlimited.
447451
MaxFDLimit int32
452+
453+
// UnixSocketOpts contains configuration options for unix sockets.
454+
UnixSocketOpts transport.UnixSocketOpts
448455
}
449456

450457
// Init initializes the Kernel with no tasks.
@@ -567,6 +574,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
567574
k.sockets = make(map[*vfs.FileDescription]*SocketRecord)
568575

569576
k.cgroupRegistry = newCgroupRegistry()
577+
k.UnixSocketOpts = args.UnixSocketOpts
570578
return nil
571579
}
572580

pkg/sentry/socket/netlink/socket.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *sy
131131
}
132132

133133
// Create a connection from which the kernel can write messages.
134-
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
134+
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t, t.Kernel().UnixSocketOpts)
135135
if err != nil {
136136
ep.Close(t)
137137
return nil, err

pkg/sentry/socket/unix/transport/connectioned.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ func (e *connectionedEndpoint) Close(ctx context.Context) {
275275
}
276276

277277
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
278-
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
278+
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error {
279279
if ce.Type() != e.stype {
280280
return syserr.ErrWrongProtocolForSocket
281281
}
@@ -378,13 +378,13 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
378378
}
379379

380380
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
381-
func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
381+
func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) {
382382
return nil, syserr.ErrConnectionRefused
383383
}
384384

385385
// Connect attempts to directly connect to another Endpoint.
386386
// Implements Endpoint.Connect.
387-
func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
387+
func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error {
388388
returnConnect := func(r Receiver, ce ConnectedEndpoint) {
389389
e.receiver = r
390390
e.connected = ce
@@ -396,7 +396,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint
396396
}
397397
}
398398

399-
return server.BidirectionalConnect(ctx, e, returnConnect)
399+
return server.BidirectionalConnect(ctx, e, returnConnect, opts)
400400
}
401401

402402
// Listen starts listening on the connection.
@@ -405,7 +405,7 @@ func (e *connectionedEndpoint) Listen(ctx context.Context, backlog int) *syserr.
405405
defer e.Unlock()
406406
if e.ListeningLocked() {
407407
// Adjust the size of the channel iff we can fix existing
408-
// pending connections into the new one.
408+
// pending connections into the new one.
409409
if len(e.acceptedChan) > backlog {
410410
return syserr.ErrInvalidEndpointState
411411
}
@@ -438,15 +438,15 @@ func (e *connectionedEndpoint) Listen(ctx context.Context, backlog int) *syserr.
438438
}
439439

440440
// Accept accepts a new connection.
441-
func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (Endpoint, *syserr.Error) {
441+
func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address, opts UnixSocketOpts) (Endpoint, *syserr.Error) {
442442
e.Lock()
443443

444444
if !e.ListeningLocked() {
445445
e.Unlock()
446446
return nil, syserr.ErrInvalidEndpointState
447447
}
448448

449-
ne, err := e.getAcceptedEndpointLocked(ctx)
449+
ne, err := e.getAcceptedEndpointLocked(ctx, opts)
450450
e.Unlock()
451451
if err != nil {
452452
return nil, err
@@ -470,7 +470,7 @@ func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (E
470470
// Preconditions:
471471
// - e.Listening()
472472
// - e is locked.
473-
func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context) (*connectionedEndpoint, *syserr.Error) {
473+
func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context, opts UnixSocketOpts) (*connectionedEndpoint, *syserr.Error) {
474474
// Accept connections from within the sentry first, since this avoids
475475
// an RPC to the gofer on the common path.
476476
select {
@@ -493,7 +493,7 @@ func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context) (*
493493
return nil, syserr.FromError(err)
494494
}
495495
q := &waiter.Queue{}
496-
scme, serr := NewSCMEndpoint(nfd, q, e.path)
496+
scme, serr := NewSCMEndpoint(nfd, q, e.path, opts)
497497
if serr != nil {
498498
unix.Close(nfd)
499499
return nil, serr

pkg/sentry/socket/unix/transport/connectionless.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ func (e *connectionlessEndpoint) Close(ctx context.Context) {
7878
}
7979

8080
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
81-
func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
81+
func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error {
8282
return syserr.ErrConnectionRefused
8383
}
8484

8585
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
86-
func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
86+
func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) {
8787
e.Lock()
8888
r := e.receiver
8989
e.Unlock()
@@ -107,7 +107,8 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C
107107
return e.baseEndpoint.SendMsg(ctx, data, c, nil)
108108
}
109109

110-
connected, err := to.UnidirectionalConnect(ctx)
110+
opts := UnixSocketOpts{}
111+
connected, err := to.UnidirectionalConnect(ctx, opts)
111112
if err != nil {
112113
return 0, nil, syserr.ErrInvalidEndpointState
113114
}
@@ -131,8 +132,8 @@ func (e *connectionlessEndpoint) Type() linux.SockType {
131132
}
132133

133134
// Connect attempts to connect directly to server.
134-
func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
135-
connected, err := server.UnidirectionalConnect(ctx)
135+
func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error {
136+
connected, err := server.UnidirectionalConnect(ctx, opts)
136137
if err != nil {
137138
return err
138139
}
@@ -153,7 +154,7 @@ func (*connectionlessEndpoint) Listen(context.Context, int) *syserr.Error {
153154
}
154155

155156
// Accept accepts a new connection.
156-
func (*connectionlessEndpoint) Accept(context.Context, *Address) (Endpoint, *syserr.Error) {
157+
func (*connectionlessEndpoint) Accept(context.Context, *Address, UnixSocketOpts) (Endpoint, *syserr.Error) {
157158
return nil, syserr.ErrNotSupported
158159
}
159160

pkg/sentry/socket/unix/transport/host.go

+56-4
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ func (c *HostConnectedEndpoint) init() *syserr.Error {
9898
}
9999

100100
func (c *HostConnectedEndpoint) initFromOptions() *syserr.Error {
101+
if c.fd < 0 {
102+
// There is no underlying FD to restore; nothing to do.
103+
return nil
104+
}
105+
101106
family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
102107
if err != nil {
103108
return syserr.FromError(err)
@@ -163,6 +168,10 @@ func (c *HostConnectedEndpoint) Send(ctx context.Context, data [][]byte, control
163168
return 0, false, syserr.ErrInvalidEndpointState
164169
}
165170

171+
if c.IsSendClosed() {
172+
return 0, false, syserr.ErrClosedForSend
173+
}
174+
166175
// Since stream sockets don't preserve message boundaries, we can write
167176
// only as much of the message as fits in the send buffer.
168177
truncate := c.stype == linux.SOCK_STREAM
@@ -192,6 +201,14 @@ func (c *HostConnectedEndpoint) SendNotify() {}
192201
func (c *HostConnectedEndpoint) CloseSend() {
193202
c.mu.Lock()
194203
defer c.mu.Unlock()
204+
c.closeSendLocked()
205+
}
206+
207+
// Preconditions: c.mu must be held.
208+
func (c *HostConnectedEndpoint) closeSendLocked() {
209+
if c.IsSendClosed() {
210+
return
211+
}
195212

196213
if err := unix.Shutdown(c.fd, unix.SHUT_WR); err != nil {
197214
// A well-formed UDS shutdown can't fail. See
@@ -300,6 +317,14 @@ func (c *HostConnectedEndpoint) RecvNotify() {}
300317
func (c *HostConnectedEndpoint) CloseRecv() {
301318
c.mu.Lock()
302319
defer c.mu.Unlock()
320+
c.closeRecvLocked()
321+
}
322+
323+
// Preconditions: c.mu must be held.
324+
func (c *HostConnectedEndpoint) closeRecvLocked() {
325+
if c.IsRecvClosed() {
326+
return
327+
}
303328

304329
if err := unix.Shutdown(c.fd, unix.SHUT_RD); err != nil {
305330
// A well-formed UDS shutdown can't fail. See
@@ -382,13 +407,34 @@ func (c *HostConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
382407
// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
383408
// passed through a gofer Unix socket. It resembles HostConnectedEndpoint, with the
384409
// following differences:
385-
// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
386-
// the same descriptor number across S/R.
410+
// - SCMConnectedEndpoint is not saveable by default, because the host
411+
// cannot guarantee the same descriptor number across S/R.
412+
// However, if opts.DisconnectOnSave is set, it is placed in a closed
// state before save instead of causing the save to panic.
387413
// - SCMConnectedEndpoint holds ownership of its fd and notification queue.
414+
//
415+
// +stateify savable
388416
type SCMConnectedEndpoint struct {
389417
HostConnectedEndpoint
390418

391419
queue *waiter.Queue
420+
opts UnixSocketOpts
421+
}
422+
423+
// beforeSave is invoked by stateify.
424+
func (e *SCMConnectedEndpoint) beforeSave() {
425+
if !e.opts.DisconnectOnSave {
426+
panic("socket cannot be saved in a connected state")
427+
}
428+
429+
e.mu.Lock()
430+
defer e.mu.Unlock()
431+
fdnotifier.RemoveFD(int32(e.fd))
432+
e.closeRecvLocked()
433+
e.closeSendLocked()
434+
if err := unix.Close(e.fd); err != nil {
435+
log.Warningf("Failed to close host fd %d: %v", err)
436+
}
437+
e.destroyLocked()
392438
}
393439

394440
// Init will do the initialization required without holding other locks.
@@ -400,12 +446,17 @@ func (e *SCMConnectedEndpoint) Init() error {
400446
func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
401447
e.DecRef(func() {
402448
e.mu.Lock()
449+
defer e.mu.Unlock()
450+
451+
if e.fd < 0 {
452+
return
453+
}
454+
403455
fdnotifier.RemoveFD(int32(e.fd))
404456
if err := unix.Close(e.fd); err != nil {
405457
log.Warningf("Failed to close host fd %d: %v", err)
406458
}
407459
e.destroyLocked()
408-
e.mu.Unlock()
409460
})
410461
}
411462

@@ -415,13 +466,14 @@ func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
415466
// The caller is responsible for calling Init(). Additionally, Release needs to
416467
// be called twice because ConnectedEndpoint is both a Receiver and
417468
// ConnectedEndpoint.
418-
func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
469+
func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string, opts UnixSocketOpts) (*SCMConnectedEndpoint, *syserr.Error) {
419470
e := SCMConnectedEndpoint{
420471
HostConnectedEndpoint: HostConnectedEndpoint{
421472
fd: hostFD,
422473
addr: addr,
423474
},
424475
queue: queue,
476+
opts: opts,
425477
}
426478

427479
if err := e.init(); err != nil {

0 commit comments

Comments
 (0)