diff --git a/src/Http/WebSocketSession.h b/src/Http/WebSocketSession.h index 682a2559..af302b8a 100644 --- a/src/Http/WebSocketSession.h +++ b/src/Http/WebSocketSession.h @@ -30,16 +30,80 @@ #include "HttpSession.h" #include "Network/TcpServer.h" +/** + * 数据发送拦截器 + */ +class SendInterceptor{ +public: + typedef function onBeforeSendCB; + SendInterceptor() = default; + virtual ~SendInterceptor() = default; + virtual void setOnBeforeSendCB(const onBeforeSendCB &cb) = 0; +}; + +/** + * 该类实现了TcpSession派生类发送数据的截取 + * 目的是发送业务数据前进行websocket协议的打包 + */ +template +class TcpSessionTypeImp : public TcpSessionType, public SendInterceptor{ +public: + typedef std::shared_ptr Ptr; + + TcpSessionTypeImp(const Parser &header, const HttpSession &parent, const Socket::Ptr &pSock) : + _identifier(parent.getIdentifier()), TcpSessionType(pSock) {} + + ~TcpSessionTypeImp() {} + + /** + * 设置发送数据截取回调函数 + * @param cb 截取回调函数 + */ + void setOnBeforeSendCB(const onBeforeSendCB &cb) override { + _beforeSendCB = cb; + } + +protected: + /** + * 重载send函数截取数据 + * @param buf 需要截取的数据 + * @return 数据字节数 + */ + int send(const Buffer::Ptr &buf) override { + if (_beforeSendCB) { + return _beforeSendCB(buf); + } + return TcpSessionType::send(buf); + } + + string getIdentifier() const override { + return _identifier; + } + +private: + onBeforeSendCB _beforeSendCB; + string _identifier; +}; + +template +class TcpSessionCreator { +public: + //返回的TcpSession必须派生于SendInterceptor,可以返回null + TcpSession::Ptr operator()(const Parser &header, const HttpSession &parent, const Socket::Ptr &pSock){ + return std::make_shared >(header,parent,pSock); + } +}; + + /** * 通过该模板类可以透明化WebSocket协议, * 用户只要实现WebSock协议下的具体业务协议,譬如基于WebSocket协议的Rtmp协议等 -* @tparam SessionType 业务协议的TcpSession类 */ -template -class WebSocketSession : public HttpSessionType { +template +class WebSocketSessionBase : public HttpSessionType { public: - WebSocketSession(const Socket::Ptr &pSock) : HttpSessionType(pSock){} - virtual ~WebSocketSession(){} + WebSocketSessionBase(const Socket::Ptr &pSock) : HttpSessionType(pSock){} + virtual ~WebSocketSessionBase(){} //收到eof或其他导致脱离TcpServer事件的回调 void onError(const SockException &err) override{ @@ -69,23 +133,27 @@ protected: */ bool onWebSocketConnect(const Parser &header) override{ //创建websocket session类 - _session = std::make_shared(HttpSessionType::getIdentifier(),HttpSessionType::_sock); + _session = _creator(header, *this,HttpSessionType::_sock); + if(!_session){ + //此url不允许创建websocket连接 + return false; + } auto strongServer = _weakServer.lock(); if(strongServer){ _session->attachServer(*strongServer); } //此处截取数据并进行websocket协议打包 - weak_ptr weakSelf = dynamic_pointer_cast(HttpSessionType::shared_from_this()); - _session->setOnBeforeSendCB([weakSelf](const Buffer::Ptr &buf){ + weak_ptr weakSelf = dynamic_pointer_cast(HttpSessionType::shared_from_this()); + dynamic_pointer_cast(_session)->setOnBeforeSendCB([weakSelf](const Buffer::Ptr &buf) { auto strongSelf = weakSelf.lock(); - if(strongSelf){ + if (strongSelf) { WebSocketHeader header; header._fin = true; header._reserved = 0; header._opcode = DataType; header._mask_flag = false; - strongSelf->WebSocketSplitter::encode(header,buf); + strongSelf->WebSocketSplitter::encode(header, buf); } return buf->size(); }); @@ -155,50 +223,19 @@ protected: void onWebSocketEncodeData(const Buffer::Ptr &buffer) override{ SocketHelper::send(buffer); } -private: - typedef function onBeforeSendCB; - /** - * 该类实现了TcpSession派生类发送数据的截取 - * 目的是发送业务数据前进行websocket协议的打包 - */ - class SessionImp : public SessionType{ - public: - SessionImp(const string &identifier,const Socket::Ptr &pSock) : - _identifier(identifier),SessionType(pSock){} - - ~SessionImp(){} - - /** - * 设置发送数据截取回调函数 - * @param cb 截取回调函数 - */ - void setOnBeforeSendCB(const onBeforeSendCB &cb){ - _beforeSendCB = cb; - } - protected: - /** - * 重载send函数截取数据 - * @param buf 需要截取的数据 - * @return 数据字节数 - */ - int send(const Buffer::Ptr &buf) override { - if(_beforeSendCB){ - return _beforeSendCB(buf); - } - return SessionType::send(buf); - } - string getIdentifier() const override{ - return _identifier; - } - private: - onBeforeSendCB _beforeSendCB; - string _identifier; - }; private: string _remian_data; weak_ptr _weakServer; - std::shared_ptr _session; + TcpSession::Ptr _session; + Creator _creator; }; +template +class WebSocketSession : public WebSocketSessionBase,HttpSessionType,DataType>{ +public: + WebSocketSession(const Socket::Ptr &pSock) : WebSocketSessionBase,HttpSessionType,DataType>(pSock){} + virtual ~WebSocketSession(){} +}; + #endif //ZLMEDIAKIT_WEBSOCKETSESSION_H diff --git a/tests/test_wsServer.cpp b/tests/test_wsServer.cpp index 12891889..1821467b 100644 --- a/tests/test_wsServer.cpp +++ b/tests/test_wsServer.cpp @@ -51,6 +51,7 @@ public: } void onRecv(const Buffer::Ptr &buffer) override { //回显数据 + send("from EchoSession:"); send(buffer); } void onError(const SockException &err) override{ @@ -62,6 +63,48 @@ public: } }; + +class EchoSessionWithUrl : public TcpSession { +public: + EchoSessionWithUrl(const Socket::Ptr &pSock) : TcpSession(pSock){ + DebugL; + } + virtual ~EchoSessionWithUrl(){ + DebugL; + } + + void attachServer(const TcpServer &server) override{ + DebugL << getIdentifier() << " " << TcpSession::getIdentifier(); + } + void onRecv(const Buffer::Ptr &buffer) override { + //回显数据 + send("from EchoSessionWithUrl:"); + send(buffer); + } + void onError(const SockException &err) override{ + WarnL << err.what(); + } + //每隔一段时间触发,用来做超时管理 + void onManager() override{ + DebugL; + } +}; + + +/** + * 此对象可以根据websocket 客户端访问的url选择创建不同的对象 + */ +struct EchoSessionCreator { + //返回的TcpSession必须派生于SendInterceptor,可以返回null(拒绝连接) + TcpSession::Ptr operator()(const Parser &header, const HttpSession &parent, const Socket::Ptr &pSock) { +// return nullptr; + if (header.Url() == "/") { + return std::make_shared >(header, parent, pSock); + } + return std::make_shared >(header, parent, pSock); + } +}; + int main(int argc, char *argv[]) { //设置日志 Logger::Instance().add(std::make_shared()); @@ -71,13 +114,19 @@ int main(int argc, char *argv[]) { TcpServer::Ptr httpSrv(new TcpServer()); //http服务器,支持websocket - httpSrv->start>(80);//默认80 + httpSrv->start >(80);//默认80 TcpServer::Ptr httpsSrv(new TcpServer()); //https服务器,支持websocket - httpsSrv->start>(443);//默认443 + httpsSrv->start >(443);//默认443 + + TcpServer::Ptr httpSrvOld(new TcpServer()); + //兼容之前的代码(但是不支持根据url选择生成TcpSession类型) + httpSrvOld->start >(8080); + + DebugL << "请打开网页:http://www.websocket-test.com/,进行测试"; + DebugL << "连接 ws://127.0.0.1/xxxx,ws://127.0.0.1/ 测试的效果将不同,支持根据url选择不同的处理逻辑"; - DebugL << "请打开网页:http://www.websocket-test.com/,连接 ws://127.0.0.1/测试"; //设置退出信号处理函数 static semaphore sem;