diff --git a/src/modules/chat/chat.gateway.spec.ts b/src/modules/chat/chat.gateway.spec.ts new file mode 100644 index 0000000..9fa019a --- /dev/null +++ b/src/modules/chat/chat.gateway.spec.ts @@ -0,0 +1,90 @@ +import { Types } from 'mongoose'; +import { ChatGateway } from './chat.gateway'; + +const createClient = (userId?: string) => + ({ + id: new Types.ObjectId().toString(), + data: { userId }, + join: jest.fn().mockResolvedValue(undefined), + leave: jest.fn().mockResolvedValue(undefined), + emit: jest.fn(), + to: jest.fn().mockReturnValue({ emit: jest.fn() }), + }) as any; + +describe('ChatGateway realtime events', () => { + const userId = new Types.ObjectId().toString(); + const conversationId = new Types.ObjectId().toString(); + let chatService: Record; + let gateway: ChatGateway; + + beforeEach(() => { + chatService = { + assertConversationMember: jest.fn().mockResolvedValue({ id: conversationId }), + sendMessage: jest.fn(), + markMessageSeen: jest.fn(), + markMessageDelivered: jest.fn(), + }; + gateway = new ChatGateway( + chatService as any, + { bindServer: jest.fn() } as any, + {} as any, + {} as any, + {} as any, + ); + (gateway as any).server = { + to: jest.fn().mockReturnValue({ emit: jest.fn() }), + }; + }); + + it('join_conversation validates membership before joining the room', async () => { + const client = createClient(userId); + + await gateway.joinConversation(client, { conversationId }); + + expect(chatService.assertConversationMember).toHaveBeenCalledWith(userId, conversationId); + expect(client.join).toHaveBeenCalledWith(`conversation:${conversationId}`); + expect(client.emit).toHaveBeenCalledWith('joined_conversation', { conversationId }); + }); + + it('leave_conversation leaves only the requested conversation room', async () => { + const client = createClient(userId); + + await gateway.leaveConversation(client, { conversationId }); + + expect(chatService.assertConversationMember).toHaveBeenCalledWith(userId, conversationId); + expect(client.leave).toHaveBeenCalledWith(`conversation:${conversationId}`); + expect(client.emit).toHaveBeenCalledWith('left_conversation', { conversationId }); + }); + + it('emits socket_error for invalid realtime operations', async () => { + const client = createClient(userId); + chatService.assertConversationMember.mockRejectedValue(new Error('Not allowed')); + + await gateway.joinConversation(client, { conversationId }); + + expect(client.emit).toHaveBeenCalledWith('socket_error', { + event: 'join_conversation', + message: 'Not allowed', + code: 'ERROR', + }); + }); + + it('mark_delivered emits message_delivered payload returned by the service', async () => { + const client = createClient(userId); + const messageId = new Types.ObjectId().toString(); + const delivered = { + conversationId, + messageId, + userId, + deliveredAt: '2026-06-07T10:00:00.000Z', + }; + const roomEmitter = { emit: jest.fn() }; + (gateway as any).server.to.mockReturnValue(roomEmitter); + chatService.markMessageDelivered.mockResolvedValue(delivered); + + await gateway.markDelivered(client, { conversationId, messageId }); + + expect(chatService.markMessageDelivered).toHaveBeenCalledWith(userId, conversationId, messageId); + expect(roomEmitter.emit).toHaveBeenCalledWith('message_delivered', delivered); + }); +}); diff --git a/src/modules/chat/chat.gateway.ts b/src/modules/chat/chat.gateway.ts index ec4923a..dfda3eb 100644 --- a/src/modules/chat/chat.gateway.ts +++ b/src/modules/chat/chat.gateway.ts @@ -1,4 +1,10 @@ import { ConfigService } from '@nestjs/config'; +import { + BadRequestException, + ForbiddenException, + NotFoundException, + UnauthorizedException, +} from '@nestjs/common'; import { JwtService } from '@nestjs/jwt'; import { ConnectedSocket, @@ -22,7 +28,9 @@ type SocketWithUser = Socket & { data: { userId?: string } }; export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect { @WebSocketServer() server!: Server; + // TODO: Move presence counts to Redis when running multiple backend instances. private readonly connectionCountsByUser = new Map(); + private readonly typingTimestampsBySocket = new Map(); constructor( private readonly chatService: ChatService, @@ -52,10 +60,17 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa return; } client.data.userId = payload.sub; - this.incrementUserConnection(payload.sub); - await this.usersService.setPresence(payload.sub, true); + const connectionCount = this.incrementUserConnection(payload.sub); + if (connectionCount === 1) { + await this.usersService.setPresence(payload.sub, true); + } await client.join(this.userRoom(payload.sub)); - this.server.to(this.userRoom(payload.sub)).emit('presence', { userId: payload.sub, online: true }); + this.server.to(this.userRoom(payload.sub)).emit('presence', { + userId: payload.sub, + isOnline: true, + online: true, + lastSeenAt: null, + }); } catch { client.disconnect(true); } @@ -66,10 +81,17 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa if (userId) { const remainingConnections = this.decrementUserConnection(userId); if (remainingConnections === 0) { + const lastSeenAt = new Date().toISOString(); void this.usersService.setPresence(userId, false); - this.server.to(this.userRoom(userId)).emit('presence', { userId, online: false }); + this.server.to(this.userRoom(userId)).emit('presence', { + userId, + isOnline: false, + online: false, + lastSeenAt, + }); } } + this.typingTimestampsBySocket.delete(client.id); } @SubscribeMessage('join_conversation') @@ -78,11 +100,32 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa @MessageBody() body: { conversationId: string }, ) { const userId = client.data.userId; - if (!userId) return; + if (!userId) return this.emitSocketError(client, 'join_conversation', 'Unauthorized socket', 'UNAUTHORIZED'); - const conversation = await this.chatService.assertConversationMember(userId, body.conversationId); - await client.join(this.conversationRoom(conversation.id)); - client.emit('joined_conversation', { conversationId: conversation.id }); + try { + const conversation = await this.chatService.assertConversationMember(userId, body?.conversationId); + await client.join(this.conversationRoom(conversation.id)); + client.emit('joined_conversation', { conversationId: conversation.id }); + } catch (error) { + this.emitSocketError(client, 'join_conversation', error); + } + } + + @SubscribeMessage('leave_conversation') + async leaveConversation( + @ConnectedSocket() client: SocketWithUser, + @MessageBody() body: { conversationId: string }, + ) { + const userId = client.data.userId; + if (!userId) return this.emitSocketError(client, 'leave_conversation', 'Unauthorized socket', 'UNAUTHORIZED'); + + try { + const conversation = await this.chatService.assertConversationMember(userId, body?.conversationId); + await client.leave(this.conversationRoom(conversation.id)); + client.emit('left_conversation', { conversationId: conversation.id }); + } catch (error) { + this.emitSocketError(client, 'leave_conversation', error); + } } @SubscribeMessage('send_message') @@ -91,10 +134,15 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa @MessageBody() dto: SendMessageDto, ) { const userId = client.data.userId; - if (!userId) return; + if (!userId) return this.emitSocketError(client, 'send_message', 'Unauthorized socket', 'UNAUTHORIZED'); - const message = await this.chatService.sendMessage(userId, dto); - return message; + try { + const message = await this.chatService.sendMessage(userId, dto); + return message; + } catch (error) { + this.emitSocketError(client, 'send_message', error); + return null; + } } @SubscribeMessage('typing') @@ -103,14 +151,25 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa @MessageBody() body: { conversationId: string; isTyping: boolean }, ) { const userId = client.data.userId; - if (!userId) return; + if (!userId) return this.emitSocketError(client, 'typing', 'Unauthorized socket', 'UNAUTHORIZED'); - await this.chatService.assertConversationMember(userId, body.conversationId); - client.to(this.conversationRoom(body.conversationId)).emit('typing', { - conversationId: body.conversationId, - userId, - isTyping: !!body.isTyping, - }); + try { + const now = Date.now(); + const previous = this.typingTimestampsBySocket.get(client.id) ?? 0; + if (now - previous < 500) { + return; + } + this.typingTimestampsBySocket.set(client.id, now); + + const conversation = await this.chatService.assertConversationMember(userId, body?.conversationId); + client.to(this.conversationRoom(conversation.id)).emit('typing', { + conversationId: conversation.id, + userId, + isTyping: !!body?.isTyping, + }); + } catch (error) { + this.emitSocketError(client, 'typing', error); + } } @SubscribeMessage('mark_seen') @@ -119,13 +178,39 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa @MessageBody() body: { messageId: string; conversationId: string }, ) { const userId = client.data.userId; - if (!userId) return; + if (!userId) return this.emitSocketError(client, 'mark_seen', 'Unauthorized socket', 'UNAUTHORIZED'); - await this.chatService.markMessageSeen(userId, body.messageId); - this.server.to(this.conversationRoom(body.conversationId)).emit('message_seen', { - messageId: body.messageId, - userId, - }); + try { + const seen = await this.chatService.markMessageSeen(userId, body?.messageId, body?.conversationId); + this.server.to(this.conversationRoom(seen.conversationId)).emit('message_seen', { + conversationId: seen.conversationId, + messageId: seen.messageId, + userId, + seenAt: seen.seenAt, + }); + } catch (error) { + this.emitSocketError(client, 'mark_seen', error); + } + } + + @SubscribeMessage('mark_delivered') + async markDelivered( + @ConnectedSocket() client: SocketWithUser, + @MessageBody() body: { messageId: string; conversationId: string }, + ) { + const userId = client.data.userId; + if (!userId) return this.emitSocketError(client, 'mark_delivered', 'Unauthorized socket', 'UNAUTHORIZED'); + + try { + const delivered = await this.chatService.markMessageDelivered( + userId, + body?.conversationId, + body?.messageId, + ); + this.server.to(this.conversationRoom(delivered.conversationId)).emit('message_delivered', delivered); + } catch (error) { + this.emitSocketError(client, 'mark_delivered', error); + } } private extractToken(client: Socket): string | null { @@ -150,8 +235,10 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa return `conversation:${conversationId}`; } - private incrementUserConnection(userId: string): void { - this.connectionCountsByUser.set(userId, (this.connectionCountsByUser.get(userId) ?? 0) + 1); + private incrementUserConnection(userId: string): number { + const nextCount = (this.connectionCountsByUser.get(userId) ?? 0) + 1; + this.connectionCountsByUser.set(userId, nextCount); + return nextCount; } private decrementUserConnection(userId: string): number { @@ -163,4 +250,26 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa this.connectionCountsByUser.set(userId, nextCount); return nextCount; } + + private emitSocketError( + client: SocketWithUser, + event: string, + errorOrMessage: unknown, + explicitCode?: string, + ): void { + const message = errorOrMessage instanceof Error ? errorOrMessage.message : String(errorOrMessage); + client.emit('socket_error', { + event, + message, + code: explicitCode ?? this.resolveErrorCode(errorOrMessage), + }); + } + + private resolveErrorCode(error: unknown): string { + if (error instanceof BadRequestException) return 'BAD_REQUEST'; + if (error instanceof ForbiddenException) return 'FORBIDDEN'; + if (error instanceof NotFoundException) return 'NOT_FOUND'; + if (error instanceof UnauthorizedException) return 'UNAUTHORIZED'; + return 'ERROR'; + } } diff --git a/src/modules/chat/chat.repository.spec.ts b/src/modules/chat/chat.repository.spec.ts index a2489a4..5587a2a 100644 --- a/src/modules/chat/chat.repository.spec.ts +++ b/src/modules/chat/chat.repository.spec.ts @@ -7,6 +7,7 @@ const queryResult = (value: T) => ({ describe('ChatRepository unread counters', () => { let conversationModel: Record; + let messageModel: Record; let repository: ChatRepository; beforeEach(() => { @@ -16,7 +17,10 @@ describe('ChatRepository unread counters', () => { findByIdAndUpdate: jest.fn(), updateOne: jest.fn(), }; - repository = new ChatRepository(conversationModel as any, {} as any, {} as any); + messageModel = { + updateOne: jest.fn(), + }; + repository = new ChatRepository(conversationModel as any, messageModel as any, {} as any); }); it('creates conversations with string user ids as unread counter keys', async () => { @@ -88,4 +92,28 @@ describe('ChatRepository unread counters', () => { }, ); }); + + it('adds a delivered receipt only when the user has not already delivered the message', async () => { + const messageId = new Types.ObjectId().toString(); + const userId = new Types.ObjectId().toString(); + const deliveredAt = new Date('2026-06-07T10:00:00.000Z'); + messageModel.updateOne.mockReturnValue(queryResult({ modifiedCount: 1 })); + + await repository.markMessageDelivered(messageId, userId, deliveredAt); + + expect(messageModel.updateOne).toHaveBeenCalledWith( + { + _id: new Types.ObjectId(messageId), + 'deliveredBy.userId': { $ne: new Types.ObjectId(userId) }, + }, + { + $push: { + deliveredBy: { + userId: new Types.ObjectId(userId), + deliveredAt, + }, + }, + }, + ); + }); }); diff --git a/src/modules/chat/chat.repository.ts b/src/modules/chat/chat.repository.ts index 1b6a431..af9147d 100644 --- a/src/modules/chat/chat.repository.ts +++ b/src/modules/chat/chat.repository.ts @@ -134,6 +134,25 @@ export class ChatRepository { .exec(); } + async markMessageDelivered(messageId: string, userId: string, deliveredAt: Date): Promise { + await this.messageModel + .updateOne( + { + _id: new Types.ObjectId(messageId), + 'deliveredBy.userId': { $ne: new Types.ObjectId(userId) }, + }, + { + $push: { + deliveredBy: { + userId: new Types.ObjectId(userId), + deliveredAt, + }, + }, + }, + ) + .exec(); + } + async setMessageReaction( messageId: string, userId: string, diff --git a/src/modules/chat/chat.service.spec.ts b/src/modules/chat/chat.service.spec.ts index 5454cb6..37459ed 100644 --- a/src/modules/chat/chat.service.spec.ts +++ b/src/modules/chat/chat.service.spec.ts @@ -24,6 +24,9 @@ describe('ChatService realtime message broadcasting', () => { createMessage: jest.fn(), updateConversationAfterNewMessage: jest.fn().mockResolvedValue(conversation), findMessageById: jest.fn(), + markMessageSeen: jest.fn().mockResolvedValue(undefined), + clearConversationUnreadForUser: jest.fn().mockResolvedValue(undefined), + markMessageDelivered: jest.fn().mockResolvedValue(undefined), }; notificationsService = { createMessageNotification: jest.fn().mockResolvedValue(null), @@ -102,4 +105,62 @@ describe('ChatService realtime message broadcasting', () => { expect(chatRealtimeService.emitNewMessage).toHaveBeenCalledTimes(1); expect(chatRealtimeService.emitNewMessage).toHaveBeenCalledWith(conversationId, message); }); + + it('marks a message delivered for a participant that is not the sender', async () => { + const messageId = new Types.ObjectId().toString(); + chatRepository.findMessageById.mockResolvedValue({ + id: messageId, + conversationId: new Types.ObjectId(conversationId), + senderId: new Types.ObjectId(senderId), + }); + + const result = await service.markMessageDelivered(recipientId, conversationId, messageId); + + expect(chatRepository.markMessageDelivered).toHaveBeenCalledWith( + messageId, + recipientId, + expect.any(Date), + ); + expect(result).toMatchObject({ + conversationId, + messageId, + userId: recipientId, + }); + expect(result.deliveredAt).toEqual(expect.any(String)); + }); + + it('does not let the sender mark their own message delivered', async () => { + const messageId = new Types.ObjectId().toString(); + chatRepository.findMessageById.mockResolvedValue({ + id: messageId, + conversationId: new Types.ObjectId(conversationId), + senderId: new Types.ObjectId(senderId), + }); + + await expect(service.markMessageDelivered(senderId, conversationId, messageId)).rejects.toThrow( + 'Sender cannot mark own message delivered', + ); + expect(chatRepository.markMessageDelivered).not.toHaveBeenCalled(); + }); + + it('marks a message seen and returns conversationId and seenAt for realtime payloads', async () => { + const messageId = new Types.ObjectId().toString(); + chatRepository.findMessageById.mockResolvedValue({ + id: messageId, + conversationId: new Types.ObjectId(conversationId), + senderId: new Types.ObjectId(senderId), + }); + + const result = await service.markMessageSeen(recipientId, messageId, conversationId); + + expect(chatRepository.markMessageSeen).toHaveBeenCalledWith(messageId, recipientId); + expect(chatRepository.clearConversationUnreadForUser).toHaveBeenCalledWith(conversationId, recipientId); + expect(result).toMatchObject({ + success: true, + conversationId, + messageId, + userId: recipientId, + }); + expect(result.seenAt).toEqual(expect.any(String)); + }); }); diff --git a/src/modules/chat/chat.service.ts b/src/modules/chat/chat.service.ts index efdffb7..8aa794e 100644 --- a/src/modules/chat/chat.service.ts +++ b/src/modules/chat/chat.service.ts @@ -212,17 +212,60 @@ export class ChatService { } } - async markMessageSeen(currentUserId: string, messageId: string) { + async markMessageSeen(currentUserId: string, messageId: string, conversationId?: string) { + if (!Types.ObjectId.isValid(messageId)) { + throw new BadRequestException('Invalid message id'); + } + if (conversationId && !Types.ObjectId.isValid(conversationId)) { + throw new BadRequestException('Invalid conversation id'); + } + const message = await this.chatRepository.findMessageById(messageId); if (!message) { throw new NotFoundException('Message not found'); } + if (conversationId && message.conversationId.toString() !== conversationId) { + throw new BadRequestException('Message must belong to the conversation'); + } await this.assertConversationMember(currentUserId, message.conversationId.toString()); await this.chatRepository.markMessageSeen(message.id, currentUserId); await this.chatRepository.clearConversationUnreadForUser(message.conversationId.toString(), currentUserId); + const seenAt = new Date(); - return { success: true }; + return { + success: true, + conversationId: message.conversationId.toString(), + messageId: message.id, + userId: currentUserId, + seenAt: seenAt.toISOString(), + }; + } + + async markMessageDelivered(currentUserId: string, conversationId: string, messageId: string) { + const conversation = await this.assertConversationMember(currentUserId, conversationId); + if (!Types.ObjectId.isValid(messageId)) { + throw new BadRequestException('Invalid message id'); + } + + const message = await this.chatRepository.findMessageById(messageId); + if (!message || message.conversationId.toString() !== conversation.id) { + throw new BadRequestException('Message must belong to the conversation'); + } + + if (message.senderId.toString() === currentUserId) { + throw new BadRequestException('Sender cannot mark own message delivered'); + } + + const deliveredAt = new Date(); + await this.chatRepository.markMessageDelivered(message.id, currentUserId, deliveredAt); + + return { + conversationId: conversation.id, + messageId: message.id, + userId: currentUserId, + deliveredAt: deliveredAt.toISOString(), + }; } async unsendMessage(currentUserId: string, messageId: string) { diff --git a/src/modules/chat/schemas/message.schema.ts b/src/modules/chat/schemas/message.schema.ts index 18b1a08..2744acc 100644 --- a/src/modules/chat/schemas/message.schema.ts +++ b/src/modules/chat/schemas/message.schema.ts @@ -5,6 +5,17 @@ import { User } from '../../users/schemas/user.schema'; export type MessageDocument = HydratedDocument; +@Schema({ _id: false, versionKey: false }) +export class MessageDeliveryReceipt { + @Prop({ type: Types.ObjectId, ref: User.name, required: true }) + userId!: Types.ObjectId; + + @Prop({ type: Date, required: true }) + deliveredAt!: Date; +} + +export const MessageDeliveryReceiptSchema = SchemaFactory.createForClass(MessageDeliveryReceipt); + @Schema({ timestamps: true, versionKey: false }) export class Message { @Prop({ type: Types.ObjectId, required: true, index: true }) @@ -31,6 +42,9 @@ export class Message { @Prop({ type: [Types.ObjectId], ref: User.name, default: [] }) seenBy!: Types.ObjectId[]; + @Prop({ type: [MessageDeliveryReceiptSchema], default: [] }) + deliveredBy!: MessageDeliveryReceipt[]; + @Prop({ type: [Types.ObjectId], ref: User.name, default: [], index: true }) deletedForUserIds!: Types.ObjectId[]; diff --git a/src/modules/notifications/notifications.gateway.spec.ts b/src/modules/notifications/notifications.gateway.spec.ts new file mode 100644 index 0000000..e41aa5d --- /dev/null +++ b/src/modules/notifications/notifications.gateway.spec.ts @@ -0,0 +1,29 @@ +import { NotificationsGateway } from './notifications.gateway'; + +describe('NotificationsGateway realtime events', () => { + it('emits backward-compatible and new notification events to the recipient room', () => { + const gateway = new NotificationsGateway({} as any, {} as any); + const roomEmitter = { emit: jest.fn() }; + (gateway as any).server = { + to: jest.fn().mockReturnValue(roomEmitter), + }; + const notification = { _id: 'notification-1', type: 'message' }; + const unreadCounts = { + total: 3, + interactions: 1, + messages: 2, + follows: 0, + followRequests: 0, + collaboration: 0, + system: 0, + }; + + gateway.emitCreated('user-1', notification, 3, unreadCounts); + + expect((gateway as any).server.to).toHaveBeenCalledWith('user:user-1'); + expect(roomEmitter.emit).toHaveBeenCalledWith('notification_created', notification); + expect(roomEmitter.emit).toHaveBeenCalledWith('notifications_unread_count', { unreadCount: 3 }); + expect(roomEmitter.emit).toHaveBeenCalledWith('notification:new', notification); + expect(roomEmitter.emit).toHaveBeenCalledWith('notification:unread_counts', unreadCounts); + }); +}); diff --git a/src/modules/notifications/notifications.gateway.ts b/src/modules/notifications/notifications.gateway.ts index 2ef66b7..b5727f0 100644 --- a/src/modules/notifications/notifications.gateway.ts +++ b/src/modules/notifications/notifications.gateway.ts @@ -42,13 +42,25 @@ export class NotificationsGateway implements OnGatewayConnection { } } - emitCreated(recipientId: string, notification: unknown, unreadCount: number): void { + emitCreated( + recipientId: string, + notification: unknown, + unreadCount: number, + unreadCounts?: Record, + ): void { this.server.to(this.userRoom(recipientId)).emit('notification_created', notification); this.server.to(this.userRoom(recipientId)).emit('notifications_unread_count', { unreadCount }); + this.server.to(this.userRoom(recipientId)).emit('notification:new', notification); + if (unreadCounts) { + this.server.to(this.userRoom(recipientId)).emit('notification:unread_counts', unreadCounts); + } } - emitUnreadCount(recipientId: string, unreadCount: number): void { + emitUnreadCount(recipientId: string, unreadCount: number, unreadCounts?: Record): void { this.server.to(this.userRoom(recipientId)).emit('notifications_unread_count', { unreadCount }); + if (unreadCounts) { + this.server.to(this.userRoom(recipientId)).emit('notification:unread_counts', unreadCounts); + } } private extractToken(client: Socket): string | null { diff --git a/src/modules/notifications/notifications.service.spec.ts b/src/modules/notifications/notifications.service.spec.ts index 2e21a7f..64c9372 100644 --- a/src/modules/notifications/notifications.service.spec.ts +++ b/src/modules/notifications/notifications.service.spec.ts @@ -18,6 +18,7 @@ describe('NotificationsService', () => { const notificationsRepository = { create: jest.fn().mockResolvedValue({ toJSON: () => ({ _id: 'notification-1' }) }), countUnread: jest.fn().mockResolvedValue(5), + countUnreadByFilter: jest.fn().mockResolvedValue(0), }; const notificationsGateway = { emitCreated: jest.fn(), @@ -47,6 +48,7 @@ describe('NotificationsService', () => { const notificationsRepository = { markAllRead: jest.fn().mockResolvedValue(4), countUnread: jest.fn().mockResolvedValue(2), + countUnreadByFilter: jest.fn().mockResolvedValue(0), }; const notificationsGateway = { emitUnreadCount: jest.fn(), @@ -63,7 +65,15 @@ describe('NotificationsService', () => { updatedCount: 4, unreadCount: 2, }); - expect(notificationsGateway.emitUnreadCount).toHaveBeenCalledWith('user-1', 2); + expect(notificationsGateway.emitUnreadCount).toHaveBeenCalledWith( + 'user-1', + 2, + expect.objectContaining({ + total: 2, + interactions: 0, + messages: 0, + }), + ); }); it('throws not found for invalid notification id in markRead', async () => { @@ -330,6 +340,7 @@ describe('NotificationsService', () => { const notificationsRepository = { create: jest.fn().mockResolvedValue({ toJSON: () => ({ _id: 'notification-1' }) }), countUnread: jest.fn().mockResolvedValue(1), + countUnreadByFilter: jest.fn().mockResolvedValue(0), }; const notificationsGateway = { emitCreated: jest.fn(), diff --git a/src/modules/notifications/notifications.service.ts b/src/modules/notifications/notifications.service.ts index 0d7dbfe..004fa53 100644 --- a/src/modules/notifications/notifications.service.ts +++ b/src/modules/notifications/notifications.service.ts @@ -56,8 +56,14 @@ export class NotificationsService { readAt: null, }); - const unreadCount = await this.notificationsRepository.countUnread(dto.recipientId); - this.notificationsGateway.emitCreated(dto.recipientId, notification.toJSON(), unreadCount); + const unreadCounts = await this.getUnreadCounts(dto.recipientId); + const unreadCount = unreadCounts.total; + this.notificationsGateway.emitCreated( + dto.recipientId, + notification.toJSON(), + unreadCount, + unreadCounts, + ); return notification; } @@ -299,8 +305,9 @@ export class NotificationsService { throw new NotFoundException('Notification not found'); } - const unreadCount = await this.notificationsRepository.countUnread(recipientId); - this.notificationsGateway.emitUnreadCount(recipientId, unreadCount); + const unreadCounts = await this.getUnreadCounts(recipientId); + const unreadCount = unreadCounts.total; + this.notificationsGateway.emitUnreadCount(recipientId, unreadCount, unreadCounts); return { message: 'Notification marked as read', @@ -311,8 +318,9 @@ export class NotificationsService { async markAllRead(recipientId: string) { const modifiedCount = await this.notificationsRepository.markAllRead(recipientId); - const unreadCount = await this.notificationsRepository.countUnread(recipientId); - this.notificationsGateway.emitUnreadCount(recipientId, unreadCount); + const unreadCounts = await this.getUnreadCounts(recipientId); + const unreadCount = unreadCounts.total; + this.notificationsGateway.emitUnreadCount(recipientId, unreadCount, unreadCounts); return { message: 'All notifications marked as read',